From 427ebf0072862f45db0fa005428709d48dc20c59 Mon Sep 17 00:00:00 2001 From: "andrea.terzani" Date: Fri, 27 Sep 2024 09:24:04 +0200 Subject: [PATCH] refactor: Enable asynchronous execution in HermioneApplication This commit enables asynchronous execution in HermioneApplication by adding the @EnableAsync annotation. This allows for improved performance and responsiveness in the application. --- .../olympus/hermione/HermioneApplication.java | 2 + .../controllers/ExecutionController.java | 25 +++++ .../controllers/ScenarioController.java | 17 ++- .../hermione/models/ScenarioExecution.java | 5 +- .../services/ScenarioExecutionService.java | 103 +++++++++++++----- 5 files changed, 122 insertions(+), 30 deletions(-) create mode 100644 src/main/java/com/olympus/hermione/controllers/ExecutionController.java diff --git a/src/main/java/com/olympus/hermione/HermioneApplication.java b/src/main/java/com/olympus/hermione/HermioneApplication.java index 87fc3c6..25f53ef 100644 --- a/src/main/java/com/olympus/hermione/HermioneApplication.java +++ b/src/main/java/com/olympus/hermione/HermioneApplication.java @@ -2,8 +2,10 @@ package com.olympus.hermione; import org.springframework.boot.SpringApplication; import org.springframework.boot.autoconfigure.SpringBootApplication; +import org.springframework.scheduling.annotation.EnableAsync; @SpringBootApplication +@EnableAsync public class HermioneApplication { public static void main(String[] args) { diff --git a/src/main/java/com/olympus/hermione/controllers/ExecutionController.java b/src/main/java/com/olympus/hermione/controllers/ExecutionController.java new file mode 100644 index 0000000..a20746e --- /dev/null +++ b/src/main/java/com/olympus/hermione/controllers/ExecutionController.java @@ -0,0 +1,25 @@ +package com.olympus.hermione.controllers; + +import java.util.Iterator; +import java.util.List; + +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.web.bind.annotation.GetMapping; +import org.springframework.web.bind.annotation.RequestParam; +import org.springframework.web.bind.annotation.RestController; + +import com.olympus.hermione.models.ScenarioExecution; +import com.olympus.hermione.repository.ScenarioExecutionRepository; + +@RestController +public class ExecutionController { + @Autowired + ScenarioExecutionRepository scenarioExecutionRepository; + + @GetMapping("/execution") + public ScenarioExecution getOldExections(@RequestParam String id){ + + + return scenarioExecutionRepository.findById(id).get(); + } +} diff --git a/src/main/java/com/olympus/hermione/controllers/ScenarioController.java b/src/main/java/com/olympus/hermione/controllers/ScenarioController.java index b68f02a..3d90a40 100644 --- a/src/main/java/com/olympus/hermione/controllers/ScenarioController.java +++ b/src/main/java/com/olympus/hermione/controllers/ScenarioController.java @@ -39,10 +39,25 @@ public class ScenarioController { } @PostMapping("scenarios/execute") - public ScenarioOutput executeScenario(@RequestBody ScenarioExecutionInput scenarioExecutionInput) { + public ScenarioOutput executeScenario(@RequestBody ScenarioExecutionInput scenarioExecutionInput) { return scenarioExecutionService.executeScenario(scenarioExecutionInput); } + @PostMapping("scenarios/execute-async") + public ScenarioOutput executeScenarioAsync(@RequestBody ScenarioExecutionInput scenarioExecutionInput) { + ScenarioOutput preOutput = scenarioExecutionService.prepareScrenarioExecution(scenarioExecutionInput); + + scenarioExecutionService.executeScenarioAsync(preOutput); + + return preOutput; + } + + @GetMapping("/scenarios/getExecutionProgress/{id}") + public ScenarioOutput getExecutionProgress(@PathVariable String id) { + return scenarioExecutionService.getExecutionProgress(id); + } + + @GetMapping("/scenarios/execute/{id}") public ScenarioExecution getScenarioExecution(@PathVariable String id) { return scenarioExecutionRepository.findById(id).get(); diff --git a/src/main/java/com/olympus/hermione/models/ScenarioExecution.java b/src/main/java/com/olympus/hermione/models/ScenarioExecution.java index 2511f3b..eb18abf 100644 --- a/src/main/java/com/olympus/hermione/models/ScenarioExecution.java +++ b/src/main/java/com/olympus/hermione/models/ScenarioExecution.java @@ -5,6 +5,8 @@ import java.util.HashMap; import org.springframework.data.annotation.Id; import org.springframework.data.mongodb.core.mapping.Document; +import com.olympus.hermione.dto.ScenarioExecutionInput; + import lombok.Getter; import lombok.Setter; @@ -24,5 +26,6 @@ public class ScenarioExecution { private String currentStepId; private String nextStepId; - + + private ScenarioExecutionInput scenarioExecutionInput; } diff --git a/src/main/java/com/olympus/hermione/services/ScenarioExecutionService.java b/src/main/java/com/olympus/hermione/services/ScenarioExecutionService.java index 5c1f6da..e5b2fdd 100644 --- a/src/main/java/com/olympus/hermione/services/ScenarioExecutionService.java +++ b/src/main/java/com/olympus/hermione/services/ScenarioExecutionService.java @@ -9,6 +9,7 @@ import org.slf4j.LoggerFactory; import org.springframework.ai.chat.model.ChatModel; import org.springframework.ai.vectorstore.VectorStore; import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.scheduling.annotation.Async; import org.springframework.stereotype.Service; import com.olympus.hermione.dto.ScenarioExecutionInput; @@ -50,44 +51,97 @@ public class ScenarioExecutionService { private Logger logger = LoggerFactory.getLogger(ScenarioExecutionService.class); + public ScenarioOutput getExecutionProgress(String scenarioExecutionId){ + ScenarioOutput scenarioOutput = new ScenarioOutput(); + Optional o_scenarioExecution = scenarioExecutionRepository.findById(scenarioExecutionId); + + if(o_scenarioExecution.isPresent()){ + ScenarioExecution scenarioExecution = o_scenarioExecution.get(); + + if(scenarioExecution.getExecSharedMap().get("scenario_output")!=null){ + scenarioOutput.setScenarioExecution_id(scenarioExecution.getId()); + scenarioOutput.setStatus("OK"); + scenarioOutput.setStringOutput(scenarioExecution.getExecSharedMap().get("scenario_output").toString()); + }else{ + scenarioOutput.setScenarioExecution_id(scenarioExecution.getId()); + scenarioOutput.setStatus("IN_PROGRESS"); + } + + }else{ + scenarioOutput.setScenarioExecution_id(null); + scenarioOutput.setStatus("ERROR"); + scenarioOutput.setMessage("Scenario Execution not found"); + } + + return scenarioOutput; + } + + @Async + public void executeScenarioAsync(ScenarioOutput scenarioOutput){ + + String scenarioExecutionID = scenarioOutput.getScenarioExecution_id(); + + Optional o_scenarioExecution = scenarioExecutionRepository.findById(scenarioExecutionID); + + if(o_scenarioExecution.isPresent()){ + ScenarioExecution scenarioExecution = o_scenarioExecution.get(); + + HashMap inputs = scenarioExecution.getScenarioExecutionInput().getInputs(); + + Optional o_scenario = scenarioRepository.findById(scenarioExecution.getScenario().getId()); + + Scenario scenario = o_scenario.get(); + logger.info("Start execution of scenario: " + scenario.getName()); + + HashMap execSharedMap = new HashMap(); + execSharedMap.put("user_input", inputs); + scenarioExecution.setExecSharedMap(execSharedMap); + scenarioExecutionRepository.save(scenarioExecution); + + List steps = scenario.getSteps(); + String startStepId=scenario.getStartWithStepId(); + + ScenarioStep startStep = steps.stream().filter(step -> step.getStepId().equals(startStepId)).findFirst().orElse(null); + + executeScenarioStep(startStep, scenarioExecution); + + while (scenarioExecution.getNextStepId()!=null) { + ScenarioStep step = steps.stream().filter(s -> s.getStepId().equals(scenarioExecution.getNextStepId())).findFirst().orElse(null); + executeScenarioStep(step, scenarioExecution); + } + + + } + } - public String executeScenario(String scenarioId, String input){ + public ScenarioOutput prepareScrenarioExecution(ScenarioExecutionInput scenarioExecutionInput){ + + String scenarioId = scenarioExecutionInput.getScenario_id(); + HashMap inputs = scenarioExecutionInput.getInputs(); + ScenarioOutput scenarioOutput = new ScenarioOutput(); Optional o_scenario = scenarioRepository.findById(scenarioId); if(o_scenario.isPresent()){ Scenario scenario = o_scenario.get(); - logger.info("Executing scenario: " + scenario.getName()); + logger.info("Executing scenario: " + scenario.getName()); + ScenarioExecution scenarioExecution = new ScenarioExecution(); scenarioExecution.setScenario(scenario); - - HashMap execSharedMap = new HashMap(); - execSharedMap.put("user_input", input); - scenarioExecution.setExecSharedMap(execSharedMap); + scenarioExecution.setScenarioExecutionInput(scenarioExecutionInput); scenarioExecutionRepository.save(scenarioExecution); - - List steps = scenario.getSteps(); - steps.sort(Comparator.comparingInt(ScenarioStep::getOrder)); + scenarioOutput.setScenarioExecution_id(scenarioExecution.getId()); + scenarioOutput.setStatus("STARTED"); - - - for (ScenarioStep step : steps) { - executeScenarioStep(step, scenarioExecution); - } - - return scenarioExecution.getExecSharedMap().get("scenario_output").toString(); - - }else{ - logger.error("Scenario not found with id: " + scenarioId); } - ; - return ""; + return scenarioOutput; } + public ScenarioOutput executeScenario(ScenarioExecutionInput scenarioExecutionInput){ String scenarioId = scenarioExecutionInput.getScenario_id(); @@ -109,9 +163,6 @@ public class ScenarioExecutionService { scenarioExecutionRepository.save(scenarioExecution); List steps = scenario.getSteps(); - - - String startStepId=scenario.getStartWithStepId(); ScenarioStep startStep = steps.stream().filter(step -> step.getStepId().equals(startStepId)).findFirst().orElse(null); @@ -120,7 +171,6 @@ public class ScenarioExecutionService { while (scenarioExecution.getNextStepId()!=null) { ScenarioStep step = steps.stream().filter(s -> s.getStepId().equals(scenarioExecution.getNextStepId())).findFirst().orElse(null); - executeScenarioStep(step, scenarioExecution); } @@ -165,15 +215,12 @@ public class ScenarioExecutionService { logger.info("Solving step: " + step.getStepId()); scenarioExecution.setCurrentStepId(step.getStepId()); - //must add try catch in order to catch exceptions and step execution errors ScenarioExecution scenarioExecutionNew = solver.solveStep(); logger.info("Step solved: " + step.getStepId()); logger.info("Next step: " + scenarioExecutionNew.getNextStepId()); scenarioExecutionRepository.save(scenarioExecutionNew); - - }