Nessuna modifica apportata
```
This commit is contained in:
andrea.terzani
2025-03-29 09:36:16 +01:00
parent 8f14260a9b
commit 717cb365ec

View File

@@ -43,7 +43,10 @@ public class AdvancedQueryRagSolver extends StepSolver {
private boolean enableRetrieveOther;
private int rank_threshold;
private int otherTopK;
private boolean enableNeighborRetrieve;
private int preNeighbor;
private int afterNeighbor;
private void loadParameters(){
logger.info("Loading parameters");
@@ -101,7 +104,22 @@ public class AdvancedQueryRagSolver extends StepSolver {
}else{
this.enableRetrieveOther = false;
}
if(this.step.getAttributes().containsKey("enable_neighbor_retrieve")){
this.enableNeighborRetrieve = (boolean) this.step.getAttributes().get("enable_neighbor_retrieve");
}else{
this.enableNeighborRetrieve = false;
}
if(this.step.getAttributes().containsKey("pre_neighbor")){
this.preNeighbor = (int) this.step.getAttributes().get("pre_neighbor");
}else{
this.preNeighbor = 1;
}
if(this.step.getAttributes().containsKey("after_neighbor")){
this.afterNeighbor = (int) this.step.getAttributes().get("after_neighbor");
}else{
this.afterNeighbor = 1;
}
if(this.step.getAttributes().containsKey("rephrase_prompt")){
this.rephrasePrompt = attributeParser.parse((String)this.step.getAttributes().get("rephrase_prompt"));
@@ -131,6 +149,12 @@ public class AdvancedQueryRagSolver extends StepSolver {
logger.info("enableRanking: " + this.enableRanking);
logger.info("rephrasePrompt: " + this.rephrasePrompt);
logger.info("MultiPhrasePrompt: " + this.MultiPhrasePrompt);
logger.info("rank_threshold: " + this.rank_threshold);
logger.info("otherTopK: " + this.otherTopK);
logger.info("enableRetrieveOther: " + this.enableRetrieveOther);
logger.info("enableNeighborRetrieve: " + this.enableNeighborRetrieve);
logger.info("preNeighbor: " + this.preNeighbor);
logger.info("afterNeighbor: " + this.afterNeighbor);
}
@@ -210,6 +234,20 @@ public class AdvancedQueryRagSolver extends StepSolver {
docs = rankedDocs;
}
if(enableNeighborRetrieve){
logger.info("Retrieving neighbor documents: #PRE " + this.preNeighbor + " #POST " + this.afterNeighbor);
List<Document> neighborDocs = new ArrayList<Document>();
for(Document doc : docs){
neighborDocs.addAll( retrieveNeighborDocuments(doc, preNeighbor, afterNeighbor));
logger.info("Number of neighbor documents: " + neighborDocs.size());
}
docs.addAll(neighborDocs);
logger.info("Number of documents after neighbor retrieval: " + docs.size());
}
if(enableRetrieveOther){
List<Document> otherDocs = new ArrayList<Document>();
logger.info("Retrieving other documents");
@@ -254,11 +292,26 @@ public class AdvancedQueryRagSolver extends StepSolver {
docs = docs.stream().collect(Collectors.toMap(Document::getId, d -> d, (d1, d2) -> d1)).values().stream().collect(Collectors.toList());
//Sort by KsDocumentId and KsDocumentIndex
docs.sort((d1, d2) -> {
String docId1 = (String) d1.getMetadata().get("KsDocumentId");
String docId2 = (String) d2.getMetadata().get("KsDocumentId");
int cmp = docId1.compareTo(docId2);
if (cmp != 0) {
return cmp;
} else {
int idx1 = Integer.parseInt((String)d1.getMetadata().get("KsDocumentIndex"));
int idx2 = Integer.parseInt((String)d2.getMetadata().get("KsDocumentIndex"));
return Integer.compare(idx1, idx2);
}
});
List<String> source_doc = new ArrayList<String>();
String resultString = "";
for(Document doc : docs){
resultString += "Document Source: " + doc.getMetadata().get("KsFileSource") + "\n";
resultString += "Document Source: " + doc.getMetadata().get("KsFileSource") + " - Chunk Index " + doc.getMetadata().get("KsDocumentIndex")+ "\n";
resultString += "--------------------------------------------------------------\n";
resultString += doc.getText() + "\n";
source_doc.add((String)doc.getMetadata().get("KsFileSource"));
@@ -274,6 +327,32 @@ public class AdvancedQueryRagSolver extends StepSolver {
return this.scenarioExecution;
}
private List<Document> retrieveNeighborDocuments(Document doc, int preNeighbor, int afterNeighbor){
int currentIdx = Integer.parseInt((String)doc.getMetadata().get("KsDocumentIndex"));
String KsDocumentId = (String)doc.getMetadata().get("KsDocumentId");
int startIdx = currentIdx - preNeighbor;
int endIdx = currentIdx + afterNeighbor;
List<Document> otherDocs = new ArrayList<Document>();
logger.info("Retrieving neighbors for: " + KsDocumentId);
logger.info("Retrieving other documents for neighbor index: " + startIdx + " to " + endIdx);
for(int i = startIdx; i <= endIdx; i++){
if(i == currentIdx){
continue;
}
String queryExpression = "KsDocumentId == '" + KsDocumentId + "' AND KsDocumentIndex == '" + i+"'";
logger.info("Query expression: " + queryExpression);
SearchRequest request = SearchRequest.builder()
.query("*")
.topK(1)
.similarityThreshold(0)
.filterExpression(queryExpression).build();
otherDocs.addAll( this.vectorStore.similaritySearch(request));
}
logger.info("Number of neighbor retrieved : " + otherDocs.size());
return otherDocs;
}
}