107 lines
4.4 KiB
Java
107 lines
4.4 KiB
Java
package com.olympus.hermione.stepSolvers;
|
|
|
|
import ch.qos.logback.classic.Logger;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;

import com.olympus.hermione.dto.aientity.CiaOutputEntity;
import com.olympus.hermione.models.ScenarioExecution;
import com.olympus.hermione.utility.AttributeParser;

import org.slf4j.LoggerFactory;

import org.springframework.ai.chat.client.ChatClient.CallResponseSpec;
import org.springframework.ai.chat.client.ResponseEntity;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.SystemMessage;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.metadata.Usage;
import org.springframework.ai.chat.model.ChatResponse;
|
|
|
|
public class AdvancedAIPromptSolver extends StepSolver {
|
|
|
|
private String qai_system_prompt_template;
|
|
private String qai_user_input;
|
|
private String qai_output_variable;
|
|
private boolean qai_load_graph_schema=false;
|
|
private String qai_output_entityType;
|
|
private String qai_custom_memory_id;
|
|
|
|
|
|
Logger logger = (Logger) LoggerFactory.getLogger(BasicQueryRagSolver.class);
|
|
|
|
private void loadParameters(){
|
|
logger.info("Loading parameters");
|
|
|
|
if(this.step.getAttributes().get("qai_load_graph_schema")!=null ){
|
|
this.qai_load_graph_schema = Boolean.parseBoolean((String) this.step.getAttributes().get("qai_load_graph_schema"));
|
|
logger.info("qai_load_graph_schema: " + this.qai_load_graph_schema);
|
|
}
|
|
|
|
if(this.qai_load_graph_schema){
|
|
logger.info("Loading graph schema");
|
|
this.scenarioExecution.getExecSharedMap().put("graph_schema", this.neo4JUitilityService.generateSchemaDescription().toString());
|
|
}
|
|
|
|
AttributeParser attributeParser = new AttributeParser(this.scenarioExecution);
|
|
|
|
this.qai_system_prompt_template = attributeParser.parse((String) this.step.getAttributes().get("qai_system_prompt_template"));
|
|
this.qai_user_input = attributeParser.parse((String)this.step.getAttributes().get("qai_user_input"));
|
|
|
|
this.qai_output_variable = (String) this.step.getAttributes().get("qai_output_variable");
|
|
this.qai_output_entityType = (String) this.step.getAttributes().get("qai_output_entityType");
|
|
|
|
this.qai_custom_memory_id = (String) this.step.getAttributes().get("qai_custom_memory_id");
|
|
//TODO: Add memory ID attribute to have the possibility of multiple conversations
|
|
|
|
}
|
|
|
|
@Override
|
|
public ScenarioExecution solveStep(){
|
|
|
|
System.out.println("Solving step: " + this.step.getName());
|
|
|
|
this.scenarioExecution.setCurrentStepId(this.step.getStepId());
|
|
|
|
loadParameters();
|
|
|
|
String userText = this.qai_user_input;
|
|
|
|
Message userMessage = new UserMessage(userText);
|
|
Message systemMessage = new SystemMessage(this.qai_system_prompt_template);
|
|
|
|
CallResponseSpec resp = chatClient.prompt()
|
|
.messages(userMessage,systemMessage)
|
|
.advisors(advisor -> advisor
|
|
.param("chat_memory_conversation_id", this.scenarioExecution.getId()+this.qai_custom_memory_id)
|
|
.param("chat_memory_response_size", 100))
|
|
.call();
|
|
|
|
if(qai_output_entityType!=null && qai_output_entityType.equals("CiaOutputEntity")){
|
|
logger.info("Output is of type CiaOutputEntity");
|
|
|
|
CiaOutputEntity ouputEntity = resp.entity(CiaOutputEntity.class);
|
|
|
|
try {
|
|
ObjectMapper objectMapper = new ObjectMapper();
|
|
String jsonOutput = objectMapper.writeValueAsString(ouputEntity);
|
|
this.scenarioExecution.getExecSharedMap().put(this.qai_output_variable, jsonOutput);
|
|
} catch (JsonProcessingException e) {
|
|
logger.error("Error converting output entity to JSON", e);
|
|
}
|
|
}else{
|
|
|
|
String output = resp.content();
|
|
this.scenarioExecution.getExecSharedMap().put(this.qai_output_variable, output);
|
|
}
|
|
|
|
Usage usage = resp.chatResponse().getMetadata().getUsage();
|
|
if (usage != null) {
|
|
Integer usedTokens = usage.getTotalTokens();
|
|
this.scenarioExecution.setUsedTokens(usedTokens);
|
|
} else {
|
|
logger.info("Token usage information is not available.");
|
|
}
|
|
|
|
this.scenarioExecution.setNextStepId(this.step.getNextStepId());
|
|
|
|
return this.scenarioExecution;
|
|
}
|
|
|
|
}
|