new OlynmpusChatClientSolver
This commit is contained in:
@@ -0,0 +1,94 @@
|
|||||||
|
package com.olympus.hermione.client;
|
||||||
|
|
||||||
|
import java.util.HashMap;
|
||||||
|
import java.util.Map;
|
||||||
|
|
||||||
|
import org.slf4j.Logger;
|
||||||
|
import org.slf4j.LoggerFactory;
|
||||||
|
import org.springframework.http.HttpEntity;
|
||||||
|
import org.springframework.http.HttpHeaders;
|
||||||
|
import org.springframework.http.HttpMethod;
|
||||||
|
import org.springframework.http.MediaType;
|
||||||
|
import org.springframework.http.ResponseEntity;
|
||||||
|
import org.springframework.web.client.RestTemplate;
|
||||||
|
|
||||||
|
import com.fasterxml.jackson.core.JsonProcessingException;
|
||||||
|
import com.fasterxml.jackson.databind.JsonNode;
|
||||||
|
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||||
|
|
||||||
|
public class AzureOpenApiChatClient implements OlympusChatClient {
|
||||||
|
|
||||||
|
private String apiUrl;
|
||||||
|
private String apiKey;
|
||||||
|
private int maxOutputTokens;
|
||||||
|
|
||||||
|
Logger logger = LoggerFactory.getLogger(AzureOpenApiChatClient.class);
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void init(String apiUrl, String apiKey, int maxOutputTokens) {
|
||||||
|
this.apiUrl = apiUrl;
|
||||||
|
this.apiKey = apiKey;
|
||||||
|
this.maxOutputTokens = maxOutputTokens;
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public OlympusChatClientResponse getChatCompletion(String userInput) {
|
||||||
|
HttpHeaders headers = new HttpHeaders();
|
||||||
|
headers.setContentType(MediaType.APPLICATION_JSON);
|
||||||
|
headers.set("api-key", apiKey);
|
||||||
|
|
||||||
|
Map<String, Object> requestBody = new HashMap<>();
|
||||||
|
|
||||||
|
requestBody.put("messages", new Object[]{
|
||||||
|
Map.of("role", "user", "content", userInput)
|
||||||
|
});
|
||||||
|
requestBody.put("max_completion_tokens", maxOutputTokens);
|
||||||
|
|
||||||
|
|
||||||
|
HttpEntity<Map<String, Object>> request = new HttpEntity<>(requestBody, headers);
|
||||||
|
RestTemplate restTemplate = new RestTemplate();
|
||||||
|
ResponseEntity<String> response = restTemplate.exchange(apiUrl, HttpMethod.POST, request, String.class);
|
||||||
|
|
||||||
|
|
||||||
|
ObjectMapper objectMapper = new ObjectMapper();
|
||||||
|
JsonNode rootNode;
|
||||||
|
try {
|
||||||
|
rootNode = objectMapper.readTree(response.getBody());
|
||||||
|
} catch (JsonProcessingException e) {
|
||||||
|
logger.error(e.getMessage());
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
OlympusChatClientResponse olympusChatClientResponse = new OlympusChatClientResponse();
|
||||||
|
olympusChatClientResponse.setContent(extractMessageContent(response.getBody()));
|
||||||
|
olympusChatClientResponse.setInputTokens(rootNode.path("usage").path("prompt_tokens").asInt());
|
||||||
|
olympusChatClientResponse.setOutputTokens(rootNode.path("usage").path("completion_tokens").asInt());
|
||||||
|
olympusChatClientResponse.setTotalTokens(rootNode.path("usage").path("total_tokens").asInt());
|
||||||
|
|
||||||
|
|
||||||
|
return olympusChatClientResponse;
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
public String extractMessageContent(String jsonResponse) {
|
||||||
|
|
||||||
|
|
||||||
|
try {
|
||||||
|
ObjectMapper objectMapper = new ObjectMapper();
|
||||||
|
JsonNode rootNode = objectMapper.readTree(jsonResponse);
|
||||||
|
JsonNode choicesNode = rootNode.path("choices");
|
||||||
|
if (choicesNode.isArray() && choicesNode.size() > 0) {
|
||||||
|
JsonNode messageNode = choicesNode.get(0).path("message");
|
||||||
|
JsonNode contentNode = messageNode.path("content");
|
||||||
|
return contentNode.asText();
|
||||||
|
}
|
||||||
|
} catch (Exception e) {
|
||||||
|
logger.error(e.getMessage());
|
||||||
|
|
||||||
|
}
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
@@ -0,0 +1,8 @@
|
|||||||
|
package com.olympus.hermione.client;
|
||||||
|
|
||||||
|
public interface OlympusChatClient {
|
||||||
|
|
||||||
|
public OlympusChatClientResponse getChatCompletion(String userInput);
|
||||||
|
public void init(String apiUrl, String apiKey, int maxOutputTokens);
|
||||||
|
|
||||||
|
}
|
||||||
@@ -0,0 +1,46 @@
|
|||||||
|
package com.olympus.hermione.client;
|
||||||
|
|
||||||
|
public class OlympusChatClientResponse {
|
||||||
|
|
||||||
|
private String content;
|
||||||
|
private int inputTokens;
|
||||||
|
private int outputTokens;
|
||||||
|
private int totalTokens;
|
||||||
|
|
||||||
|
public OlympusChatClientResponse() {
|
||||||
|
}
|
||||||
|
|
||||||
|
public OlympusChatClientResponse(String content, int inputTokens, int outputTokens, int totalTokens) {
|
||||||
|
this.content = content;
|
||||||
|
this.inputTokens = inputTokens;
|
||||||
|
this.outputTokens = outputTokens;
|
||||||
|
this.totalTokens = totalTokens;
|
||||||
|
}
|
||||||
|
|
||||||
|
public String getContent() {
|
||||||
|
return content;
|
||||||
|
}
|
||||||
|
public void setContent(String content) {
|
||||||
|
this.content = content;
|
||||||
|
}
|
||||||
|
public int getInputTokens() {
|
||||||
|
return inputTokens;
|
||||||
|
}
|
||||||
|
public void setInputTokens(int inputTokens) {
|
||||||
|
this.inputTokens = inputTokens;
|
||||||
|
}
|
||||||
|
public int getOutputTokens() {
|
||||||
|
return outputTokens;
|
||||||
|
}
|
||||||
|
public void setOutputTokens(int outputTokens) {
|
||||||
|
this.outputTokens = outputTokens;
|
||||||
|
}
|
||||||
|
public int getTotalTokens() {
|
||||||
|
return totalTokens;
|
||||||
|
}
|
||||||
|
public void setTotalTokens(int totalTokens) {
|
||||||
|
this.totalTokens = totalTokens;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
}
|
||||||
@@ -34,6 +34,7 @@ import org.springframework.scheduling.annotation.Async;
|
|||||||
import org.springframework.security.core.context.SecurityContextHolder;
|
import org.springframework.security.core.context.SecurityContextHolder;
|
||||||
import org.springframework.stereotype.Service;
|
import org.springframework.stereotype.Service;
|
||||||
|
|
||||||
|
import com.olympus.hermione.client.OlympusChatClient;
|
||||||
import com.olympus.hermione.dto.ScenarioExecutionInput;
|
import com.olympus.hermione.dto.ScenarioExecutionInput;
|
||||||
import com.olympus.hermione.dto.ScenarioOutput;
|
import com.olympus.hermione.dto.ScenarioOutput;
|
||||||
import com.olympus.hermione.models.AiModel;
|
import com.olympus.hermione.models.AiModel;
|
||||||
@@ -60,6 +61,7 @@ import com.azure.ai.openai.OpenAIClientBuilder;
|
|||||||
import com.azure.core.credential.AzureKeyCredential;
|
import com.azure.core.credential.AzureKeyCredential;
|
||||||
import com.olympus.hermione.stepSolvers.ExternalAgentSolver;
|
import com.olympus.hermione.stepSolvers.ExternalAgentSolver;
|
||||||
import com.olympus.hermione.stepSolvers.ExternalCodeGenieSolver;
|
import com.olympus.hermione.stepSolvers.ExternalCodeGenieSolver;
|
||||||
|
import com.olympus.hermione.stepSolvers.OlynmpusChatClientSolver;
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@@ -262,6 +264,9 @@ public class ScenarioExecutionService {
|
|||||||
case "SUMMARIZE_DOC":
|
case "SUMMARIZE_DOC":
|
||||||
solver = new SummarizeDocSolver();
|
solver = new SummarizeDocSolver();
|
||||||
break;
|
break;
|
||||||
|
case "OLYMPUS_QUERY_AI":
|
||||||
|
solver = new OlynmpusChatClientSolver();
|
||||||
|
break;
|
||||||
default:
|
default:
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
@@ -476,7 +481,7 @@ public class ScenarioExecutionService {
|
|||||||
|
|
||||||
return result;
|
return result;
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public String updateRating(String id, String rating){
|
public String updateRating(String id, String rating){
|
||||||
logger.info("updateRating function:");
|
logger.info("updateRating function:");
|
||||||
|
|||||||
@@ -0,0 +1,80 @@
|
|||||||
|
package com.olympus.hermione.stepSolvers;
|
||||||
|
|
||||||
|
import org.slf4j.LoggerFactory;
|
||||||
|
import org.springframework.ai.chat.client.ChatClient.CallResponseSpec;
|
||||||
|
import org.springframework.ai.chat.messages.Message;
|
||||||
|
import org.springframework.ai.chat.messages.SystemMessage;
|
||||||
|
import org.springframework.ai.chat.messages.UserMessage;
|
||||||
|
import org.springframework.ai.chat.metadata.Usage;
|
||||||
|
|
||||||
|
import com.fasterxml.jackson.core.JsonProcessingException;
|
||||||
|
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||||
|
import com.olympus.hermione.client.AzureOpenApiChatClient;
|
||||||
|
import com.olympus.hermione.client.OlympusChatClient;
|
||||||
|
import com.olympus.hermione.client.OlympusChatClientResponse;
|
||||||
|
import com.olympus.hermione.dto.aientity.CiaOutputEntity;
|
||||||
|
import com.olympus.hermione.models.ScenarioExecution;
|
||||||
|
import com.olympus.hermione.utility.AttributeParser;
|
||||||
|
|
||||||
|
import ch.qos.logback.classic.Logger;
|
||||||
|
|
||||||
|
public class OlynmpusChatClientSolver extends StepSolver{
|
||||||
|
|
||||||
|
private String qai_system_prompt_template;
|
||||||
|
private String qai_user_input;
|
||||||
|
private String qai_output_variable;
|
||||||
|
private boolean qai_load_graph_schema=false;
|
||||||
|
|
||||||
|
|
||||||
|
Logger logger = (Logger) LoggerFactory.getLogger(BasicQueryRagSolver.class);
|
||||||
|
|
||||||
|
private void loadParameters(){
|
||||||
|
logger.info("Loading parameters");
|
||||||
|
|
||||||
|
AttributeParser attributeParser = new AttributeParser(this.scenarioExecution);
|
||||||
|
|
||||||
|
this.qai_system_prompt_template = attributeParser.parse((String) this.step.getAttributes().get("qai_system_prompt_template"));
|
||||||
|
this.qai_user_input = attributeParser.parse((String)this.step.getAttributes().get("qai_user_input"));
|
||||||
|
this.qai_output_variable = (String) this.step.getAttributes().get("qai_output_variable");
|
||||||
|
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public ScenarioExecution solveStep(){
|
||||||
|
|
||||||
|
System.out.println("Solving step: " + this.step.getName());
|
||||||
|
|
||||||
|
this.scenarioExecution.setCurrentStepId(this.step.getStepId());
|
||||||
|
|
||||||
|
loadParameters();
|
||||||
|
|
||||||
|
OlympusChatClient chatClient;
|
||||||
|
|
||||||
|
if ( scenarioExecution.getScenario().getAiModel().getApiProvider().equals("AzureOpenAI")) {
|
||||||
|
chatClient = new AzureOpenApiChatClient();
|
||||||
|
} else {
|
||||||
|
chatClient = null;
|
||||||
|
}
|
||||||
|
|
||||||
|
chatClient.init(scenarioExecution.getScenario().getAiModel().getEndpoint(),
|
||||||
|
scenarioExecution.getScenario().getAiModel().getApiKey(),
|
||||||
|
scenarioExecution.getScenario().getAiModel().getMaxTokens());
|
||||||
|
String userText = this.qai_user_input;
|
||||||
|
String systemText = this.qai_system_prompt_template;
|
||||||
|
|
||||||
|
OlympusChatClientResponse resp = chatClient.getChatCompletion(systemText +"\n"+ userText);
|
||||||
|
|
||||||
|
String output = resp.getContent();
|
||||||
|
this.scenarioExecution.getExecSharedMap().put(this.qai_output_variable, output);
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
Long usedTokens = (long) resp.getTotalTokens();
|
||||||
|
this.scenarioExecution.setUsedTokens(usedTokens);
|
||||||
|
|
||||||
|
this.scenarioExecution.setNextStepId(this.step.getNextStepId());
|
||||||
|
|
||||||
|
return this.scenarioExecution;
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -59,6 +59,7 @@ hermione.fe.url = http://localhost:5173
|
|||||||
|
|
||||||
java-parser-module.url: http://java-parser-module-service.olympus.svc.cluster.local:8080
|
java-parser-module.url: http://java-parser-module-service.olympus.svc.cluster.local:8080
|
||||||
java-re-module.url: http://java-re-module-service.olympus.svc.cluster.local:8080
|
java-re-module.url: http://java-re-module-service.olympus.svc.cluster.local:8080
|
||||||
|
jsp-parser-module.url: http://jsp-parser-module-service.olympus.svc.cluster.local:8080
|
||||||
|
|
||||||
spring.ai.vectorstore.chroma.client.host=http://108.142.74.161
|
spring.ai.vectorstore.chroma.client.host=http://108.142.74.161
|
||||||
spring.ai.vectorstore.chroma.client.port=8000
|
spring.ai.vectorstore.chroma.client.port=8000
|
||||||
@@ -69,3 +70,4 @@ file.upload-dir=/mnt/hermione_storage/documents/file_input_scenarios/
|
|||||||
|
|
||||||
spring.servlet.multipart.max-file-size=10MB
|
spring.servlet.multipart.max-file-size=10MB
|
||||||
spring.servlet.multipart.max-request-size=10MB
|
spring.servlet.multipart.max-request-size=10MB
|
||||||
|
generic-file-parser-module.url=http://generic-file-parser-module-service.olympus.svc.cluster.local:8080
|
||||||
Reference in New Issue
Block a user