new OlynmpusChatClientSolver
This commit is contained in:
@@ -0,0 +1,94 @@
|
||||
package com.olympus.hermione.client;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import org.springframework.http.HttpEntity;
|
||||
import org.springframework.http.HttpHeaders;
|
||||
import org.springframework.http.HttpMethod;
|
||||
import org.springframework.http.MediaType;
|
||||
import org.springframework.http.ResponseEntity;
|
||||
import org.springframework.web.client.RestTemplate;
|
||||
|
||||
import com.fasterxml.jackson.core.JsonProcessingException;
|
||||
import com.fasterxml.jackson.databind.JsonNode;
|
||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||
|
||||
public class AzureOpenApiChatClient implements OlympusChatClient {
|
||||
|
||||
private String apiUrl;
|
||||
private String apiKey;
|
||||
private int maxOutputTokens;
|
||||
|
||||
Logger logger = LoggerFactory.getLogger(AzureOpenApiChatClient.class);
|
||||
|
||||
@Override
|
||||
public void init(String apiUrl, String apiKey, int maxOutputTokens) {
|
||||
this.apiUrl = apiUrl;
|
||||
this.apiKey = apiKey;
|
||||
this.maxOutputTokens = maxOutputTokens;
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
public OlympusChatClientResponse getChatCompletion(String userInput) {
|
||||
HttpHeaders headers = new HttpHeaders();
|
||||
headers.setContentType(MediaType.APPLICATION_JSON);
|
||||
headers.set("api-key", apiKey);
|
||||
|
||||
Map<String, Object> requestBody = new HashMap<>();
|
||||
|
||||
requestBody.put("messages", new Object[]{
|
||||
Map.of("role", "user", "content", userInput)
|
||||
});
|
||||
requestBody.put("max_completion_tokens", maxOutputTokens);
|
||||
|
||||
|
||||
HttpEntity<Map<String, Object>> request = new HttpEntity<>(requestBody, headers);
|
||||
RestTemplate restTemplate = new RestTemplate();
|
||||
ResponseEntity<String> response = restTemplate.exchange(apiUrl, HttpMethod.POST, request, String.class);
|
||||
|
||||
|
||||
ObjectMapper objectMapper = new ObjectMapper();
|
||||
JsonNode rootNode;
|
||||
try {
|
||||
rootNode = objectMapper.readTree(response.getBody());
|
||||
} catch (JsonProcessingException e) {
|
||||
logger.error(e.getMessage());
|
||||
return null;
|
||||
}
|
||||
|
||||
OlympusChatClientResponse olympusChatClientResponse = new OlympusChatClientResponse();
|
||||
olympusChatClientResponse.setContent(extractMessageContent(response.getBody()));
|
||||
olympusChatClientResponse.setInputTokens(rootNode.path("usage").path("prompt_tokens").asInt());
|
||||
olympusChatClientResponse.setOutputTokens(rootNode.path("usage").path("completion_tokens").asInt());
|
||||
olympusChatClientResponse.setTotalTokens(rootNode.path("usage").path("total_tokens").asInt());
|
||||
|
||||
|
||||
return olympusChatClientResponse;
|
||||
|
||||
}
|
||||
|
||||
|
||||
public String extractMessageContent(String jsonResponse) {
|
||||
|
||||
|
||||
try {
|
||||
ObjectMapper objectMapper = new ObjectMapper();
|
||||
JsonNode rootNode = objectMapper.readTree(jsonResponse);
|
||||
JsonNode choicesNode = rootNode.path("choices");
|
||||
if (choicesNode.isArray() && choicesNode.size() > 0) {
|
||||
JsonNode messageNode = choicesNode.get(0).path("message");
|
||||
JsonNode contentNode = messageNode.path("content");
|
||||
return contentNode.asText();
|
||||
}
|
||||
} catch (Exception e) {
|
||||
logger.error(e.getMessage());
|
||||
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,8 @@
|
||||
package com.olympus.hermione.client;
|
||||
|
||||
public interface OlympusChatClient {
|
||||
|
||||
public OlympusChatClientResponse getChatCompletion(String userInput);
|
||||
public void init(String apiUrl, String apiKey, int maxOutputTokens);
|
||||
|
||||
}
|
||||
@@ -0,0 +1,46 @@
|
||||
package com.olympus.hermione.client;
|
||||
|
||||
public class OlympusChatClientResponse {
|
||||
|
||||
private String content;
|
||||
private int inputTokens;
|
||||
private int outputTokens;
|
||||
private int totalTokens;
|
||||
|
||||
public OlympusChatClientResponse() {
|
||||
}
|
||||
|
||||
public OlympusChatClientResponse(String content, int inputTokens, int outputTokens, int totalTokens) {
|
||||
this.content = content;
|
||||
this.inputTokens = inputTokens;
|
||||
this.outputTokens = outputTokens;
|
||||
this.totalTokens = totalTokens;
|
||||
}
|
||||
|
||||
public String getContent() {
|
||||
return content;
|
||||
}
|
||||
public void setContent(String content) {
|
||||
this.content = content;
|
||||
}
|
||||
public int getInputTokens() {
|
||||
return inputTokens;
|
||||
}
|
||||
public void setInputTokens(int inputTokens) {
|
||||
this.inputTokens = inputTokens;
|
||||
}
|
||||
public int getOutputTokens() {
|
||||
return outputTokens;
|
||||
}
|
||||
public void setOutputTokens(int outputTokens) {
|
||||
this.outputTokens = outputTokens;
|
||||
}
|
||||
public int getTotalTokens() {
|
||||
return totalTokens;
|
||||
}
|
||||
public void setTotalTokens(int totalTokens) {
|
||||
this.totalTokens = totalTokens;
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
@@ -34,6 +34,7 @@ import org.springframework.scheduling.annotation.Async;
|
||||
import org.springframework.security.core.context.SecurityContextHolder;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
import com.olympus.hermione.client.OlympusChatClient;
|
||||
import com.olympus.hermione.dto.ScenarioExecutionInput;
|
||||
import com.olympus.hermione.dto.ScenarioOutput;
|
||||
import com.olympus.hermione.models.AiModel;
|
||||
@@ -60,6 +61,7 @@ import com.azure.ai.openai.OpenAIClientBuilder;
|
||||
import com.azure.core.credential.AzureKeyCredential;
|
||||
import com.olympus.hermione.stepSolvers.ExternalAgentSolver;
|
||||
import com.olympus.hermione.stepSolvers.ExternalCodeGenieSolver;
|
||||
import com.olympus.hermione.stepSolvers.OlynmpusChatClientSolver;
|
||||
|
||||
|
||||
|
||||
@@ -262,6 +264,9 @@ public class ScenarioExecutionService {
|
||||
case "SUMMARIZE_DOC":
|
||||
solver = new SummarizeDocSolver();
|
||||
break;
|
||||
case "OLYMPUS_QUERY_AI":
|
||||
solver = new OlynmpusChatClientSolver();
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
@@ -476,7 +481,7 @@ public class ScenarioExecutionService {
|
||||
|
||||
return result;
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
public String updateRating(String id, String rating){
|
||||
logger.info("updateRating function:");
|
||||
|
||||
@@ -0,0 +1,80 @@
|
||||
package com.olympus.hermione.stepSolvers;
|
||||
|
||||
import org.slf4j.LoggerFactory;
|
||||
import org.springframework.ai.chat.client.ChatClient.CallResponseSpec;
|
||||
import org.springframework.ai.chat.messages.Message;
|
||||
import org.springframework.ai.chat.messages.SystemMessage;
|
||||
import org.springframework.ai.chat.messages.UserMessage;
|
||||
import org.springframework.ai.chat.metadata.Usage;
|
||||
|
||||
import com.fasterxml.jackson.core.JsonProcessingException;
|
||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||
import com.olympus.hermione.client.AzureOpenApiChatClient;
|
||||
import com.olympus.hermione.client.OlympusChatClient;
|
||||
import com.olympus.hermione.client.OlympusChatClientResponse;
|
||||
import com.olympus.hermione.dto.aientity.CiaOutputEntity;
|
||||
import com.olympus.hermione.models.ScenarioExecution;
|
||||
import com.olympus.hermione.utility.AttributeParser;
|
||||
|
||||
import ch.qos.logback.classic.Logger;
|
||||
|
||||
public class OlynmpusChatClientSolver extends StepSolver{
|
||||
|
||||
private String qai_system_prompt_template;
|
||||
private String qai_user_input;
|
||||
private String qai_output_variable;
|
||||
private boolean qai_load_graph_schema=false;
|
||||
|
||||
|
||||
Logger logger = (Logger) LoggerFactory.getLogger(BasicQueryRagSolver.class);
|
||||
|
||||
private void loadParameters(){
|
||||
logger.info("Loading parameters");
|
||||
|
||||
AttributeParser attributeParser = new AttributeParser(this.scenarioExecution);
|
||||
|
||||
this.qai_system_prompt_template = attributeParser.parse((String) this.step.getAttributes().get("qai_system_prompt_template"));
|
||||
this.qai_user_input = attributeParser.parse((String)this.step.getAttributes().get("qai_user_input"));
|
||||
this.qai_output_variable = (String) this.step.getAttributes().get("qai_output_variable");
|
||||
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
public ScenarioExecution solveStep(){
|
||||
|
||||
System.out.println("Solving step: " + this.step.getName());
|
||||
|
||||
this.scenarioExecution.setCurrentStepId(this.step.getStepId());
|
||||
|
||||
loadParameters();
|
||||
|
||||
OlympusChatClient chatClient;
|
||||
|
||||
if ( scenarioExecution.getScenario().getAiModel().getApiProvider().equals("AzureOpenAI")) {
|
||||
chatClient = new AzureOpenApiChatClient();
|
||||
} else {
|
||||
chatClient = null;
|
||||
}
|
||||
|
||||
chatClient.init(scenarioExecution.getScenario().getAiModel().getEndpoint(),
|
||||
scenarioExecution.getScenario().getAiModel().getApiKey(),
|
||||
scenarioExecution.getScenario().getAiModel().getMaxTokens());
|
||||
String userText = this.qai_user_input;
|
||||
String systemText = this.qai_system_prompt_template;
|
||||
|
||||
OlympusChatClientResponse resp = chatClient.getChatCompletion(systemText +"\n"+ userText);
|
||||
|
||||
String output = resp.getContent();
|
||||
this.scenarioExecution.getExecSharedMap().put(this.qai_output_variable, output);
|
||||
|
||||
|
||||
|
||||
Long usedTokens = (long) resp.getTotalTokens();
|
||||
this.scenarioExecution.setUsedTokens(usedTokens);
|
||||
|
||||
this.scenarioExecution.setNextStepId(this.step.getNextStepId());
|
||||
|
||||
return this.scenarioExecution;
|
||||
}
|
||||
}
|
||||
@@ -59,6 +59,7 @@ hermione.fe.url = http://localhost:5173
|
||||
|
||||
java-parser-module.url: http://java-parser-module-service.olympus.svc.cluster.local:8080
|
||||
java-re-module.url: http://java-re-module-service.olympus.svc.cluster.local:8080
|
||||
jsp-parser-module.url: http://jsp-parser-module-service.olympus.svc.cluster.local:8080
|
||||
|
||||
spring.ai.vectorstore.chroma.client.host=http://108.142.74.161
|
||||
spring.ai.vectorstore.chroma.client.port=8000
|
||||
@@ -69,3 +70,4 @@ file.upload-dir=/mnt/hermione_storage/documents/file_input_scenarios/
|
||||
|
||||
spring.servlet.multipart.max-file-size=10MB
|
||||
spring.servlet.multipart.max-request-size=10MB
|
||||
generic-file-parser-module.url=http://generic-file-parser-module-service.olympus.svc.cluster.local:8080
|
||||
Reference in New Issue
Block a user