Aggiungi AdvancedQueryRagSolver per la gestione avanzata delle query RAG
This commit is contained in:
@@ -45,6 +45,7 @@ import com.olympus.hermione.repository.ScenarioExecutionRepository;
|
||||
import com.olympus.hermione.repository.ScenarioRepository;
|
||||
import com.olympus.hermione.security.entity.User;
|
||||
import com.olympus.hermione.stepSolvers.AdvancedAIPromptSolver;
|
||||
import com.olympus.hermione.stepSolvers.AdvancedQueryRagSolver;
|
||||
import com.olympus.hermione.stepSolvers.SummarizeDocSolver;
|
||||
import com.olympus.hermione.stepSolvers.BasicAIPromptSolver;
|
||||
import com.olympus.hermione.stepSolvers.BasicQueryRagSolver;
|
||||
@@ -250,6 +251,9 @@ public class ScenarioExecutionService {
|
||||
case "SIMPLE_QUERY_AI":
|
||||
solver = new BasicAIPromptSolver();
|
||||
break;
|
||||
case "ADVANCED_RAG_QUERY":
|
||||
solver = new AdvancedQueryRagSolver();
|
||||
break;
|
||||
case "QUERY_GRAPH":
|
||||
solver = new QueryNeo4JSolver();
|
||||
break;
|
||||
|
||||
@@ -0,0 +1,178 @@
|
||||
package com.olympus.hermione.stepSolvers;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
import org.slf4j.LoggerFactory;
|
||||
import org.springframework.ai.chat.client.ChatClient.CallResponseSpec;
|
||||
import org.springframework.ai.document.Document;
|
||||
import org.springframework.ai.vectorstore.SearchRequest;
|
||||
import org.springframework.ai.vectorstore.SearchRequest.Builder;
|
||||
|
||||
import com.olympus.hermione.models.ScenarioExecution;
|
||||
import com.olympus.hermione.utility.AttributeParser;
|
||||
|
||||
import ch.qos.logback.classic.Logger;
|
||||
|
||||
public class AdvancedQueryRagSolver extends StepSolver {
|
||||
|
||||
Logger logger = (Logger) LoggerFactory.getLogger(BasicQueryRagSolver.class);
|
||||
|
||||
// PARAMETERS
|
||||
// rag_input : the field of the sharedMap to use as rag input
|
||||
// rag_filter : the field of the sharedMap to use as rag filter
|
||||
// rag_topk : the number of topk results to return
|
||||
// rag_threshold : the similarity threshold to use
|
||||
// rag_output_variable : the field of the sharedMap to store the output
|
||||
private String query;
|
||||
private String rag_filter;
|
||||
private int topk;
|
||||
private double threshold;
|
||||
private String outputField;
|
||||
private boolean enableRephrase;
|
||||
private boolean enableMultiPhrase;
|
||||
private String rephrasePrompt;
|
||||
private String MultiPhrasePrompt;
|
||||
|
||||
|
||||
|
||||
private void loadParameters(){
|
||||
logger.info("Loading parameters");
|
||||
|
||||
AttributeParser attributeParser = new AttributeParser(this.scenarioExecution);
|
||||
|
||||
this.query = attributeParser.parse((String) this.step.getAttributes().get("rag_input"));
|
||||
|
||||
this.rag_filter = attributeParser.parse((String) this.step.getAttributes().get("rag_filter"));
|
||||
|
||||
if(this.step.getAttributes().containsKey("rag_topk")){
|
||||
this.topk = (int) this.step.getAttributes().get("rag_topk");
|
||||
}else{
|
||||
this.topk = 4;
|
||||
}
|
||||
|
||||
if(this.step.getAttributes().containsKey("rag_threshold")){
|
||||
this.threshold = (double) this.step.getAttributes().get("rag_threshold");
|
||||
}else{
|
||||
this.threshold = 0.8;
|
||||
}
|
||||
|
||||
|
||||
if(this.step.getAttributes().containsKey("enable_rephrase")){
|
||||
this.enableRephrase = (boolean) this.step.getAttributes().get("enable_rephrase");
|
||||
}else{
|
||||
this.enableRephrase = false;
|
||||
}
|
||||
|
||||
if(this.step.getAttributes().containsKey("enable_multi_phrase")){
|
||||
this.enableMultiPhrase = (boolean) this.step.getAttributes().get("enable_multi_phrase");
|
||||
}else{
|
||||
this.enableMultiPhrase = false;
|
||||
}
|
||||
|
||||
if(this.step.getAttributes().containsKey("rephrase_prompt")){
|
||||
this.rephrasePrompt = (String) this.step.getAttributes().get("rephrase_prompt");
|
||||
}else{
|
||||
this.rephrasePrompt = "Please rewrite the following query to be more efficent in a similarity search context : " + this.query;
|
||||
}
|
||||
|
||||
if(this.step.getAttributes().containsKey("multi_phrase_prompt")){
|
||||
this.MultiPhrasePrompt = (String) this.step.getAttributes().get("multi_phrase_prompt");
|
||||
}else{
|
||||
this.MultiPhrasePrompt = "Please rewrite the following query in 3 different way that could be used to perform a similarity search in a RAG scenario: " + this.query;
|
||||
this.MultiPhrasePrompt += "\n\n The output must be a list of 3 different queries, one per line and nothing else.";
|
||||
}
|
||||
|
||||
this.outputField = (String) this.step.getAttributes().get("rag_output_variable");
|
||||
}
|
||||
|
||||
private void logParameters(){
|
||||
logger.info("query: " + this.query);
|
||||
logger.info("rag_filter: " + this.rag_filter);
|
||||
logger.info("topk: " + this.topk);
|
||||
logger.info("threshold: " + this.threshold);
|
||||
logger.info("outputField: " + this.outputField);
|
||||
logger.info("enableRephrase: " + this.enableRephrase);
|
||||
logger.info("enableMultiPhrase: " + this.enableMultiPhrase);
|
||||
logger.info("rephrasePrompt: " + this.rephrasePrompt);
|
||||
logger.info("MultiPhrasePrompt: " + this.MultiPhrasePrompt);
|
||||
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public ScenarioExecution solveStep(){
|
||||
|
||||
logger.info("Solving step: " + this.step.getName());
|
||||
loadParameters();
|
||||
logParameters();
|
||||
String elaboratedQuery = this.query;
|
||||
|
||||
if( this.enableRephrase ){
|
||||
|
||||
CallResponseSpec resp = chatClient.prompt()
|
||||
.user(this.rephrasePrompt)
|
||||
.call();
|
||||
|
||||
elaboratedQuery = resp.content();
|
||||
}
|
||||
|
||||
if( this.enableMultiPhrase ){
|
||||
|
||||
CallResponseSpec resp = chatClient.prompt()
|
||||
.user(this.MultiPhrasePrompt)
|
||||
.call();
|
||||
|
||||
elaboratedQuery = resp.content();
|
||||
}
|
||||
|
||||
List<Document> docs = new ArrayList<Document>();
|
||||
|
||||
for (String query : elaboratedQuery.split("\n")) {
|
||||
|
||||
logger.info("Elaborated query: " + query);
|
||||
Builder request_builder = SearchRequest.builder()
|
||||
.query(elaboratedQuery)
|
||||
.topK(this.topk)
|
||||
.similarityThreshold(this.threshold);
|
||||
|
||||
if(this.rag_filter != null && !this.rag_filter.isEmpty()){
|
||||
request_builder.filterExpression(this.rag_filter);
|
||||
logger.info("Using Filter expression: " + this.rag_filter);
|
||||
}
|
||||
|
||||
SearchRequest request = request_builder.build();
|
||||
docs.addAll( this.vectorStore.similaritySearch(request));
|
||||
}
|
||||
|
||||
logger.info("Number of not unique VDB retrieved documents: " + docs.size());
|
||||
|
||||
//Remove duplicates from docs using document id
|
||||
docs = docs.stream().collect(Collectors.toMap(Document::getId, d -> d, (d1, d2) -> d1)).values().stream().collect(Collectors.toList());
|
||||
|
||||
logger.info("Number of VDB retrieved documents: " + docs.size());
|
||||
|
||||
List<String> result = new ArrayList<String>();
|
||||
for (Document doc : docs) {
|
||||
result.add(doc.getText());
|
||||
}
|
||||
|
||||
ArrayList<String> source_doc = new ArrayList<String>();
|
||||
for (Document doc : docs) {
|
||||
source_doc.add((String)doc.getMetadata().get("KsFileSource"));
|
||||
}
|
||||
|
||||
String resultString = String.join("\n", result);
|
||||
|
||||
this.scenarioExecution.getExecSharedMap().put("tech_rag_source_documents", source_doc);
|
||||
this.scenarioExecution.getExecSharedMap().put(this.outputField, resultString);
|
||||
this.scenarioExecution.setCurrentStepId(this.step.getStepId());
|
||||
this.scenarioExecution.setNextStepId(this.step.getNextStepId());
|
||||
|
||||
return this.scenarioExecution;
|
||||
}
|
||||
|
||||
|
||||
|
||||
}
|
||||
Reference in New Issue
Block a user