Aggiungi supporto per il ranking dei documenti nelle query avanzate

This commit is contained in:
andrea.terzani
2025-03-28 12:08:27 +01:00
parent ba836d8c9f
commit 9a1ffa847d
3 changed files with 101 additions and 2 deletions

View File

@@ -15,6 +15,8 @@ import com.olympus.hermione.client.OlympusChatClient;
import com.olympus.hermione.client.OlympusChatClientResponse;
import com.olympus.hermione.models.ScenarioExecution;
import com.olympus.hermione.utility.AttributeParser;
import com.olympus.hermione.utility.RAGDocumentRanker;
import com.olympus.hermione.utility.RankedDocument;
import ch.qos.logback.classic.Logger;
@@ -37,6 +39,7 @@ public class AdvancedQueryRagSolver extends StepSolver {
private boolean enableMultiPhrase;
private String rephrasePrompt;
private String MultiPhrasePrompt;
private boolean enableRanking;
@@ -69,11 +72,17 @@ public class AdvancedQueryRagSolver extends StepSolver {
}
if(this.step.getAttributes().containsKey("enable_multi_phrase")){
this.enableMultiPhrase = (boolean) this.step.getAttributes().get("enable_multi_phrase");
this.enableMultiPhrase = (boolean) this.step.getAttributes().get("enable_multi_phrase");
}else{
this.enableMultiPhrase = false;
}
if(this.step.getAttributes().containsKey("enable_ranking")){
this.enableRanking = (boolean) this.step.getAttributes().get("enable_ranking");
}else{
this.enableRanking = false;
}
if(this.step.getAttributes().containsKey("rephrase_prompt")){
this.rephrasePrompt = attributeParser.parse((String)this.step.getAttributes().get("rephrase_prompt"));
}else{
@@ -99,9 +108,9 @@ public class AdvancedQueryRagSolver extends StepSolver {
logger.info("outputField: " + this.outputField);
logger.info("enableRephrase: " + this.enableRephrase);
logger.info("enableMultiPhrase: " + this.enableMultiPhrase);
logger.info("enableRanking: " + this.enableRanking);
logger.info("rephrasePrompt: " + this.rephrasePrompt);
logger.info("MultiPhrasePrompt: " + this.MultiPhrasePrompt);
}
@@ -168,6 +177,19 @@ public class AdvancedQueryRagSolver extends StepSolver {
logger.info("Number of VDB retrieved documents: " + docs.size());
if (enableRanking){
List<Document> rankedDocs = new ArrayList<Document>();
logger.info("Ranking documents");
RAGDocumentRanker ranker = new RAGDocumentRanker();
List<RankedDocument> rankedDocument = ranker.rankDocuments(docs, this.query, 5, chatClient);
for (RankedDocument rankedDoc : rankedDocument) {
rankedDocs.add(rankedDoc.getDocument());
}
logger.info("Number of ranked documents: " + rankedDocs.size());
docs = rankedDocs;
}
List<String> source_doc = new ArrayList<String>();
String resultString = "";

View File

@@ -0,0 +1,47 @@
package com.olympus.hermione.utility;
import java.util.ArrayList;
import java.util.List;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.document.Document;
import org.springframework.stereotype.Service;
import com.olympus.hermione.client.OlympusChatClient;
import com.olympus.hermione.client.OlympusChatClientResponse;
public class RAGDocumentRanker {
Logger logger = (Logger) LoggerFactory.getLogger(RAGDocumentRanker.class);
public List<RankedDocument> rankDocuments(List<Document> documents, String query, int threshold, OlympusChatClient chatClient) {
List<RankedDocument> rankedDocuments = new ArrayList<>();
for(Document document : documents){
logger.info("Ranking document: " + document.getId());
String rankingPrompt="On a scale of 1-10, rate the relevance of the following chunk of a bigger document to the query. Consider the specific context and intent of the query, not just keyword matches.";
rankingPrompt +="Query: " + query;
rankingPrompt +="Chunk: " + document.getText();
rankingPrompt +="Answer with a number between 1 and 10 and nothing else.";
OlympusChatClientResponse resp = chatClient.getChatCompletion(rankingPrompt);
String rank = resp.getContent();
int rankInt = Integer.parseInt(rank);
logger.info("Rank: " + rankInt);
if(rankInt > threshold){
rankedDocuments.add(new RankedDocument(document, rankInt));
logger.info("Document ranked: " + document.getId());
}
}
return rankedDocuments;
}
}

View File

@@ -0,0 +1,30 @@
package com.olympus.hermione.utility;
import java.util.List;
import org.springframework.ai.document.Document;
public class RankedDocument {
private Document document;
private int rank;
public RankedDocument(Document document, int rank) {
this.document = document;
this.rank = rank;
}
public Document getDocument() {
return document;
}
public void setDocument(Document document) {
this.document = document;
}
public int getRank() {
return rank;
}
public void setRank(int rank) {
this.rank = rank;
}
}