Merge branch 'jworkflow' into 'master'
refactor: Enable asynchronous execution in HermioneApplication See merge request olympus_ai/hermione!1
This commit is contained in:
9
pom.xml
9
pom.xml
@@ -91,10 +91,11 @@
|
||||
<version>4.3.4</version>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>com.google.code.gson</groupId>
|
||||
<artifactId>gson</artifactId>
|
||||
<version>2.8.8</version>
|
||||
</dependency>
|
||||
<groupId>com.google.code.gson</groupId>
|
||||
<artifactId>gson</artifactId>
|
||||
<version>2.8.8</version>
|
||||
</dependency>
|
||||
|
||||
</dependencies>
|
||||
<dependencyManagement>
|
||||
<dependencies>
|
||||
|
||||
@@ -2,8 +2,10 @@ package com.olympus.hermione;
|
||||
|
||||
import org.springframework.boot.SpringApplication;
|
||||
import org.springframework.boot.autoconfigure.SpringBootApplication;
|
||||
import org.springframework.scheduling.annotation.EnableAsync;
|
||||
|
||||
@SpringBootApplication
|
||||
@EnableAsync
|
||||
public class HermioneApplication {
|
||||
|
||||
public static void main(String[] args) {
|
||||
|
||||
@@ -0,0 +1,25 @@
|
||||
package com.olympus.hermione.controllers;
|
||||
|
||||
import java.util.Iterator;
|
||||
import java.util.List;
|
||||
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.web.bind.annotation.GetMapping;
|
||||
import org.springframework.web.bind.annotation.RequestParam;
|
||||
import org.springframework.web.bind.annotation.RestController;
|
||||
|
||||
import com.olympus.hermione.models.ScenarioExecution;
|
||||
import com.olympus.hermione.repository.ScenarioExecutionRepository;
|
||||
|
||||
@RestController
|
||||
public class ExecutionController {
|
||||
@Autowired
|
||||
ScenarioExecutionRepository scenarioExecutionRepository;
|
||||
|
||||
@GetMapping("/execution")
|
||||
public ScenarioExecution getOldExections(@RequestParam String id){
|
||||
|
||||
|
||||
return scenarioExecutionRepository.findById(id).get();
|
||||
}
|
||||
}
|
||||
@@ -39,10 +39,25 @@ public class ScenarioController {
|
||||
}
|
||||
|
||||
@PostMapping("scenarios/execute")
|
||||
public ScenarioOutput executeScenario(@RequestBody ScenarioExecutionInput scenarioExecutionInput) {
|
||||
public ScenarioOutput executeScenario(@RequestBody ScenarioExecutionInput scenarioExecutionInput) {
|
||||
return scenarioExecutionService.executeScenario(scenarioExecutionInput);
|
||||
}
|
||||
|
||||
@PostMapping("scenarios/execute-async")
|
||||
public ScenarioOutput executeScenarioAsync(@RequestBody ScenarioExecutionInput scenarioExecutionInput) {
|
||||
ScenarioOutput preOutput = scenarioExecutionService.prepareScrenarioExecution(scenarioExecutionInput);
|
||||
|
||||
scenarioExecutionService.executeScenarioAsync(preOutput);
|
||||
|
||||
return preOutput;
|
||||
}
|
||||
|
||||
@GetMapping("/scenarios/getExecutionProgress/{id}")
|
||||
public ScenarioOutput getExecutionProgress(@PathVariable String id) {
|
||||
return scenarioExecutionService.getExecutionProgress(id);
|
||||
}
|
||||
|
||||
|
||||
@GetMapping("/scenarios/execute/{id}")
|
||||
public ScenarioExecution getScenarioExecution(@PathVariable String id) {
|
||||
return scenarioExecutionRepository.findById(id).get();
|
||||
|
||||
@@ -3,8 +3,12 @@ package com.olympus.hermione.controllers;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.data.mongodb.core.aggregation.ComparisonOperators.Ne;
|
||||
import org.springframework.web.bind.annotation.GetMapping;
|
||||
import org.springframework.web.bind.annotation.PostMapping;
|
||||
import org.springframework.web.bind.annotation.RequestBody;
|
||||
import org.springframework.web.bind.annotation.RestController;
|
||||
|
||||
import com.olympus.hermione.dto.ScenarioExecutionInput;
|
||||
import com.olympus.hermione.dto.ScenarioOutput;
|
||||
import com.olympus.hermione.services.Neo4JUitilityService;
|
||||
import com.olympus.hermione.services.ScenarioExecutionService;
|
||||
|
||||
@@ -14,20 +18,23 @@ public class TestController {
|
||||
@Autowired
|
||||
ScenarioExecutionService scenarioExecutionService;
|
||||
|
||||
|
||||
|
||||
@Autowired
|
||||
Neo4JUitilityService neo4JUitilityService;
|
||||
|
||||
@GetMapping("/test/scenario_execution")
|
||||
public String testScenarioExecution(){
|
||||
|
||||
String result = scenarioExecutionService.executeScenario("66aa162debe80dfcd17f0ef4","How i can change the path where uploaded documents are stored ?");
|
||||
return result;
|
||||
@PostMapping("test/execute")
|
||||
public ScenarioOutput executeScenario(@RequestBody ScenarioExecutionInput scenarioExecutionInput) {
|
||||
return scenarioExecutionService.executeScenario(scenarioExecutionInput);
|
||||
}
|
||||
|
||||
|
||||
|
||||
@GetMapping("/test/generate_schema_description")
|
||||
public String testGenerateSchemaDescription(){
|
||||
return neo4JUitilityService.generateSchemaDescription().toString();
|
||||
}
|
||||
|
||||
|
||||
|
||||
}
|
||||
|
||||
@@ -15,6 +15,7 @@ public class Scenario {
|
||||
private String id;
|
||||
private String name;
|
||||
private String description;
|
||||
private String startWithStepId;
|
||||
private List<ScenarioStep> steps;
|
||||
private List<ScenarioInputs> inputs;
|
||||
|
||||
|
||||
@@ -5,6 +5,8 @@ import java.util.HashMap;
|
||||
import org.springframework.data.annotation.Id;
|
||||
import org.springframework.data.mongodb.core.mapping.Document;
|
||||
|
||||
import com.olympus.hermione.dto.ScenarioExecutionInput;
|
||||
|
||||
import lombok.Getter;
|
||||
import lombok.Setter;
|
||||
|
||||
@@ -21,5 +23,9 @@ public class ScenarioExecution {
|
||||
|
||||
private String scenarioId;
|
||||
|
||||
private String currentStepId;
|
||||
private String nextStepId;
|
||||
|
||||
|
||||
private ScenarioExecutionInput scenarioExecutionInput;
|
||||
}
|
||||
|
||||
@@ -12,6 +12,9 @@ public class ScenarioStep {
|
||||
private String description;
|
||||
private String name;
|
||||
private int order;
|
||||
|
||||
private String stepId;
|
||||
private String nextStepId;
|
||||
|
||||
private HashMap<String,Object> attributes;
|
||||
|
||||
|
||||
@@ -9,6 +9,7 @@ import org.slf4j.LoggerFactory;
|
||||
import org.springframework.ai.chat.model.ChatModel;
|
||||
import org.springframework.ai.vectorstore.VectorStore;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.scheduling.annotation.Async;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
import com.olympus.hermione.dto.ScenarioExecutionInput;
|
||||
@@ -50,41 +51,97 @@ public class ScenarioExecutionService {
|
||||
|
||||
private Logger logger = LoggerFactory.getLogger(ScenarioExecutionService.class);
|
||||
|
||||
public ScenarioOutput getExecutionProgress(String scenarioExecutionId){
|
||||
ScenarioOutput scenarioOutput = new ScenarioOutput();
|
||||
Optional<ScenarioExecution> o_scenarioExecution = scenarioExecutionRepository.findById(scenarioExecutionId);
|
||||
|
||||
if(o_scenarioExecution.isPresent()){
|
||||
ScenarioExecution scenarioExecution = o_scenarioExecution.get();
|
||||
|
||||
if(scenarioExecution.getExecSharedMap().get("scenario_output")!=null){
|
||||
scenarioOutput.setScenarioExecution_id(scenarioExecution.getId());
|
||||
scenarioOutput.setStatus("OK");
|
||||
scenarioOutput.setStringOutput(scenarioExecution.getExecSharedMap().get("scenario_output").toString());
|
||||
}else{
|
||||
scenarioOutput.setScenarioExecution_id(scenarioExecution.getId());
|
||||
scenarioOutput.setStatus("IN_PROGRESS");
|
||||
}
|
||||
|
||||
}else{
|
||||
scenarioOutput.setScenarioExecution_id(null);
|
||||
scenarioOutput.setStatus("ERROR");
|
||||
scenarioOutput.setMessage("Scenario Execution not found");
|
||||
}
|
||||
|
||||
return scenarioOutput;
|
||||
}
|
||||
|
||||
@Async
|
||||
public void executeScenarioAsync(ScenarioOutput scenarioOutput){
|
||||
|
||||
String scenarioExecutionID = scenarioOutput.getScenarioExecution_id();
|
||||
|
||||
Optional<ScenarioExecution> o_scenarioExecution = scenarioExecutionRepository.findById(scenarioExecutionID);
|
||||
|
||||
if(o_scenarioExecution.isPresent()){
|
||||
ScenarioExecution scenarioExecution = o_scenarioExecution.get();
|
||||
|
||||
HashMap<String,String> inputs = scenarioExecution.getScenarioExecutionInput().getInputs();
|
||||
|
||||
Optional<Scenario> o_scenario = scenarioRepository.findById(scenarioExecution.getScenario().getId());
|
||||
|
||||
Scenario scenario = o_scenario.get();
|
||||
logger.info("Start execution of scenario: " + scenario.getName());
|
||||
|
||||
HashMap<String, Object> execSharedMap = new HashMap<String, Object>();
|
||||
execSharedMap.put("user_input", inputs);
|
||||
scenarioExecution.setExecSharedMap(execSharedMap);
|
||||
scenarioExecutionRepository.save(scenarioExecution);
|
||||
|
||||
List<ScenarioStep> steps = scenario.getSteps();
|
||||
String startStepId=scenario.getStartWithStepId();
|
||||
|
||||
ScenarioStep startStep = steps.stream().filter(step -> step.getStepId().equals(startStepId)).findFirst().orElse(null);
|
||||
|
||||
executeScenarioStep(startStep, scenarioExecution);
|
||||
|
||||
while (scenarioExecution.getNextStepId()!=null) {
|
||||
ScenarioStep step = steps.stream().filter(s -> s.getStepId().equals(scenarioExecution.getNextStepId())).findFirst().orElse(null);
|
||||
executeScenarioStep(step, scenarioExecution);
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
public String executeScenario(String scenarioId, String input){
|
||||
public ScenarioOutput prepareScrenarioExecution(ScenarioExecutionInput scenarioExecutionInput){
|
||||
|
||||
String scenarioId = scenarioExecutionInput.getScenario_id();
|
||||
HashMap<String,String> inputs = scenarioExecutionInput.getInputs();
|
||||
ScenarioOutput scenarioOutput = new ScenarioOutput();
|
||||
|
||||
Optional<Scenario> o_scenario = scenarioRepository.findById(scenarioId);
|
||||
|
||||
if(o_scenario.isPresent()){
|
||||
Scenario scenario = o_scenario.get();
|
||||
logger.info("Executing scenario: " + scenario.getName());
|
||||
|
||||
logger.info("Executing scenario: " + scenario.getName());
|
||||
|
||||
ScenarioExecution scenarioExecution = new ScenarioExecution();
|
||||
scenarioExecution.setScenario(scenario);
|
||||
|
||||
HashMap<String, Object> execSharedMap = new HashMap<String, Object>();
|
||||
execSharedMap.put("user_input", input);
|
||||
scenarioExecution.setExecSharedMap(execSharedMap);
|
||||
scenarioExecution.setScenarioExecutionInput(scenarioExecutionInput);
|
||||
scenarioExecutionRepository.save(scenarioExecution);
|
||||
|
||||
List<ScenarioStep> steps = scenario.getSteps();
|
||||
|
||||
steps.sort(Comparator.comparingInt(ScenarioStep::getOrder));
|
||||
for (ScenarioStep step : steps) {
|
||||
executeScenarioStep(step, scenarioExecution, step.getOrder());
|
||||
}
|
||||
scenarioOutput.setScenarioExecution_id(scenarioExecution.getId());
|
||||
scenarioOutput.setStatus("STARTED");
|
||||
|
||||
return scenarioExecution.getExecSharedMap().get("scenario_output").toString();
|
||||
|
||||
}else{
|
||||
logger.error("Scenario not found with id: " + scenarioId);
|
||||
}
|
||||
|
||||
;
|
||||
return "";
|
||||
return scenarioOutput;
|
||||
}
|
||||
|
||||
|
||||
public ScenarioOutput executeScenario(ScenarioExecutionInput scenarioExecutionInput){
|
||||
|
||||
String scenarioId = scenarioExecutionInput.getScenario_id();
|
||||
@@ -106,10 +163,15 @@ public class ScenarioExecutionService {
|
||||
scenarioExecutionRepository.save(scenarioExecution);
|
||||
|
||||
List<ScenarioStep> steps = scenario.getSteps();
|
||||
String startStepId=scenario.getStartWithStepId();
|
||||
|
||||
steps.sort(Comparator.comparingInt(ScenarioStep::getOrder));
|
||||
for (ScenarioStep step : steps) {
|
||||
executeScenarioStep(step, scenarioExecution, step.getOrder());
|
||||
ScenarioStep startStep = steps.stream().filter(step -> step.getStepId().equals(startStepId)).findFirst().orElse(null);
|
||||
|
||||
executeScenarioStep(startStep, scenarioExecution);
|
||||
|
||||
while (scenarioExecution.getNextStepId()!=null) {
|
||||
ScenarioStep step = steps.stream().filter(s -> s.getStepId().equals(scenarioExecution.getNextStepId())).findFirst().orElse(null);
|
||||
executeScenarioStep(step, scenarioExecution);
|
||||
}
|
||||
|
||||
scenarioOutput.setScenarioExecution_id(scenarioExecution.getId());
|
||||
@@ -130,7 +192,7 @@ public class ScenarioExecutionService {
|
||||
|
||||
|
||||
|
||||
private void executeScenarioStep(ScenarioStep step, ScenarioExecution scenarioExecution, int stepIndex){
|
||||
private void executeScenarioStep(ScenarioStep step, ScenarioExecution scenarioExecution){
|
||||
logger.info("Start working on step: " + step.getName() + " with type: " + step.getType());
|
||||
StepSolver solver=new StepSolver();
|
||||
switch (step.getType()) {
|
||||
@@ -147,17 +209,18 @@ public class ScenarioExecutionService {
|
||||
break;
|
||||
}
|
||||
|
||||
logger.info("Initializing step: " + step.getName());
|
||||
logger.info("Initializing step: " + step.getStepId());
|
||||
solver.init(step, scenarioExecution, vectorStore,chatModel,graphDriver,neo4JUitilityService);
|
||||
|
||||
logger.info("Solving step: " + step.getName());
|
||||
logger.info("Solving step: " + step.getStepId());
|
||||
scenarioExecution.setCurrentStepId(step.getStepId());
|
||||
|
||||
//must add try catch in order to catch exceptions and step execution errors
|
||||
ScenarioExecution scenarioExecutionNew = solver.solveStep();
|
||||
|
||||
scenarioExecutionRepository.save(scenarioExecutionNew);
|
||||
|
||||
logger.info("Step solved: " + step.getStepId());
|
||||
logger.info("Next step: " + scenarioExecutionNew.getNextStepId());
|
||||
|
||||
scenarioExecutionRepository.save(scenarioExecutionNew);
|
||||
|
||||
}
|
||||
|
||||
|
||||
@@ -49,6 +49,8 @@ public class BasicAIPromptSolver extends StepSolver {
|
||||
|
||||
System.out.println("Solving step: " + this.step.getName());
|
||||
|
||||
this.scenarioExecution.setCurrentStepId(this.step.getStepId());
|
||||
|
||||
loadParameters();
|
||||
|
||||
|
||||
@@ -64,7 +66,9 @@ public class BasicAIPromptSolver extends StepSolver {
|
||||
String output = response.get(0).getOutput().getContent();
|
||||
|
||||
this.scenarioExecution.getExecSharedMap().put(this.qai_output_variable, output);
|
||||
|
||||
|
||||
this.scenarioExecution.setNextStepId(this.step.getNextStepId());
|
||||
|
||||
return this.scenarioExecution;
|
||||
}
|
||||
|
||||
|
||||
@@ -93,7 +93,9 @@ public class BasicQueryRagSolver extends StepSolver {
|
||||
//concatenate the content of the results into a single string
|
||||
String resultString = String.join("\n", result);
|
||||
this.scenarioExecution.getExecSharedMap().put(this.outputField, resultString);
|
||||
|
||||
this.scenarioExecution.setCurrentStepId(this.step.getStepId());
|
||||
this.scenarioExecution.setNextStepId(this.step.getNextStepId());
|
||||
|
||||
return this.scenarioExecution;
|
||||
}
|
||||
|
||||
|
||||
@@ -58,6 +58,9 @@ public class QueryNeo4JSolver extends StepSolver {
|
||||
this.scenarioExecution.getExecSharedMap().put(this.outputField, "Error executing query: " + e.getMessage());
|
||||
|
||||
}
|
||||
|
||||
this.scenarioExecution.setCurrentStepId(this.step.getStepId());
|
||||
this.scenarioExecution.setNextStepId(this.step.getNextStepId());
|
||||
|
||||
return this.scenarioExecution;
|
||||
}
|
||||
|
||||
@@ -20,8 +20,14 @@ public class StepSolver {
|
||||
|
||||
public ScenarioExecution solveStep(){
|
||||
System.out.println("Solving step: " + this.step.getName());
|
||||
|
||||
this.scenarioExecution.setCurrentStepId(this.step.getStepId());
|
||||
this.scenarioExecution.setNextStepId(this.step.getNextStepId());
|
||||
|
||||
return this.scenarioExecution;
|
||||
};
|
||||
|
||||
|
||||
public void init(ScenarioStep step, ScenarioExecution scenarioExecution, VectorStore vectorStore,ChatModel chatModel,Driver graphDriver,Neo4JUitilityService neo4JUitilityService){
|
||||
this.scenarioExecution = scenarioExecution;
|
||||
this.step = step;
|
||||
|
||||
@@ -1,16 +0,0 @@
|
||||
{
|
||||
"folders": [
|
||||
{
|
||||
"path": "../../../../../../../../apollo"
|
||||
},
|
||||
{
|
||||
"name": "hermione",
|
||||
"path": "../../../../../../.."
|
||||
},
|
||||
{
|
||||
"name": "hermione-fe",
|
||||
"path": "../../../../../../../../hermione-fe"
|
||||
}
|
||||
],
|
||||
"settings": {}
|
||||
}
|
||||
15
src/main/resources/workflow.json
Normal file
15
src/main/resources/workflow.json
Normal file
@@ -0,0 +1,15 @@
|
||||
{
|
||||
"id": "test-workflow",
|
||||
"version": 1,
|
||||
"steps": [
|
||||
{
|
||||
"id": "step1",
|
||||
"stepType": "com.olympus.hermione.stepSolversV2.Task1",
|
||||
"nextStepId": "step2"
|
||||
},
|
||||
{
|
||||
"id": "step2",
|
||||
"stepType": "com.olympus.hermione.stepSolversV2.Task2"
|
||||
}
|
||||
]
|
||||
}
|
||||
Reference in New Issue
Block a user