Aggiungi AdvancedQueryRagSolver per la gestione avanzata delle query RAG

This commit is contained in:
andrea.terzani
2025-03-25 09:21:23 +01:00
parent 4cb34c0087
commit 4fe9d9f356
2 changed files with 182 additions and 0 deletions

View File

@@ -45,6 +45,7 @@ import com.olympus.hermione.repository.ScenarioExecutionRepository;
import com.olympus.hermione.repository.ScenarioRepository;
import com.olympus.hermione.security.entity.User;
import com.olympus.hermione.stepSolvers.AdvancedAIPromptSolver;
import com.olympus.hermione.stepSolvers.AdvancedQueryRagSolver;
import com.olympus.hermione.stepSolvers.SummarizeDocSolver;
import com.olympus.hermione.stepSolvers.BasicAIPromptSolver;
import com.olympus.hermione.stepSolvers.BasicQueryRagSolver;
@@ -250,6 +251,9 @@ public class ScenarioExecutionService {
case "SIMPLE_QUERY_AI":
solver = new BasicAIPromptSolver();
break;
case "ADVANCED_RAG_QUERY":
solver = new AdvancedQueryRagSolver();
break;
case "QUERY_GRAPH":
solver = new QueryNeo4JSolver();
break;

View File

@@ -0,0 +1,178 @@
package com.olympus.hermione.stepSolvers;
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.stream.Collectors;

import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.client.ChatClient.CallResponseSpec;
import org.springframework.ai.document.Document;
import org.springframework.ai.vectorstore.SearchRequest;
import org.springframework.ai.vectorstore.SearchRequest.Builder;

import com.olympus.hermione.models.ScenarioExecution;
import com.olympus.hermione.utility.AttributeParser;

import ch.qos.logback.classic.Logger;
/**
 * Step solver that runs a similarity search against the vector store, optionally asking the
 * chat client to rewrite the query first, and stores the concatenated text of the retrieved
 * documents in the execution shared map.
 *
 * <p>Step attributes read by this solver:
 * <ul>
 *   <li>{@code rag_input} : the query text (run through {@link AttributeParser})</li>
 *   <li>{@code rag_filter} : optional filter expression (run through {@link AttributeParser})</li>
 *   <li>{@code rag_topk} : number of results requested per search (default 4)</li>
 *   <li>{@code rag_threshold} : similarity threshold (default 0.8)</li>
 *   <li>{@code enable_rephrase} : when true, the query is replaced by the chat client's answer
 *       to {@code rephrase_prompt} (default false)</li>
 *   <li>{@code enable_multi_phrase} : when true, the chat client's answer to
 *       {@code multi_phrase_prompt} is split into lines and each line is searched separately
 *       (default false)</li>
 *   <li>{@code rephrase_prompt} / {@code multi_phrase_prompt} : optional prompt overrides;
 *       custom prompts are sent as-is, the built-in defaults embed the query</li>
 *   <li>{@code rag_output_variable} : shared-map key that receives the joined document text</li>
 * </ul>
 *
 * <p>The file sources of the retrieved documents (metadata key {@code KsFileSource}) are stored
 * in the shared map under {@code tech_rag_source_documents}.
 */
public class AdvancedQueryRagSolver extends StepSolver {

    private static final int DEFAULT_TOP_K = 4;
    private static final double DEFAULT_THRESHOLD = 0.8;

    // Named after this class so log lines are attributed to the right solver.
    Logger logger = (Logger) LoggerFactory.getLogger(AdvancedQueryRagSolver.class);

    private String query;
    private String ragFilter;
    private int topk;
    private double threshold;
    private String outputField;
    private boolean enableRephrase;
    private boolean enableMultiPhrase;
    private String rephrasePrompt;
    private String multiPhrasePrompt;

    /** Returns the raw step attribute stored under {@code key}, or null when absent. */
    private Object attribute(String key) {
        return this.step.getAttributes().get(key);
    }

    /**
     * Reads all step attributes into the solver fields, applying defaults for the optional ones.
     * Numeric attributes are read through {@link Number} so that any boxed numeric type
     * (Integer, Long, Double, ...) is accepted; a non-numeric value still fails with a
     * ClassCastException.
     */
    private void loadParameters() {
        logger.info("Loading parameters");

        AttributeParser attributeParser = new AttributeParser(this.scenarioExecution);

        this.query = attributeParser.parse((String) attribute("rag_input"));
        this.ragFilter = attributeParser.parse((String) attribute("rag_filter"));

        Object topkValue = attribute("rag_topk");
        this.topk = (topkValue == null) ? DEFAULT_TOP_K : ((Number) topkValue).intValue();

        Object thresholdValue = attribute("rag_threshold");
        this.threshold = (thresholdValue == null)
                ? DEFAULT_THRESHOLD
                : ((Number) thresholdValue).doubleValue();

        Object rephraseFlag = attribute("enable_rephrase");
        this.enableRephrase = (rephraseFlag != null) && (Boolean) rephraseFlag;

        Object multiPhraseFlag = attribute("enable_multi_phrase");
        this.enableMultiPhrase = (multiPhraseFlag != null) && (Boolean) multiPhraseFlag;

        Object customRephrasePrompt = attribute("rephrase_prompt");
        if (customRephrasePrompt != null) {
            this.rephrasePrompt = (String) customRephrasePrompt;
        } else {
            this.rephrasePrompt = "Please rewrite the following query to be more efficent in a similarity search context : " + this.query;
        }

        Object customMultiPhrasePrompt = attribute("multi_phrase_prompt");
        if (customMultiPhrasePrompt != null) {
            this.multiPhrasePrompt = (String) customMultiPhrasePrompt;
        } else {
            this.multiPhrasePrompt = "Please rewrite the following query in 3 different way that could be used to perform a similarity search in a RAG scenario: " + this.query;
            this.multiPhrasePrompt += "\n\n The output must be a list of 3 different queries, one per line and nothing else.";
        }

        this.outputField = (String) attribute("rag_output_variable");
    }

    /** Logs the effective value of every parameter loaded by {@link #loadParameters()}. */
    private void logParameters() {
        logger.info("query: {}", this.query);
        logger.info("rag_filter: {}", this.ragFilter);
        logger.info("topk: {}", this.topk);
        logger.info("threshold: {}", this.threshold);
        logger.info("outputField: {}", this.outputField);
        logger.info("enableRephrase: {}", this.enableRephrase);
        logger.info("enableMultiPhrase: {}", this.enableMultiPhrase);
        logger.info("rephrasePrompt: {}", this.rephrasePrompt);
        logger.info("MultiPhrasePrompt: {}", this.multiPhrasePrompt);
    }

    /** Sends {@code prompt} as a user message to the chat client and returns the answer text. */
    private String askModel(String prompt) {
        CallResponseSpec resp = chatClient.prompt()
                .user(prompt)
                .call();
        return resp.content();
    }

    /**
     * Builds the list of queries to search for.
     *
     * <p>Only a multi-phrase answer is split into lines (one query per non-blank line); a plain
     * or rephrased query is searched as a single text. If the model answer yields no usable
     * line, the original query is used instead.
     */
    private List<String> buildSearchQueries(String elaboratedQuery) {
        List<String> searchQueries = new ArrayList<String>();

        if (this.enableMultiPhrase && elaboratedQuery != null) {
            for (String line : elaboratedQuery.split("\n")) {
                // trim() also removes a trailing '\r' left by "\r\n" line endings
                String candidate = line.trim();
                if (!candidate.isEmpty()) {
                    searchQueries.add(candidate);
                }
            }
        } else if (elaboratedQuery != null) {
            searchQueries.add(elaboratedQuery);
        }

        if (searchQueries.isEmpty() && this.query != null) {
            searchQueries.add(this.query);
        }
        return searchQueries;
    }

    /** Runs one similarity search for {@code searchQuery} with the configured topk/threshold/filter. */
    private List<Document> search(String searchQuery) {
        logger.info("Elaborated query: {}", searchQuery);

        Builder requestBuilder = SearchRequest.builder()
                .query(searchQuery)
                .topK(this.topk)
                .similarityThreshold(this.threshold);

        if (this.ragFilter != null && !this.ragFilter.isEmpty()) {
            requestBuilder.filterExpression(this.ragFilter);
            logger.info("Using Filter expression: {}", this.ragFilter);
        }

        return this.vectorStore.similaritySearch(requestBuilder.build());
    }

    /**
     * Executes the step: loads the parameters, optionally rewrites the query through the chat
     * client, searches the vector store once per resulting query, removes duplicate documents
     * and publishes the results in the execution shared map.
     *
     * <p>When both rephrase and multi-phrase are enabled, the multi-phrase answer replaces the
     * rephrased query.
     *
     * @return the scenario execution, with current/next step ids updated from this step
     */
    @Override
    public ScenarioExecution solveStep() {
        logger.info("Solving step: {}", this.step.getName());

        loadParameters();
        logParameters();

        String elaboratedQuery = this.query;
        if (this.enableRephrase) {
            elaboratedQuery = askModel(this.rephrasePrompt);
        }
        if (this.enableMultiPhrase) {
            elaboratedQuery = askModel(this.multiPhrasePrompt);
        }

        List<Document> docs = new ArrayList<Document>();
        for (String searchQuery : buildSearchQueries(elaboratedQuery)) {
            docs.addAll(search(searchQuery));
        }
        logger.info("Number of not unique VDB retrieved documents: {}", docs.size());

        // Remove duplicates by document id; LinkedHashMap keeps the retrieval order so the
        // concatenated context is deterministic.
        docs = new ArrayList<Document>(docs.stream()
                .collect(Collectors.toMap(Document::getId, d -> d, (d1, d2) -> d1, LinkedHashMap::new))
                .values());
        logger.info("Number of VDB retrieved documents: {}", docs.size());

        List<String> result = new ArrayList<String>();
        ArrayList<String> sourceDocs = new ArrayList<String>();
        for (Document doc : docs) {
            result.add(doc.getText());
            sourceDocs.add((String) doc.getMetadata().get("KsFileSource"));
        }

        String resultString = String.join("\n", result);

        this.scenarioExecution.getExecSharedMap().put("tech_rag_source_documents", sourceDocs);
        this.scenarioExecution.getExecSharedMap().put(this.outputField, resultString);

        this.scenarioExecution.setCurrentStepId(this.step.getStepId());
        this.scenarioExecution.setNextStepId(this.step.getNextStepId());

        return this.scenarioExecution;
    }
}