Nessuna modifica apportata
This commit is contained in:
@@ -40,8 +40,9 @@ public class AdvancedQueryRagSolver extends StepSolver {
|
|||||||
private String rephrasePrompt;
|
private String rephrasePrompt;
|
||||||
private String MultiPhrasePrompt;
|
private String MultiPhrasePrompt;
|
||||||
private boolean enableRanking;
|
private boolean enableRanking;
|
||||||
|
private boolean enableRetrieveOther;
|
||||||
|
private int rank_threshold;
|
||||||
|
private int otherTopK;
|
||||||
|
|
||||||
private void loadParameters(){
|
private void loadParameters(){
|
||||||
logger.info("Loading parameters");
|
logger.info("Loading parameters");
|
||||||
@@ -83,6 +84,25 @@ public class AdvancedQueryRagSolver extends StepSolver {
|
|||||||
this.enableRanking = false;
|
this.enableRanking = false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if(this.step.getAttributes().containsKey("rank_threshold")){
|
||||||
|
this.rank_threshold = (int) this.step.getAttributes().get("rank_threshold");
|
||||||
|
}else{
|
||||||
|
this.rank_threshold = 5;
|
||||||
|
}
|
||||||
|
|
||||||
|
if(this.step.getAttributes().containsKey("otherTopK")){
|
||||||
|
this.otherTopK = (int) this.step.getAttributes().get("otherTopK");
|
||||||
|
}else{
|
||||||
|
this.otherTopK = 2;
|
||||||
|
}
|
||||||
|
|
||||||
|
if(this.step.getAttributes().containsKey("enable_retrieve_other")){
|
||||||
|
this.enableRetrieveOther = (boolean) this.step.getAttributes().get("enable_retrieve_other");
|
||||||
|
}else{
|
||||||
|
this.enableRetrieveOther = false;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
if(this.step.getAttributes().containsKey("rephrase_prompt")){
|
if(this.step.getAttributes().containsKey("rephrase_prompt")){
|
||||||
this.rephrasePrompt = attributeParser.parse((String)this.step.getAttributes().get("rephrase_prompt"));
|
this.rephrasePrompt = attributeParser.parse((String)this.step.getAttributes().get("rephrase_prompt"));
|
||||||
}else{
|
}else{
|
||||||
@@ -182,7 +202,7 @@ public class AdvancedQueryRagSolver extends StepSolver {
|
|||||||
List<Document> rankedDocs = new ArrayList<Document>();
|
List<Document> rankedDocs = new ArrayList<Document>();
|
||||||
logger.info("Ranking documents");
|
logger.info("Ranking documents");
|
||||||
RAGDocumentRanker ranker = new RAGDocumentRanker();
|
RAGDocumentRanker ranker = new RAGDocumentRanker();
|
||||||
List<RankedDocument> rankedDocument = ranker.rankDocuments(docs, this.query, 5, chatClient);
|
List<RankedDocument> rankedDocument = ranker.rankDocuments(docs, this.query, this.rank_threshold, chatClient);
|
||||||
for (RankedDocument rankedDoc : rankedDocument) {
|
for (RankedDocument rankedDoc : rankedDocument) {
|
||||||
rankedDocs.add(rankedDoc.getDocument());
|
rankedDocs.add(rankedDoc.getDocument());
|
||||||
}
|
}
|
||||||
@@ -190,6 +210,50 @@ public class AdvancedQueryRagSolver extends StepSolver {
|
|||||||
docs = rankedDocs;
|
docs = rankedDocs;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if(enableRetrieveOther){
|
||||||
|
List<Document> otherDocs = new ArrayList<Document>();
|
||||||
|
logger.info("Retrieving other documents");
|
||||||
|
|
||||||
|
String resultString = "";
|
||||||
|
|
||||||
|
for(Document doc : docs){
|
||||||
|
resultString += "Document Source: " + doc.getMetadata().get("KsFileSource") + "\n";
|
||||||
|
resultString += "--------------------------------------------------------------\n";
|
||||||
|
resultString += doc.getText() + "\n";
|
||||||
|
}
|
||||||
|
|
||||||
|
String prompt = "I've retrieved the following chunks of document using similarity search.";
|
||||||
|
prompt +="Can you provide me a query that can be used with another vector similarity search to retrieve other relevant documents to answer the query?";
|
||||||
|
prompt += "\n\n" + resultString;
|
||||||
|
prompt += "\n\n The query is: " + this.query;
|
||||||
|
prompt += "\n\n The output must be a single query and nothing else.";
|
||||||
|
|
||||||
|
OlympusChatClientResponse resp = chatClient.getChatCompletion(prompt);
|
||||||
|
|
||||||
|
String otherDocQuery = resp.getContent();
|
||||||
|
logger.info("Other document query: " + otherDocQuery);
|
||||||
|
|
||||||
|
Builder request_builder = SearchRequest.builder()
|
||||||
|
.query(otherDocQuery)
|
||||||
|
.topK(otherTopK)
|
||||||
|
.similarityThreshold(0.8);
|
||||||
|
|
||||||
|
if(this.rag_filter != null && !this.rag_filter.isEmpty()){
|
||||||
|
request_builder.filterExpression(this.rag_filter);
|
||||||
|
logger.info("Using Filter expression: " + this.rag_filter);
|
||||||
|
}
|
||||||
|
|
||||||
|
SearchRequest request = request_builder.build();
|
||||||
|
otherDocs.addAll( this.vectorStore.similaritySearch(request));
|
||||||
|
logger.info("Number of addictional documents: " + otherDocs.size());
|
||||||
|
|
||||||
|
docs.addAll(otherDocs);
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
docs = docs.stream().collect(Collectors.toMap(Document::getId, d -> d, (d1, d2) -> d1)).values().stream().collect(Collectors.toList());
|
||||||
|
|
||||||
List<String> source_doc = new ArrayList<String>();
|
List<String> source_doc = new ArrayList<String>();
|
||||||
String resultString = "";
|
String resultString = "";
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user