Merge branch 'jworkflow' into 'master'

refactor: Enable asynchronous execution in HermioneApplication

See merge request olympus_ai/hermione!1
This commit is contained in:
andrea.terzani
2024-09-27 07:45:46 +00:00
15 changed files with 192 additions and 55 deletions

View File

@@ -95,6 +95,7 @@
<artifactId>gson</artifactId> <artifactId>gson</artifactId>
<version>2.8.8</version> <version>2.8.8</version>
</dependency> </dependency>
</dependencies> </dependencies>
<dependencyManagement> <dependencyManagement>
<dependencies> <dependencies>

View File

@@ -2,8 +2,10 @@ package com.olympus.hermione;
import org.springframework.boot.SpringApplication; import org.springframework.boot.SpringApplication;
import org.springframework.boot.autoconfigure.SpringBootApplication; import org.springframework.boot.autoconfigure.SpringBootApplication;
import org.springframework.scheduling.annotation.EnableAsync;
@SpringBootApplication @SpringBootApplication
@EnableAsync
public class HermioneApplication { public class HermioneApplication {
public static void main(String[] args) { public static void main(String[] args) {

View File

@@ -0,0 +1,25 @@
package com.olympus.hermione.controllers;
import java.util.Iterator;
import java.util.List;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.bind.annotation.RestController;
import com.olympus.hermione.models.ScenarioExecution;
import com.olympus.hermione.repository.ScenarioExecutionRepository;
@RestController
public class ExecutionController {
@Autowired
ScenarioExecutionRepository scenarioExecutionRepository;
@GetMapping("/execution")
public ScenarioExecution getOldExections(@RequestParam String id){
return scenarioExecutionRepository.findById(id).get();
}
}

View File

@@ -43,6 +43,21 @@ public class ScenarioController {
return scenarioExecutionService.executeScenario(scenarioExecutionInput); return scenarioExecutionService.executeScenario(scenarioExecutionInput);
} }
@PostMapping("scenarios/execute-async")
public ScenarioOutput executeScenarioAsync(@RequestBody ScenarioExecutionInput scenarioExecutionInput) {
ScenarioOutput preOutput = scenarioExecutionService.prepareScrenarioExecution(scenarioExecutionInput);
scenarioExecutionService.executeScenarioAsync(preOutput);
return preOutput;
}
@GetMapping("/scenarios/getExecutionProgress/{id}")
public ScenarioOutput getExecutionProgress(@PathVariable String id) {
return scenarioExecutionService.getExecutionProgress(id);
}
@GetMapping("/scenarios/execute/{id}") @GetMapping("/scenarios/execute/{id}")
public ScenarioExecution getScenarioExecution(@PathVariable String id) { public ScenarioExecution getScenarioExecution(@PathVariable String id) {
return scenarioExecutionRepository.findById(id).get(); return scenarioExecutionRepository.findById(id).get();

View File

@@ -3,8 +3,12 @@ package com.olympus.hermione.controllers;
import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.data.mongodb.core.aggregation.ComparisonOperators.Ne; import org.springframework.data.mongodb.core.aggregation.ComparisonOperators.Ne;
import org.springframework.web.bind.annotation.GetMapping; import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.bind.annotation.RestController; import org.springframework.web.bind.annotation.RestController;
import com.olympus.hermione.dto.ScenarioExecutionInput;
import com.olympus.hermione.dto.ScenarioOutput;
import com.olympus.hermione.services.Neo4JUitilityService; import com.olympus.hermione.services.Neo4JUitilityService;
import com.olympus.hermione.services.ScenarioExecutionService; import com.olympus.hermione.services.ScenarioExecutionService;
@@ -14,16 +18,19 @@ public class TestController {
@Autowired @Autowired
ScenarioExecutionService scenarioExecutionService; ScenarioExecutionService scenarioExecutionService;
@Autowired @Autowired
Neo4JUitilityService neo4JUitilityService; Neo4JUitilityService neo4JUitilityService;
@GetMapping("/test/scenario_execution")
public String testScenarioExecution(){
String result = scenarioExecutionService.executeScenario("66aa162debe80dfcd17f0ef4","How i can change the path where uploaded documents are stored ?"); @PostMapping("test/execute")
return result; public ScenarioOutput executeScenario(@RequestBody ScenarioExecutionInput scenarioExecutionInput) {
return scenarioExecutionService.executeScenario(scenarioExecutionInput);
} }
@GetMapping("/test/generate_schema_description") @GetMapping("/test/generate_schema_description")
public String testGenerateSchemaDescription(){ public String testGenerateSchemaDescription(){
return neo4JUitilityService.generateSchemaDescription().toString(); return neo4JUitilityService.generateSchemaDescription().toString();

View File

@@ -15,6 +15,7 @@ public class Scenario {
private String id; private String id;
private String name; private String name;
private String description; private String description;
private String startWithStepId;
private List<ScenarioStep> steps; private List<ScenarioStep> steps;
private List<ScenarioInputs> inputs; private List<ScenarioInputs> inputs;

View File

@@ -5,6 +5,8 @@ import java.util.HashMap;
import org.springframework.data.annotation.Id; import org.springframework.data.annotation.Id;
import org.springframework.data.mongodb.core.mapping.Document; import org.springframework.data.mongodb.core.mapping.Document;
import com.olympus.hermione.dto.ScenarioExecutionInput;
import lombok.Getter; import lombok.Getter;
import lombok.Setter; import lombok.Setter;
@@ -21,5 +23,9 @@ public class ScenarioExecution {
private String scenarioId; private String scenarioId;
private String currentStepId;
private String nextStepId;
private ScenarioExecutionInput scenarioExecutionInput;
} }

View File

@@ -13,6 +13,9 @@ public class ScenarioStep {
private String name; private String name;
private int order; private int order;
private String stepId;
private String nextStepId;
private HashMap<String,Object> attributes; private HashMap<String,Object> attributes;
} }

View File

@@ -9,6 +9,7 @@ import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.model.ChatModel; import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.vectorstore.VectorStore; import org.springframework.ai.vectorstore.VectorStore;
import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.scheduling.annotation.Async;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;
import com.olympus.hermione.dto.ScenarioExecutionInput; import com.olympus.hermione.dto.ScenarioExecutionInput;
@@ -50,40 +51,96 @@ public class ScenarioExecutionService {
private Logger logger = LoggerFactory.getLogger(ScenarioExecutionService.class); private Logger logger = LoggerFactory.getLogger(ScenarioExecutionService.class);
public ScenarioOutput getExecutionProgress(String scenarioExecutionId){
ScenarioOutput scenarioOutput = new ScenarioOutput();
Optional<ScenarioExecution> o_scenarioExecution = scenarioExecutionRepository.findById(scenarioExecutionId);
if(o_scenarioExecution.isPresent()){
ScenarioExecution scenarioExecution = o_scenarioExecution.get();
if(scenarioExecution.getExecSharedMap().get("scenario_output")!=null){
scenarioOutput.setScenarioExecution_id(scenarioExecution.getId());
scenarioOutput.setStatus("OK");
scenarioOutput.setStringOutput(scenarioExecution.getExecSharedMap().get("scenario_output").toString());
}else{
scenarioOutput.setScenarioExecution_id(scenarioExecution.getId());
scenarioOutput.setStatus("IN_PROGRESS");
}
}else{
scenarioOutput.setScenarioExecution_id(null);
scenarioOutput.setStatus("ERROR");
scenarioOutput.setMessage("Scenario Execution not found");
}
return scenarioOutput;
}
@Async
public void executeScenarioAsync(ScenarioOutput scenarioOutput){
String scenarioExecutionID = scenarioOutput.getScenarioExecution_id();
Optional<ScenarioExecution> o_scenarioExecution = scenarioExecutionRepository.findById(scenarioExecutionID);
if(o_scenarioExecution.isPresent()){
ScenarioExecution scenarioExecution = o_scenarioExecution.get();
HashMap<String,String> inputs = scenarioExecution.getScenarioExecutionInput().getInputs();
Optional<Scenario> o_scenario = scenarioRepository.findById(scenarioExecution.getScenario().getId());
Scenario scenario = o_scenario.get();
logger.info("Start execution of scenario: " + scenario.getName());
HashMap<String, Object> execSharedMap = new HashMap<String, Object>();
execSharedMap.put("user_input", inputs);
scenarioExecution.setExecSharedMap(execSharedMap);
scenarioExecutionRepository.save(scenarioExecution);
List<ScenarioStep> steps = scenario.getSteps();
String startStepId=scenario.getStartWithStepId();
ScenarioStep startStep = steps.stream().filter(step -> step.getStepId().equals(startStepId)).findFirst().orElse(null);
executeScenarioStep(startStep, scenarioExecution);
while (scenarioExecution.getNextStepId()!=null) {
ScenarioStep step = steps.stream().filter(s -> s.getStepId().equals(scenarioExecution.getNextStepId())).findFirst().orElse(null);
executeScenarioStep(step, scenarioExecution);
}
public String executeScenario(String scenarioId, String input){ }
}
public ScenarioOutput prepareScrenarioExecution(ScenarioExecutionInput scenarioExecutionInput){
String scenarioId = scenarioExecutionInput.getScenario_id();
HashMap<String,String> inputs = scenarioExecutionInput.getInputs();
ScenarioOutput scenarioOutput = new ScenarioOutput();
Optional<Scenario> o_scenario = scenarioRepository.findById(scenarioId); Optional<Scenario> o_scenario = scenarioRepository.findById(scenarioId);
if(o_scenario.isPresent()){ if(o_scenario.isPresent()){
Scenario scenario = o_scenario.get(); Scenario scenario = o_scenario.get();
logger.info("Executing scenario: " + scenario.getName()); logger.info("Executing scenario: " + scenario.getName());
ScenarioExecution scenarioExecution = new ScenarioExecution(); ScenarioExecution scenarioExecution = new ScenarioExecution();
scenarioExecution.setScenario(scenario); scenarioExecution.setScenario(scenario);
scenarioExecution.setScenarioExecutionInput(scenarioExecutionInput);
HashMap<String, Object> execSharedMap = new HashMap<String, Object>();
execSharedMap.put("user_input", input);
scenarioExecution.setExecSharedMap(execSharedMap);
scenarioExecutionRepository.save(scenarioExecution); scenarioExecutionRepository.save(scenarioExecution);
List<ScenarioStep> steps = scenario.getSteps(); scenarioOutput.setScenarioExecution_id(scenarioExecution.getId());
scenarioOutput.setStatus("STARTED");
steps.sort(Comparator.comparingInt(ScenarioStep::getOrder));
for (ScenarioStep step : steps) {
executeScenarioStep(step, scenarioExecution, step.getOrder());
} }
return scenarioExecution.getExecSharedMap().get("scenario_output").toString(); return scenarioOutput;
}else{
logger.error("Scenario not found with id: " + scenarioId);
} }
;
return "";
}
public ScenarioOutput executeScenario(ScenarioExecutionInput scenarioExecutionInput){ public ScenarioOutput executeScenario(ScenarioExecutionInput scenarioExecutionInput){
@@ -106,10 +163,15 @@ public class ScenarioExecutionService {
scenarioExecutionRepository.save(scenarioExecution); scenarioExecutionRepository.save(scenarioExecution);
List<ScenarioStep> steps = scenario.getSteps(); List<ScenarioStep> steps = scenario.getSteps();
String startStepId=scenario.getStartWithStepId();
steps.sort(Comparator.comparingInt(ScenarioStep::getOrder)); ScenarioStep startStep = steps.stream().filter(step -> step.getStepId().equals(startStepId)).findFirst().orElse(null);
for (ScenarioStep step : steps) {
executeScenarioStep(step, scenarioExecution, step.getOrder()); executeScenarioStep(startStep, scenarioExecution);
while (scenarioExecution.getNextStepId()!=null) {
ScenarioStep step = steps.stream().filter(s -> s.getStepId().equals(scenarioExecution.getNextStepId())).findFirst().orElse(null);
executeScenarioStep(step, scenarioExecution);
} }
scenarioOutput.setScenarioExecution_id(scenarioExecution.getId()); scenarioOutput.setScenarioExecution_id(scenarioExecution.getId());
@@ -130,7 +192,7 @@ public class ScenarioExecutionService {
private void executeScenarioStep(ScenarioStep step, ScenarioExecution scenarioExecution, int stepIndex){ private void executeScenarioStep(ScenarioStep step, ScenarioExecution scenarioExecution){
logger.info("Start working on step: " + step.getName() + " with type: " + step.getType()); logger.info("Start working on step: " + step.getName() + " with type: " + step.getType());
StepSolver solver=new StepSolver(); StepSolver solver=new StepSolver();
switch (step.getType()) { switch (step.getType()) {
@@ -147,18 +209,19 @@ public class ScenarioExecutionService {
break; break;
} }
logger.info("Initializing step: " + step.getName()); logger.info("Initializing step: " + step.getStepId());
solver.init(step, scenarioExecution, vectorStore,chatModel,graphDriver,neo4JUitilityService); solver.init(step, scenarioExecution, vectorStore,chatModel,graphDriver,neo4JUitilityService);
logger.info("Solving step: " + step.getName()); logger.info("Solving step: " + step.getStepId());
scenarioExecution.setCurrentStepId(step.getStepId());
//must add try catch in order to catch exceptions and step execution errors
ScenarioExecution scenarioExecutionNew = solver.solveStep(); ScenarioExecution scenarioExecutionNew = solver.solveStep();
logger.info("Step solved: " + step.getStepId());
logger.info("Next step: " + scenarioExecutionNew.getNextStepId());
scenarioExecutionRepository.save(scenarioExecutionNew); scenarioExecutionRepository.save(scenarioExecutionNew);
} }
} }

View File

@@ -49,6 +49,8 @@ public class BasicAIPromptSolver extends StepSolver {
System.out.println("Solving step: " + this.step.getName()); System.out.println("Solving step: " + this.step.getName());
this.scenarioExecution.setCurrentStepId(this.step.getStepId());
loadParameters(); loadParameters();
@@ -65,6 +67,8 @@ public class BasicAIPromptSolver extends StepSolver {
this.scenarioExecution.getExecSharedMap().put(this.qai_output_variable, output); this.scenarioExecution.getExecSharedMap().put(this.qai_output_variable, output);
this.scenarioExecution.setNextStepId(this.step.getNextStepId());
return this.scenarioExecution; return this.scenarioExecution;
} }

View File

@@ -93,6 +93,8 @@ public class BasicQueryRagSolver extends StepSolver {
//concatenate the content of the results into a single string //concatenate the content of the results into a single string
String resultString = String.join("\n", result); String resultString = String.join("\n", result);
this.scenarioExecution.getExecSharedMap().put(this.outputField, resultString); this.scenarioExecution.getExecSharedMap().put(this.outputField, resultString);
this.scenarioExecution.setCurrentStepId(this.step.getStepId());
this.scenarioExecution.setNextStepId(this.step.getNextStepId());
return this.scenarioExecution; return this.scenarioExecution;
} }

View File

@@ -59,6 +59,9 @@ public class QueryNeo4JSolver extends StepSolver {
} }
this.scenarioExecution.setCurrentStepId(this.step.getStepId());
this.scenarioExecution.setNextStepId(this.step.getNextStepId());
return this.scenarioExecution; return this.scenarioExecution;
} }

View File

@@ -20,8 +20,14 @@ public class StepSolver {
public ScenarioExecution solveStep(){ public ScenarioExecution solveStep(){
System.out.println("Solving step: " + this.step.getName()); System.out.println("Solving step: " + this.step.getName());
this.scenarioExecution.setCurrentStepId(this.step.getStepId());
this.scenarioExecution.setNextStepId(this.step.getNextStepId());
return this.scenarioExecution; return this.scenarioExecution;
}; };
public void init(ScenarioStep step, ScenarioExecution scenarioExecution, VectorStore vectorStore,ChatModel chatModel,Driver graphDriver,Neo4JUitilityService neo4JUitilityService){ public void init(ScenarioStep step, ScenarioExecution scenarioExecution, VectorStore vectorStore,ChatModel chatModel,Driver graphDriver,Neo4JUitilityService neo4JUitilityService){
this.scenarioExecution = scenarioExecution; this.scenarioExecution = scenarioExecution;
this.step = step; this.step = step;

View File

@@ -1,16 +0,0 @@
{
"folders": [
{
"path": "../../../../../../../../apollo"
},
{
"name": "hermione",
"path": "../../../../../../.."
},
{
"name": "hermione-fe",
"path": "../../../../../../../../hermione-fe"
}
],
"settings": {}
}

View File

@@ -0,0 +1,15 @@
{
"id": "test-workflow",
"version": 1,
"steps": [
{
"id": "step1",
"stepType": "com.olympus.hermione.stepSolversV2.Task1",
"nextStepId": "step2"
},
{
"id": "step2",
"stepType": "com.olympus.hermione.stepSolversV2.Task2"
}
]
}