From 967b61a0fe6b5c0786404d87e0df06f26a2dbc6e Mon Sep 17 00:00:00 2001 From: "andrea.terzani" Date: Thu, 5 Sep 2024 10:14:22 +0200 Subject: [PATCH 1/2] new engine workflows --- pom.xml | 9 +++--- .../hermione/controllers/TestController.java | 17 +++++++---- .../com/olympus/hermione/models/Scenario.java | 1 + .../hermione/models/ScenarioExecution.java | 3 ++ .../olympus/hermione/models/ScenarioStep.java | 3 ++ .../services/ScenarioExecutionService.java | 30 ++++++++++++++----- .../stepSolvers/BasicAIPromptSolver.java | 6 +++- .../stepSolvers/BasicQueryRagSolver.java | 4 ++- .../stepSolvers/QueryNeo4JSolver.java | 3 ++ .../hermione/stepSolvers/StepSolver.java | 6 ++++ .../stepSolvers/apollo.code-workspace | 16 ---------- src/main/resources/workflow.json | 15 ++++++++++ 12 files changed, 79 insertions(+), 34 deletions(-) delete mode 100644 src/main/java/com/olympus/hermione/stepSolvers/apollo.code-workspace create mode 100644 src/main/resources/workflow.json diff --git a/pom.xml b/pom.xml index 94371a8..ead9aa4 100644 --- a/pom.xml +++ b/pom.xml @@ -91,10 +91,11 @@ 4.3.4 - com.google.code.gson - gson - 2.8.8 - + com.google.code.gson + gson + 2.8.8 + + diff --git a/src/main/java/com/olympus/hermione/controllers/TestController.java b/src/main/java/com/olympus/hermione/controllers/TestController.java index d534ef1..cc683b6 100644 --- a/src/main/java/com/olympus/hermione/controllers/TestController.java +++ b/src/main/java/com/olympus/hermione/controllers/TestController.java @@ -3,8 +3,12 @@ package com.olympus.hermione.controllers; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.data.mongodb.core.aggregation.ComparisonOperators.Ne; import org.springframework.web.bind.annotation.GetMapping; +import org.springframework.web.bind.annotation.PostMapping; +import org.springframework.web.bind.annotation.RequestBody; import org.springframework.web.bind.annotation.RestController; +import com.olympus.hermione.dto.ScenarioExecutionInput; +import com.olympus.hermione.dto.ScenarioOutput; import com.olympus.hermione.services.Neo4JUitilityService; import com.olympus.hermione.services.ScenarioExecutionService; @@ -14,20 +18,23 @@ public class TestController { @Autowired ScenarioExecutionService scenarioExecutionService; + + @Autowired Neo4JUitilityService neo4JUitilityService; - @GetMapping("/test/scenario_execution") - public String testScenarioExecution(){ - String result = scenarioExecutionService.executeScenario("66aa162debe80dfcd17f0ef4","How i can change the path where uploaded documents are stored ?"); - return result; + @PostMapping("test/execute") + public ScenarioOutput executeScenario(@RequestBody ScenarioExecutionInput scenarioExecutionInput) { + return scenarioExecutionService.executeScenario(scenarioExecutionInput); } + + @GetMapping("/test/generate_schema_description") public String testGenerateSchemaDescription(){ return neo4JUitilityService.generateSchemaDescription().toString(); } - + } diff --git a/src/main/java/com/olympus/hermione/models/Scenario.java b/src/main/java/com/olympus/hermione/models/Scenario.java index da997f3..43ce81a 100644 --- a/src/main/java/com/olympus/hermione/models/Scenario.java +++ b/src/main/java/com/olympus/hermione/models/Scenario.java @@ -15,6 +15,7 @@ public class Scenario { private String id; private String name; private String description; + private String startWithStepId; private List steps; private List inputs; diff --git a/src/main/java/com/olympus/hermione/models/ScenarioExecution.java b/src/main/java/com/olympus/hermione/models/ScenarioExecution.java index dc679e9..2511f3b 100644 --- a/src/main/java/com/olympus/hermione/models/ScenarioExecution.java +++ b/src/main/java/com/olympus/hermione/models/ScenarioExecution.java @@ -21,5 +21,8 @@ public class ScenarioExecution { private String scenarioId; + private String currentStepId; + private String nextStepId; + } diff --git a/src/main/java/com/olympus/hermione/models/ScenarioStep.java b/src/main/java/com/olympus/hermione/models/ScenarioStep.java index 4e59efc..a504bfd 100644 --- a/src/main/java/com/olympus/hermione/models/ScenarioStep.java +++ b/src/main/java/com/olympus/hermione/models/ScenarioStep.java @@ -12,6 +12,9 @@ public class ScenarioStep { private String description; private String name; private int order; + + private String stepId; + private String nextStepId; private HashMap attributes; diff --git a/src/main/java/com/olympus/hermione/services/ScenarioExecutionService.java b/src/main/java/com/olympus/hermione/services/ScenarioExecutionService.java index 64c73ed..5c1f6da 100644 --- a/src/main/java/com/olympus/hermione/services/ScenarioExecutionService.java +++ b/src/main/java/com/olympus/hermione/services/ScenarioExecutionService.java @@ -71,8 +71,11 @@ public class ScenarioExecutionService { List steps = scenario.getSteps(); steps.sort(Comparator.comparingInt(ScenarioStep::getOrder)); + + + for (ScenarioStep step : steps) { - executeScenarioStep(step, scenarioExecution, step.getOrder()); + executeScenarioStep(step, scenarioExecution); } return scenarioExecution.getExecSharedMap().get("scenario_output").toString(); @@ -107,9 +110,18 @@ public class ScenarioExecutionService { List steps = scenario.getSteps(); - steps.sort(Comparator.comparingInt(ScenarioStep::getOrder)); - for (ScenarioStep step : steps) { - executeScenarioStep(step, scenarioExecution, step.getOrder()); + + + String startStepId=scenario.getStartWithStepId(); + + ScenarioStep startStep = steps.stream().filter(step -> step.getStepId().equals(startStepId)).findFirst().orElse(null); + + executeScenarioStep(startStep, scenarioExecution); + + while (scenarioExecution.getNextStepId()!=null) { + ScenarioStep step = steps.stream().filter(s -> s.getStepId().equals(scenarioExecution.getNextStepId())).findFirst().orElse(null); + + executeScenarioStep(step, scenarioExecution); } scenarioOutput.setScenarioExecution_id(scenarioExecution.getId()); @@ -130,7 +142,7 @@ public class ScenarioExecutionService { - private void executeScenarioStep(ScenarioStep step, ScenarioExecution scenarioExecution, int stepIndex){ + private void executeScenarioStep(ScenarioStep step, ScenarioExecution scenarioExecution){ logger.info("Start working on step: " + step.getName() + " with type: " + step.getType()); StepSolver solver=new StepSolver(); switch (step.getType()) { @@ -147,14 +159,18 @@ public class ScenarioExecutionService { break; } - logger.info("Initializing step: " + step.getName()); + logger.info("Initializing step: " + step.getStepId()); solver.init(step, scenarioExecution, vectorStore,chatModel,graphDriver,neo4JUitilityService); - logger.info("Solving step: " + step.getName()); + logger.info("Solving step: " + step.getStepId()); + scenarioExecution.setCurrentStepId(step.getStepId()); //must add try catch in order to catch exceptions and step execution errors ScenarioExecution scenarioExecutionNew = solver.solveStep(); + logger.info("Step solved: " + step.getStepId()); + logger.info("Next step: " + scenarioExecutionNew.getNextStepId()); + scenarioExecutionRepository.save(scenarioExecutionNew); diff --git a/src/main/java/com/olympus/hermione/stepSolvers/BasicAIPromptSolver.java b/src/main/java/com/olympus/hermione/stepSolvers/BasicAIPromptSolver.java index 856cc68..137471a 100644 --- a/src/main/java/com/olympus/hermione/stepSolvers/BasicAIPromptSolver.java +++ b/src/main/java/com/olympus/hermione/stepSolvers/BasicAIPromptSolver.java @@ -49,6 +49,8 @@ public class BasicAIPromptSolver extends StepSolver { System.out.println("Solving step: " + this.step.getName()); + this.scenarioExecution.setCurrentStepId(this.step.getStepId()); + loadParameters(); @@ -64,7 +66,9 @@ public class BasicAIPromptSolver extends StepSolver { String output = response.get(0).getOutput().getContent(); this.scenarioExecution.getExecSharedMap().put(this.qai_output_variable, output); - + + this.scenarioExecution.setNextStepId(this.step.getNextStepId()); + return this.scenarioExecution; } diff --git a/src/main/java/com/olympus/hermione/stepSolvers/BasicQueryRagSolver.java b/src/main/java/com/olympus/hermione/stepSolvers/BasicQueryRagSolver.java index 12a6423..00bf6ed 100644 --- a/src/main/java/com/olympus/hermione/stepSolvers/BasicQueryRagSolver.java +++ b/src/main/java/com/olympus/hermione/stepSolvers/BasicQueryRagSolver.java @@ -93,7 +93,9 @@ public class BasicQueryRagSolver extends StepSolver { //concatenate the content of the results into a single string String resultString = String.join("\n", result); this.scenarioExecution.getExecSharedMap().put(this.outputField, resultString); - + this.scenarioExecution.setCurrentStepId(this.step.getStepId()); + this.scenarioExecution.setNextStepId(this.step.getNextStepId()); + return this.scenarioExecution; } diff --git a/src/main/java/com/olympus/hermione/stepSolvers/QueryNeo4JSolver.java b/src/main/java/com/olympus/hermione/stepSolvers/QueryNeo4JSolver.java index 66d6a12..f267760 100644 --- a/src/main/java/com/olympus/hermione/stepSolvers/QueryNeo4JSolver.java +++ b/src/main/java/com/olympus/hermione/stepSolvers/QueryNeo4JSolver.java @@ -58,6 +58,9 @@ public class QueryNeo4JSolver extends StepSolver { this.scenarioExecution.getExecSharedMap().put(this.outputField, "Error executing query: " + e.getMessage()); } + + this.scenarioExecution.setCurrentStepId(this.step.getStepId()); + this.scenarioExecution.setNextStepId(this.step.getNextStepId()); return this.scenarioExecution; } diff --git a/src/main/java/com/olympus/hermione/stepSolvers/StepSolver.java b/src/main/java/com/olympus/hermione/stepSolvers/StepSolver.java index 0360de4..40f82c3 100644 --- a/src/main/java/com/olympus/hermione/stepSolvers/StepSolver.java +++ b/src/main/java/com/olympus/hermione/stepSolvers/StepSolver.java @@ -20,8 +20,14 @@ public class StepSolver { public ScenarioExecution solveStep(){ System.out.println("Solving step: " + this.step.getName()); + + this.scenarioExecution.setCurrentStepId(this.step.getStepId()); + this.scenarioExecution.setNextStepId(this.step.getNextStepId()); + return this.scenarioExecution; }; + + public void init(ScenarioStep step, ScenarioExecution scenarioExecution, VectorStore vectorStore,ChatModel chatModel,Driver graphDriver,Neo4JUitilityService neo4JUitilityService){ this.scenarioExecution = scenarioExecution; this.step = step; diff --git a/src/main/java/com/olympus/hermione/stepSolvers/apollo.code-workspace b/src/main/java/com/olympus/hermione/stepSolvers/apollo.code-workspace deleted file mode 100644 index 4378d1e..0000000 --- a/src/main/java/com/olympus/hermione/stepSolvers/apollo.code-workspace +++ /dev/null @@ -1,16 +0,0 @@ -{ - "folders": [ - { - "path": "../../../../../../../../apollo" - }, - { - "name": "hermione", - "path": "../../../../../../.." - }, - { - "name": "hermione-fe", - "path": "../../../../../../../../hermione-fe" - } - ], - "settings": {} -} \ No newline at end of file diff --git a/src/main/resources/workflow.json b/src/main/resources/workflow.json new file mode 100644 index 0000000..4d77fb8 --- /dev/null +++ b/src/main/resources/workflow.json @@ -0,0 +1,15 @@ +{ + "id": "test-workflow", + "version": 1, + "steps": [ + { + "id": "step1", + "stepType": "com.olympus.hermione.stepSolversV2.Task1", + "nextStepId": "step2" + }, + { + "id": "step2", + "stepType": "com.olympus.hermione.stepSolversV2.Task2" + } + ] + } \ No newline at end of file From 427ebf0072862f45db0fa005428709d48dc20c59 Mon Sep 17 00:00:00 2001 From: "andrea.terzani" Date: Fri, 27 Sep 2024 09:24:04 +0200 Subject: [PATCH 2/2] refactor: Enable asynchronous execution in HermioneApplication This commit enables asynchronous execution in HermioneApplication by adding the @EnableAsync annotation. This allows for improved performance and responsiveness in the application. --- .../olympus/hermione/HermioneApplication.java | 2 + .../controllers/ExecutionController.java | 25 +++++ .../controllers/ScenarioController.java | 17 ++- .../hermione/models/ScenarioExecution.java | 5 +- .../services/ScenarioExecutionService.java | 103 +++++++++++++----- 5 files changed, 122 insertions(+), 30 deletions(-) create mode 100644 src/main/java/com/olympus/hermione/controllers/ExecutionController.java diff --git a/src/main/java/com/olympus/hermione/HermioneApplication.java b/src/main/java/com/olympus/hermione/HermioneApplication.java index 87fc3c6..25f53ef 100644 --- a/src/main/java/com/olympus/hermione/HermioneApplication.java +++ b/src/main/java/com/olympus/hermione/HermioneApplication.java @@ -2,8 +2,10 @@ package com.olympus.hermione; import org.springframework.boot.SpringApplication; import org.springframework.boot.autoconfigure.SpringBootApplication; +import org.springframework.scheduling.annotation.EnableAsync; @SpringBootApplication +@EnableAsync public class HermioneApplication { public static void main(String[] args) { diff --git a/src/main/java/com/olympus/hermione/controllers/ExecutionController.java b/src/main/java/com/olympus/hermione/controllers/ExecutionController.java new file mode 100644 index 0000000..a20746e --- /dev/null +++ b/src/main/java/com/olympus/hermione/controllers/ExecutionController.java @@ -0,0 +1,25 @@ +package com.olympus.hermione.controllers; + +import java.util.Iterator; +import java.util.List; + +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.web.bind.annotation.GetMapping; +import org.springframework.web.bind.annotation.RequestParam; +import org.springframework.web.bind.annotation.RestController; + +import com.olympus.hermione.models.ScenarioExecution; +import com.olympus.hermione.repository.ScenarioExecutionRepository; + +@RestController +public class ExecutionController { + @Autowired + ScenarioExecutionRepository scenarioExecutionRepository; + + @GetMapping("/execution") + public ScenarioExecution getOldExections(@RequestParam String id){ + + + return scenarioExecutionRepository.findById(id).get(); + } +} diff --git a/src/main/java/com/olympus/hermione/controllers/ScenarioController.java b/src/main/java/com/olympus/hermione/controllers/ScenarioController.java index b68f02a..3d90a40 100644 --- a/src/main/java/com/olympus/hermione/controllers/ScenarioController.java +++ b/src/main/java/com/olympus/hermione/controllers/ScenarioController.java @@ -39,10 +39,25 @@ public class ScenarioController { } @PostMapping("scenarios/execute") - public ScenarioOutput executeScenario(@RequestBody ScenarioExecutionInput scenarioExecutionInput) { + public ScenarioOutput executeScenario(@RequestBody ScenarioExecutionInput scenarioExecutionInput) { return scenarioExecutionService.executeScenario(scenarioExecutionInput); } + @PostMapping("scenarios/execute-async") + public ScenarioOutput executeScenarioAsync(@RequestBody ScenarioExecutionInput scenarioExecutionInput) { + ScenarioOutput preOutput = scenarioExecutionService.prepareScrenarioExecution(scenarioExecutionInput); + + scenarioExecutionService.executeScenarioAsync(preOutput); + + return preOutput; + } + + @GetMapping("/scenarios/getExecutionProgress/{id}") + public ScenarioOutput getExecutionProgress(@PathVariable String id) { + return scenarioExecutionService.getExecutionProgress(id); + } + + @GetMapping("/scenarios/execute/{id}") public ScenarioExecution getScenarioExecution(@PathVariable String id) { return scenarioExecutionRepository.findById(id).get(); diff --git a/src/main/java/com/olympus/hermione/models/ScenarioExecution.java b/src/main/java/com/olympus/hermione/models/ScenarioExecution.java index 2511f3b..eb18abf 100644 --- a/src/main/java/com/olympus/hermione/models/ScenarioExecution.java +++ b/src/main/java/com/olympus/hermione/models/ScenarioExecution.java @@ -5,6 +5,8 @@ import java.util.HashMap; import org.springframework.data.annotation.Id; import org.springframework.data.mongodb.core.mapping.Document; +import com.olympus.hermione.dto.ScenarioExecutionInput; + import lombok.Getter; import lombok.Setter; @@ -24,5 +26,6 @@ public class ScenarioExecution { private String currentStepId; private String nextStepId; - + + private ScenarioExecutionInput scenarioExecutionInput; } diff --git a/src/main/java/com/olympus/hermione/services/ScenarioExecutionService.java b/src/main/java/com/olympus/hermione/services/ScenarioExecutionService.java index 5c1f6da..e5b2fdd 100644 --- a/src/main/java/com/olympus/hermione/services/ScenarioExecutionService.java +++ b/src/main/java/com/olympus/hermione/services/ScenarioExecutionService.java @@ -9,6 +9,7 @@ import org.slf4j.LoggerFactory; import org.springframework.ai.chat.model.ChatModel; import org.springframework.ai.vectorstore.VectorStore; import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.scheduling.annotation.Async; import org.springframework.stereotype.Service; import com.olympus.hermione.dto.ScenarioExecutionInput; @@ -50,44 +51,97 @@ public class ScenarioExecutionService { private Logger logger = LoggerFactory.getLogger(ScenarioExecutionService.class); + public ScenarioOutput getExecutionProgress(String scenarioExecutionId){ + ScenarioOutput scenarioOutput = new ScenarioOutput(); + Optional o_scenarioExecution = scenarioExecutionRepository.findById(scenarioExecutionId); + + if(o_scenarioExecution.isPresent()){ + ScenarioExecution scenarioExecution = o_scenarioExecution.get(); + + if(scenarioExecution.getExecSharedMap().get("scenario_output")!=null){ + scenarioOutput.setScenarioExecution_id(scenarioExecution.getId()); + scenarioOutput.setStatus("OK"); + scenarioOutput.setStringOutput(scenarioExecution.getExecSharedMap().get("scenario_output").toString()); + }else{ + scenarioOutput.setScenarioExecution_id(scenarioExecution.getId()); + scenarioOutput.setStatus("IN_PROGRESS"); + } + + }else{ + scenarioOutput.setScenarioExecution_id(null); + scenarioOutput.setStatus("ERROR"); + scenarioOutput.setMessage("Scenario Execution not found"); + } + + return scenarioOutput; + } + + @Async + public void executeScenarioAsync(ScenarioOutput scenarioOutput){ + + String scenarioExecutionID = scenarioOutput.getScenarioExecution_id(); + + Optional o_scenarioExecution = scenarioExecutionRepository.findById(scenarioExecutionID); + + if(o_scenarioExecution.isPresent()){ + ScenarioExecution scenarioExecution = o_scenarioExecution.get(); + + HashMap inputs = scenarioExecution.getScenarioExecutionInput().getInputs(); + + Optional o_scenario = scenarioRepository.findById(scenarioExecution.getScenario().getId()); + + Scenario scenario = o_scenario.get(); + logger.info("Start execution of scenario: " + scenario.getName()); + + HashMap execSharedMap = new HashMap(); + execSharedMap.put("user_input", inputs); + scenarioExecution.setExecSharedMap(execSharedMap); + scenarioExecutionRepository.save(scenarioExecution); + + List steps = scenario.getSteps(); + String startStepId=scenario.getStartWithStepId(); + + ScenarioStep startStep = steps.stream().filter(step -> step.getStepId().equals(startStepId)).findFirst().orElse(null); + + executeScenarioStep(startStep, scenarioExecution); + + while (scenarioExecution.getNextStepId()!=null) { + ScenarioStep step = steps.stream().filter(s -> s.getStepId().equals(scenarioExecution.getNextStepId())).findFirst().orElse(null); + executeScenarioStep(step, scenarioExecution); + } + + + } + } - public String executeScenario(String scenarioId, String input){ + public ScenarioOutput prepareScrenarioExecution(ScenarioExecutionInput scenarioExecutionInput){ + + String scenarioId = scenarioExecutionInput.getScenario_id(); + HashMap inputs = scenarioExecutionInput.getInputs(); + ScenarioOutput scenarioOutput = new ScenarioOutput(); Optional o_scenario = scenarioRepository.findById(scenarioId); if(o_scenario.isPresent()){ Scenario scenario = o_scenario.get(); - logger.info("Executing scenario: " + scenario.getName()); + logger.info("Executing scenario: " + scenario.getName()); + ScenarioExecution scenarioExecution = new ScenarioExecution(); scenarioExecution.setScenario(scenario); - - HashMap execSharedMap = new HashMap(); - execSharedMap.put("user_input", input); - scenarioExecution.setExecSharedMap(execSharedMap); + scenarioExecution.setScenarioExecutionInput(scenarioExecutionInput); scenarioExecutionRepository.save(scenarioExecution); - - List steps = scenario.getSteps(); - steps.sort(Comparator.comparingInt(ScenarioStep::getOrder)); + scenarioOutput.setScenarioExecution_id(scenarioExecution.getId()); + scenarioOutput.setStatus("STARTED"); - - - for (ScenarioStep step : steps) { - executeScenarioStep(step, scenarioExecution); - } - - return scenarioExecution.getExecSharedMap().get("scenario_output").toString(); - - }else{ - logger.error("Scenario not found with id: " + scenarioId); } - ; - return ""; + return scenarioOutput; } + public ScenarioOutput executeScenario(ScenarioExecutionInput scenarioExecutionInput){ String scenarioId = scenarioExecutionInput.getScenario_id(); @@ -109,9 +163,6 @@ public class ScenarioExecutionService { scenarioExecutionRepository.save(scenarioExecution); List steps = scenario.getSteps(); - - - String startStepId=scenario.getStartWithStepId(); ScenarioStep startStep = steps.stream().filter(step -> step.getStepId().equals(startStepId)).findFirst().orElse(null); @@ -120,7 +171,6 @@ public class ScenarioExecutionService { while (scenarioExecution.getNextStepId()!=null) { ScenarioStep step = steps.stream().filter(s -> s.getStepId().equals(scenarioExecution.getNextStepId())).findFirst().orElse(null); - executeScenarioStep(step, scenarioExecution); } @@ -165,15 +215,12 @@ public class ScenarioExecutionService { logger.info("Solving step: " + step.getStepId()); scenarioExecution.setCurrentStepId(step.getStepId()); - //must add try catch in order to catch exceptions and step execution errors ScenarioExecution scenarioExecutionNew = solver.solveStep(); logger.info("Step solved: " + step.getStepId()); logger.info("Next step: " + scenarioExecutionNew.getNextStepId()); scenarioExecutionRepository.save(scenarioExecutionNew); - - }