new OlynmpusChatClientSolver

This commit is contained in:
andrea.terzani
2025-02-20 17:31:22 +01:00
parent 7def3b2770
commit 170a2dcfc9
6 changed files with 236 additions and 1 deletions

View File

@@ -0,0 +1,94 @@
package com.olympus.hermione.client;
import java.util.HashMap;
import java.util.Map;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.http.HttpEntity;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.http.MediaType;
import org.springframework.http.ResponseEntity;
import org.springframework.web.client.RestTemplate;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
public class AzureOpenApiChatClient implements OlympusChatClient {
private String apiUrl;
private String apiKey;
private int maxOutputTokens;
Logger logger = LoggerFactory.getLogger(AzureOpenApiChatClient.class);
@Override
public void init(String apiUrl, String apiKey, int maxOutputTokens) {
this.apiUrl = apiUrl;
this.apiKey = apiKey;
this.maxOutputTokens = maxOutputTokens;
}
@Override
public OlympusChatClientResponse getChatCompletion(String userInput) {
HttpHeaders headers = new HttpHeaders();
headers.setContentType(MediaType.APPLICATION_JSON);
headers.set("api-key", apiKey);
Map<String, Object> requestBody = new HashMap<>();
requestBody.put("messages", new Object[]{
Map.of("role", "user", "content", userInput)
});
requestBody.put("max_completion_tokens", maxOutputTokens);
HttpEntity<Map<String, Object>> request = new HttpEntity<>(requestBody, headers);
RestTemplate restTemplate = new RestTemplate();
ResponseEntity<String> response = restTemplate.exchange(apiUrl, HttpMethod.POST, request, String.class);
ObjectMapper objectMapper = new ObjectMapper();
JsonNode rootNode;
try {
rootNode = objectMapper.readTree(response.getBody());
} catch (JsonProcessingException e) {
logger.error(e.getMessage());
return null;
}
OlympusChatClientResponse olympusChatClientResponse = new OlympusChatClientResponse();
olympusChatClientResponse.setContent(extractMessageContent(response.getBody()));
olympusChatClientResponse.setInputTokens(rootNode.path("usage").path("prompt_tokens").asInt());
olympusChatClientResponse.setOutputTokens(rootNode.path("usage").path("completion_tokens").asInt());
olympusChatClientResponse.setTotalTokens(rootNode.path("usage").path("total_tokens").asInt());
return olympusChatClientResponse;
}
public String extractMessageContent(String jsonResponse) {
try {
ObjectMapper objectMapper = new ObjectMapper();
JsonNode rootNode = objectMapper.readTree(jsonResponse);
JsonNode choicesNode = rootNode.path("choices");
if (choicesNode.isArray() && choicesNode.size() > 0) {
JsonNode messageNode = choicesNode.get(0).path("message");
JsonNode contentNode = messageNode.path("content");
return contentNode.asText();
}
} catch (Exception e) {
logger.error(e.getMessage());
}
return null;
}
}

View File

@@ -0,0 +1,8 @@
package com.olympus.hermione.client;
public interface OlympusChatClient {
public OlympusChatClientResponse getChatCompletion(String userInput);
public void init(String apiUrl, String apiKey, int maxOutputTokens);
}

View File

@@ -0,0 +1,46 @@
package com.olympus.hermione.client;
public class OlympusChatClientResponse {
private String content;
private int inputTokens;
private int outputTokens;
private int totalTokens;
public OlympusChatClientResponse() {
}
public OlympusChatClientResponse(String content, int inputTokens, int outputTokens, int totalTokens) {
this.content = content;
this.inputTokens = inputTokens;
this.outputTokens = outputTokens;
this.totalTokens = totalTokens;
}
public String getContent() {
return content;
}
public void setContent(String content) {
this.content = content;
}
public int getInputTokens() {
return inputTokens;
}
public void setInputTokens(int inputTokens) {
this.inputTokens = inputTokens;
}
public int getOutputTokens() {
return outputTokens;
}
public void setOutputTokens(int outputTokens) {
this.outputTokens = outputTokens;
}
public int getTotalTokens() {
return totalTokens;
}
public void setTotalTokens(int totalTokens) {
this.totalTokens = totalTokens;
}
}

View File

@@ -34,6 +34,7 @@ import org.springframework.scheduling.annotation.Async;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.stereotype.Service;
import com.olympus.hermione.client.OlympusChatClient;
import com.olympus.hermione.dto.ScenarioExecutionInput;
import com.olympus.hermione.dto.ScenarioOutput;
import com.olympus.hermione.models.AiModel;
@@ -60,6 +61,7 @@ import com.azure.ai.openai.OpenAIClientBuilder;
import com.azure.core.credential.AzureKeyCredential;
import com.olympus.hermione.stepSolvers.ExternalAgentSolver;
import com.olympus.hermione.stepSolvers.ExternalCodeGenieSolver;
import com.olympus.hermione.stepSolvers.OlynmpusChatClientSolver;
@@ -262,6 +264,9 @@ public class ScenarioExecutionService {
case "SUMMARIZE_DOC":
solver = new SummarizeDocSolver();
break;
case "OLYMPUS_QUERY_AI":
solver = new OlynmpusChatClientSolver();
break;
default:
break;
}
@@ -476,7 +481,7 @@ public class ScenarioExecutionService {
return result;
}
}
public String updateRating(String id, String rating){
logger.info("updateRating function:");

View File

@@ -0,0 +1,80 @@
package com.olympus.hermione.stepSolvers;
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.client.ChatClient.CallResponseSpec;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.SystemMessage;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.metadata.Usage;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.olympus.hermione.client.AzureOpenApiChatClient;
import com.olympus.hermione.client.OlympusChatClient;
import com.olympus.hermione.client.OlympusChatClientResponse;
import com.olympus.hermione.dto.aientity.CiaOutputEntity;
import com.olympus.hermione.models.ScenarioExecution;
import com.olympus.hermione.utility.AttributeParser;
import ch.qos.logback.classic.Logger;
public class OlynmpusChatClientSolver extends StepSolver{
private String qai_system_prompt_template;
private String qai_user_input;
private String qai_output_variable;
private boolean qai_load_graph_schema=false;
Logger logger = (Logger) LoggerFactory.getLogger(BasicQueryRagSolver.class);
private void loadParameters(){
logger.info("Loading parameters");
AttributeParser attributeParser = new AttributeParser(this.scenarioExecution);
this.qai_system_prompt_template = attributeParser.parse((String) this.step.getAttributes().get("qai_system_prompt_template"));
this.qai_user_input = attributeParser.parse((String)this.step.getAttributes().get("qai_user_input"));
this.qai_output_variable = (String) this.step.getAttributes().get("qai_output_variable");
}
@Override
public ScenarioExecution solveStep(){
System.out.println("Solving step: " + this.step.getName());
this.scenarioExecution.setCurrentStepId(this.step.getStepId());
loadParameters();
OlympusChatClient chatClient;
if ( scenarioExecution.getScenario().getAiModel().getApiProvider().equals("AzureOpenAI")) {
chatClient = new AzureOpenApiChatClient();
} else {
chatClient = null;
}
chatClient.init(scenarioExecution.getScenario().getAiModel().getEndpoint(),
scenarioExecution.getScenario().getAiModel().getApiKey(),
scenarioExecution.getScenario().getAiModel().getMaxTokens());
String userText = this.qai_user_input;
String systemText = this.qai_system_prompt_template;
OlympusChatClientResponse resp = chatClient.getChatCompletion(systemText +"\n"+ userText);
String output = resp.getContent();
this.scenarioExecution.getExecSharedMap().put(this.qai_output_variable, output);
Long usedTokens = (long) resp.getTotalTokens();
this.scenarioExecution.setUsedTokens(usedTokens);
this.scenarioExecution.setNextStepId(this.step.getNextStepId());
return this.scenarioExecution;
}
}

View File

@@ -59,6 +59,7 @@ hermione.fe.url = http://localhost:5173
java-parser-module.url: http://java-parser-module-service.olympus.svc.cluster.local:8080
java-re-module.url: http://java-re-module-service.olympus.svc.cluster.local:8080
jsp-parser-module.url: http://jsp-parser-module-service.olympus.svc.cluster.local:8080
spring.ai.vectorstore.chroma.client.host=http://108.142.74.161
spring.ai.vectorstore.chroma.client.port=8000
@@ -69,3 +70,4 @@ file.upload-dir=/mnt/hermione_storage/documents/file_input_scenarios/
spring.servlet.multipart.max-file-size=10MB
spring.servlet.multipart.max-request-size=10MB
generic-file-parser-module.url=http://generic-file-parser-module-service.olympus.svc.cluster.local:8080