diff --git a/pom.xml b/pom.xml
index 94371a8..ead9aa4 100644
--- a/pom.xml
+++ b/pom.xml
@@ -91,10 +91,11 @@
4.3.4
- com.google.code.gson
- gson
- 2.8.8
-
+ com.google.code.gson
+ gson
+ 2.8.8
+
+
diff --git a/src/main/java/com/olympus/hermione/HermioneApplication.java b/src/main/java/com/olympus/hermione/HermioneApplication.java
index 87fc3c6..25f53ef 100644
--- a/src/main/java/com/olympus/hermione/HermioneApplication.java
+++ b/src/main/java/com/olympus/hermione/HermioneApplication.java
@@ -2,8 +2,10 @@ package com.olympus.hermione;
import org.springframework.boot.SpringApplication;
import org.springframework.boot.autoconfigure.SpringBootApplication;
+import org.springframework.scheduling.annotation.EnableAsync;
@SpringBootApplication
+@EnableAsync
public class HermioneApplication {
public static void main(String[] args) {
diff --git a/src/main/java/com/olympus/hermione/controllers/ExecutionController.java b/src/main/java/com/olympus/hermione/controllers/ExecutionController.java
new file mode 100644
index 0000000..a20746e
--- /dev/null
+++ b/src/main/java/com/olympus/hermione/controllers/ExecutionController.java
@@ -0,0 +1,25 @@
+package com.olympus.hermione.controllers;
+
+import java.util.Iterator;
+import java.util.List;
+
+import org.springframework.beans.factory.annotation.Autowired;
+import org.springframework.web.bind.annotation.GetMapping;
+import org.springframework.web.bind.annotation.RequestParam;
+import org.springframework.web.bind.annotation.RestController;
+
+import com.olympus.hermione.models.ScenarioExecution;
+import com.olympus.hermione.repository.ScenarioExecutionRepository;
+
+@RestController
+public class ExecutionController {
+ @Autowired
+ ScenarioExecutionRepository scenarioExecutionRepository;
+
+ @GetMapping("/execution")
+ public ScenarioExecution getOldExections(@RequestParam String id){
+
+
+ return scenarioExecutionRepository.findById(id).get();
+ }
+}
diff --git a/src/main/java/com/olympus/hermione/controllers/ScenarioController.java b/src/main/java/com/olympus/hermione/controllers/ScenarioController.java
index b68f02a..3d90a40 100644
--- a/src/main/java/com/olympus/hermione/controllers/ScenarioController.java
+++ b/src/main/java/com/olympus/hermione/controllers/ScenarioController.java
@@ -39,10 +39,25 @@ public class ScenarioController {
}
@PostMapping("scenarios/execute")
- public ScenarioOutput executeScenario(@RequestBody ScenarioExecutionInput scenarioExecutionInput) {
+ public ScenarioOutput executeScenario(@RequestBody ScenarioExecutionInput scenarioExecutionInput) {
return scenarioExecutionService.executeScenario(scenarioExecutionInput);
}
+ @PostMapping("scenarios/execute-async")
+ public ScenarioOutput executeScenarioAsync(@RequestBody ScenarioExecutionInput scenarioExecutionInput) {
+ ScenarioOutput preOutput = scenarioExecutionService.prepareScrenarioExecution(scenarioExecutionInput);
+
+ scenarioExecutionService.executeScenarioAsync(preOutput);
+
+ return preOutput;
+ }
+
+ @GetMapping("/scenarios/getExecutionProgress/{id}")
+ public ScenarioOutput getExecutionProgress(@PathVariable String id) {
+ return scenarioExecutionService.getExecutionProgress(id);
+ }
+
+
@GetMapping("/scenarios/execute/{id}")
public ScenarioExecution getScenarioExecution(@PathVariable String id) {
return scenarioExecutionRepository.findById(id).get();
diff --git a/src/main/java/com/olympus/hermione/controllers/TestController.java b/src/main/java/com/olympus/hermione/controllers/TestController.java
index d534ef1..cc683b6 100644
--- a/src/main/java/com/olympus/hermione/controllers/TestController.java
+++ b/src/main/java/com/olympus/hermione/controllers/TestController.java
@@ -3,8 +3,12 @@ package com.olympus.hermione.controllers;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.data.mongodb.core.aggregation.ComparisonOperators.Ne;
import org.springframework.web.bind.annotation.GetMapping;
+import org.springframework.web.bind.annotation.PostMapping;
+import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.bind.annotation.RestController;
+import com.olympus.hermione.dto.ScenarioExecutionInput;
+import com.olympus.hermione.dto.ScenarioOutput;
import com.olympus.hermione.services.Neo4JUitilityService;
import com.olympus.hermione.services.ScenarioExecutionService;
@@ -14,20 +18,23 @@ public class TestController {
@Autowired
ScenarioExecutionService scenarioExecutionService;
+
+
@Autowired
Neo4JUitilityService neo4JUitilityService;
- @GetMapping("/test/scenario_execution")
- public String testScenarioExecution(){
- String result = scenarioExecutionService.executeScenario("66aa162debe80dfcd17f0ef4","How i can change the path where uploaded documents are stored ?");
- return result;
+ @PostMapping("test/execute")
+ public ScenarioOutput executeScenario(@RequestBody ScenarioExecutionInput scenarioExecutionInput) {
+ return scenarioExecutionService.executeScenario(scenarioExecutionInput);
}
+
+
@GetMapping("/test/generate_schema_description")
public String testGenerateSchemaDescription(){
return neo4JUitilityService.generateSchemaDescription().toString();
}
-
+
}
diff --git a/src/main/java/com/olympus/hermione/models/Scenario.java b/src/main/java/com/olympus/hermione/models/Scenario.java
index da997f3..43ce81a 100644
--- a/src/main/java/com/olympus/hermione/models/Scenario.java
+++ b/src/main/java/com/olympus/hermione/models/Scenario.java
@@ -15,6 +15,7 @@ public class Scenario {
private String id;
private String name;
private String description;
+ private String startWithStepId;
private List steps;
private List inputs;
diff --git a/src/main/java/com/olympus/hermione/models/ScenarioExecution.java b/src/main/java/com/olympus/hermione/models/ScenarioExecution.java
index dc679e9..eb18abf 100644
--- a/src/main/java/com/olympus/hermione/models/ScenarioExecution.java
+++ b/src/main/java/com/olympus/hermione/models/ScenarioExecution.java
@@ -5,6 +5,8 @@ import java.util.HashMap;
import org.springframework.data.annotation.Id;
import org.springframework.data.mongodb.core.mapping.Document;
+import com.olympus.hermione.dto.ScenarioExecutionInput;
+
import lombok.Getter;
import lombok.Setter;
@@ -21,5 +23,9 @@ public class ScenarioExecution {
private String scenarioId;
+ private String currentStepId;
+ private String nextStepId;
+
+ private ScenarioExecutionInput scenarioExecutionInput;
}
diff --git a/src/main/java/com/olympus/hermione/models/ScenarioStep.java b/src/main/java/com/olympus/hermione/models/ScenarioStep.java
index 4e59efc..a504bfd 100644
--- a/src/main/java/com/olympus/hermione/models/ScenarioStep.java
+++ b/src/main/java/com/olympus/hermione/models/ScenarioStep.java
@@ -12,6 +12,9 @@ public class ScenarioStep {
private String description;
private String name;
private int order;
+
+ private String stepId;
+ private String nextStepId;
private HashMap attributes;
diff --git a/src/main/java/com/olympus/hermione/services/ScenarioExecutionService.java b/src/main/java/com/olympus/hermione/services/ScenarioExecutionService.java
index 64c73ed..e5b2fdd 100644
--- a/src/main/java/com/olympus/hermione/services/ScenarioExecutionService.java
+++ b/src/main/java/com/olympus/hermione/services/ScenarioExecutionService.java
@@ -9,6 +9,7 @@ import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.vectorstore.VectorStore;
import org.springframework.beans.factory.annotation.Autowired;
+import org.springframework.scheduling.annotation.Async;
import org.springframework.stereotype.Service;
import com.olympus.hermione.dto.ScenarioExecutionInput;
@@ -50,41 +51,97 @@ public class ScenarioExecutionService {
private Logger logger = LoggerFactory.getLogger(ScenarioExecutionService.class);
+ public ScenarioOutput getExecutionProgress(String scenarioExecutionId){
+ ScenarioOutput scenarioOutput = new ScenarioOutput();
+ Optional o_scenarioExecution = scenarioExecutionRepository.findById(scenarioExecutionId);
+
+ if(o_scenarioExecution.isPresent()){
+ ScenarioExecution scenarioExecution = o_scenarioExecution.get();
+
+ if(scenarioExecution.getExecSharedMap().get("scenario_output")!=null){
+ scenarioOutput.setScenarioExecution_id(scenarioExecution.getId());
+ scenarioOutput.setStatus("OK");
+ scenarioOutput.setStringOutput(scenarioExecution.getExecSharedMap().get("scenario_output").toString());
+ }else{
+ scenarioOutput.setScenarioExecution_id(scenarioExecution.getId());
+ scenarioOutput.setStatus("IN_PROGRESS");
+ }
+
+ }else{
+ scenarioOutput.setScenarioExecution_id(null);
+ scenarioOutput.setStatus("ERROR");
+ scenarioOutput.setMessage("Scenario Execution not found");
+ }
+
+ return scenarioOutput;
+ }
+
+ @Async
+ public void executeScenarioAsync(ScenarioOutput scenarioOutput){
+
+ String scenarioExecutionID = scenarioOutput.getScenarioExecution_id();
+
+ Optional o_scenarioExecution = scenarioExecutionRepository.findById(scenarioExecutionID);
+
+ if(o_scenarioExecution.isPresent()){
+ ScenarioExecution scenarioExecution = o_scenarioExecution.get();
+
+ HashMap inputs = scenarioExecution.getScenarioExecutionInput().getInputs();
+
+ Optional o_scenario = scenarioRepository.findById(scenarioExecution.getScenario().getId());
+
+ Scenario scenario = o_scenario.get();
+ logger.info("Start execution of scenario: " + scenario.getName());
+
+ HashMap execSharedMap = new HashMap();
+ execSharedMap.put("user_input", inputs);
+ scenarioExecution.setExecSharedMap(execSharedMap);
+ scenarioExecutionRepository.save(scenarioExecution);
+
+ List steps = scenario.getSteps();
+ String startStepId=scenario.getStartWithStepId();
+
+ ScenarioStep startStep = steps.stream().filter(step -> step.getStepId().equals(startStepId)).findFirst().orElse(null);
+
+ executeScenarioStep(startStep, scenarioExecution);
+
+ while (scenarioExecution.getNextStepId()!=null) {
+ ScenarioStep step = steps.stream().filter(s -> s.getStepId().equals(scenarioExecution.getNextStepId())).findFirst().orElse(null);
+ executeScenarioStep(step, scenarioExecution);
+ }
+
+
+ }
+ }
- public String executeScenario(String scenarioId, String input){
+ public ScenarioOutput prepareScrenarioExecution(ScenarioExecutionInput scenarioExecutionInput){
+
+ String scenarioId = scenarioExecutionInput.getScenario_id();
+ HashMap inputs = scenarioExecutionInput.getInputs();
+ ScenarioOutput scenarioOutput = new ScenarioOutput();
Optional o_scenario = scenarioRepository.findById(scenarioId);
if(o_scenario.isPresent()){
Scenario scenario = o_scenario.get();
- logger.info("Executing scenario: " + scenario.getName());
+ logger.info("Executing scenario: " + scenario.getName());
+
ScenarioExecution scenarioExecution = new ScenarioExecution();
scenarioExecution.setScenario(scenario);
-
- HashMap execSharedMap = new HashMap();
- execSharedMap.put("user_input", input);
- scenarioExecution.setExecSharedMap(execSharedMap);
+ scenarioExecution.setScenarioExecutionInput(scenarioExecutionInput);
scenarioExecutionRepository.save(scenarioExecution);
-
- List steps = scenario.getSteps();
- steps.sort(Comparator.comparingInt(ScenarioStep::getOrder));
- for (ScenarioStep step : steps) {
- executeScenarioStep(step, scenarioExecution, step.getOrder());
- }
+ scenarioOutput.setScenarioExecution_id(scenarioExecution.getId());
+ scenarioOutput.setStatus("STARTED");
- return scenarioExecution.getExecSharedMap().get("scenario_output").toString();
-
- }else{
- logger.error("Scenario not found with id: " + scenarioId);
}
- ;
- return "";
+ return scenarioOutput;
}
+
public ScenarioOutput executeScenario(ScenarioExecutionInput scenarioExecutionInput){
String scenarioId = scenarioExecutionInput.getScenario_id();
@@ -106,10 +163,15 @@ public class ScenarioExecutionService {
scenarioExecutionRepository.save(scenarioExecution);
List steps = scenario.getSteps();
+ String startStepId=scenario.getStartWithStepId();
- steps.sort(Comparator.comparingInt(ScenarioStep::getOrder));
- for (ScenarioStep step : steps) {
- executeScenarioStep(step, scenarioExecution, step.getOrder());
+ ScenarioStep startStep = steps.stream().filter(step -> step.getStepId().equals(startStepId)).findFirst().orElse(null);
+
+ executeScenarioStep(startStep, scenarioExecution);
+
+ while (scenarioExecution.getNextStepId()!=null) {
+ ScenarioStep step = steps.stream().filter(s -> s.getStepId().equals(scenarioExecution.getNextStepId())).findFirst().orElse(null);
+ executeScenarioStep(step, scenarioExecution);
}
scenarioOutput.setScenarioExecution_id(scenarioExecution.getId());
@@ -130,7 +192,7 @@ public class ScenarioExecutionService {
- private void executeScenarioStep(ScenarioStep step, ScenarioExecution scenarioExecution, int stepIndex){
+ private void executeScenarioStep(ScenarioStep step, ScenarioExecution scenarioExecution){
logger.info("Start working on step: " + step.getName() + " with type: " + step.getType());
StepSolver solver=new StepSolver();
switch (step.getType()) {
@@ -147,17 +209,18 @@ public class ScenarioExecutionService {
break;
}
- logger.info("Initializing step: " + step.getName());
+ logger.info("Initializing step: " + step.getStepId());
solver.init(step, scenarioExecution, vectorStore,chatModel,graphDriver,neo4JUitilityService);
- logger.info("Solving step: " + step.getName());
+ logger.info("Solving step: " + step.getStepId());
+ scenarioExecution.setCurrentStepId(step.getStepId());
- //must add try catch in order to catch exceptions and step execution errors
ScenarioExecution scenarioExecutionNew = solver.solveStep();
- scenarioExecutionRepository.save(scenarioExecutionNew);
-
+ logger.info("Step solved: " + step.getStepId());
+ logger.info("Next step: " + scenarioExecutionNew.getNextStepId());
+ scenarioExecutionRepository.save(scenarioExecutionNew);
}
diff --git a/src/main/java/com/olympus/hermione/stepSolvers/BasicAIPromptSolver.java b/src/main/java/com/olympus/hermione/stepSolvers/BasicAIPromptSolver.java
index 856cc68..137471a 100644
--- a/src/main/java/com/olympus/hermione/stepSolvers/BasicAIPromptSolver.java
+++ b/src/main/java/com/olympus/hermione/stepSolvers/BasicAIPromptSolver.java
@@ -49,6 +49,8 @@ public class BasicAIPromptSolver extends StepSolver {
System.out.println("Solving step: " + this.step.getName());
+ this.scenarioExecution.setCurrentStepId(this.step.getStepId());
+
loadParameters();
@@ -64,7 +66,9 @@ public class BasicAIPromptSolver extends StepSolver {
String output = response.get(0).getOutput().getContent();
this.scenarioExecution.getExecSharedMap().put(this.qai_output_variable, output);
-
+
+ this.scenarioExecution.setNextStepId(this.step.getNextStepId());
+
return this.scenarioExecution;
}
diff --git a/src/main/java/com/olympus/hermione/stepSolvers/BasicQueryRagSolver.java b/src/main/java/com/olympus/hermione/stepSolvers/BasicQueryRagSolver.java
index 12a6423..00bf6ed 100644
--- a/src/main/java/com/olympus/hermione/stepSolvers/BasicQueryRagSolver.java
+++ b/src/main/java/com/olympus/hermione/stepSolvers/BasicQueryRagSolver.java
@@ -93,7 +93,9 @@ public class BasicQueryRagSolver extends StepSolver {
//concatenate the content of the results into a single string
String resultString = String.join("\n", result);
this.scenarioExecution.getExecSharedMap().put(this.outputField, resultString);
-
+ this.scenarioExecution.setCurrentStepId(this.step.getStepId());
+ this.scenarioExecution.setNextStepId(this.step.getNextStepId());
+
return this.scenarioExecution;
}
diff --git a/src/main/java/com/olympus/hermione/stepSolvers/QueryNeo4JSolver.java b/src/main/java/com/olympus/hermione/stepSolvers/QueryNeo4JSolver.java
index 66d6a12..f267760 100644
--- a/src/main/java/com/olympus/hermione/stepSolvers/QueryNeo4JSolver.java
+++ b/src/main/java/com/olympus/hermione/stepSolvers/QueryNeo4JSolver.java
@@ -58,6 +58,9 @@ public class QueryNeo4JSolver extends StepSolver {
this.scenarioExecution.getExecSharedMap().put(this.outputField, "Error executing query: " + e.getMessage());
}
+
+ this.scenarioExecution.setCurrentStepId(this.step.getStepId());
+ this.scenarioExecution.setNextStepId(this.step.getNextStepId());
return this.scenarioExecution;
}
diff --git a/src/main/java/com/olympus/hermione/stepSolvers/StepSolver.java b/src/main/java/com/olympus/hermione/stepSolvers/StepSolver.java
index 0360de4..40f82c3 100644
--- a/src/main/java/com/olympus/hermione/stepSolvers/StepSolver.java
+++ b/src/main/java/com/olympus/hermione/stepSolvers/StepSolver.java
@@ -20,8 +20,14 @@ public class StepSolver {
public ScenarioExecution solveStep(){
System.out.println("Solving step: " + this.step.getName());
+
+ this.scenarioExecution.setCurrentStepId(this.step.getStepId());
+ this.scenarioExecution.setNextStepId(this.step.getNextStepId());
+
return this.scenarioExecution;
};
+
+
public void init(ScenarioStep step, ScenarioExecution scenarioExecution, VectorStore vectorStore,ChatModel chatModel,Driver graphDriver,Neo4JUitilityService neo4JUitilityService){
this.scenarioExecution = scenarioExecution;
this.step = step;
diff --git a/src/main/java/com/olympus/hermione/stepSolvers/apollo.code-workspace b/src/main/java/com/olympus/hermione/stepSolvers/apollo.code-workspace
deleted file mode 100644
index 4378d1e..0000000
--- a/src/main/java/com/olympus/hermione/stepSolvers/apollo.code-workspace
+++ /dev/null
@@ -1,16 +0,0 @@
-{
- "folders": [
- {
- "path": "../../../../../../../../apollo"
- },
- {
- "name": "hermione",
- "path": "../../../../../../.."
- },
- {
- "name": "hermione-fe",
- "path": "../../../../../../../../hermione-fe"
- }
- ],
- "settings": {}
-}
\ No newline at end of file
diff --git a/src/main/resources/workflow.json b/src/main/resources/workflow.json
new file mode 100644
index 0000000..4d77fb8
--- /dev/null
+++ b/src/main/resources/workflow.json
@@ -0,0 +1,15 @@
+{
+ "id": "test-workflow",
+ "version": 1,
+ "steps": [
+ {
+ "id": "step1",
+ "stepType": "com.olympus.hermione.stepSolversV2.Task1",
+ "nextStepId": "step2"
+ },
+ {
+ "id": "step2",
+ "stepType": "com.olympus.hermione.stepSolversV2.Task2"
+ }
+ ]
+ }
\ No newline at end of file