Files
hermione/src/main/java/com/olympus/hermione/stepSolvers/BasicAIPromptSolver.java
paola.trabucco 7c20fbac45 authentication
2024-08-07 12:37:05 +02:00

62 lines
2.0 KiB
Java

package com.olympus.hermione.stepSolvers;
import java.util.List;
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.SystemMessage;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.model.Generation;
import org.springframework.ai.chat.prompt.Prompt;
import com.olympus.hermione.models.ScenarioExecution;
import com.olympus.hermione.utility.AttributeParser;
import ch.qos.logback.classic.Logger;
/**
 * Step solver that sends one user message and one system message to the chat
 * model and stores the text of the first generation in the scenario's shared
 * execution map.
 *
 * <p>Configuration is read from the step attributes:
 * <ul>
 *   <li>{@code qai_system_prompt_template} - text of the system message
 *       (passed through {@link AttributeParser#parse} first);</li>
 *   <li>{@code qai_user_input} - text of the user message
 *       (passed through {@link AttributeParser#parse} first);</li>
 *   <li>{@code qai_output_variable} - key under which the model output is
 *       stored in the shared map (used as-is, not parsed).</li>
 * </ul>
 */
public class BasicAIPromptSolver extends StepSolver {

    /** System message text, after attribute parsing. */
    private String qai_system_prompt_template;

    /** User message text, after attribute parsing. */
    private String qai_user_input;

    /** Shared-map key that receives the model output. */
    private String qai_output_variable;

    // Bound to this class so that log lines are attributed to the right solver.
    Logger logger = (Logger) LoggerFactory.getLogger(BasicAIPromptSolver.class);

    /**
     * Reads the three {@code qai_*} attributes from the current step into the
     * instance fields.
     *
     * <p>The prompt template and the user input go through
     * {@link AttributeParser}; presumably this resolves placeholders against
     * the scenario execution - confirm against {@code AttributeParser}.
     */
    private void loadParameters() {
        logger.info("Loading parameters");

        AttributeParser attributeParser = new AttributeParser(this.scenarioExecution);

        this.qai_system_prompt_template = attributeParser.parse(
                (String) this.step.getAttributes().get("qai_system_prompt_template"));
        this.qai_user_input = attributeParser.parse(
                (String) this.step.getAttributes().get("qai_user_input"));
        this.qai_output_variable =
                (String) this.step.getAttributes().get("qai_output_variable");
    }

    /**
     * Builds a prompt from the configured user and system messages, calls the
     * chat model, and stores the content of the first generation in the
     * shared execution map under {@code qai_output_variable}.
     *
     * @return the scenario execution held by this solver, after the output
     *         has been written to its shared map
     * @throws IllegalStateException if the chat model returns no generations
     */
    @Override
    public ScenarioExecution solveStep() {
        logger.info("Solving step: {}", this.step.getName());

        loadParameters();

        Message userMessage = new UserMessage(this.qai_user_input);
        Message systemMessage = new SystemMessage(this.qai_system_prompt_template);

        // Message order (user, then system) is intentionally kept unchanged.
        Prompt prompt = new Prompt(List.of(userMessage, systemMessage));

        List<Generation> response = chatModel.call(prompt).getResults();

        // Fail with a descriptive error instead of an opaque index exception.
        if (response == null || response.isEmpty()) {
            throw new IllegalStateException(
                    "Chat model returned no generations for step: " + this.step.getName());
        }

        String output = response.get(0).getOutput().getContent();
        this.scenarioExecution.getExecSharedMap().put(this.qai_output_variable, output);

        return this.scenarioExecution;
    }
}