diff --git a/src/main/java/com/olympus/hermione/services/ScenarioExecutionService.java b/src/main/java/com/olympus/hermione/services/ScenarioExecutionService.java index 7b5d914..44412f0 100644 --- a/src/main/java/com/olympus/hermione/services/ScenarioExecutionService.java +++ b/src/main/java/com/olympus/hermione/services/ScenarioExecutionService.java @@ -67,11 +67,10 @@ import org.bson.types.ObjectId; import org.neo4j.driver.Driver; import org.slf4j.Logger; -import com.azure.ai.openai.OpenAIClient; import com.azure.ai.openai.OpenAIClientBuilder; import com.azure.core.credential.AzureKeyCredential; -import com.olympus.hermione.stepSolvers.ExternalAgentSolver; import com.olympus.hermione.stepSolvers.ExternalCodeGenieSolver; +import com.olympus.hermione.stepSolvers.OlympusAgentSolver; import com.olympus.hermione.stepSolvers.OlynmpusChatClientSolver; @Service @@ -265,12 +264,12 @@ public class ScenarioExecutionService { case "RAG_SOURCE_CODE": solver = new SourceCodeRagSolver(); break; - case "EXTERNAL_AGENT": - solver = new ExternalAgentSolver(); - break; case "EXTERNAL_CODEGENIE": solver = new ExternalCodeGenieSolver(); break; + case "OLYMPUS_AGENT": + solver = new OlympusAgentSolver(); + break; case "SUMMARIZE_DOC": solver = new SummarizeDocSolver(); break; @@ -291,6 +290,9 @@ public class ScenarioExecutionService { logger.info("Solving step: " + step.getStepId()); scenarioExecution.setCurrentStepId(step.getStepId()); scenarioExecution.setCurrentStepDescription(step.getName()); + + // Save immediately so frontend can show correct current step + scenarioExecutionRepository.save(scenarioExecution); ScenarioExecution scenarioExecutionNew = scenarioExecution; diff --git a/src/main/java/com/olympus/hermione/stepSolvers/OlympusAgentSolver.java b/src/main/java/com/olympus/hermione/stepSolvers/OlympusAgentSolver.java new file mode 100644 index 0000000..bf17cdb --- /dev/null +++ b/src/main/java/com/olympus/hermione/stepSolvers/OlympusAgentSolver.java @@ -0,0 +1,164 @@ +package com.olympus.hermione.stepSolvers; + +import org.json.JSONArray; +import org.json.JSONObject; +import org.slf4j.LoggerFactory; +import org.springframework.http.HttpEntity; +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpMethod; +import org.springframework.http.MediaType; +import org.springframework.http.ResponseEntity; +import org.springframework.web.client.RestTemplate; + +import com.olympus.hermione.models.ScenarioExecution; +import com.olympus.hermione.utility.AttributeParser; + +import ch.qos.logback.classic.Logger; + +public class OlympusAgentSolver extends StepSolver { + + private String agent_id; + private String agent_message; + private String agent_input; + private String agent_context; + private String agent_base_url; + private String agent_output_variable; + + Logger logger = (Logger) LoggerFactory.getLogger(OlympusAgentSolver.class); + + private void loadParameters(){ + logger.info("Loading parameters for LangGraph Agent"); + + // Agent ID (required) + if(this.step.getAttributes().get("agent_id") != null){ + this.agent_id = (String) this.step.getAttributes().get("agent_id"); + logger.info("agent_id: " + this.agent_id); + } else { + throw new IllegalArgumentException("agent_id is required"); + } + + // Message/Input + if(this.step.getAttributes().get("agent_message") != null){ + this.agent_message = (String) this.step.getAttributes().get("agent_message"); + logger.info("agent_message: " + this.agent_message); + } + + // Agent Input + if(this.step.getAttributes().get("agent_input") != null){ + this.agent_input = (String) this.step.getAttributes().get("agent_input"); + logger.info("agent_input: " + this.agent_input); + } + + // Context (optional JSON) + if(this.step.getAttributes().get("agent_context") != null){ + this.agent_context = (String) this.step.getAttributes().get("agent_context"); + logger.info("agent_context: " + this.agent_context); + } + + // Base URL (default to localhost) + if(this.step.getAttributes().get("agent_base_url") != null){ + this.agent_base_url = (String) this.step.getAttributes().get("agent_base_url"); + } + logger.info("agent_base_url: " + this.agent_base_url); + + // Output variable name + if(this.step.getAttributes().get("agent_output_variable") != null){ + this.agent_output_variable = (String) this.step.getAttributes().get("agent_output_variable"); + } + logger.info("agent_output_variable: " + this.agent_output_variable); + + + // Parse variables from execution context + AttributeParser attributeParser = new AttributeParser(this.scenarioExecution); + + if(this.agent_message != null){ + this.agent_message = attributeParser.parse(this.agent_message); + } + if(this.agent_input != null){ + this.agent_input = attributeParser.parse(this.agent_input); + } + if(this.agent_context != null){ + this.agent_context = attributeParser.parse(this.agent_context); + } + } + + @Override + public ScenarioExecution solveStep() throws Exception { + + logger.info("Solving LangGraph Agent step: " + this.step.getName()); + + this.scenarioExecution.setCurrentStepId(this.step.getStepId()); + + loadParameters(); + + RestTemplate restTemplate = new RestTemplate(); + HttpHeaders headers = new HttpHeaders(); + headers.setContentType(MediaType.APPLICATION_JSON); + + String endpoint; + JSONObject requestBody = new JSONObject(); + + endpoint = this.agent_base_url + "/agent/" + this.agent_id + "/execute"; + + // Add agent input (required) + String inputValue = this.agent_input != null ? this.agent_input : + (this.agent_message != null ? this.agent_message : ""); + requestBody.put("agent_input", inputValue); + + // Add context if provided + if(this.agent_context != null && !this.agent_context.isEmpty()){ + try { + JSONObject contextObj = new JSONObject(this.agent_context); + requestBody.put("context", contextObj); + } catch (Exception e) { + logger.warn("Failed to parse agent_context as JSON, using empty object", e); + requestBody.put("context", new JSONObject()); + } + } else { + requestBody.put("context", new JSONObject()); + } + + // Add scenario ID from execution + requestBody.put("scenario_id", this.scenarioExecution.getScenario().getId()); + + + logger.info("Calling LangGraph Agent endpoint: " + endpoint); + logger.info("Request body: " + requestBody.toString()); + + HttpEntity request = new HttpEntity<>(requestBody.toString(), headers); + + try { + ResponseEntity response = restTemplate.exchange( + endpoint, + HttpMethod.POST, + request, + String.class + ); + + JSONObject jsonResponse = new JSONObject(response.getBody()); + + logger.info("Hermione execution completed"); + logger.info("Execution ID: " + jsonResponse.optString("execution_id")); + logger.info("Final output: " + jsonResponse.optString("final_output")); + + // Store the complete response + this.scenarioExecution.getExecSharedMap().put(this.agent_output_variable, jsonResponse.toString()); + + // Also store final output separately for easy access + this.scenarioExecution.getExecSharedMap().put( + this.agent_output_variable + "_final_output", + jsonResponse.optString("final_output", "") + ); + + + // Move to next step + this.scenarioExecution.setNextStepId(this.step.getNextStepId()); + + } catch (Exception e) { + logger.error("Error calling LangGraph Agent service", e); + throw new Exception("LangGraph Agent execution failed: " + e.getMessage()); + } + + return this.scenarioExecution; + } +}