From 06b3491f09e80e8abb57d436f40eea17d80675ef Mon Sep 17 00:00:00 2001 From: "andrea.terzani" Date: Tue, 25 Mar 2025 10:51:18 +0100 Subject: [PATCH] Aggiungi supporto per diversi client di chat e modifica la gestione degli endpoint nel solver delle query avanzate --- .../com/olympus/hermione/models/AiModel.java | 1 + .../stepSolvers/AdvancedQueryRagSolver.java | 43 +++++++++++++------ .../stepSolvers/OlynmpusChatClientSolver.java | 2 +- 3 files changed, 31 insertions(+), 15 deletions(-) diff --git a/src/main/java/com/olympus/hermione/models/AiModel.java b/src/main/java/com/olympus/hermione/models/AiModel.java index 0dff9f1..730a791 100644 --- a/src/main/java/com/olympus/hermione/models/AiModel.java +++ b/src/main/java/com/olympus/hermione/models/AiModel.java @@ -20,5 +20,6 @@ public class AiModel { private String apiKey; private boolean isDefault = false; private boolean isCanvasDefault = false; + private String full_path_endpoint; } diff --git a/src/main/java/com/olympus/hermione/stepSolvers/AdvancedQueryRagSolver.java b/src/main/java/com/olympus/hermione/stepSolvers/AdvancedQueryRagSolver.java index c09452f..1b9c21f 100644 --- a/src/main/java/com/olympus/hermione/stepSolvers/AdvancedQueryRagSolver.java +++ b/src/main/java/com/olympus/hermione/stepSolvers/AdvancedQueryRagSolver.java @@ -5,11 +5,14 @@ import java.util.List; import java.util.stream.Collectors; import org.slf4j.LoggerFactory; -import org.springframework.ai.chat.client.ChatClient.CallResponseSpec; import org.springframework.ai.document.Document; import org.springframework.ai.vectorstore.SearchRequest; import org.springframework.ai.vectorstore.SearchRequest.Builder; +import com.olympus.hermione.client.AzureOpenApiChatClient; +import com.olympus.hermione.client.GoogleGeminiChatClient; +import com.olympus.hermione.client.OlympusChatClient; +import com.olympus.hermione.client.OlympusChatClientResponse; import com.olympus.hermione.models.ScenarioExecution; import com.olympus.hermione.utility.AttributeParser; @@ -72,16 +75,17 @@ public class AdvancedQueryRagSolver extends StepSolver { } if(this.step.getAttributes().containsKey("rephrase_prompt")){ - this.rephrasePrompt = (String) this.step.getAttributes().get("rephrase_prompt"); + this.rephrasePrompt = attributeParser.parse((String)this.step.getAttributes().get("rephrase_prompt")); }else{ this.rephrasePrompt = "Please rewrite the following query to be more efficent in a similarity search context : " + this.query; + this.rephrasePrompt += "\n\n The output must be a single query and nothing else."; } if(this.step.getAttributes().containsKey("multi_phrase_prompt")){ - this.MultiPhrasePrompt = (String) this.step.getAttributes().get("multi_phrase_prompt"); + this.MultiPhrasePrompt = attributeParser.parse((String)this.step.getAttributes().get("multi_phrase_prompt")); }else{ this.MultiPhrasePrompt = "Please rewrite the following query in 3 different way that could be used to perform a similarity search in a RAG scenario: " + this.query; - this.MultiPhrasePrompt += "\n\n The output must be a list of 3 different queries, one per line and nothing else."; + this.MultiPhrasePrompt += "\n\n The output must be a list of 3 different queries, one per line and nothing else. Do not put line numbers, just the queries."; } this.outputField = (String) this.step.getAttributes().get("rag_output_variable"); @@ -109,22 +113,33 @@ public class AdvancedQueryRagSolver extends StepSolver { logParameters(); String elaboratedQuery = this.query; + OlympusChatClient chatClient = null; + + if ( scenarioExecution.getScenario().getAiModel().getApiProvider().equals("AzureOpenAI")) { + chatClient = new AzureOpenApiChatClient(); + } + if ( scenarioExecution.getScenario().getAiModel().getApiProvider().equals("GoogleGemini")) { + chatClient = new GoogleGeminiChatClient(); + } + + + chatClient.init(scenarioExecution.getScenario().getAiModel().getFull_path_endpoint(), + scenarioExecution.getScenario().getAiModel().getApiKey(), + scenarioExecution.getScenario().getAiModel().getMaxTokens()); + + + if( this.enableRephrase ){ - CallResponseSpec resp = chatClient.prompt() - .user(this.rephrasePrompt) - .call(); - - elaboratedQuery = resp.content(); + OlympusChatClientResponse resp = chatClient.getChatCompletion(this.rephrasePrompt); + elaboratedQuery = resp.getContent(); } if( this.enableMultiPhrase ){ - CallResponseSpec resp = chatClient.prompt() - .user(this.MultiPhrasePrompt) - .call(); - - elaboratedQuery = resp.content(); + OlympusChatClientResponse resp = chatClient.getChatCompletion(this.MultiPhrasePrompt); + elaboratedQuery = resp.getContent(); + } List docs = new ArrayList(); diff --git a/src/main/java/com/olympus/hermione/stepSolvers/OlynmpusChatClientSolver.java b/src/main/java/com/olympus/hermione/stepSolvers/OlynmpusChatClientSolver.java index c0e0985..733c499 100644 --- a/src/main/java/com/olympus/hermione/stepSolvers/OlynmpusChatClientSolver.java +++ b/src/main/java/com/olympus/hermione/stepSolvers/OlynmpusChatClientSolver.java @@ -59,7 +59,7 @@ public class OlynmpusChatClientSolver extends StepSolver{ } - chatClient.init(scenarioExecution.getScenario().getAiModel().getEndpoint(), + chatClient.init(scenarioExecution.getScenario().getAiModel().getFull_path_endpoint(), scenarioExecution.getScenario().getAiModel().getApiKey(), scenarioExecution.getScenario().getAiModel().getMaxTokens()); String userText = this.qai_user_input;