Aggiungi supporto per il ranking dei documenti nelle query avanzate
This commit is contained in:
@@ -15,6 +15,8 @@ import com.olympus.hermione.client.OlympusChatClient;
|
||||
import com.olympus.hermione.client.OlympusChatClientResponse;
|
||||
import com.olympus.hermione.models.ScenarioExecution;
|
||||
import com.olympus.hermione.utility.AttributeParser;
|
||||
import com.olympus.hermione.utility.RAGDocumentRanker;
|
||||
import com.olympus.hermione.utility.RankedDocument;
|
||||
|
||||
import ch.qos.logback.classic.Logger;
|
||||
|
||||
@@ -37,6 +39,7 @@ public class AdvancedQueryRagSolver extends StepSolver {
|
||||
private boolean enableMultiPhrase;
|
||||
private String rephrasePrompt;
|
||||
private String MultiPhrasePrompt;
|
||||
private boolean enableRanking;
|
||||
|
||||
|
||||
|
||||
@@ -69,11 +72,17 @@ public class AdvancedQueryRagSolver extends StepSolver {
|
||||
}
|
||||
|
||||
if(this.step.getAttributes().containsKey("enable_multi_phrase")){
|
||||
this.enableMultiPhrase = (boolean) this.step.getAttributes().get("enable_multi_phrase");
|
||||
this.enableMultiPhrase = (boolean) this.step.getAttributes().get("enable_multi_phrase");
|
||||
}else{
|
||||
this.enableMultiPhrase = false;
|
||||
}
|
||||
|
||||
if(this.step.getAttributes().containsKey("enable_ranking")){
|
||||
this.enableRanking = (boolean) this.step.getAttributes().get("enable_ranking");
|
||||
}else{
|
||||
this.enableRanking = false;
|
||||
}
|
||||
|
||||
if(this.step.getAttributes().containsKey("rephrase_prompt")){
|
||||
this.rephrasePrompt = attributeParser.parse((String)this.step.getAttributes().get("rephrase_prompt"));
|
||||
}else{
|
||||
@@ -99,9 +108,9 @@ public class AdvancedQueryRagSolver extends StepSolver {
|
||||
logger.info("outputField: " + this.outputField);
|
||||
logger.info("enableRephrase: " + this.enableRephrase);
|
||||
logger.info("enableMultiPhrase: " + this.enableMultiPhrase);
|
||||
logger.info("enableRanking: " + this.enableRanking);
|
||||
logger.info("rephrasePrompt: " + this.rephrasePrompt);
|
||||
logger.info("MultiPhrasePrompt: " + this.MultiPhrasePrompt);
|
||||
|
||||
}
|
||||
|
||||
|
||||
@@ -168,6 +177,19 @@ public class AdvancedQueryRagSolver extends StepSolver {
|
||||
|
||||
logger.info("Number of VDB retrieved documents: " + docs.size());
|
||||
|
||||
|
||||
if (enableRanking){
|
||||
List<Document> rankedDocs = new ArrayList<Document>();
|
||||
logger.info("Ranking documents");
|
||||
RAGDocumentRanker ranker = new RAGDocumentRanker();
|
||||
List<RankedDocument> rankedDocument = ranker.rankDocuments(docs, this.query, 5, chatClient);
|
||||
for (RankedDocument rankedDoc : rankedDocument) {
|
||||
rankedDocs.add(rankedDoc.getDocument());
|
||||
}
|
||||
logger.info("Number of ranked documents: " + rankedDocs.size());
|
||||
docs = rankedDocs;
|
||||
}
|
||||
|
||||
List<String> source_doc = new ArrayList<String>();
|
||||
String resultString = "";
|
||||
|
||||
|
||||
@@ -0,0 +1,47 @@
|
||||
package com.olympus.hermione.utility;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import org.springframework.ai.document.Document;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
import com.olympus.hermione.client.OlympusChatClient;
|
||||
import com.olympus.hermione.client.OlympusChatClientResponse;
|
||||
|
||||
|
||||
|
||||
|
||||
public class RAGDocumentRanker {
|
||||
|
||||
Logger logger = (Logger) LoggerFactory.getLogger(RAGDocumentRanker.class);
|
||||
|
||||
|
||||
public List<RankedDocument> rankDocuments(List<Document> documents, String query, int threshold, OlympusChatClient chatClient) {
|
||||
|
||||
List<RankedDocument> rankedDocuments = new ArrayList<>();
|
||||
|
||||
for(Document document : documents){
|
||||
logger.info("Ranking document: " + document.getId());
|
||||
String rankingPrompt="On a scale of 1-10, rate the relevance of the following chunk of a bigger document to the query. Consider the specific context and intent of the query, not just keyword matches.";
|
||||
rankingPrompt +="Query: " + query;
|
||||
rankingPrompt +="Chunk: " + document.getText();
|
||||
rankingPrompt +="Answer with a number between 1 and 10 and nothing else.";
|
||||
|
||||
OlympusChatClientResponse resp = chatClient.getChatCompletion(rankingPrompt);
|
||||
String rank = resp.getContent();
|
||||
|
||||
int rankInt = Integer.parseInt(rank);
|
||||
logger.info("Rank: " + rankInt);
|
||||
if(rankInt > threshold){
|
||||
rankedDocuments.add(new RankedDocument(document, rankInt));
|
||||
logger.info("Document ranked: " + document.getId());
|
||||
}
|
||||
}
|
||||
|
||||
return rankedDocuments;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,30 @@
|
||||
package com.olympus.hermione.utility;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
import org.springframework.ai.document.Document;
|
||||
|
||||
public class RankedDocument {
|
||||
|
||||
private Document document;
|
||||
private int rank;
|
||||
|
||||
public RankedDocument(Document document, int rank) {
|
||||
this.document = document;
|
||||
this.rank = rank;
|
||||
}
|
||||
public Document getDocument() {
|
||||
return document;
|
||||
}
|
||||
public void setDocument(Document document) {
|
||||
this.document = document;
|
||||
}
|
||||
public int getRank() {
|
||||
return rank;
|
||||
}
|
||||
public void setRank(int rank) {
|
||||
this.rank = rank;
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
Reference in New Issue
Block a user