From 717cb365ecee6c13e2c16c9e401212a802bad5c7 Mon Sep 17 00:00:00 2001 From: "andrea.terzani" Date: Sat, 29 Mar 2025 09:36:16 +0100 Subject: [PATCH] Add optional neighbor-chunk retrieval to AdvancedQueryRagSolver --- .../stepSolvers/AdvancedQueryRagSolver.java | 89 +++++++++++++++++-- 1 file changed, 84 insertions(+), 5 deletions(-) diff --git a/src/main/java/com/olympus/hermione/stepSolvers/AdvancedQueryRagSolver.java b/src/main/java/com/olympus/hermione/stepSolvers/AdvancedQueryRagSolver.java index 2d898d8..537f1b9 100644 --- a/src/main/java/com/olympus/hermione/stepSolvers/AdvancedQueryRagSolver.java +++ b/src/main/java/com/olympus/hermione/stepSolvers/AdvancedQueryRagSolver.java @@ -43,7 +43,10 @@ public class AdvancedQueryRagSolver extends StepSolver { private boolean enableRetrieveOther; private int rank_threshold; private int otherTopK; - + private boolean enableNeighborRetrieve; + private int preNeighbor; + private int afterNeighbor; + private void loadParameters(){ logger.info("Loading parameters"); @@ -101,7 +104,22 @@ public class AdvancedQueryRagSolver extends StepSolver { }else{ this.enableRetrieveOther = false; } - + + if(this.step.getAttributes().containsKey("enable_neighbor_retrieve")){ + this.enableNeighborRetrieve = (boolean) this.step.getAttributes().get("enable_neighbor_retrieve"); + }else{ + this.enableNeighborRetrieve = false; + } + if(this.step.getAttributes().containsKey("pre_neighbor")){ + this.preNeighbor = (int) this.step.getAttributes().get("pre_neighbor"); + }else{ + this.preNeighbor = 1; + } + if(this.step.getAttributes().containsKey("after_neighbor")){ + this.afterNeighbor = (int) this.step.getAttributes().get("after_neighbor"); + }else{ + this.afterNeighbor = 1; + } if(this.step.getAttributes().containsKey("rephrase_prompt")){ this.rephrasePrompt = attributeParser.parse((String)this.step.getAttributes().get("rephrase_prompt")); @@ -131,6 +149,12 @@ public class AdvancedQueryRagSolver extends StepSolver { logger.info("enableRanking: " + this.enableRanking); 
logger.info("rephrasePrompt: " + this.rephrasePrompt); logger.info("MultiPhrasePrompt: " + this.MultiPhrasePrompt); + logger.info("rank_threshold: " + this.rank_threshold); + logger.info("otherTopK: " + this.otherTopK); + logger.info("enableRetrieveOther: " + this.enableRetrieveOther); + logger.info("enableNeighborRetrieve: " + this.enableNeighborRetrieve); + logger.info("preNeighbor: " + this.preNeighbor); + logger.info("afterNeighbor: " + this.afterNeighbor); } @@ -210,6 +234,20 @@ public class AdvancedQueryRagSolver extends StepSolver { docs = rankedDocs; } + if(enableNeighborRetrieve){ + logger.info("Retrieving neighbor documents: #PRE " + this.preNeighbor + " #POST " + this.afterNeighbor); + List neighborDocs = new ArrayList(); + for(Document doc : docs){ + neighborDocs.addAll( retrieveNeighborDocuments(doc, preNeighbor, afterNeighbor)); + logger.info("Number of neighbor documents: " + neighborDocs.size()); + } + docs.addAll(neighborDocs); + logger.info("Number of documents after neighbor retrieval: " + docs.size()); + + } + + + if(enableRetrieveOther){ List otherDocs = new ArrayList(); logger.info("Retrieving other documents"); @@ -254,11 +292,26 @@ public class AdvancedQueryRagSolver extends StepSolver { docs = docs.stream().collect(Collectors.toMap(Document::getId, d -> d, (d1, d2) -> d1)).values().stream().collect(Collectors.toList()); + //Sort by KsDocumentId and KsDocumentIndex + docs.sort((d1, d2) -> { + String docId1 = (String) d1.getMetadata().get("KsDocumentId"); + String docId2 = (String) d2.getMetadata().get("KsDocumentId"); + int cmp = docId1.compareTo(docId2); + if (cmp != 0) { + return cmp; + } else { + int idx1 = Integer.parseInt((String)d1.getMetadata().get("KsDocumentIndex")); + int idx2 = Integer.parseInt((String)d2.getMetadata().get("KsDocumentIndex")); + return Integer.compare(idx1, idx2); + } + }); + + List source_doc = new ArrayList(); String resultString = ""; for(Document doc : docs){ - resultString += "Document Source: " + 
doc.getMetadata().get("KsFileSource") + "\n"; + resultString += "Document Source: " + doc.getMetadata().get("KsFileSource") + " - Chunk Index " + doc.getMetadata().get("KsDocumentIndex")+ "\n"; resultString += "--------------------------------------------------------------\n"; resultString += doc.getText() + "\n"; source_doc.add((String)doc.getMetadata().get("KsFileSource")); @@ -274,6 +327,32 @@ public class AdvancedQueryRagSolver extends StepSolver { return this.scenarioExecution; } - - + private List retrieveNeighborDocuments(Document doc, int preNeighbor, int afterNeighbor){ + + int currentIdx = Integer.parseInt((String)doc.getMetadata().get("KsDocumentIndex")); + String KsDocumentId = (String)doc.getMetadata().get("KsDocumentId"); + + int startIdx = currentIdx - preNeighbor; + int endIdx = currentIdx + afterNeighbor; + List otherDocs = new ArrayList(); + logger.info("Retrieving neighbors for: " + KsDocumentId); + logger.info("Retrieving other documents for neighbor index: " + startIdx + " to " + endIdx); + for(int i = startIdx; i <= endIdx; i++){ + if(i == currentIdx){ + continue; + } + String queryExpression = "KsDocumentId == '" + KsDocumentId + "' AND KsDocumentIndex == '" + i+"'"; + logger.info("Query expression: " + queryExpression); + SearchRequest request = SearchRequest.builder() + .query("*") + .topK(1) + .similarityThreshold(0) + .filterExpression(queryExpression).build(); + + otherDocs.addAll( this.vectorStore.similaritySearch(request)); + + } + logger.info("Number of neighbor retrieved : " + otherDocs.size()); + return otherDocs; + } }