diff --git a/src/main/java/com/olympus/hermione/client/AzureOpenApiChatClient.java b/src/main/java/com/olympus/hermione/client/AzureOpenApiChatClient.java new file mode 100644 index 0000000..3ed6d83 --- /dev/null +++ b/src/main/java/com/olympus/hermione/client/AzureOpenApiChatClient.java @@ -0,0 +1,94 @@ +package com.olympus.hermione.client; + +import java.util.HashMap; +import java.util.Map; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.http.HttpEntity; +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpMethod; +import org.springframework.http.MediaType; +import org.springframework.http.ResponseEntity; +import org.springframework.web.client.RestTemplate; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; + +public class AzureOpenApiChatClient implements OlympusChatClient { + + private String apiUrl; + private String apiKey; + private int maxOutputTokens; + + Logger logger = LoggerFactory.getLogger(AzureOpenApiChatClient.class); + + @Override + public void init(String apiUrl, String apiKey, int maxOutputTokens) { + this.apiUrl = apiUrl; + this.apiKey = apiKey; + this.maxOutputTokens = maxOutputTokens; + + } + + @Override + public OlympusChatClientResponse getChatCompletion(String userInput) { + HttpHeaders headers = new HttpHeaders(); + headers.setContentType(MediaType.APPLICATION_JSON); + headers.set("api-key", apiKey); + + Map requestBody = new HashMap<>(); + + requestBody.put("messages", new Object[]{ + Map.of("role", "user", "content", userInput) + }); + requestBody.put("max_completion_tokens", maxOutputTokens); + + + HttpEntity> request = new HttpEntity<>(requestBody, headers); + RestTemplate restTemplate = new RestTemplate(); + ResponseEntity response = restTemplate.exchange(apiUrl, HttpMethod.POST, request, String.class); + + + ObjectMapper objectMapper = new ObjectMapper(); + JsonNode rootNode; + try { + rootNode = objectMapper.readTree(response.getBody()); + } catch (JsonProcessingException e) { + logger.error(e.getMessage()); + return null; + } + + OlympusChatClientResponse olympusChatClientResponse = new OlympusChatClientResponse(); + olympusChatClientResponse.setContent(extractMessageContent(response.getBody())); + olympusChatClientResponse.setInputTokens(rootNode.path("usage").path("prompt_tokens").asInt()); + olympusChatClientResponse.setOutputTokens(rootNode.path("usage").path("completion_tokens").asInt()); + olympusChatClientResponse.setTotalTokens(rootNode.path("usage").path("total_tokens").asInt()); + + + return olympusChatClientResponse; + + } + + + public String extractMessageContent(String jsonResponse) { + + + try { + ObjectMapper objectMapper = new ObjectMapper(); + JsonNode rootNode = objectMapper.readTree(jsonResponse); + JsonNode choicesNode = rootNode.path("choices"); + if (choicesNode.isArray() && choicesNode.size() > 0) { + JsonNode messageNode = choicesNode.get(0).path("message"); + JsonNode contentNode = messageNode.path("content"); + return contentNode.asText(); + } + } catch (Exception e) { + logger.error(e.getMessage()); + + } + return null; + } + +} diff --git a/src/main/java/com/olympus/hermione/client/OlympusChatClient.java b/src/main/java/com/olympus/hermione/client/OlympusChatClient.java new file mode 100644 index 0000000..03b7d06 --- /dev/null +++ b/src/main/java/com/olympus/hermione/client/OlympusChatClient.java @@ -0,0 +1,8 @@ +package com.olympus.hermione.client; + +public interface OlympusChatClient { + + public OlympusChatClientResponse getChatCompletion(String userInput); + public void init(String apiUrl, String apiKey, int maxOutputTokens); + +} diff --git a/src/main/java/com/olympus/hermione/client/OlympusChatClientResponse.java b/src/main/java/com/olympus/hermione/client/OlympusChatClientResponse.java new file mode 100644 index 0000000..c544e90 --- /dev/null +++ b/src/main/java/com/olympus/hermione/client/OlympusChatClientResponse.java @@ -0,0 +1,46 @@ +package com.olympus.hermione.client; + +public class OlympusChatClientResponse { + + private String content; + private int inputTokens; + private int outputTokens; + private int totalTokens; + + public OlympusChatClientResponse() { + } + + public OlympusChatClientResponse(String content, int inputTokens, int outputTokens, int totalTokens) { + this.content = content; + this.inputTokens = inputTokens; + this.outputTokens = outputTokens; + this.totalTokens = totalTokens; + } + + public String getContent() { + return content; + } + public void setContent(String content) { + this.content = content; + } + public int getInputTokens() { + return inputTokens; + } + public void setInputTokens(int inputTokens) { + this.inputTokens = inputTokens; + } + public int getOutputTokens() { + return outputTokens; + } + public void setOutputTokens(int outputTokens) { + this.outputTokens = outputTokens; + } + public int getTotalTokens() { + return totalTokens; + } + public void setTotalTokens(int totalTokens) { + this.totalTokens = totalTokens; + } + + +} diff --git a/src/main/java/com/olympus/hermione/services/ScenarioExecutionService.java b/src/main/java/com/olympus/hermione/services/ScenarioExecutionService.java index 639655a..4bfcaef 100644 --- a/src/main/java/com/olympus/hermione/services/ScenarioExecutionService.java +++ b/src/main/java/com/olympus/hermione/services/ScenarioExecutionService.java @@ -34,6 +34,7 @@ import org.springframework.scheduling.annotation.Async; import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.stereotype.Service; +import com.olympus.hermione.client.OlympusChatClient; import com.olympus.hermione.dto.ScenarioExecutionInput; import com.olympus.hermione.dto.ScenarioOutput; import com.olympus.hermione.models.AiModel; @@ -60,6 +61,7 @@ import com.azure.ai.openai.OpenAIClientBuilder; import com.azure.core.credential.AzureKeyCredential; import com.olympus.hermione.stepSolvers.ExternalAgentSolver; import com.olympus.hermione.stepSolvers.ExternalCodeGenieSolver; +import com.olympus.hermione.stepSolvers.OlynmpusChatClientSolver; @@ -262,6 +264,9 @@ public class ScenarioExecutionService { case "SUMMARIZE_DOC": solver = new SummarizeDocSolver(); break; + case "OLYMPUS_QUERY_AI": + solver = new OlynmpusChatClientSolver(); + break; default: break; } @@ -476,7 +481,7 @@ public class ScenarioExecutionService { return result; -} + } public String updateRating(String id, String rating){ logger.info("updateRating function:"); diff --git a/src/main/java/com/olympus/hermione/stepSolvers/OlynmpusChatClientSolver.java b/src/main/java/com/olympus/hermione/stepSolvers/OlynmpusChatClientSolver.java new file mode 100644 index 0000000..9d07732 --- /dev/null +++ b/src/main/java/com/olympus/hermione/stepSolvers/OlynmpusChatClientSolver.java @@ -0,0 +1,80 @@ +package com.olympus.hermione.stepSolvers; + +import org.slf4j.LoggerFactory; +import org.springframework.ai.chat.client.ChatClient.CallResponseSpec; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.SystemMessage; +import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.ai.chat.metadata.Usage; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.olympus.hermione.client.AzureOpenApiChatClient; +import com.olympus.hermione.client.OlympusChatClient; +import com.olympus.hermione.client.OlympusChatClientResponse; +import com.olympus.hermione.dto.aientity.CiaOutputEntity; +import com.olympus.hermione.models.ScenarioExecution; +import com.olympus.hermione.utility.AttributeParser; + +import ch.qos.logback.classic.Logger; + +public class OlynmpusChatClientSolver extends StepSolver{ + + private String qai_system_prompt_template; + private String qai_user_input; + private String qai_output_variable; + private boolean qai_load_graph_schema=false; + + + Logger logger = (Logger) LoggerFactory.getLogger(BasicQueryRagSolver.class); + + private void loadParameters(){ + logger.info("Loading parameters"); + + AttributeParser attributeParser = new AttributeParser(this.scenarioExecution); + + this.qai_system_prompt_template = attributeParser.parse((String) this.step.getAttributes().get("qai_system_prompt_template")); + this.qai_user_input = attributeParser.parse((String)this.step.getAttributes().get("qai_user_input")); + this.qai_output_variable = (String) this.step.getAttributes().get("qai_output_variable"); + + + } + + @Override + public ScenarioExecution solveStep(){ + + System.out.println("Solving step: " + this.step.getName()); + + this.scenarioExecution.setCurrentStepId(this.step.getStepId()); + + loadParameters(); + + OlympusChatClient chatClient; + + if ( scenarioExecution.getScenario().getAiModel().getApiProvider().equals("AzureOpenAI")) { + chatClient = new AzureOpenApiChatClient(); + } else { + chatClient = null; + } + + chatClient.init(scenarioExecution.getScenario().getAiModel().getEndpoint(), + scenarioExecution.getScenario().getAiModel().getApiKey(), + scenarioExecution.getScenario().getAiModel().getMaxTokens()); + String userText = this.qai_user_input; + String systemText = this.qai_system_prompt_template; + + OlympusChatClientResponse resp = chatClient.getChatCompletion(systemText +"\n"+ userText); + + String output = resp.getContent(); + this.scenarioExecution.getExecSharedMap().put(this.qai_output_variable, output); + + + + Long usedTokens = (long) resp.getTotalTokens(); + this.scenarioExecution.setUsedTokens(usedTokens); + + this.scenarioExecution.setNextStepId(this.step.getNextStepId()); + + return this.scenarioExecution; + } +} diff --git a/src/main/resources/application.properties b/src/main/resources/application.properties index cf469cc..f5e540e 100644 --- a/src/main/resources/application.properties +++ b/src/main/resources/application.properties @@ -59,6 +59,7 @@ hermione.fe.url = http://localhost:5173 java-parser-module.url: http://java-parser-module-service.olympus.svc.cluster.local:8080 java-re-module.url: http://java-re-module-service.olympus.svc.cluster.local:8080 +jsp-parser-module.url: http://jsp-parser-module-service.olympus.svc.cluster.local:8080 spring.ai.vectorstore.chroma.client.host=http://108.142.74.161 spring.ai.vectorstore.chroma.client.port=8000 @@ -69,3 +70,4 @@ file.upload-dir=/mnt/hermione_storage/documents/file_input_scenarios/ spring.servlet.multipart.max-file-size=10MB spring.servlet.multipart.max-request-size=10MB +generic-file-parser-module.url=http://generic-file-parser-module-service.olympus.svc.cluster.local:8080 \ No newline at end of file