From 4fe9d9f356346c3bd14ec38f5548d47a20e18207 Mon Sep 17 00:00:00 2001 From: "andrea.terzani" Date: Tue, 25 Mar 2025 09:21:23 +0100 Subject: [PATCH] Aggiungi AdvancedQueryRagSolver per la gestione avanzata delle query RAG --- .../services/ScenarioExecutionService.java | 4 + .../stepSolvers/AdvancedQueryRagSolver.java | 178 ++++++++++++++++++ 2 files changed, 182 insertions(+) create mode 100644 src/main/java/com/olympus/hermione/stepSolvers/AdvancedQueryRagSolver.java diff --git a/src/main/java/com/olympus/hermione/services/ScenarioExecutionService.java b/src/main/java/com/olympus/hermione/services/ScenarioExecutionService.java index 4d25613..a03a2c3 100644 --- a/src/main/java/com/olympus/hermione/services/ScenarioExecutionService.java +++ b/src/main/java/com/olympus/hermione/services/ScenarioExecutionService.java @@ -45,6 +45,7 @@ import com.olympus.hermione.repository.ScenarioExecutionRepository; import com.olympus.hermione.repository.ScenarioRepository; import com.olympus.hermione.security.entity.User; import com.olympus.hermione.stepSolvers.AdvancedAIPromptSolver; +import com.olympus.hermione.stepSolvers.AdvancedQueryRagSolver; import com.olympus.hermione.stepSolvers.SummarizeDocSolver; import com.olympus.hermione.stepSolvers.BasicAIPromptSolver; import com.olympus.hermione.stepSolvers.BasicQueryRagSolver; @@ -250,6 +251,9 @@ public class ScenarioExecutionService { case "SIMPLE_QUERY_AI": solver = new BasicAIPromptSolver(); break; + case "ADVANCED_RAG_QUERY": + solver = new AdvancedQueryRagSolver(); + break; case "QUERY_GRAPH": solver = new QueryNeo4JSolver(); break; diff --git a/src/main/java/com/olympus/hermione/stepSolvers/AdvancedQueryRagSolver.java b/src/main/java/com/olympus/hermione/stepSolvers/AdvancedQueryRagSolver.java new file mode 100644 index 0000000..f292190 --- /dev/null +++ b/src/main/java/com/olympus/hermione/stepSolvers/AdvancedQueryRagSolver.java @@ -0,0 +1,178 @@ +package com.olympus.hermione.stepSolvers; + +import java.util.ArrayList; 
+import java.util.LinkedHashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.stream.Collectors;
+
+import org.slf4j.LoggerFactory;
+import org.springframework.ai.chat.client.ChatClient.CallResponseSpec;
+import org.springframework.ai.document.Document;
+import org.springframework.ai.vectorstore.SearchRequest;
+import org.springframework.ai.vectorstore.SearchRequest.Builder;
+
+import com.olympus.hermione.models.ScenarioExecution;
+import com.olympus.hermione.utility.AttributeParser;
+
+import ch.qos.logback.classic.Logger;
+
+/**
+ * Step solver that runs one or more similarity searches against the vector store and
+ * stores the concatenated text of the retrieved documents in the execution shared map.
+ *
+ * <p>Step attributes read by this solver:
+ * <ul>
+ *   <li>{@code rag_input}: search query, passed through {@link AttributeParser#parse}</li>
+ *   <li>{@code rag_filter}: optional filter expression, passed through {@link AttributeParser#parse}</li>
+ *   <li>{@code rag_topk}: number of results per search (default 4)</li>
+ *   <li>{@code rag_threshold}: similarity threshold (default 0.8)</li>
+ *   <li>{@code enable_rephrase}: ask the chat model to rewrite the query before searching (default false)</li>
+ *   <li>{@code enable_multi_phrase}: ask the chat model for several query variants, one per line,
+ *       and search each of them (default false)</li>
+ *   <li>{@code rephrase_prompt} / {@code multi_phrase_prompt}: optional custom prompts</li>
+ *   <li>{@code rag_output_variable}: shared-map key that receives the joined document text</li>
+ * </ul>
+ *
+ * <p>The values of the {@code KsFileSource} metadata entry of the retrieved documents are stored
+ * under the shared-map key {@code tech_rag_source_documents}.
+ */
+public class AdvancedQueryRagSolver extends StepSolver {
+
+    private static final int DEFAULT_TOP_K = 4;
+    private static final double DEFAULT_THRESHOLD = 0.8;
+
+    Logger logger = (Logger) LoggerFactory.getLogger(AdvancedQueryRagSolver.class);
+
+    private String query;
+    private String ragFilter;
+    private int topk;
+    private double threshold;
+    private String outputField;
+    private boolean enableRephrase;
+    private boolean enableMultiPhrase;
+    private String rephrasePrompt;
+    private String multiPhrasePrompt;
+
+    /** Returns the raw value of a step attribute, or {@code null} when it is absent. */
+    private Object attribute(String key) {
+        return this.step.getAttributes().get(key);
+    }
+
+    /** Reads a boolean attribute; a missing or {@code null} value yields {@code fallback}. */
+    private boolean booleanAttribute(String key, boolean fallback) {
+        Object value = attribute(key);
+        return value == null ? fallback : (Boolean) value;
+    }
+
+    /**
+     * Reads the step attributes into the solver fields, applying defaults for the optional ones.
+     */
+    private void loadParameters() {
+        logger.info("Loading parameters");
+
+        AttributeParser attributeParser = new AttributeParser(this.scenarioExecution);
+
+        this.query = attributeParser.parse((String) attribute("rag_input"));
+        this.ragFilter = attributeParser.parse((String) attribute("rag_filter"));
+
+        // Numeric attributes are read as Number so that a value stored as Integer, Long or
+        // Double (e.g. a threshold written as 1) does not fail with a ClassCastException.
+        Object topkValue = attribute("rag_topk");
+        this.topk = topkValue == null ? DEFAULT_TOP_K : ((Number) topkValue).intValue();
+
+        Object thresholdValue = attribute("rag_threshold");
+        this.threshold = thresholdValue == null ? DEFAULT_THRESHOLD : ((Number) thresholdValue).doubleValue();
+
+        this.enableRephrase = booleanAttribute("enable_rephrase", false);
+        this.enableMultiPhrase = booleanAttribute("enable_multi_phrase", false);
+
+        // NOTE(review): custom prompts are used verbatim; they are neither passed through
+        // AttributeParser nor completed with the query - confirm this is intended.
+        Object customRephrasePrompt = attribute("rephrase_prompt");
+        if (customRephrasePrompt != null) {
+            this.rephrasePrompt = (String) customRephrasePrompt;
+        } else {
+            this.rephrasePrompt = "Please rewrite the following query to be more efficent in a similarity search context : " + this.query;
+        }
+
+        Object customMultiPhrasePrompt = attribute("multi_phrase_prompt");
+        if (customMultiPhrasePrompt != null) {
+            this.multiPhrasePrompt = (String) customMultiPhrasePrompt;
+        } else {
+            this.multiPhrasePrompt = "Please rewrite the following query in 3 different way that could be used to perform a similarity search in a RAG scenario: " + this.query
+                    + "\n\n The output must be a list of 3 different queries, one per line and nothing else.";
+        }
+
+        this.outputField = (String) attribute("rag_output_variable");
+    }
+
+    /** Logs the effective value of every parameter loaded by {@link #loadParameters()}. */
+    private void logParameters() {
+        logger.info("query: {}", this.query);
+        logger.info("rag_filter: {}", this.ragFilter);
+        logger.info("topk: {}", this.topk);
+        logger.info("threshold: {}", this.threshold);
+        logger.info("outputField: {}", this.outputField);
+        logger.info("enableRephrase: {}", this.enableRephrase);
+        logger.info("enableMultiPhrase: {}", this.enableMultiPhrase);
+        logger.info("rephrasePrompt: {}", this.rephrasePrompt);
+        logger.info("MultiPhrasePrompt: {}", this.multiPhrasePrompt);
+    }
+
+    /** Sends {@code prompt} as a user message to the chat client and returns the answer text. */
+    private String askModel(String prompt) {
+        CallResponseSpec response = chatClient.prompt()
+                .user(prompt)
+                .call();
+        return response.content();
+    }
+
+    /**
+     * Builds the list of queries to run against the vector store.
+     *
+     * <p>With multi-phrase enabled, every non-blank line of the model answer is one query. The
+     * multi-phrase answer always replaced the rephrased query, so the rephrase call is skipped in
+     * that case instead of being made and thrown away. Otherwise the rephrased query (or the
+     * original one) is used as a single query, without splitting it into lines.
+     *
+     * @return the queries to search; empty only when no usable query text is available
+     */
+    private List<String> buildSearchQueries() {
+        List<String> queries = new ArrayList<>();
+
+        if (this.enableMultiPhrase) {
+            String variants = askModel(this.multiPhrasePrompt);
+            if (variants != null) {
+                for (String line : variants.split("\\R")) {
+                    String candidate = line.trim();
+                    if (!candidate.isEmpty()) {
+                        queries.add(candidate);
+                    }
+                }
+            }
+        } else if (this.enableRephrase) {
+            String rephrased = askModel(this.rephrasePrompt);
+            if (rephrased != null && !rephrased.trim().isEmpty()) {
+                queries.add(rephrased);
+            }
+        }
+
+        // Fall back to the original query when the model was not used or returned nothing usable.
+        if (queries.isEmpty() && this.query != null && !this.query.trim().isEmpty()) {
+            queries.add(this.query);
+        }
+        return queries;
+    }
+
+    /**
+     * Runs the similarity search(es), de-duplicates the hits by document id and publishes the
+     * result in the execution shared map.
+     *
+     * @return the scenario execution, with the current and next step ids updated
+     */
+    @Override
+    public ScenarioExecution solveStep() {
+
+        logger.info("Solving step: {}", this.step.getName());
+        loadParameters();
+        logParameters();
+
+        List<String> searchQueries = buildSearchQueries();
+        if (searchQueries.isEmpty()) {
+            logger.warn("No usable query for step {}: the similarity search is skipped", this.step.getName());
+        }
+
+        boolean useFilter = this.ragFilter != null && !this.ragFilter.isEmpty();
+        List<Document> retrieved = new ArrayList<>();
+
+        for (String searchQuery : searchQueries) {
+            logger.info("Elaborated query: {}", searchQuery);
+
+            // Each variant is searched on its own.
+            Builder requestBuilder = SearchRequest.builder()
+                    .query(searchQuery)
+                    .topK(this.topk)
+                    .similarityThreshold(this.threshold);
+
+            if (useFilter) {
+                requestBuilder.filterExpression(this.ragFilter);
+                logger.info("Using Filter expression: {}", this.ragFilter);
+            }
+
+            List<Document> hits = this.vectorStore.similaritySearch(requestBuilder.build());
+            if (hits != null) {
+                retrieved.addAll(hits);
+            }
+        }
+
+        logger.info("Number of not unique VDB retrieved documents: {}", retrieved.size());
+
+        // Keep the first occurrence of every document id; LinkedHashMap preserves retrieval order
+        // so the output is deterministic.
+        Map<String, Document> uniqueById = retrieved.stream()
+                .collect(Collectors.toMap(Document::getId, doc -> doc, (first, duplicate) -> first, LinkedHashMap::new));
+        List<Document> docs = new ArrayList<>(uniqueById.values());
+
+        logger.info("Number of VDB retrieved documents: {}", docs.size());
+
+        List<String> texts = new ArrayList<>(docs.size());
+        ArrayList<String> sourceDocs = new ArrayList<>(docs.size());
+        for (Document doc : docs) {
+            texts.add(doc.getText());
+            sourceDocs.add((String) doc.getMetadata().get("KsFileSource"));
+        }
+
+        String resultString = String.join("\n", texts);
+
+        this.scenarioExecution.getExecSharedMap().put("tech_rag_source_documents", sourceDocs);
+        this.scenarioExecution.getExecSharedMap().put(this.outputField, resultString);
+        this.scenarioExecution.setCurrentStepId(this.step.getStepId());
+        this.scenarioExecution.setNextStepId(this.step.getNextStepId());
+
+        return this.scenarioExecution;
+    }
+}