```
Nessuna modifica apportata ```
This commit is contained in:
@@ -43,6 +43,9 @@ public class AdvancedQueryRagSolver extends StepSolver {
|
|||||||
private boolean enableRetrieveOther;
|
private boolean enableRetrieveOther;
|
||||||
private int rank_threshold;
|
private int rank_threshold;
|
||||||
private int otherTopK;
|
private int otherTopK;
|
||||||
|
private boolean enableNeighborRetrieve;
|
||||||
|
private int preNeighbor;
|
||||||
|
private int afterNeighbor;
|
||||||
|
|
||||||
private void loadParameters(){
|
private void loadParameters(){
|
||||||
logger.info("Loading parameters");
|
logger.info("Loading parameters");
|
||||||
@@ -102,6 +105,21 @@ public class AdvancedQueryRagSolver extends StepSolver {
|
|||||||
this.enableRetrieveOther = false;
|
this.enableRetrieveOther = false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if(this.step.getAttributes().containsKey("enable_neighbor_retrieve")){
|
||||||
|
this.enableNeighborRetrieve = (boolean) this.step.getAttributes().get("enable_neighbor_retrieve");
|
||||||
|
}else{
|
||||||
|
this.enableNeighborRetrieve = false;
|
||||||
|
}
|
||||||
|
if(this.step.getAttributes().containsKey("pre_neighbor")){
|
||||||
|
this.preNeighbor = (int) this.step.getAttributes().get("pre_neighbor");
|
||||||
|
}else{
|
||||||
|
this.preNeighbor = 1;
|
||||||
|
}
|
||||||
|
if(this.step.getAttributes().containsKey("after_neighbor")){
|
||||||
|
this.afterNeighbor = (int) this.step.getAttributes().get("after_neighbor");
|
||||||
|
}else{
|
||||||
|
this.afterNeighbor = 1;
|
||||||
|
}
|
||||||
|
|
||||||
if(this.step.getAttributes().containsKey("rephrase_prompt")){
|
if(this.step.getAttributes().containsKey("rephrase_prompt")){
|
||||||
this.rephrasePrompt = attributeParser.parse((String)this.step.getAttributes().get("rephrase_prompt"));
|
this.rephrasePrompt = attributeParser.parse((String)this.step.getAttributes().get("rephrase_prompt"));
|
||||||
@@ -131,6 +149,12 @@ public class AdvancedQueryRagSolver extends StepSolver {
|
|||||||
logger.info("enableRanking: " + this.enableRanking);
|
logger.info("enableRanking: " + this.enableRanking);
|
||||||
logger.info("rephrasePrompt: " + this.rephrasePrompt);
|
logger.info("rephrasePrompt: " + this.rephrasePrompt);
|
||||||
logger.info("MultiPhrasePrompt: " + this.MultiPhrasePrompt);
|
logger.info("MultiPhrasePrompt: " + this.MultiPhrasePrompt);
|
||||||
|
logger.info("rank_threshold: " + this.rank_threshold);
|
||||||
|
logger.info("otherTopK: " + this.otherTopK);
|
||||||
|
logger.info("enableRetrieveOther: " + this.enableRetrieveOther);
|
||||||
|
logger.info("enableNeighborRetrieve: " + this.enableNeighborRetrieve);
|
||||||
|
logger.info("preNeighbor: " + this.preNeighbor);
|
||||||
|
logger.info("afterNeighbor: " + this.afterNeighbor);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -210,6 +234,20 @@ public class AdvancedQueryRagSolver extends StepSolver {
|
|||||||
docs = rankedDocs;
|
docs = rankedDocs;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if(enableNeighborRetrieve){
|
||||||
|
logger.info("Retrieving neighbor documents: #PRE " + this.preNeighbor + " #POST " + this.afterNeighbor);
|
||||||
|
List<Document> neighborDocs = new ArrayList<Document>();
|
||||||
|
for(Document doc : docs){
|
||||||
|
neighborDocs.addAll( retrieveNeighborDocuments(doc, preNeighbor, afterNeighbor));
|
||||||
|
logger.info("Number of neighbor documents: " + neighborDocs.size());
|
||||||
|
}
|
||||||
|
docs.addAll(neighborDocs);
|
||||||
|
logger.info("Number of documents after neighbor retrieval: " + docs.size());
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
if(enableRetrieveOther){
|
if(enableRetrieveOther){
|
||||||
List<Document> otherDocs = new ArrayList<Document>();
|
List<Document> otherDocs = new ArrayList<Document>();
|
||||||
logger.info("Retrieving other documents");
|
logger.info("Retrieving other documents");
|
||||||
@@ -254,11 +292,26 @@ public class AdvancedQueryRagSolver extends StepSolver {
|
|||||||
|
|
||||||
docs = docs.stream().collect(Collectors.toMap(Document::getId, d -> d, (d1, d2) -> d1)).values().stream().collect(Collectors.toList());
|
docs = docs.stream().collect(Collectors.toMap(Document::getId, d -> d, (d1, d2) -> d1)).values().stream().collect(Collectors.toList());
|
||||||
|
|
||||||
|
//Sort by KsDocumentId and KsDocumentIndex
|
||||||
|
docs.sort((d1, d2) -> {
|
||||||
|
String docId1 = (String) d1.getMetadata().get("KsDocumentId");
|
||||||
|
String docId2 = (String) d2.getMetadata().get("KsDocumentId");
|
||||||
|
int cmp = docId1.compareTo(docId2);
|
||||||
|
if (cmp != 0) {
|
||||||
|
return cmp;
|
||||||
|
} else {
|
||||||
|
int idx1 = Integer.parseInt((String)d1.getMetadata().get("KsDocumentIndex"));
|
||||||
|
int idx2 = Integer.parseInt((String)d2.getMetadata().get("KsDocumentIndex"));
|
||||||
|
return Integer.compare(idx1, idx2);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
|
||||||
List<String> source_doc = new ArrayList<String>();
|
List<String> source_doc = new ArrayList<String>();
|
||||||
String resultString = "";
|
String resultString = "";
|
||||||
|
|
||||||
for(Document doc : docs){
|
for(Document doc : docs){
|
||||||
resultString += "Document Source: " + doc.getMetadata().get("KsFileSource") + "\n";
|
resultString += "Document Source: " + doc.getMetadata().get("KsFileSource") + " - Chunk Index " + doc.getMetadata().get("KsDocumentIndex")+ "\n";
|
||||||
resultString += "--------------------------------------------------------------\n";
|
resultString += "--------------------------------------------------------------\n";
|
||||||
resultString += doc.getText() + "\n";
|
resultString += doc.getText() + "\n";
|
||||||
source_doc.add((String)doc.getMetadata().get("KsFileSource"));
|
source_doc.add((String)doc.getMetadata().get("KsFileSource"));
|
||||||
@@ -274,6 +327,32 @@ public class AdvancedQueryRagSolver extends StepSolver {
|
|||||||
return this.scenarioExecution;
|
return this.scenarioExecution;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private List<Document> retrieveNeighborDocuments(Document doc, int preNeighbor, int afterNeighbor){
|
||||||
|
|
||||||
|
int currentIdx = Integer.parseInt((String)doc.getMetadata().get("KsDocumentIndex"));
|
||||||
|
String KsDocumentId = (String)doc.getMetadata().get("KsDocumentId");
|
||||||
|
|
||||||
|
int startIdx = currentIdx - preNeighbor;
|
||||||
|
int endIdx = currentIdx + afterNeighbor;
|
||||||
|
List<Document> otherDocs = new ArrayList<Document>();
|
||||||
|
logger.info("Retrieving neighbors for: " + KsDocumentId);
|
||||||
|
logger.info("Retrieving other documents for neighbor index: " + startIdx + " to " + endIdx);
|
||||||
|
for(int i = startIdx; i <= endIdx; i++){
|
||||||
|
if(i == currentIdx){
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
String queryExpression = "KsDocumentId == '" + KsDocumentId + "' AND KsDocumentIndex == '" + i+"'";
|
||||||
|
logger.info("Query expression: " + queryExpression);
|
||||||
|
SearchRequest request = SearchRequest.builder()
|
||||||
|
.query("*")
|
||||||
|
.topK(1)
|
||||||
|
.similarityThreshold(0)
|
||||||
|
.filterExpression(queryExpression).build();
|
||||||
|
|
||||||
|
otherDocs.addAll( this.vectorStore.similaritySearch(request));
|
||||||
|
|
||||||
|
}
|
||||||
|
logger.info("Number of neighbor retrieved : " + otherDocs.size());
|
||||||
|
return otherDocs;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user