Files
hermione/src/main/java/com/olympus/hermione/stepSolvers/AdvancedAIPromptSolver.java

107 lines
4.4 KiB
Java

package com.olympus.hermione.stepSolvers;
import ch.qos.logback.classic.Logger;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.olympus.hermione.dto.aientity.CiaOutputEntity;
import com.olympus.hermione.models.ScenarioExecution;
import com.olympus.hermione.utility.AttributeParser;
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.client.ChatClient.CallResponseSpec;
import org.springframework.ai.chat.client.ResponseEntity;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.SystemMessage;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.metadata.Usage;
import org.springframework.ai.chat.model.ChatResponse;
public class AdvancedAIPromptSolver extends StepSolver {
private String qai_system_prompt_template;
private String qai_user_input;
private String qai_output_variable;
private boolean qai_load_graph_schema=false;
private String qai_output_entityType;
private String qai_custom_memory_id;
Logger logger = (Logger) LoggerFactory.getLogger(BasicQueryRagSolver.class);
private void loadParameters(){
logger.info("Loading parameters");
if(this.step.getAttributes().get("qai_load_graph_schema")!=null ){
this.qai_load_graph_schema = Boolean.parseBoolean((String) this.step.getAttributes().get("qai_load_graph_schema"));
logger.info("qai_load_graph_schema: " + this.qai_load_graph_schema);
}
if(this.qai_load_graph_schema){
logger.info("Loading graph schema");
this.scenarioExecution.getExecSharedMap().put("graph_schema", this.neo4JUitilityService.generateSchemaDescription().toString());
}
AttributeParser attributeParser = new AttributeParser(this.scenarioExecution);
this.qai_system_prompt_template = attributeParser.parse((String) this.step.getAttributes().get("qai_system_prompt_template"));
this.qai_user_input = attributeParser.parse((String)this.step.getAttributes().get("qai_user_input"));
this.qai_output_variable = (String) this.step.getAttributes().get("qai_output_variable");
this.qai_output_entityType = (String) this.step.getAttributes().get("qai_output_entityType");
this.qai_custom_memory_id = (String) this.step.getAttributes().get("qai_custom_memory_id");
//TODO: Add memory ID attribute to have the possibility of multiple conversations
}
@Override
public ScenarioExecution solveStep(){
System.out.println("Solving step: " + this.step.getName());
this.scenarioExecution.setCurrentStepId(this.step.getStepId());
loadParameters();
String userText = this.qai_user_input;
Message userMessage = new UserMessage(userText);
Message systemMessage = new SystemMessage(this.qai_system_prompt_template);
CallResponseSpec resp = chatClient.prompt()
.messages(userMessage,systemMessage)
.advisors(advisor -> advisor
.param("chat_memory_conversation_id", this.scenarioExecution.getId()+this.qai_custom_memory_id)
.param("chat_memory_response_size", 100))
.call();
if(qai_output_entityType!=null && qai_output_entityType.equals("CiaOutputEntity")){
logger.info("Output is of type CiaOutputEntity");
CiaOutputEntity ouputEntity = resp.entity(CiaOutputEntity.class);
try {
ObjectMapper objectMapper = new ObjectMapper();
String jsonOutput = objectMapper.writeValueAsString(ouputEntity);
this.scenarioExecution.getExecSharedMap().put(this.qai_output_variable, jsonOutput);
} catch (JsonProcessingException e) {
logger.error("Error converting output entity to JSON", e);
}
}else{
String output = resp.content();
this.scenarioExecution.getExecSharedMap().put(this.qai_output_variable, output);
}
Usage usage = resp.chatResponse().getMetadata().getUsage();
if (usage != null) {
Integer usedTokens = usage.getTotalTokens();
this.scenarioExecution.setUsedTokens(usedTokens);
} else {
logger.info("Token usage information is not available.");
}
this.scenarioExecution.setNextStepId(this.step.getNextStepId());
return this.scenarioExecution;
}
}