From 37caef5c9db9839a1158c5403e0139b88c1e238f Mon Sep 17 00:00:00 2001 From: Emanuele Ferrelli Date: Fri, 21 Feb 2025 10:01:18 +0100 Subject: [PATCH] Develop Q&A on Doc scenario --- .../services/ScenarioExecutionService.java | 8 ++ .../stepSolvers/DeleteDocTempSolver.java | 75 +++++++++++++ .../stepSolvers/EmbeddingDocTempSolver.java | 106 ++++++++++++++++++ 3 files changed, 189 insertions(+) create mode 100644 src/main/java/com/olympus/hermione/stepSolvers/DeleteDocTempSolver.java create mode 100644 src/main/java/com/olympus/hermione/stepSolvers/EmbeddingDocTempSolver.java diff --git a/src/main/java/com/olympus/hermione/services/ScenarioExecutionService.java b/src/main/java/com/olympus/hermione/services/ScenarioExecutionService.java index 412f409..4644c1f 100644 --- a/src/main/java/com/olympus/hermione/services/ScenarioExecutionService.java +++ b/src/main/java/com/olympus/hermione/services/ScenarioExecutionService.java @@ -49,6 +49,8 @@ import com.olympus.hermione.stepSolvers.AdvancedAIPromptSolver; import com.olympus.hermione.stepSolvers.SummarizeDocSolver; import com.olympus.hermione.stepSolvers.BasicAIPromptSolver; import com.olympus.hermione.stepSolvers.BasicQueryRagSolver; +import com.olympus.hermione.stepSolvers.DeleteDocTempSolver; +import com.olympus.hermione.stepSolvers.EmbeddingDocTempSolver; import com.olympus.hermione.stepSolvers.QueryNeo4JSolver; import com.olympus.hermione.stepSolvers.SourceCodeRagSolver; import com.olympus.hermione.stepSolvers.StepSolver; @@ -267,6 +269,12 @@ public class ScenarioExecutionService { case "OLYMPUS_QUERY_AI": solver = new OlynmpusChatClientSolver(); break; + case "EMBED_TEMPORARY_DOC": + solver = new EmbeddingDocTempSolver(); + break; + case "DELETE_TEMPORARY_DOC": + solver = new DeleteDocTempSolver(); + break; default: break; } diff --git a/src/main/java/com/olympus/hermione/stepSolvers/DeleteDocTempSolver.java b/src/main/java/com/olympus/hermione/stepSolvers/DeleteDocTempSolver.java new file mode 100644 index 0000000..266e82f --- /dev/null +++ b/src/main/java/com/olympus/hermione/stepSolvers/DeleteDocTempSolver.java @@ -0,0 +1,75 @@ +package com.olympus.hermione.stepSolvers; + +import ch.qos.logback.classic.Logger; +import com.olympus.hermione.models.ScenarioExecution; +import com.olympus.hermione.utility.AttributeParser; +import java.util.List; +import org.slf4j.LoggerFactory; +import org.springframework.ai.document.Document; +import org.springframework.ai.vectorstore.SearchRequest; + +public class DeleteDocTempSolver extends StepSolver { + + private String rag_filter; + private int topk; + private double threshold; + private String query; + + + Logger logger = (Logger) LoggerFactory.getLogger(BasicQueryRagSolver.class); + + private void loadParameters(){ + logger.info("Loading parameters"); + this.scenarioExecution.getExecSharedMap().put("scenario_execution_id", this.scenarioExecution.getId()); + + AttributeParser attributeParser = new AttributeParser(this.scenarioExecution); + + // this.scenario_execution_id = attributeParser.parse((String) this.scenarioExecution.getId()); + + this.rag_filter = attributeParser.parse((String) this.step.getAttributes().get("rag_filter")); + + if(this.step.getAttributes().containsKey("rag_query")){ + this.query = (String) this.step.getAttributes().get("rag_query"); + }else{ + this.query = "*"; + } + + if(this.step.getAttributes().containsKey("rag_topk")){ + this.topk = (int) this.step.getAttributes().get("rag_topk"); + }else{ + this.topk = 1000; + } + + if(this.step.getAttributes().containsKey("rag_threshold")){ + this.threshold = (double) this.step.getAttributes().get("rag_threshold"); + }else{ + this.threshold = 0.0; + } + + } + + @Override + public ScenarioExecution solveStep(){ + + System.out.println("Solving step: " + this.step.getName()); + + this.scenarioExecution.setCurrentStepId(this.step.getStepId()); + + loadParameters(); + SearchRequest searchRequest = SearchRequest.defaults() + .withQuery(this.query) + .withTopK(this.topk) + .withSimilarityThreshold(this.threshold) + .withFilterExpression(this.rag_filter); + + + List docs = vectorStore.similaritySearch(searchRequest); + List ids = docs.stream().map(Document::getId).toList(); + vectorStore.delete(ids); + + this.scenarioExecution.setNextStepId(this.step.getNextStepId()); + + return this.scenarioExecution; + } + +} diff --git a/src/main/java/com/olympus/hermione/stepSolvers/EmbeddingDocTempSolver.java b/src/main/java/com/olympus/hermione/stepSolvers/EmbeddingDocTempSolver.java new file mode 100644 index 0000000..6b13974 --- /dev/null +++ b/src/main/java/com/olympus/hermione/stepSolvers/EmbeddingDocTempSolver.java @@ -0,0 +1,106 @@ +package com.olympus.hermione.stepSolvers; + +import ch.qos.logback.classic.Logger; +import com.olympus.hermione.models.ScenarioExecution; +import com.olympus.hermione.utility.AttributeParser; +import java.io.File; +import java.util.Collections; +import java.util.List; +import org.apache.tika.Tika; +import org.slf4j.LoggerFactory; +import org.springframework.ai.document.Document; +import org.springframework.ai.transformer.splitter.TokenTextSplitter; + +public class EmbeddingDocTempSolver extends StepSolver { + + private String scenario_execution_id; + private String path_file; + private int default_chunk_size; + private int min_chunk_size; + private int min_chunk_length_to_embed; + private int max_num_chunks; + + + + Logger logger = (Logger) LoggerFactory.getLogger(BasicQueryRagSolver.class); + + private void loadParameters(){ + logger.info("Loading parameters"); + this.scenarioExecution.getExecSharedMap().put("scenario_execution_id", this.scenarioExecution.getId()); + logger.info("Scenario Execution ID: "+this.scenarioExecution.getId()); + + AttributeParser attributeParser = new AttributeParser(this.scenarioExecution); + + this.scenario_execution_id = attributeParser.parse((String) this.scenarioExecution.getId()); + + this.path_file = attributeParser.parse((String) this.step.getAttributes().get("path_file")); + + if(this.step.getAttributes().containsKey("default_chunk_size")){ + this.default_chunk_size = (int) this.step.getAttributes().get("default_chunk_size"); + }else{ + this.default_chunk_size = 8000; + } + + if(this.step.getAttributes().containsKey("min_chunk_size")){ + this.min_chunk_size = (int) this.step.getAttributes().get("min_chunk_size"); + }else{ + this.min_chunk_size = 50; + } + + if(this.step.getAttributes().containsKey("min_chunk_length_to_embed")){ + this.min_chunk_length_to_embed = (int) this.step.getAttributes().get("min_chunk_length_to_embed"); + }else{ + this.min_chunk_length_to_embed = 50; + } + + if(this.step.getAttributes().containsKey("max_num_chunks")){ + this.max_num_chunks = (int) this.step.getAttributes().get("max_num_chunks"); + }else{ + this.max_num_chunks = 1000; + } + + } + + @Override + public ScenarioExecution solveStep(){ + + try{ + logger.info("Solving step: " + this.step.getName()); + this.scenarioExecution.setCurrentStepId(this.step.getStepId()); + loadParameters(); + + logger.info("Embedding documents"); + File file = new File(this.path_file); + Tika tika = new Tika(); + tika.setMaxStringLength(-1); + String text = tika.parseToString(file); + Document myDoc = new Document(text); + + List docs = Collections.singletonList(myDoc); + + TokenTextSplitter splitter = new TokenTextSplitter(this.default_chunk_size, + this.min_chunk_size, + this.min_chunk_length_to_embed, + this.max_num_chunks, + true); + + docs.forEach(doc -> { + List splitDocs = splitter.split(doc); + logger.info("Number of documents: " + splitDocs.size()); + + splitDocs.forEach(splitDoc -> { + splitDoc.getMetadata().put("KsScenarioExecutionId", this.scenario_execution_id); + }); + vectorStore.add(splitDocs); + }); + }catch (Exception e){ + logger.error("Error while solvingStep: "+e.getMessage()); + e.printStackTrace(); + } + + this.scenarioExecution.setNextStepId(this.step.getNextStepId()); + + return this.scenarioExecution; + } + +}