diff --git a/src/main/java/com/olympus/hermione/stepSolvers/AdvancedQueryRagSolver.java b/src/main/java/com/olympus/hermione/stepSolvers/AdvancedQueryRagSolver.java index 2b503a3..31e6437 100644 --- a/src/main/java/com/olympus/hermione/stepSolvers/AdvancedQueryRagSolver.java +++ b/src/main/java/com/olympus/hermione/stepSolvers/AdvancedQueryRagSolver.java @@ -40,8 +40,9 @@ public class AdvancedQueryRagSolver extends StepSolver { private String rephrasePrompt; private String MultiPhrasePrompt; private boolean enableRanking; - - + private boolean enableRetrieveOther; + private int rank_threshold; + private int otherTopK; private void loadParameters(){ logger.info("Loading parameters"); @@ -83,6 +84,25 @@ public class AdvancedQueryRagSolver extends StepSolver { this.enableRanking = false; } + if(this.step.getAttributes().containsKey("rank_threshold")){ + this.rank_threshold = (int) this.step.getAttributes().get("rank_threshold"); + }else{ + this.rank_threshold = 5; + } + + if(this.step.getAttributes().containsKey("otherTopK")){ + this.otherTopK = (int) this.step.getAttributes().get("otherTopK"); + }else{ + this.otherTopK = 2; + } + + if(this.step.getAttributes().containsKey("enable_retrieve_other")){ + this.enableRetrieveOther = (boolean) this.step.getAttributes().get("enable_retrieve_other"); + }else{ + this.enableRetrieveOther = false; + } + + if(this.step.getAttributes().containsKey("rephrase_prompt")){ this.rephrasePrompt = attributeParser.parse((String)this.step.getAttributes().get("rephrase_prompt")); }else{ @@ -182,7 +202,7 @@ public class AdvancedQueryRagSolver extends StepSolver { List rankedDocs = new ArrayList(); logger.info("Ranking documents"); RAGDocumentRanker ranker = new RAGDocumentRanker(); - List rankedDocument = ranker.rankDocuments(docs, this.query, 5, chatClient); + List rankedDocument = ranker.rankDocuments(docs, this.query, this.rank_threshold, chatClient); for (RankedDocument rankedDoc : rankedDocument) { rankedDocs.add(rankedDoc.getDocument()); } @@ -190,6 +210,50 @@ public class AdvancedQueryRagSolver extends StepSolver { docs = rankedDocs; } + if(enableRetrieveOther){ + List otherDocs = new ArrayList(); + logger.info("Retrieving other documents"); + + String resultString = ""; + + for(Document doc : docs){ + resultString += "Document Source: " + doc.getMetadata().get("KsFileSource") + "\n"; + resultString += "--------------------------------------------------------------\n"; + resultString += doc.getText() + "\n"; + } + + String prompt = "I've retrieved the following chunks of document using similarity search."; + prompt +="Can you provide me a query that can be used with another vector similarity search to retrieve other relevant documents to answer the query?"; + prompt += "\n\n" + resultString; + prompt += "\n\n The query is: " + this.query; + prompt += "\n\n The output must be a single query and nothing else."; + + OlympusChatClientResponse resp = chatClient.getChatCompletion(prompt); + + String otherDocQuery = resp.getContent(); + logger.info("Other document query: " + otherDocQuery); + + Builder request_builder = SearchRequest.builder() + .query(otherDocQuery) + .topK(otherTopK) + .similarityThreshold(0.8); + + if(this.rag_filter != null && !this.rag_filter.isEmpty()){ + request_builder.filterExpression(this.rag_filter); + logger.info("Using Filter expression: " + this.rag_filter); + } + + SearchRequest request = request_builder.build(); + otherDocs.addAll( this.vectorStore.similaritySearch(request)); + logger.info("Number of addictional documents: " + otherDocs.size()); + + docs.addAll(otherDocs); + + } + + + docs = docs.stream().collect(Collectors.toMap(Document::getId, d -> d, (d1, d2) -> d1)).values().stream().collect(Collectors.toList()); + List source_doc = new ArrayList(); String resultString = "";