```
Nessuna modifica apportata ```
This commit is contained in:
@@ -43,7 +43,10 @@ public class AdvancedQueryRagSolver extends StepSolver {
|
||||
private boolean enableRetrieveOther;
|
||||
private int rank_threshold;
|
||||
private int otherTopK;
|
||||
|
||||
private boolean enableNeighborRetrieve;
|
||||
private int preNeighbor;
|
||||
private int afterNeighbor;
|
||||
|
||||
private void loadParameters(){
|
||||
logger.info("Loading parameters");
|
||||
|
||||
@@ -101,7 +104,22 @@ public class AdvancedQueryRagSolver extends StepSolver {
|
||||
}else{
|
||||
this.enableRetrieveOther = false;
|
||||
}
|
||||
|
||||
|
||||
if(this.step.getAttributes().containsKey("enable_neighbor_retrieve")){
|
||||
this.enableNeighborRetrieve = (boolean) this.step.getAttributes().get("enable_neighbor_retrieve");
|
||||
}else{
|
||||
this.enableNeighborRetrieve = false;
|
||||
}
|
||||
if(this.step.getAttributes().containsKey("pre_neighbor")){
|
||||
this.preNeighbor = (int) this.step.getAttributes().get("pre_neighbor");
|
||||
}else{
|
||||
this.preNeighbor = 1;
|
||||
}
|
||||
if(this.step.getAttributes().containsKey("after_neighbor")){
|
||||
this.afterNeighbor = (int) this.step.getAttributes().get("after_neighbor");
|
||||
}else{
|
||||
this.afterNeighbor = 1;
|
||||
}
|
||||
|
||||
if(this.step.getAttributes().containsKey("rephrase_prompt")){
|
||||
this.rephrasePrompt = attributeParser.parse((String)this.step.getAttributes().get("rephrase_prompt"));
|
||||
@@ -131,6 +149,12 @@ public class AdvancedQueryRagSolver extends StepSolver {
|
||||
logger.info("enableRanking: " + this.enableRanking);
|
||||
logger.info("rephrasePrompt: " + this.rephrasePrompt);
|
||||
logger.info("MultiPhrasePrompt: " + this.MultiPhrasePrompt);
|
||||
logger.info("rank_threshold: " + this.rank_threshold);
|
||||
logger.info("otherTopK: " + this.otherTopK);
|
||||
logger.info("enableRetrieveOther: " + this.enableRetrieveOther);
|
||||
logger.info("enableNeighborRetrieve: " + this.enableNeighborRetrieve);
|
||||
logger.info("preNeighbor: " + this.preNeighbor);
|
||||
logger.info("afterNeighbor: " + this.afterNeighbor);
|
||||
}
|
||||
|
||||
|
||||
@@ -210,6 +234,20 @@ public class AdvancedQueryRagSolver extends StepSolver {
|
||||
docs = rankedDocs;
|
||||
}
|
||||
|
||||
if(enableNeighborRetrieve){
|
||||
logger.info("Retrieving neighbor documents: #PRE " + this.preNeighbor + " #POST " + this.afterNeighbor);
|
||||
List<Document> neighborDocs = new ArrayList<Document>();
|
||||
for(Document doc : docs){
|
||||
neighborDocs.addAll( retrieveNeighborDocuments(doc, preNeighbor, afterNeighbor));
|
||||
logger.info("Number of neighbor documents: " + neighborDocs.size());
|
||||
}
|
||||
docs.addAll(neighborDocs);
|
||||
logger.info("Number of documents after neighbor retrieval: " + docs.size());
|
||||
|
||||
}
|
||||
|
||||
|
||||
|
||||
if(enableRetrieveOther){
|
||||
List<Document> otherDocs = new ArrayList<Document>();
|
||||
logger.info("Retrieving other documents");
|
||||
@@ -254,11 +292,26 @@ public class AdvancedQueryRagSolver extends StepSolver {
|
||||
|
||||
docs = docs.stream().collect(Collectors.toMap(Document::getId, d -> d, (d1, d2) -> d1)).values().stream().collect(Collectors.toList());
|
||||
|
||||
//Sort by KsDocumentId and KsDocumentIndex
|
||||
docs.sort((d1, d2) -> {
|
||||
String docId1 = (String) d1.getMetadata().get("KsDocumentId");
|
||||
String docId2 = (String) d2.getMetadata().get("KsDocumentId");
|
||||
int cmp = docId1.compareTo(docId2);
|
||||
if (cmp != 0) {
|
||||
return cmp;
|
||||
} else {
|
||||
int idx1 = Integer.parseInt((String)d1.getMetadata().get("KsDocumentIndex"));
|
||||
int idx2 = Integer.parseInt((String)d2.getMetadata().get("KsDocumentIndex"));
|
||||
return Integer.compare(idx1, idx2);
|
||||
}
|
||||
});
|
||||
|
||||
|
||||
List<String> source_doc = new ArrayList<String>();
|
||||
String resultString = "";
|
||||
|
||||
for(Document doc : docs){
|
||||
resultString += "Document Source: " + doc.getMetadata().get("KsFileSource") + "\n";
|
||||
resultString += "Document Source: " + doc.getMetadata().get("KsFileSource") + " - Chunk Index " + doc.getMetadata().get("KsDocumentIndex")+ "\n";
|
||||
resultString += "--------------------------------------------------------------\n";
|
||||
resultString += doc.getText() + "\n";
|
||||
source_doc.add((String)doc.getMetadata().get("KsFileSource"));
|
||||
@@ -274,6 +327,32 @@ public class AdvancedQueryRagSolver extends StepSolver {
|
||||
return this.scenarioExecution;
|
||||
}
|
||||
|
||||
|
||||
|
||||
private List<Document> retrieveNeighborDocuments(Document doc, int preNeighbor, int afterNeighbor){
|
||||
|
||||
int currentIdx = Integer.parseInt((String)doc.getMetadata().get("KsDocumentIndex"));
|
||||
String KsDocumentId = (String)doc.getMetadata().get("KsDocumentId");
|
||||
|
||||
int startIdx = currentIdx - preNeighbor;
|
||||
int endIdx = currentIdx + afterNeighbor;
|
||||
List<Document> otherDocs = new ArrayList<Document>();
|
||||
logger.info("Retrieving neighbors for: " + KsDocumentId);
|
||||
logger.info("Retrieving other documents for neighbor index: " + startIdx + " to " + endIdx);
|
||||
for(int i = startIdx; i <= endIdx; i++){
|
||||
if(i == currentIdx){
|
||||
continue;
|
||||
}
|
||||
String queryExpression = "KsDocumentId == '" + KsDocumentId + "' AND KsDocumentIndex == '" + i+"'";
|
||||
logger.info("Query expression: " + queryExpression);
|
||||
SearchRequest request = SearchRequest.builder()
|
||||
.query("*")
|
||||
.topK(1)
|
||||
.similarityThreshold(0)
|
||||
.filterExpression(queryExpression).build();
|
||||
|
||||
otherDocs.addAll( this.vectorStore.similaritySearch(request));
|
||||
|
||||
}
|
||||
logger.info("Number of neighbor retrieved : " + otherDocs.size());
|
||||
return otherDocs;
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user