From 9a1ffa847d6404d6fba27afea0de4f65fd4fda9b Mon Sep 17 00:00:00 2001
From: "andrea.terzani"
Date: Fri, 28 Mar 2025 12:08:27 +0100
Subject: [PATCH] Aggiungi supporto per il ranking dei documenti nelle query avanzate

---
Review notes (ignored by git am):
- RAGDocumentRanker no longer aborts on a model reply that is not a bare
  integer; such documents are skipped with a warning.
- The ranking prompt now separates instructions, query and chunk with newlines.
- Generic type arguments restored on the added lines.
- The whitespace-only change to the enable_multi_phrase assignment was dropped.
- The blob ids on the "index" lines come from the original commit and do not
  match the revised content.

 .../stepSolvers/AdvancedQueryRagSolver.java   | 25 +++++-
 .../hermione/utility/RAGDocumentRanker.java   | 83 +++++++++++++++++++
 .../hermione/utility/RankedDocument.java      | 33 ++++++++
 3 files changed, 140 insertions(+), 1 deletion(-)
 create mode 100644 src/main/java/com/olympus/hermione/utility/RAGDocumentRanker.java
 create mode 100644 src/main/java/com/olympus/hermione/utility/RankedDocument.java

diff --git a/src/main/java/com/olympus/hermione/stepSolvers/AdvancedQueryRagSolver.java b/src/main/java/com/olympus/hermione/stepSolvers/AdvancedQueryRagSolver.java
index 283841c..2b503a3 100644
--- a/src/main/java/com/olympus/hermione/stepSolvers/AdvancedQueryRagSolver.java
+++ b/src/main/java/com/olympus/hermione/stepSolvers/AdvancedQueryRagSolver.java
@@ -15,6 +15,8 @@ import com.olympus.hermione.client.OlympusChatClient;
 import com.olympus.hermione.client.OlympusChatClientResponse;
 import com.olympus.hermione.models.ScenarioExecution;
 import com.olympus.hermione.utility.AttributeParser;
+import com.olympus.hermione.utility.RAGDocumentRanker;
+import com.olympus.hermione.utility.RankedDocument;
 
 import ch.qos.logback.classic.Logger;
 
@@ -37,6 +39,7 @@ public class AdvancedQueryRagSolver extends StepSolver {
     private boolean enableMultiPhrase;
     private String rephrasePrompt;
     private String MultiPhrasePrompt;
+    private boolean enableRanking;
 
 
 
@@ -69,11 +72,17 @@ public class AdvancedQueryRagSolver extends StepSolver {
         }
 
         if(this.step.getAttributes().containsKey("enable_multi_phrase")){
             this.enableMultiPhrase = (boolean) this.step.getAttributes().get("enable_multi_phrase");
         }else{
             this.enableMultiPhrase = false;
         }
 
+        if(this.step.getAttributes().containsKey("enable_ranking")){
+            this.enableRanking = (boolean) this.step.getAttributes().get("enable_ranking");
+        }else{
+            this.enableRanking = false;
+        }
+
         if(this.step.getAttributes().containsKey("rephrase_prompt")){
             this.rephrasePrompt = attributeParser.parse((String)this.step.getAttributes().get("rephrase_prompt"));
         }else{
@@ -99,9 +108,9 @@ public class AdvancedQueryRagSolver extends StepSolver {
         logger.info("outputField: " + this.outputField);
         logger.info("enableRephrase: " + this.enableRephrase);
         logger.info("enableMultiPhrase: " + this.enableMultiPhrase);
+        logger.info("enableRanking: " + this.enableRanking);
         logger.info("rephrasePrompt: " + this.rephrasePrompt);
         logger.info("MultiPhrasePrompt: " + this.MultiPhrasePrompt);
-
     }
 
 
@@ -168,6 +177,20 @@ public class AdvancedQueryRagSolver extends StepSolver {
 
 
         logger.info("Number of VDB retrieved documents: " + docs.size());
+
+        if (enableRanking){
+            // Keep only the chunks the model rates above 5 out of 10 for this query.
+            List<Document> rankedDocs = new ArrayList<>();
+            logger.info("Ranking documents");
+            RAGDocumentRanker ranker = new RAGDocumentRanker();
+            List<RankedDocument> rankedDocument = ranker.rankDocuments(docs, this.query, 5, chatClient);
+            for (RankedDocument rankedDoc : rankedDocument) {
+                rankedDocs.add(rankedDoc.getDocument());
+            }
+            logger.info("Number of ranked documents: " + rankedDocs.size());
+            docs = rankedDocs;
+        }
+
         List source_doc = new ArrayList();
         String resultString = "";
 
diff --git a/src/main/java/com/olympus/hermione/utility/RAGDocumentRanker.java b/src/main/java/com/olympus/hermione/utility/RAGDocumentRanker.java
new file mode 100644
index 0000000..01c5230
--- /dev/null
+++ b/src/main/java/com/olympus/hermione/utility/RAGDocumentRanker.java
@@ -0,0 +1,83 @@
+package com.olympus.hermione.utility;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.OptionalInt;
+import java.util.regex.Matcher;
+import java.util.regex.Pattern;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import org.springframework.ai.document.Document;
+import org.springframework.stereotype.Service;
+
+import com.olympus.hermione.client.OlympusChatClient;
+import com.olympus.hermione.client.OlympusChatClientResponse;
+
+/**
+ * Keeps only the retrieved chunks that the chat model scores as relevant to a query.
+ */
+public class RAGDocumentRanker {
+
+    private static final Logger logger = LoggerFactory.getLogger(RAGDocumentRanker.class);
+
+    /** Matches the first standalone integer between 1 and 10 in the model reply. */
+    private static final Pattern SCORE_PATTERN = Pattern.compile("\\b(10|[1-9])\\b");
+
+    /**
+     * Asks the chat model for a 1-10 relevance score for each document and keeps
+     * those scoring strictly above {@code threshold}, in their input order.
+     *
+     * <p>A reply without a usable score makes that document be skipped with a
+     * warning; it does not abort the ranking of the remaining documents.
+     *
+     * @param documents  chunks to score; {@code null} is treated as empty
+     * @param query      query the chunks are scored against
+     * @param threshold  exclusive lower bound a score must exceed to be kept
+     * @param chatClient client used to request each score
+     * @return the kept documents paired with their scores
+     */
+    public List<RankedDocument> rankDocuments(List<Document> documents, String query, int threshold, OlympusChatClient chatClient) {
+
+        List<RankedDocument> rankedDocuments = new ArrayList<>();
+        if (documents == null) {
+            return rankedDocuments;
+        }
+
+        for (Document document : documents) {
+            logger.info("Ranking document: {}", document.getId());
+            // Newlines keep the query and the chunk from running into the instructions.
+            String rankingPrompt = "On a scale of 1-10, rate the relevance of the following chunk of a bigger document to the query. Consider the specific context and intent of the query, not just keyword matches.\n"
+                    + "Query: " + query + "\n"
+                    + "Chunk: " + document.getText() + "\n"
+                    + "Answer with a number between 1 and 10 and nothing else.";
+
+            OlympusChatClientResponse resp = chatClient.getChatCompletion(rankingPrompt);
+            String reply = resp == null ? null : resp.getContent();
+
+            OptionalInt score = parseScore(reply);
+            if (!score.isPresent()) {
+                logger.warn("Skipping document {}: no 1-10 score in reply: {}", document.getId(), reply);
+                continue;
+            }
+
+            int rankInt = score.getAsInt();
+            logger.info("Rank: {}", rankInt);
+            if (rankInt > threshold) {
+                rankedDocuments.add(new RankedDocument(document, rankInt));
+                logger.info("Document ranked: {}", document.getId());
+            }
+        }
+
+        return rankedDocuments;
+    }
+
+    /** Extracts the 1-10 score from a model reply, tolerating whitespace and extra text. */
+    private static OptionalInt parseScore(String reply) {
+        if (reply == null) {
+            return OptionalInt.empty();
+        }
+        Matcher matcher = SCORE_PATTERN.matcher(reply);
+        return matcher.find() ? OptionalInt.of(Integer.parseInt(matcher.group(1))) : OptionalInt.empty();
+    }
+}
diff --git a/src/main/java/com/olympus/hermione/utility/RankedDocument.java b/src/main/java/com/olympus/hermione/utility/RankedDocument.java
new file mode 100644
index 0000000..69b417b
--- /dev/null
+++ b/src/main/java/com/olympus/hermione/utility/RankedDocument.java
@@ -0,0 +1,33 @@
+package com.olympus.hermione.utility;
+
+import java.util.List;
+
+import org.springframework.ai.document.Document;
+
+/** Pairs a document with the relevance score assigned to it. */
+public class RankedDocument {
+
+    private Document document;
+    private int rank;
+
+    public RankedDocument(Document document, int rank) {
+        this.document = document;
+        this.rank = rank;
+    }
+
+    public Document getDocument() {
+        return document;
+    }
+
+    public void setDocument(Document document) {
+        this.document = document;
+    }
+
+    public int getRank() {
+        return rank;
+    }
+
+    public void setRank(int rank) {
+        this.rank = rank;
+    }
+}