diff --git a/pom.xml b/pom.xml index 94371a8..ead9aa4 100644 --- a/pom.xml +++ b/pom.xml @@ -91,10 +91,11 @@ 4.3.4 - com.google.code.gson - gson - 2.8.8 - + com.google.code.gson + gson + 2.8.8 + + diff --git a/src/main/java/com/olympus/hermione/controllers/TestController.java b/src/main/java/com/olympus/hermione/controllers/TestController.java index d534ef1..cc683b6 100644 --- a/src/main/java/com/olympus/hermione/controllers/TestController.java +++ b/src/main/java/com/olympus/hermione/controllers/TestController.java @@ -3,8 +3,12 @@ package com.olympus.hermione.controllers; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.data.mongodb.core.aggregation.ComparisonOperators.Ne; import org.springframework.web.bind.annotation.GetMapping; +import org.springframework.web.bind.annotation.PostMapping; +import org.springframework.web.bind.annotation.RequestBody; import org.springframework.web.bind.annotation.RestController; +import com.olympus.hermione.dto.ScenarioExecutionInput; +import com.olympus.hermione.dto.ScenarioOutput; import com.olympus.hermione.services.Neo4JUitilityService; import com.olympus.hermione.services.ScenarioExecutionService; @@ -14,20 +18,23 @@ public class TestController { @Autowired ScenarioExecutionService scenarioExecutionService; + + @Autowired Neo4JUitilityService neo4JUitilityService; - @GetMapping("/test/scenario_execution") - public String testScenarioExecution(){ - String result = scenarioExecutionService.executeScenario("66aa162debe80dfcd17f0ef4","How i can change the path where uploaded documents are stored ?"); - return result; + @PostMapping("test/execute") + public ScenarioOutput executeScenario(@RequestBody ScenarioExecutionInput scenarioExecutionInput) { + return scenarioExecutionService.executeScenario(scenarioExecutionInput); } + + @GetMapping("/test/generate_schema_description") public String testGenerateSchemaDescription(){ return neo4JUitilityService.generateSchemaDescription().toString(); } - + } diff --git a/src/main/java/com/olympus/hermione/models/Scenario.java b/src/main/java/com/olympus/hermione/models/Scenario.java index da997f3..43ce81a 100644 --- a/src/main/java/com/olympus/hermione/models/Scenario.java +++ b/src/main/java/com/olympus/hermione/models/Scenario.java @@ -15,6 +15,7 @@ public class Scenario { private String id; private String name; private String description; + private String startWithStepId; private List steps; private List inputs; diff --git a/src/main/java/com/olympus/hermione/models/ScenarioExecution.java b/src/main/java/com/olympus/hermione/models/ScenarioExecution.java index dc679e9..2511f3b 100644 --- a/src/main/java/com/olympus/hermione/models/ScenarioExecution.java +++ b/src/main/java/com/olympus/hermione/models/ScenarioExecution.java @@ -21,5 +21,8 @@ public class ScenarioExecution { private String scenarioId; + private String currentStepId; + private String nextStepId; + } diff --git a/src/main/java/com/olympus/hermione/models/ScenarioStep.java b/src/main/java/com/olympus/hermione/models/ScenarioStep.java index 4e59efc..a504bfd 100644 --- a/src/main/java/com/olympus/hermione/models/ScenarioStep.java +++ b/src/main/java/com/olympus/hermione/models/ScenarioStep.java @@ -12,6 +12,9 @@ public class ScenarioStep { private String description; private String name; private int order; + + private String stepId; + private String nextStepId; private HashMap attributes; diff --git a/src/main/java/com/olympus/hermione/services/ScenarioExecutionService.java b/src/main/java/com/olympus/hermione/services/ScenarioExecutionService.java index 64c73ed..5c1f6da 100644 --- a/src/main/java/com/olympus/hermione/services/ScenarioExecutionService.java +++ b/src/main/java/com/olympus/hermione/services/ScenarioExecutionService.java @@ -71,8 +71,11 @@ public class ScenarioExecutionService { List steps = scenario.getSteps(); steps.sort(Comparator.comparingInt(ScenarioStep::getOrder)); + + + for (ScenarioStep step : steps) { - executeScenarioStep(step, scenarioExecution, step.getOrder()); + executeScenarioStep(step, scenarioExecution); } return scenarioExecution.getExecSharedMap().get("scenario_output").toString(); @@ -107,9 +110,18 @@ public class ScenarioExecutionService { List steps = scenario.getSteps(); - steps.sort(Comparator.comparingInt(ScenarioStep::getOrder)); - for (ScenarioStep step : steps) { - executeScenarioStep(step, scenarioExecution, step.getOrder()); + + + String startStepId=scenario.getStartWithStepId(); + + ScenarioStep startStep = steps.stream().filter(step -> step.getStepId().equals(startStepId)).findFirst().orElse(null); + + executeScenarioStep(startStep, scenarioExecution); + + while (scenarioExecution.getNextStepId()!=null) { + ScenarioStep step = steps.stream().filter(s -> s.getStepId().equals(scenarioExecution.getNextStepId())).findFirst().orElse(null); + + executeScenarioStep(step, scenarioExecution); } scenarioOutput.setScenarioExecution_id(scenarioExecution.getId()); @@ -130,7 +142,7 @@ public class ScenarioExecutionService { - private void executeScenarioStep(ScenarioStep step, ScenarioExecution scenarioExecution, int stepIndex){ + private void executeScenarioStep(ScenarioStep step, ScenarioExecution scenarioExecution){ logger.info("Start working on step: " + step.getName() + " with type: " + step.getType()); StepSolver solver=new StepSolver(); switch (step.getType()) { @@ -147,14 +159,18 @@ public class ScenarioExecutionService { break; } - logger.info("Initializing step: " + step.getName()); + logger.info("Initializing step: " + step.getStepId()); solver.init(step, scenarioExecution, vectorStore,chatModel,graphDriver,neo4JUitilityService); - logger.info("Solving step: " + step.getName()); + logger.info("Solving step: " + step.getStepId()); + scenarioExecution.setCurrentStepId(step.getStepId()); //must add try catch in order to catch exceptions and step execution errors ScenarioExecution scenarioExecutionNew = solver.solveStep(); + logger.info("Step solved: " + step.getStepId()); + logger.info("Next step: " + scenarioExecutionNew.getNextStepId()); + scenarioExecutionRepository.save(scenarioExecutionNew); diff --git a/src/main/java/com/olympus/hermione/stepSolvers/BasicAIPromptSolver.java b/src/main/java/com/olympus/hermione/stepSolvers/BasicAIPromptSolver.java index 856cc68..137471a 100644 --- a/src/main/java/com/olympus/hermione/stepSolvers/BasicAIPromptSolver.java +++ b/src/main/java/com/olympus/hermione/stepSolvers/BasicAIPromptSolver.java @@ -49,6 +49,8 @@ public class BasicAIPromptSolver extends StepSolver { System.out.println("Solving step: " + this.step.getName()); + this.scenarioExecution.setCurrentStepId(this.step.getStepId()); + loadParameters(); @@ -64,7 +66,9 @@ public class BasicAIPromptSolver extends StepSolver { String output = response.get(0).getOutput().getContent(); this.scenarioExecution.getExecSharedMap().put(this.qai_output_variable, output); - + + this.scenarioExecution.setNextStepId(this.step.getNextStepId()); + return this.scenarioExecution; } diff --git a/src/main/java/com/olympus/hermione/stepSolvers/BasicQueryRagSolver.java b/src/main/java/com/olympus/hermione/stepSolvers/BasicQueryRagSolver.java index 12a6423..00bf6ed 100644 --- a/src/main/java/com/olympus/hermione/stepSolvers/BasicQueryRagSolver.java +++ b/src/main/java/com/olympus/hermione/stepSolvers/BasicQueryRagSolver.java @@ -93,7 +93,9 @@ public class BasicQueryRagSolver extends StepSolver { //concatenate the content of the results into a single string String resultString = String.join("\n", result); this.scenarioExecution.getExecSharedMap().put(this.outputField, resultString); - + this.scenarioExecution.setCurrentStepId(this.step.getStepId()); + this.scenarioExecution.setNextStepId(this.step.getNextStepId()); + return this.scenarioExecution; } diff --git a/src/main/java/com/olympus/hermione/stepSolvers/QueryNeo4JSolver.java b/src/main/java/com/olympus/hermione/stepSolvers/QueryNeo4JSolver.java index 66d6a12..f267760 100644 --- a/src/main/java/com/olympus/hermione/stepSolvers/QueryNeo4JSolver.java +++ b/src/main/java/com/olympus/hermione/stepSolvers/QueryNeo4JSolver.java @@ -58,6 +58,9 @@ public class QueryNeo4JSolver extends StepSolver { this.scenarioExecution.getExecSharedMap().put(this.outputField, "Error executing query: " + e.getMessage()); } + + this.scenarioExecution.setCurrentStepId(this.step.getStepId()); + this.scenarioExecution.setNextStepId(this.step.getNextStepId()); return this.scenarioExecution; } diff --git a/src/main/java/com/olympus/hermione/stepSolvers/StepSolver.java b/src/main/java/com/olympus/hermione/stepSolvers/StepSolver.java index 0360de4..40f82c3 100644 --- a/src/main/java/com/olympus/hermione/stepSolvers/StepSolver.java +++ b/src/main/java/com/olympus/hermione/stepSolvers/StepSolver.java @@ -20,8 +20,14 @@ public class StepSolver { public ScenarioExecution solveStep(){ System.out.println("Solving step: " + this.step.getName()); + + this.scenarioExecution.setCurrentStepId(this.step.getStepId()); + this.scenarioExecution.setNextStepId(this.step.getNextStepId()); + return this.scenarioExecution; }; + + public void init(ScenarioStep step, ScenarioExecution scenarioExecution, VectorStore vectorStore,ChatModel chatModel,Driver graphDriver,Neo4JUitilityService neo4JUitilityService){ this.scenarioExecution = scenarioExecution; this.step = step; diff --git a/src/main/java/com/olympus/hermione/stepSolvers/apollo.code-workspace b/src/main/java/com/olympus/hermione/stepSolvers/apollo.code-workspace deleted file mode 100644 index 4378d1e..0000000 --- a/src/main/java/com/olympus/hermione/stepSolvers/apollo.code-workspace +++ /dev/null @@ -1,16 +0,0 @@ -{ - "folders": [ - { - "path": "../../../../../../../../apollo" - }, - { - "name": "hermione", - "path": "../../../../../../.." - }, - { - "name": "hermione-fe", - "path": "../../../../../../../../hermione-fe" - } - ], - "settings": {} -} \ No newline at end of file diff --git a/src/main/resources/workflow.json b/src/main/resources/workflow.json new file mode 100644 index 0000000..4d77fb8 --- /dev/null +++ b/src/main/resources/workflow.json @@ -0,0 +1,15 @@ +{ + "id": "test-workflow", + "version": 1, + "steps": [ + { + "id": "step1", + "stepType": "com.olympus.hermione.stepSolversV2.Task1", + "nextStepId": "step2" + }, + { + "id": "step2", + "stepType": "com.olympus.hermione.stepSolversV2.Task2" + } + ] + } \ No newline at end of file