From 170a2dcfc9a743121a05eb71feb1f24ca05fc369 Mon Sep 17 00:00:00 2001 From: "andrea.terzani" Date: Thu, 20 Feb 2025 17:31:22 +0100 Subject: [PATCH 01/10] new OlynmpusChatClientSolver --- .../client/AzureOpenApiChatClient.java | 94 +++++++++++++++++++ .../hermione/client/OlympusChatClient.java | 8 ++ .../client/OlympusChatClientResponse.java | 46 +++++++++ .../services/ScenarioExecutionService.java | 7 +- .../stepSolvers/OlynmpusChatClientSolver.java | 80 ++++++++++++++++ src/main/resources/application.properties | 2 + 6 files changed, 236 insertions(+), 1 deletion(-) create mode 100644 src/main/java/com/olympus/hermione/client/AzureOpenApiChatClient.java create mode 100644 src/main/java/com/olympus/hermione/client/OlympusChatClient.java create mode 100644 src/main/java/com/olympus/hermione/client/OlympusChatClientResponse.java create mode 100644 src/main/java/com/olympus/hermione/stepSolvers/OlynmpusChatClientSolver.java diff --git a/src/main/java/com/olympus/hermione/client/AzureOpenApiChatClient.java b/src/main/java/com/olympus/hermione/client/AzureOpenApiChatClient.java new file mode 100644 index 0000000..3ed6d83 --- /dev/null +++ b/src/main/java/com/olympus/hermione/client/AzureOpenApiChatClient.java @@ -0,0 +1,94 @@ +package com.olympus.hermione.client; + +import java.util.HashMap; +import java.util.Map; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.http.HttpEntity; +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpMethod; +import org.springframework.http.MediaType; +import org.springframework.http.ResponseEntity; +import org.springframework.web.client.RestTemplate; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; + +public class AzureOpenApiChatClient implements OlympusChatClient { + + private String apiUrl; + private String apiKey; + private int maxOutputTokens; + + Logger logger = LoggerFactory.getLogger(AzureOpenApiChatClient.class); + + @Override + public void init(String apiUrl, String apiKey, int maxOutputTokens) { + this.apiUrl = apiUrl; + this.apiKey = apiKey; + this.maxOutputTokens = maxOutputTokens; + + } + + @Override + public OlympusChatClientResponse getChatCompletion(String userInput) { + HttpHeaders headers = new HttpHeaders(); + headers.setContentType(MediaType.APPLICATION_JSON); + headers.set("api-key", apiKey); + + Map requestBody = new HashMap<>(); + + requestBody.put("messages", new Object[]{ + Map.of("role", "user", "content", userInput) + }); + requestBody.put("max_completion_tokens", maxOutputTokens); + + + HttpEntity> request = new HttpEntity<>(requestBody, headers); + RestTemplate restTemplate = new RestTemplate(); + ResponseEntity response = restTemplate.exchange(apiUrl, HttpMethod.POST, request, String.class); + + + ObjectMapper objectMapper = new ObjectMapper(); + JsonNode rootNode; + try { + rootNode = objectMapper.readTree(response.getBody()); + } catch (JsonProcessingException e) { + logger.error(e.getMessage()); + return null; + } + + OlympusChatClientResponse olympusChatClientResponse = new OlympusChatClientResponse(); + olympusChatClientResponse.setContent(extractMessageContent(response.getBody())); + olympusChatClientResponse.setInputTokens(rootNode.path("usage").path("prompt_tokens").asInt()); + olympusChatClientResponse.setOutputTokens(rootNode.path("usage").path("completion_tokens").asInt()); + olympusChatClientResponse.setTotalTokens(rootNode.path("usage").path("total_tokens").asInt()); + + + return olympusChatClientResponse; + + } + + + public String extractMessageContent(String jsonResponse) { + + + try { + ObjectMapper objectMapper = new ObjectMapper(); + JsonNode rootNode = objectMapper.readTree(jsonResponse); + JsonNode choicesNode = rootNode.path("choices"); + if (choicesNode.isArray() && choicesNode.size() > 0) { + JsonNode messageNode = choicesNode.get(0).path("message"); + JsonNode contentNode = messageNode.path("content"); + return contentNode.asText(); + } + } catch (Exception e) { + logger.error(e.getMessage()); + + } + return null; + } + +} diff --git a/src/main/java/com/olympus/hermione/client/OlympusChatClient.java b/src/main/java/com/olympus/hermione/client/OlympusChatClient.java new file mode 100644 index 0000000..03b7d06 --- /dev/null +++ b/src/main/java/com/olympus/hermione/client/OlympusChatClient.java @@ -0,0 +1,8 @@ +package com.olympus.hermione.client; + +public interface OlympusChatClient { + + public OlympusChatClientResponse getChatCompletion(String userInput); + public void init(String apiUrl, String apiKey, int maxOutputTokens); + +} diff --git a/src/main/java/com/olympus/hermione/client/OlympusChatClientResponse.java b/src/main/java/com/olympus/hermione/client/OlympusChatClientResponse.java new file mode 100644 index 0000000..c544e90 --- /dev/null +++ b/src/main/java/com/olympus/hermione/client/OlympusChatClientResponse.java @@ -0,0 +1,46 @@ +package com.olympus.hermione.client; + +public class OlympusChatClientResponse { + + private String content; + private int inputTokens; + private int outputTokens; + private int totalTokens; + + public OlympusChatClientResponse() { + } + + public OlympusChatClientResponse(String content, int inputTokens, int outputTokens, int totalTokens) { + this.content = content; + this.inputTokens = inputTokens; + this.outputTokens = outputTokens; + this.totalTokens = totalTokens; + } + + public String getContent() { + return content; + } + public void setContent(String content) { + this.content = content; + } + public int getInputTokens() { + return inputTokens; + } + public void setInputTokens(int inputTokens) { + this.inputTokens = inputTokens; + } + public int getOutputTokens() { + return outputTokens; + } + public void setOutputTokens(int outputTokens) { + this.outputTokens = outputTokens; + } + public int getTotalTokens() { + return totalTokens; + } + public void setTotalTokens(int totalTokens) { + this.totalTokens = totalTokens; + } + + +} diff --git a/src/main/java/com/olympus/hermione/services/ScenarioExecutionService.java b/src/main/java/com/olympus/hermione/services/ScenarioExecutionService.java index 639655a..4bfcaef 100644 --- a/src/main/java/com/olympus/hermione/services/ScenarioExecutionService.java +++ b/src/main/java/com/olympus/hermione/services/ScenarioExecutionService.java @@ -34,6 +34,7 @@ import org.springframework.scheduling.annotation.Async; import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.stereotype.Service; +import com.olympus.hermione.client.OlympusChatClient; import com.olympus.hermione.dto.ScenarioExecutionInput; import com.olympus.hermione.dto.ScenarioOutput; import com.olympus.hermione.models.AiModel; @@ -60,6 +61,7 @@ import com.azure.ai.openai.OpenAIClientBuilder; import com.azure.core.credential.AzureKeyCredential; import com.olympus.hermione.stepSolvers.ExternalAgentSolver; import com.olympus.hermione.stepSolvers.ExternalCodeGenieSolver; +import com.olympus.hermione.stepSolvers.OlynmpusChatClientSolver; @@ -262,6 +264,9 @@ public class ScenarioExecutionService { case "SUMMARIZE_DOC": solver = new SummarizeDocSolver(); break; + case "OLYMPUS_QUERY_AI": + solver = new OlynmpusChatClientSolver(); + break; default: break; } @@ -476,7 +481,7 @@ public class ScenarioExecutionService { return result; -} + } public String updateRating(String id, String rating){ logger.info("updateRating function:"); diff --git a/src/main/java/com/olympus/hermione/stepSolvers/OlynmpusChatClientSolver.java b/src/main/java/com/olympus/hermione/stepSolvers/OlynmpusChatClientSolver.java new file mode 100644 index 0000000..9d07732 --- /dev/null +++ b/src/main/java/com/olympus/hermione/stepSolvers/OlynmpusChatClientSolver.java @@ -0,0 +1,80 @@ +package com.olympus.hermione.stepSolvers; + +import org.slf4j.LoggerFactory; +import org.springframework.ai.chat.client.ChatClient.CallResponseSpec; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.SystemMessage; +import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.ai.chat.metadata.Usage; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.olympus.hermione.client.AzureOpenApiChatClient; +import com.olympus.hermione.client.OlympusChatClient; +import com.olympus.hermione.client.OlympusChatClientResponse; +import com.olympus.hermione.dto.aientity.CiaOutputEntity; +import com.olympus.hermione.models.ScenarioExecution; +import com.olympus.hermione.utility.AttributeParser; + +import ch.qos.logback.classic.Logger; + +public class OlynmpusChatClientSolver extends StepSolver{ + + private String qai_system_prompt_template; + private String qai_user_input; + private String qai_output_variable; + private boolean qai_load_graph_schema=false; + + + Logger logger = (Logger) LoggerFactory.getLogger(BasicQueryRagSolver.class); + + private void loadParameters(){ + logger.info("Loading parameters"); + + AttributeParser attributeParser = new AttributeParser(this.scenarioExecution); + + this.qai_system_prompt_template = attributeParser.parse((String) this.step.getAttributes().get("qai_system_prompt_template")); + this.qai_user_input = attributeParser.parse((String)this.step.getAttributes().get("qai_user_input")); + this.qai_output_variable = (String) this.step.getAttributes().get("qai_output_variable"); + + + } + + @Override + public ScenarioExecution solveStep(){ + + System.out.println("Solving step: " + this.step.getName()); + + this.scenarioExecution.setCurrentStepId(this.step.getStepId()); + + loadParameters(); + + OlympusChatClient chatClient; + + if ( scenarioExecution.getScenario().getAiModel().getApiProvider().equals("AzureOpenAI")) { + chatClient = new AzureOpenApiChatClient(); + } else { + chatClient = null; + } + + chatClient.init(scenarioExecution.getScenario().getAiModel().getEndpoint(), + scenarioExecution.getScenario().getAiModel().getApiKey(), + scenarioExecution.getScenario().getAiModel().getMaxTokens()); + String userText = this.qai_user_input; + String systemText = this.qai_system_prompt_template; + + OlympusChatClientResponse resp = chatClient.getChatCompletion(systemText +"\n"+ userText); + + String output = resp.getContent(); + this.scenarioExecution.getExecSharedMap().put(this.qai_output_variable, output); + + + + Long usedTokens = (long) resp.getTotalTokens(); + this.scenarioExecution.setUsedTokens(usedTokens); + + this.scenarioExecution.setNextStepId(this.step.getNextStepId()); + + return this.scenarioExecution; + } +} diff --git a/src/main/resources/application.properties b/src/main/resources/application.properties index cf469cc..f5e540e 100644 --- a/src/main/resources/application.properties +++ b/src/main/resources/application.properties @@ -59,6 +59,7 @@ hermione.fe.url = http://localhost:5173 java-parser-module.url: http://java-parser-module-service.olympus.svc.cluster.local:8080 java-re-module.url: http://java-re-module-service.olympus.svc.cluster.local:8080 +jsp-parser-module.url: http://jsp-parser-module-service.olympus.svc.cluster.local:8080 spring.ai.vectorstore.chroma.client.host=http://108.142.74.161 spring.ai.vectorstore.chroma.client.port=8000 @@ -69,3 +70,4 @@ file.upload-dir=/mnt/hermione_storage/documents/file_input_scenarios/ spring.servlet.multipart.max-file-size=10MB spring.servlet.multipart.max-request-size=10MB +generic-file-parser-module.url=http://generic-file-parser-module-service.olympus.svc.cluster.local:8080 \ No newline at end of file From 255e65af5e137ce2f3b7d9ac04c62deb27ec0244 Mon Sep 17 00:00:00 2001 From: "andrea.terzani" Date: Thu, 20 Feb 2025 19:12:48 +0100 Subject: [PATCH 02/10] Aggiungi supporto per Google Gemini nel servizio di esecuzione degli scenari --- .../client/GoogleGeminiChatClient.java | 100 ++++++++++++++++++ .../services/ScenarioExecutionService.java | 17 ++- .../stepSolvers/OlynmpusChatClientSolver.java | 7 +- 3 files changed, 122 insertions(+), 2 deletions(-) create mode 100644 src/main/java/com/olympus/hermione/client/GoogleGeminiChatClient.java diff --git a/src/main/java/com/olympus/hermione/client/GoogleGeminiChatClient.java b/src/main/java/com/olympus/hermione/client/GoogleGeminiChatClient.java new file mode 100644 index 0000000..0082222 --- /dev/null +++ b/src/main/java/com/olympus/hermione/client/GoogleGeminiChatClient.java @@ -0,0 +1,100 @@ +package com.olympus.hermione.client; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.http.HttpEntity; +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpMethod; +import org.springframework.http.MediaType; +import org.springframework.http.ResponseEntity; +import org.springframework.web.client.RestTemplate; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; + +public class GoogleGeminiChatClient implements OlympusChatClient { + + private String apiUrl; + private String apiKey; + private int maxOutputTokens; + + Logger logger = LoggerFactory.getLogger(AzureOpenApiChatClient.class); + + @Override + public void init(String apiUrl, String apiKey, int maxOutputTokens) { + this.apiUrl = apiUrl; + this.apiKey = apiKey; + this.maxOutputTokens = maxOutputTokens; + + } + + @Override + public OlympusChatClientResponse getChatCompletion(String userInput) { + + HttpHeaders headers = new HttpHeaders(); + headers.setContentType(MediaType.APPLICATION_JSON); + + Map userMessage = new HashMap<>(); + userMessage.put("role", "user"); + userMessage.put("parts", new Object[]{Map.of("text", userInput)}); + + + Map conversation = new HashMap<>(); + conversation.put("contents", new Object[]{userMessage}); + + Map generationConfig = new HashMap<>(); + /* generationConfig.put("temperature", 1); + generationConfig.put("topK", 40); + generationConfig.put("topP", 0.95);*/ + generationConfig.put("maxOutputTokens", maxOutputTokens); + generationConfig.put("responseMimeType", "text/plain"); + + conversation.put("generationConfig", generationConfig); + + HttpEntity> request = new HttpEntity<>(conversation, headers); + RestTemplate restTemplate = new RestTemplate(); + + String fullUrl = apiUrl + "?key=" + apiKey; + ResponseEntity response = restTemplate.exchange(fullUrl, HttpMethod.POST, request, String.class); + OlympusChatClientResponse olympusChatClientResponse = new OlympusChatClientResponse(); + + + try{ + Map responseBody = new ObjectMapper().readValue(response.getBody(), Map.class); + List> candidates = (List>) responseBody.get("candidates"); + Map firstCandidate = candidates.get(0); + Map content = (Map) firstCandidate.get("content"); + List> parts = (List>) content.get("parts"); + String text = (String) parts.get(0).get("text"); + + ObjectMapper objectMapper = new ObjectMapper(); + JsonNode rootNode; + try { + rootNode = objectMapper.readTree(response.getBody()); + } catch (JsonProcessingException e) { + logger.error(e.getMessage()); + return null; + } + + olympusChatClientResponse.setContent(text); + olympusChatClientResponse.setInputTokens(rootNode.path("usageMetadata").path("promptTokenCount").asInt()); + olympusChatClientResponse.setOutputTokens(rootNode.path("usageMetadata").path("candidatesTokenCount").asInt()); + olympusChatClientResponse.setTotalTokens(rootNode.path("usageMetadata").path("totalTokenCount").asInt()); + }catch(Exception e){ + logger.error("Error parsing response", e); + return null; + } + + + return olympusChatClientResponse; + + } + + + +} diff --git a/src/main/java/com/olympus/hermione/services/ScenarioExecutionService.java b/src/main/java/com/olympus/hermione/services/ScenarioExecutionService.java index 4bfcaef..412f409 100644 --- a/src/main/java/com/olympus/hermione/services/ScenarioExecutionService.java +++ b/src/main/java/com/olympus/hermione/services/ScenarioExecutionService.java @@ -445,7 +445,22 @@ public class ScenarioExecutionService { OpenAiChatModel openaichatModel = new OpenAiChatModel(openAiApi,openAiChatOptions); logger.info("AI model used: " + aiModel.getModel()); return openaichatModel; - + + + case "GoogleGemini": + OpenAIClient tempopenAIClient = new OpenAIClientBuilder() + .credential(new AzureKeyCredential(aiModel.getApiKey())) + .endpoint(aiModel.getEndpoint()) + .buildClient(); + AzureOpenAiChatOptions tempopenAIChatOptions = AzureOpenAiChatOptions.builder() + .withDeploymentName(aiModel.getModel()) + .withTemperature(aiModel.getTemperature()) + .withMaxTokens(aiModel.getMaxTokens()) + .build(); + + AzureOpenAiChatModel tempazureOpenaichatModel = new AzureOpenAiChatModel(tempopenAIClient, tempopenAIChatOptions); + logger.info("AI model used: GoogleGemini"); + return tempazureOpenaichatModel; default: throw new IllegalArgumentException("Unsupported AI model: " + aiModel.getName()); } diff --git a/src/main/java/com/olympus/hermione/stepSolvers/OlynmpusChatClientSolver.java b/src/main/java/com/olympus/hermione/stepSolvers/OlynmpusChatClientSolver.java index 9d07732..b22d841 100644 --- a/src/main/java/com/olympus/hermione/stepSolvers/OlynmpusChatClientSolver.java +++ b/src/main/java/com/olympus/hermione/stepSolvers/OlynmpusChatClientSolver.java @@ -10,6 +10,7 @@ import org.springframework.ai.chat.metadata.Usage; import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.databind.ObjectMapper; import com.olympus.hermione.client.AzureOpenApiChatClient; +import com.olympus.hermione.client.GoogleGeminiChatClient; import com.olympus.hermione.client.OlympusChatClient; import com.olympus.hermione.client.OlympusChatClientResponse; import com.olympus.hermione.dto.aientity.CiaOutputEntity; @@ -53,7 +54,11 @@ public class OlynmpusChatClientSolver extends StepSolver{ if ( scenarioExecution.getScenario().getAiModel().getApiProvider().equals("AzureOpenAI")) { chatClient = new AzureOpenApiChatClient(); - } else { + } + if ( scenarioExecution.getScenario().getAiModel().getApiProvider().equals("GoogleGemini")) { + chatClient = new GoogleGeminiChatClient(); + } + else { chatClient = null; } From e9c300ca98fe7bdcbb7413ce133ceb1956b989fb Mon Sep 17 00:00:00 2001 From: "andrea.terzani" Date: Thu, 20 Feb 2025 19:28:01 +0100 Subject: [PATCH 03/10] Inizializza chatClient a null per gestire correttamente i provider API --- .../hermione/stepSolvers/OlynmpusChatClientSolver.java | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/main/java/com/olympus/hermione/stepSolvers/OlynmpusChatClientSolver.java b/src/main/java/com/olympus/hermione/stepSolvers/OlynmpusChatClientSolver.java index b22d841..8862430 100644 --- a/src/main/java/com/olympus/hermione/stepSolvers/OlynmpusChatClientSolver.java +++ b/src/main/java/com/olympus/hermione/stepSolvers/OlynmpusChatClientSolver.java @@ -50,7 +50,7 @@ public class OlynmpusChatClientSolver extends StepSolver{ loadParameters(); - OlympusChatClient chatClient; + OlympusChatClient chatClient = null; if ( scenarioExecution.getScenario().getAiModel().getApiProvider().equals("AzureOpenAI")) { chatClient = new AzureOpenApiChatClient(); @@ -58,9 +58,7 @@ public class OlynmpusChatClientSolver extends StepSolver{ if ( scenarioExecution.getScenario().getAiModel().getApiProvider().equals("GoogleGemini")) { chatClient = new GoogleGeminiChatClient(); } - else { - chatClient = null; - } + chatClient.init(scenarioExecution.getScenario().getAiModel().getEndpoint(), scenarioExecution.getScenario().getAiModel().getApiKey(), From 37caef5c9db9839a1158c5403e0139b88c1e238f Mon Sep 17 00:00:00 2001 From: Emanuele Ferrelli Date: Fri, 21 Feb 2025 10:01:18 +0100 Subject: [PATCH 04/10] Develop Q&A on Doc scenario --- .../services/ScenarioExecutionService.java | 8 ++ .../stepSolvers/DeleteDocTempSolver.java | 75 +++++++++++++ .../stepSolvers/EmbeddingDocTempSolver.java | 106 ++++++++++++++++++ 3 files changed, 189 insertions(+) create mode 100644 src/main/java/com/olympus/hermione/stepSolvers/DeleteDocTempSolver.java create mode 100644 src/main/java/com/olympus/hermione/stepSolvers/EmbeddingDocTempSolver.java diff --git a/src/main/java/com/olympus/hermione/services/ScenarioExecutionService.java b/src/main/java/com/olympus/hermione/services/ScenarioExecutionService.java index 412f409..4644c1f 100644 --- a/src/main/java/com/olympus/hermione/services/ScenarioExecutionService.java +++ b/src/main/java/com/olympus/hermione/services/ScenarioExecutionService.java @@ -49,6 +49,8 @@ import com.olympus.hermione.stepSolvers.AdvancedAIPromptSolver; import com.olympus.hermione.stepSolvers.SummarizeDocSolver; import com.olympus.hermione.stepSolvers.BasicAIPromptSolver; import com.olympus.hermione.stepSolvers.BasicQueryRagSolver; +import com.olympus.hermione.stepSolvers.DeleteDocTempSolver; +import com.olympus.hermione.stepSolvers.EmbeddingDocTempSolver; import com.olympus.hermione.stepSolvers.QueryNeo4JSolver; import com.olympus.hermione.stepSolvers.SourceCodeRagSolver; import com.olympus.hermione.stepSolvers.StepSolver; @@ -267,6 +269,12 @@ public class ScenarioExecutionService { case "OLYMPUS_QUERY_AI": solver = new OlynmpusChatClientSolver(); break; + case "EMBED_TEMPORARY_DOC": + solver = new EmbeddingDocTempSolver(); + break; + case "DELETE_TEMPORARY_DOC": + solver = new DeleteDocTempSolver(); + break; default: break; } diff --git a/src/main/java/com/olympus/hermione/stepSolvers/DeleteDocTempSolver.java b/src/main/java/com/olympus/hermione/stepSolvers/DeleteDocTempSolver.java new file mode 100644 index 0000000..266e82f --- /dev/null +++ b/src/main/java/com/olympus/hermione/stepSolvers/DeleteDocTempSolver.java @@ -0,0 +1,75 @@ +package com.olympus.hermione.stepSolvers; + +import ch.qos.logback.classic.Logger; +import com.olympus.hermione.models.ScenarioExecution; +import com.olympus.hermione.utility.AttributeParser; +import java.util.List; +import org.slf4j.LoggerFactory; +import org.springframework.ai.document.Document; +import org.springframework.ai.vectorstore.SearchRequest; + +public class DeleteDocTempSolver extends StepSolver { + + private String rag_filter; + private int topk; + private double threshold; + private String query; + + + Logger logger = (Logger) LoggerFactory.getLogger(BasicQueryRagSolver.class); + + private void loadParameters(){ + logger.info("Loading parameters"); + this.scenarioExecution.getExecSharedMap().put("scenario_execution_id", this.scenarioExecution.getId()); + + AttributeParser attributeParser = new AttributeParser(this.scenarioExecution); + + // this.scenario_execution_id = attributeParser.parse((String) this.scenarioExecution.getId()); + + this.rag_filter = attributeParser.parse((String) this.step.getAttributes().get("rag_filter")); + + if(this.step.getAttributes().containsKey("rag_query")){ + this.query = (String) this.step.getAttributes().get("rag_query"); + }else{ + this.query = "*"; + } + + if(this.step.getAttributes().containsKey("rag_topk")){ + this.topk = (int) this.step.getAttributes().get("rag_topk"); + }else{ + this.topk = 1000; + } + + if(this.step.getAttributes().containsKey("rag_threshold")){ + this.threshold = (double) this.step.getAttributes().get("rag_threshold"); + }else{ + this.threshold = 0.0; + } + + } + + @Override + public ScenarioExecution solveStep(){ + + System.out.println("Solving step: " + this.step.getName()); + + this.scenarioExecution.setCurrentStepId(this.step.getStepId()); + + loadParameters(); + SearchRequest searchRequest = SearchRequest.defaults() + .withQuery(this.query) + .withTopK(this.topk) + .withSimilarityThreshold(this.threshold) + .withFilterExpression(this.rag_filter); + + + List docs = vectorStore.similaritySearch(searchRequest); + List ids = docs.stream().map(Document::getId).toList(); + vectorStore.delete(ids); + + this.scenarioExecution.setNextStepId(this.step.getNextStepId()); + + return this.scenarioExecution; + } + +} diff --git a/src/main/java/com/olympus/hermione/stepSolvers/EmbeddingDocTempSolver.java b/src/main/java/com/olympus/hermione/stepSolvers/EmbeddingDocTempSolver.java new file mode 100644 index 0000000..6b13974 --- /dev/null +++ b/src/main/java/com/olympus/hermione/stepSolvers/EmbeddingDocTempSolver.java @@ -0,0 +1,106 @@ +package com.olympus.hermione.stepSolvers; + +import ch.qos.logback.classic.Logger; +import com.olympus.hermione.models.ScenarioExecution; +import com.olympus.hermione.utility.AttributeParser; +import java.io.File; +import java.util.Collections; +import java.util.List; +import org.apache.tika.Tika; +import org.slf4j.LoggerFactory; +import org.springframework.ai.document.Document; +import org.springframework.ai.transformer.splitter.TokenTextSplitter; + +public class EmbeddingDocTempSolver extends StepSolver { + + private String scenario_execution_id; + private String path_file; + private int default_chunk_size; + private int min_chunk_size; + private int min_chunk_length_to_embed; + private int max_num_chunks; + + + + Logger logger = (Logger) LoggerFactory.getLogger(BasicQueryRagSolver.class); + + private void loadParameters(){ + logger.info("Loading parameters"); + this.scenarioExecution.getExecSharedMap().put("scenario_execution_id", this.scenarioExecution.getId()); + logger.info("Scenario Execution ID: "+this.scenarioExecution.getId()); + + AttributeParser attributeParser = new AttributeParser(this.scenarioExecution); + + this.scenario_execution_id = attributeParser.parse((String) this.scenarioExecution.getId()); + + this.path_file = attributeParser.parse((String) this.step.getAttributes().get("path_file")); + + if(this.step.getAttributes().containsKey("default_chunk_size")){ + this.default_chunk_size = (int) this.step.getAttributes().get("default_chunk_size"); + }else{ + this.default_chunk_size = 8000; + } + + if(this.step.getAttributes().containsKey("min_chunk_size")){ + this.min_chunk_size = (int) this.step.getAttributes().get("min_chunk_size"); + }else{ + this.min_chunk_size = 50; + } + + if(this.step.getAttributes().containsKey("min_chunk_length_to_embed")){ + this.min_chunk_length_to_embed = (int) this.step.getAttributes().get("min_chunk_length_to_embed"); + }else{ + this.min_chunk_length_to_embed = 50; + } + + if(this.step.getAttributes().containsKey("max_num_chunks")){ + this.max_num_chunks = (int) this.step.getAttributes().get("max_num_chunks"); + }else{ + this.max_num_chunks = 1000; + } + + } + + @Override + public ScenarioExecution solveStep(){ + + try{ + logger.info("Solving step: " + this.step.getName()); + this.scenarioExecution.setCurrentStepId(this.step.getStepId()); + loadParameters(); + + logger.info("Embedding documents"); + File file = new File(this.path_file); + Tika tika = new Tika(); + tika.setMaxStringLength(-1); + String text = tika.parseToString(file); + Document myDoc = new Document(text); + + List docs = Collections.singletonList(myDoc); + + TokenTextSplitter splitter = new TokenTextSplitter(this.default_chunk_size, + this.min_chunk_size, + this.min_chunk_length_to_embed, + this.max_num_chunks, + true); + + docs.forEach(doc -> { + List splitDocs = splitter.split(doc); + logger.info("Number of documents: " + splitDocs.size()); + + splitDocs.forEach(splitDoc -> { + splitDoc.getMetadata().put("KsScenarioExecutionId", this.scenario_execution_id); + }); + vectorStore.add(splitDocs); + }); + }catch (Exception e){ + logger.error("Error while solvingStep: "+e.getMessage()); + e.printStackTrace(); + } + + this.scenarioExecution.setNextStepId(this.step.getNextStepId()); + + return this.scenarioExecution; + } + +} From 3402b4ada284b8b3f6069415567e18cd5a2a5b8e Mon Sep 17 00:00:00 2001 From: "andrea.terzani" Date: Fri, 21 Feb 2025 13:02:37 +0100 Subject: [PATCH 05/10] Refactor token handling and update dependencies; add AzureAIConfig for response timeout customization --- pom.xml | 9 +++- .../hermione/config/AzureAIConfig.java | 26 ++++++++++ .../hermione/config/VectorStoreConfig.java | 3 +- .../hermione/models/ScenarioExecution.java | 7 +-- .../services/CanvasExecutionService.java | 6 +-- .../services/ScenarioExecutionService.java | 47 +++++++++-------- .../stepSolvers/AdvancedAIPromptSolver.java | 2 +- .../stepSolvers/BasicAIPromptSolver.java | 2 +- .../stepSolvers/BasicQueryRagSolver.java | 22 +++++--- .../stepSolvers/DeleteDocTempSolver.java | 13 +++-- .../stepSolvers/OlynmpusChatClientSolver.java | 4 +- .../stepSolvers/SummarizeDocSolver.java | 4 +- src/main/resources/application.properties | 50 ++++++++++--------- 13 files changed, 122 insertions(+), 73 deletions(-) create mode 100644 src/main/java/com/olympus/hermione/config/AzureAIConfig.java diff --git a/pom.xml b/pom.xml index 2f58e71..7e56c32 100644 --- a/pom.xml +++ b/pom.xml @@ -29,7 +29,7 @@ 21 - 1.0.0-M2 + 1.0.0-M6 2023.0.3 @@ -48,6 +48,13 @@ 1.1.0 + + + org.json + json + 20250107 + +