diff --git a/pom.xml b/pom.xml
index ead9aa4..8657c2d 100644
--- a/pom.xml
+++ b/pom.xml
@@ -28,7 +28,7 @@
21
- 1.0.0-M1
+ 1.0.0-SNAPSHOT
@@ -84,6 +84,19 @@
org.springframework.boot
spring-boot-starter-security
+
+
+ org.springframework.cloud
+ spring-cloud-starter-netflix-eureka-client
+ 4.1.3
+
+
+
+ org.springframework.cloud
+ spring-cloud-starter-feign
+ 1.4.7.RELEASE
+
+
org.antlr
diff --git a/src/main/java/com/olympus/hermione/config/Neo4jConfig.java b/src/main/java/com/olympus/hermione/config/Neo4jConfig.java
index 9f0d8f2..e12120d 100644
--- a/src/main/java/com/olympus/hermione/config/Neo4jConfig.java
+++ b/src/main/java/com/olympus/hermione/config/Neo4jConfig.java
@@ -3,8 +3,6 @@ package com.olympus.hermione.config;
import org.neo4j.driver.Driver;
import org.neo4j.driver.GraphDatabase;
-import org.neo4j.driver.Session;
-import org.neo4j.driver.SessionConfig;
import org.neo4j.driver.AuthTokens;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.context.annotation.Bean;
diff --git a/src/main/java/com/olympus/hermione/controllers/ExecutionController.java b/src/main/java/com/olympus/hermione/controllers/ExecutionController.java
index a20746e..993bf76 100644
--- a/src/main/java/com/olympus/hermione/controllers/ExecutionController.java
+++ b/src/main/java/com/olympus/hermione/controllers/ExecutionController.java
@@ -1,8 +1,4 @@
package com.olympus.hermione.controllers;
-
-import java.util.Iterator;
-import java.util.List;
-
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RequestParam;
@@ -18,8 +14,6 @@ public class ExecutionController {
@GetMapping("/execution")
public ScenarioExecution getOldExections(@RequestParam String id){
-
-
return scenarioExecutionRepository.findById(id).get();
}
}
diff --git a/src/main/java/com/olympus/hermione/controllers/ScenarioController.java b/src/main/java/com/olympus/hermione/controllers/ScenarioController.java
index 3d90a40..49487f6 100644
--- a/src/main/java/com/olympus/hermione/controllers/ScenarioController.java
+++ b/src/main/java/com/olympus/hermione/controllers/ScenarioController.java
@@ -40,7 +40,19 @@ public class ScenarioController {
@PostMapping("scenarios/execute")
public ScenarioOutput executeScenario(@RequestBody ScenarioExecutionInput scenarioExecutionInput) {
- return scenarioExecutionService.executeScenario(scenarioExecutionInput);
+ ScenarioOutput preOutput = scenarioExecutionService.prepareScrenarioExecution(scenarioExecutionInput);
+ scenarioExecutionService.executeScenarioAsync(preOutput);
+
+ while(preOutput.getStatus().equals("IN_PROGRESS")) {
+ try {
+ Thread.sleep(1000);
+ preOutput = scenarioExecutionService.getExecutionProgress(preOutput.getScenarioExecution_id());
+ } catch (InterruptedException e) {
+ e.printStackTrace();
+ }
+ }
+ return preOutput;
+
}
@PostMapping("scenarios/execute-async")
diff --git a/src/main/java/com/olympus/hermione/controllers/TestController.java b/src/main/java/com/olympus/hermione/controllers/TestController.java
index cc683b6..ca10329 100644
--- a/src/main/java/com/olympus/hermione/controllers/TestController.java
+++ b/src/main/java/com/olympus/hermione/controllers/TestController.java
@@ -1,40 +1,20 @@
package com.olympus.hermione.controllers;
import org.springframework.beans.factory.annotation.Autowired;
-import org.springframework.data.mongodb.core.aggregation.ComparisonOperators.Ne;
import org.springframework.web.bind.annotation.GetMapping;
-import org.springframework.web.bind.annotation.PostMapping;
-import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.bind.annotation.RestController;
-
-import com.olympus.hermione.dto.ScenarioExecutionInput;
-import com.olympus.hermione.dto.ScenarioOutput;
import com.olympus.hermione.services.Neo4JUitilityService;
-import com.olympus.hermione.services.ScenarioExecutionService;
@RestController
public class TestController {
- @Autowired
- ScenarioExecutionService scenarioExecutionService;
-
-
-
@Autowired
Neo4JUitilityService neo4JUitilityService;
- @PostMapping("test/execute")
- public ScenarioOutput executeScenario(@RequestBody ScenarioExecutionInput scenarioExecutionInput) {
- return scenarioExecutionService.executeScenario(scenarioExecutionInput);
- }
-
-
-
@GetMapping("/test/generate_schema_description")
public String testGenerateSchemaDescription(){
return neo4JUitilityService.generateSchemaDescription().toString();
}
-
}
diff --git a/src/main/java/com/olympus/hermione/dto/ScenarioOutput.java b/src/main/java/com/olympus/hermione/dto/ScenarioOutput.java
index 8f52ada..34fca54 100644
--- a/src/main/java/com/olympus/hermione/dto/ScenarioOutput.java
+++ b/src/main/java/com/olympus/hermione/dto/ScenarioOutput.java
@@ -11,4 +11,6 @@ public class ScenarioOutput {
private String scenarioExecution_id;
private String status;
private String message;
+ private String currentStepDescription;
+ private String currentStepId;
}
\ No newline at end of file
diff --git a/src/main/java/com/olympus/hermione/dto/aientity/Change.java b/src/main/java/com/olympus/hermione/dto/aientity/Change.java
new file mode 100644
index 0000000..06d9f14
--- /dev/null
+++ b/src/main/java/com/olympus/hermione/dto/aientity/Change.java
@@ -0,0 +1,93 @@
+package com.olympus.hermione.dto.aientity;
+
+import com.fasterxml.jackson.annotation.JsonProperty;
+
+public class Change {
+ private String filename;
+ private String classname;
+ private String methodname;
+ @JsonProperty("new_code")
+ private String newCode;
+ @JsonProperty("previous_code")
+ private String previousCode;
+ @JsonProperty("change_description")
+ private String changeDescription;
+
+ // Constructors
+ public Change() {}
+
+
+ public Change(String filename, String classname, String methodname, String newCode, String previousCode, String changeDescription) {
+ this.filename = filename;
+ this.classname = classname;
+ this.methodname = methodname;
+ this.newCode = newCode;
+ this.previousCode = previousCode;
+ this.changeDescription = changeDescription;
+ }
+
+ // Getters and Setters
+
+
+
+ public String getFilename() {
+ return filename;
+ }
+
+ public void setFilename(String filename) {
+ this.filename = filename;
+ }
+
+ public String getClassname() {
+ return classname;
+ }
+
+ public void setClassname(String classname) {
+ this.classname = classname;
+ }
+
+ public String getMethodname() {
+ return methodname;
+ }
+
+ public void setMethodname(String methodname) {
+ this.methodname = methodname;
+ }
+
+ public String getNewCode() {
+ return newCode;
+ }
+
+ public void setNewCode(String newCode) {
+ this.newCode = newCode;
+ }
+
+ public String getPreviousCode() {
+ return previousCode;
+ }
+
+ public void setPreviousCode(String previousCode) {
+ this.previousCode = previousCode;
+ }
+
+ public String getChangeDescription() {
+ return changeDescription;
+ }
+
+ public void setChangeDescription(String changeDescription) {
+ this.changeDescription = changeDescription;
+ }
+
+ // toString Method
+ @Override
+ public String toString() {
+ return "Change{" +
+ "filename='" + filename + '\'' +
+ ", classname='" + classname + '\'' +
+ ", methodname='" + methodname + '\'' +
+ ", newCode='" + newCode + '\'' +
+ ", previousCode='" + previousCode + '\'' +
+ ", changeDescription='" + changeDescription + '\'' +
+ '}';
+ }
+}
diff --git a/src/main/java/com/olympus/hermione/dto/aientity/CiaOutputEntity.java b/src/main/java/com/olympus/hermione/dto/aientity/CiaOutputEntity.java
new file mode 100644
index 0000000..de0f2af
--- /dev/null
+++ b/src/main/java/com/olympus/hermione/dto/aientity/CiaOutputEntity.java
@@ -0,0 +1,66 @@
+package com.olympus.hermione.dto.aientity;
+
+import java.util.List;
+
+
+public class CiaOutputEntity {
+ private String application;
+ private String title;
+ private String description;
+ private List changes;
+
+ // Constructors
+ public CiaOutputEntity() {}
+
+ public CiaOutputEntity(String application, String title, String description, List changes) {
+ this.application = application;
+ this.title = title;
+ this.description = description;
+ this.changes = changes;
+ }
+
+ public void setApplication(String application) {
+ this.application = application;
+ }
+
+ public String getApplication() {
+ return application;
+ }
+
+ // Getters and Setters
+ public String getTitle() {
+ return title;
+ }
+
+ public void setTitle(String title) {
+ this.title = title;
+ }
+
+ public String getDescription() {
+ return description;
+ }
+
+ public void setDescription(String description) {
+ this.description = description;
+ }
+
+ public List getChanges() {
+ return changes;
+ }
+
+ public void setChanges(List changes) {
+ this.changes = changes;
+ }
+
+ // toString Method
+ @Override
+ public String toString() {
+ return "CodeChangeRequest{" +
+ "title='" + title + '\'' +
+ ", description='" + description + '\'' +
+ ", changes=" + changes ;
+
+
+
+ }
+}
diff --git a/src/main/java/com/olympus/hermione/dto/external/CodeRagResponse.java b/src/main/java/com/olympus/hermione/dto/external/CodeRagResponse.java
new file mode 100644
index 0000000..8f74b81
--- /dev/null
+++ b/src/main/java/com/olympus/hermione/dto/external/CodeRagResponse.java
@@ -0,0 +1,16 @@
+package com.olympus.hermione.dto.external;
+
+import lombok.Getter;
+import lombok.Setter;
+
+@Getter
+@Setter
+public class CodeRagResponse {
+
+ private String fullyQualifiedName;
+ private String code;
+ private String description;
+ private String similarityScore;
+ private String documentId;
+ private String codeType;
+}
diff --git a/src/main/java/com/olympus/hermione/dto/external/SimilaritySearchCodeInput.java b/src/main/java/com/olympus/hermione/dto/external/SimilaritySearchCodeInput.java
new file mode 100644
index 0000000..51e069a
--- /dev/null
+++ b/src/main/java/com/olympus/hermione/dto/external/SimilaritySearchCodeInput.java
@@ -0,0 +1,23 @@
+package com.olympus.hermione.dto.external;
+
+import lombok.Getter;
+import lombok.Setter;
+
+@Getter
+@Setter
+public class SimilaritySearchCodeInput {
+
+ private String query;
+ private String topK;
+ private String similarityThreshold;
+ private String filterExpression;
+
+}
+
+
+/*
+ * .withQuery(query)
+ .withTopK(3)
+ .withSimilarityThreshold(0.8)
+ .withFilterExpression(b.eq("doc_type","GeneratedMethodDoc").build())
+ */
\ No newline at end of file
diff --git a/src/main/java/com/olympus/hermione/models/Scenario.java b/src/main/java/com/olympus/hermione/models/Scenario.java
index 43ce81a..979216f 100644
--- a/src/main/java/com/olympus/hermione/models/Scenario.java
+++ b/src/main/java/com/olympus/hermione/models/Scenario.java
@@ -19,4 +19,5 @@ public class Scenario {
private List steps;
private List inputs;
+ private boolean useChatMemory=false;
}
diff --git a/src/main/java/com/olympus/hermione/models/ScenarioExecution.java b/src/main/java/com/olympus/hermione/models/ScenarioExecution.java
index eb18abf..475fdf2 100644
--- a/src/main/java/com/olympus/hermione/models/ScenarioExecution.java
+++ b/src/main/java/com/olympus/hermione/models/ScenarioExecution.java
@@ -25,6 +25,7 @@ public class ScenarioExecution {
private String currentStepId;
private String nextStepId;
+ private String currentStepDescription;
private ScenarioExecutionInput scenarioExecutionInput;
diff --git a/src/main/java/com/olympus/hermione/services/ScenarioExecutionService.java b/src/main/java/com/olympus/hermione/services/ScenarioExecutionService.java
index e5b2fdd..2f1ea67 100644
--- a/src/main/java/com/olympus/hermione/services/ScenarioExecutionService.java
+++ b/src/main/java/com/olympus/hermione/services/ScenarioExecutionService.java
@@ -1,14 +1,20 @@
package com.olympus.hermione.services;
-import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Optional;
import org.slf4j.LoggerFactory;
+import org.springframework.ai.chat.client.ChatClient;
+import org.springframework.ai.chat.client.advisor.MessageChatMemoryAdvisor;
+import org.springframework.ai.chat.client.advisor.SimpleLoggerAdvisor;
+import org.springframework.ai.chat.memory.ChatMemory;
+import org.springframework.ai.chat.memory.InMemoryChatMemory;
import org.springframework.ai.chat.model.ChatModel;
+
import org.springframework.ai.vectorstore.VectorStore;
import org.springframework.beans.factory.annotation.Autowired;
+import org.springframework.cloud.client.discovery.DiscoveryClient;
import org.springframework.scheduling.annotation.Async;
import org.springframework.stereotype.Service;
@@ -19,9 +25,11 @@ import com.olympus.hermione.models.ScenarioExecution;
import com.olympus.hermione.models.ScenarioStep;
import com.olympus.hermione.repository.ScenarioExecutionRepository;
import com.olympus.hermione.repository.ScenarioRepository;
+import com.olympus.hermione.stepSolvers.AdvancedAIPromptSolver;
import com.olympus.hermione.stepSolvers.BasicAIPromptSolver;
import com.olympus.hermione.stepSolvers.BasicQueryRagSolver;
import com.olympus.hermione.stepSolvers.QueryNeo4JSolver;
+import com.olympus.hermione.stepSolvers.SourceCodeRagSolver;
import com.olympus.hermione.stepSolvers.StepSolver;
import org.neo4j.driver.Driver;
@@ -38,6 +46,9 @@ public class ScenarioExecutionService {
@Autowired
VectorStore vectorStore;
+
+ @Autowired
+ DiscoveryClient discoveryClient;
@Autowired
ChatModel chatModel;
@@ -48,6 +59,8 @@ public class ScenarioExecutionService {
@Autowired
Neo4JUitilityService neo4JUitilityService;
+ private ChatClient chatClient;
+
private Logger logger = LoggerFactory.getLogger(ScenarioExecutionService.class);
@@ -57,6 +70,7 @@ public class ScenarioExecutionService {
if(o_scenarioExecution.isPresent()){
ScenarioExecution scenarioExecution = o_scenarioExecution.get();
+ scenarioExecution.setCurrentStepId(scenarioExecution.getCurrentStepId());
if(scenarioExecution.getExecSharedMap().get("scenario_output")!=null){
scenarioOutput.setScenarioExecution_id(scenarioExecution.getId());
@@ -85,12 +99,11 @@ public class ScenarioExecutionService {
if(o_scenarioExecution.isPresent()){
ScenarioExecution scenarioExecution = o_scenarioExecution.get();
-
HashMap inputs = scenarioExecution.getScenarioExecutionInput().getInputs();
-
Optional o_scenario = scenarioRepository.findById(scenarioExecution.getScenario().getId());
Scenario scenario = o_scenario.get();
+
logger.info("Start execution of scenario: " + scenario.getName());
HashMap execSharedMap = new HashMap();
@@ -98,6 +111,31 @@ public class ScenarioExecutionService {
scenarioExecution.setExecSharedMap(execSharedMap);
scenarioExecutionRepository.save(scenarioExecution);
+ //TODO: Chatmodel should be initialized with the scenario specific infos
+
+ if(scenario.isUseChatMemory()){
+ logger.info("Initializing chatClient with chat-memory advisor");
+
+ ChatMemory chatMemory = new InMemoryChatMemory();
+ chatClient = ChatClient.builder(chatModel)
+ .defaultAdvisors(
+ new MessageChatMemoryAdvisor(chatMemory), // chat-memory advisor
+ new SimpleLoggerAdvisor()
+ )
+ .build();
+
+ }else{
+ logger.info("Initializing chatClient with simple logger advisor");
+
+ chatClient = ChatClient.builder(chatModel)
+ .defaultAdvisors(
+ new SimpleLoggerAdvisor()
+ )
+ .build();
+ }
+
+
+
List steps = scenario.getSteps();
String startStepId=scenario.getStartWithStepId();
@@ -115,10 +153,56 @@ public class ScenarioExecutionService {
}
+
+
+
+ private void executeScenarioStep(ScenarioStep step, ScenarioExecution scenarioExecution){
+ logger.info("Start working on step: " + step.getName() + " with type: " + step.getType());
+ StepSolver solver=new StepSolver();
+ switch (step.getType()) {
+ case "ADVANCED_QUERY_AI":
+ solver = new AdvancedAIPromptSolver();
+ break;
+ case "RAG_QUERY":
+ solver = new BasicQueryRagSolver();
+ break;
+ case "SIMPLE_QUERY_AI":
+ solver = new BasicAIPromptSolver();
+ break;
+ case "QUERY_GRAPH":
+ solver = new QueryNeo4JSolver();
+ break;
+ case "RAG_SOURCE_CODE":
+ solver = new SourceCodeRagSolver();
+ break;
+
+ default:
+ break;
+ }
+
+
+ logger.info("Initializing step: " + step.getStepId());
+ solver.init(step, scenarioExecution,
+ vectorStore,chatModel,chatClient,graphDriver,neo4JUitilityService
+ ,discoveryClient);
+
+ logger.info("Solving step: " + step.getStepId());
+ scenarioExecution.setCurrentStepId(step.getStepId());
+ scenarioExecution.setCurrentStepDescription(step.getName());
+
+ ScenarioExecution scenarioExecutionNew = solver.solveStep();
+
+ logger.info("Step solved: " + step.getStepId());
+ logger.info("Next step: " + scenarioExecutionNew.getNextStepId());
+
+ scenarioExecutionRepository.save(scenarioExecutionNew);
+
+ }
+
+
public ScenarioOutput prepareScrenarioExecution(ScenarioExecutionInput scenarioExecutionInput){
String scenarioId = scenarioExecutionInput.getScenario_id();
- HashMap inputs = scenarioExecutionInput.getInputs();
ScenarioOutput scenarioOutput = new ScenarioOutput();
Optional o_scenario = scenarioRepository.findById(scenarioId);
@@ -134,94 +218,11 @@ public class ScenarioExecutionService {
scenarioExecutionRepository.save(scenarioExecution);
scenarioOutput.setScenarioExecution_id(scenarioExecution.getId());
- scenarioOutput.setStatus("STARTED");
+ scenarioOutput.setStatus("IN_PROGRESS");
}
return scenarioOutput;
}
-
- public ScenarioOutput executeScenario(ScenarioExecutionInput scenarioExecutionInput){
-
- String scenarioId = scenarioExecutionInput.getScenario_id();
- HashMap inputs = scenarioExecutionInput.getInputs();
- ScenarioOutput scenarioOutput = new ScenarioOutput();
-
- Optional o_scenario = scenarioRepository.findById(scenarioId);
-
- if(o_scenario.isPresent()){
- Scenario scenario = o_scenario.get();
- logger.info("Executing scenario: " + scenario.getName());
-
- ScenarioExecution scenarioExecution = new ScenarioExecution();
- scenarioExecution.setScenario(scenario);
-
- HashMap execSharedMap = new HashMap();
- execSharedMap.put("user_input", inputs);
- scenarioExecution.setExecSharedMap(execSharedMap);
- scenarioExecutionRepository.save(scenarioExecution);
-
- List steps = scenario.getSteps();
- String startStepId=scenario.getStartWithStepId();
-
- ScenarioStep startStep = steps.stream().filter(step -> step.getStepId().equals(startStepId)).findFirst().orElse(null);
-
- executeScenarioStep(startStep, scenarioExecution);
-
- while (scenarioExecution.getNextStepId()!=null) {
- ScenarioStep step = steps.stream().filter(s -> s.getStepId().equals(scenarioExecution.getNextStepId())).findFirst().orElse(null);
- executeScenarioStep(step, scenarioExecution);
- }
-
- scenarioOutput.setScenarioExecution_id(scenarioExecution.getId());
- scenarioOutput.setStringOutput(scenarioExecution.getExecSharedMap().get("scenario_output").toString());
- scenarioOutput.setStatus("OK");
-
- return scenarioOutput;
-
- }else{
- logger.error("Scenario not found with id: " + scenarioId);
- scenarioOutput.setScenarioExecution_id(null);
- scenarioOutput.setStatus("ERROR");
- scenarioOutput.setMessage("Scenario not found");
- }
-
- return scenarioOutput;
- }
-
-
-
- private void executeScenarioStep(ScenarioStep step, ScenarioExecution scenarioExecution){
- logger.info("Start working on step: " + step.getName() + " with type: " + step.getType());
- StepSolver solver=new StepSolver();
- switch (step.getType()) {
- case "RAG_QUERY":
- solver = new BasicQueryRagSolver();
- break;
- case "SIMPLE_QUERY_AI":
- solver = new BasicAIPromptSolver();
- break;
- case "QUERY_GRAPH":
- solver = new QueryNeo4JSolver();
- break;
- default:
- break;
- }
-
- logger.info("Initializing step: " + step.getStepId());
- solver.init(step, scenarioExecution, vectorStore,chatModel,graphDriver,neo4JUitilityService);
-
- logger.info("Solving step: " + step.getStepId());
- scenarioExecution.setCurrentStepId(step.getStepId());
-
- ScenarioExecution scenarioExecutionNew = solver.solveStep();
-
- logger.info("Step solved: " + step.getStepId());
- logger.info("Next step: " + scenarioExecutionNew.getNextStepId());
-
- scenarioExecutionRepository.save(scenarioExecutionNew);
-
- }
-
}
diff --git a/src/main/java/com/olympus/hermione/stepSolvers/AdvancedAIPromptSolver.java b/src/main/java/com/olympus/hermione/stepSolvers/AdvancedAIPromptSolver.java
new file mode 100644
index 0000000..75ae268
--- /dev/null
+++ b/src/main/java/com/olympus/hermione/stepSolvers/AdvancedAIPromptSolver.java
@@ -0,0 +1,95 @@
+package com.olympus.hermione.stepSolvers;
+
+import org.slf4j.LoggerFactory;
+import org.springframework.ai.chat.client.ChatClient.CallResponseSpec;
+import org.springframework.ai.chat.messages.Message;
+import org.springframework.ai.chat.messages.SystemMessage;
+import org.springframework.ai.chat.messages.UserMessage;
+import com.fasterxml.jackson.core.JsonProcessingException;
+import com.fasterxml.jackson.databind.ObjectMapper;
+import com.olympus.hermione.dto.aientity.CiaOutputEntity;
+import com.olympus.hermione.models.ScenarioExecution;
+import com.olympus.hermione.utility.AttributeParser;
+
+import ch.qos.logback.classic.Logger;
+
+
+public class AdvancedAIPromptSolver extends StepSolver {
+
+ private String qai_system_prompt_template;
+ private String qai_user_input;
+ private String qai_output_variable;
+ private boolean qai_load_graph_schema=false;
+ private String qai_output_entityType;
+
+
+ Logger logger = (Logger) LoggerFactory.getLogger(BasicQueryRagSolver.class);
+
+ private void loadParameters(){
+ logger.info("Loading parameters");
+
+ if(this.step.getAttributes().get("qai_load_graph_schema")!=null ){
+ this.qai_load_graph_schema = Boolean.parseBoolean((String) this.step.getAttributes().get("qai_load_graph_schema"));
+ logger.info("qai_load_graph_schema: " + this.qai_load_graph_schema);
+ }
+
+ if(this.qai_load_graph_schema){
+ logger.info("Loading graph schema");
+ this.scenarioExecution.getExecSharedMap().put("graph_schema", this.neo4JUitilityService.generateSchemaDescription().toString());
+ }
+
+ AttributeParser attributeParser = new AttributeParser(this.scenarioExecution);
+
+ this.qai_system_prompt_template = attributeParser.parse((String) this.step.getAttributes().get("qai_system_prompt_template"));
+ this.qai_user_input = attributeParser.parse((String)this.step.getAttributes().get("qai_user_input"));
+
+ this.qai_output_variable = (String) this.step.getAttributes().get("qai_output_variable");
+ this.qai_output_entityType = (String) this.step.getAttributes().get("qai_output_entityType");
+
+ //TODO: Add memory ID attribute to have the possibility of multiple conversations
+
+ }
+
+ @Override
+ public ScenarioExecution solveStep(){
+
+ System.out.println("Solving step: " + this.step.getName());
+
+ this.scenarioExecution.setCurrentStepId(this.step.getStepId());
+
+ loadParameters();
+
+ String userText = this.qai_user_input;
+
+ Message userMessage = new UserMessage(userText);
+ Message systemMessage = new SystemMessage(this.qai_system_prompt_template);
+
+ CallResponseSpec resp = chatClient.prompt()
+ .messages(userMessage,systemMessage)
+ .advisors(advisor -> advisor.param("chat_memory_conversation_id", this.scenarioExecution.getId())
+ .param("chat_memory_response_size", 100))
+ .call();
+
+
+ if(qai_output_entityType!=null && qai_output_entityType.equals("CiaOutputEntity")){
+ logger.info("Output is of type CiaOutputEntity");
+ CiaOutputEntity ouputEntity = resp.entity(CiaOutputEntity.class);
+ try {
+ ObjectMapper objectMapper = new ObjectMapper();
+ String jsonOutput = objectMapper.writeValueAsString(ouputEntity);
+ this.scenarioExecution.getExecSharedMap().put(this.qai_output_variable, jsonOutput);
+ } catch (JsonProcessingException e) {
+ logger.error("Error converting output entity to JSON", e);
+ }
+ }else{
+
+ String output = resp.content();
+ this.scenarioExecution.getExecSharedMap().put(this.qai_output_variable, output);
+ }
+
+ this.scenarioExecution.setNextStepId(this.step.getNextStepId());
+
+ return this.scenarioExecution;
+ }
+
+}
diff --git a/src/main/java/com/olympus/hermione/stepSolvers/BasicAIPromptSolver.java b/src/main/java/com/olympus/hermione/stepSolvers/BasicAIPromptSolver.java
index 137471a..524d6fc 100644
--- a/src/main/java/com/olympus/hermione/stepSolvers/BasicAIPromptSolver.java
+++ b/src/main/java/com/olympus/hermione/stepSolvers/BasicAIPromptSolver.java
@@ -53,10 +53,7 @@ public class BasicAIPromptSolver extends StepSolver {
loadParameters();
-
- String userText = this.qai_user_input;
-
- Message userMessage = new UserMessage(userText);
+ Message userMessage = new UserMessage(this.qai_user_input);
Message systemMessage = new SystemMessage(this.qai_system_prompt_template);
Prompt prompt = new Prompt(List.of(userMessage, systemMessage));
diff --git a/src/main/java/com/olympus/hermione/stepSolvers/SourceCodeRagSolver.java b/src/main/java/com/olympus/hermione/stepSolvers/SourceCodeRagSolver.java
new file mode 100644
index 0000000..dbacaea
--- /dev/null
+++ b/src/main/java/com/olympus/hermione/stepSolvers/SourceCodeRagSolver.java
@@ -0,0 +1,99 @@
+package com.olympus.hermione.stepSolvers;
+
+
+
+import org.slf4j.LoggerFactory;
+import org.springframework.cloud.client.ServiceInstance;
+import org.springframework.web.client.RestTemplate;
+import com.fasterxml.jackson.databind.ObjectMapper;
+import com.fasterxml.jackson.databind.ObjectWriter;
+import com.olympus.hermione.dto.external.CodeRagResponse;
+import com.olympus.hermione.dto.external.SimilaritySearchCodeInput;
+import com.olympus.hermione.models.ScenarioExecution;
+import com.olympus.hermione.utility.AttributeParser;
+
+import ch.qos.logback.classic.Logger;
+
+public class SourceCodeRagSolver extends StepSolver {
+
+ Logger logger = (Logger) LoggerFactory.getLogger(SourceCodeRagSolver.class);
+
+
+ private String query;
+ private String rag_filter;
+ private int topk;
+ private double threshold;
+ private String outputField;
+
+
+
+ private void loadParameters(){
+ logger.info("Loading parameters");
+
+ AttributeParser attributeParser = new AttributeParser(this.scenarioExecution);
+
+ this.query = attributeParser.parse((String) this.step.getAttributes().get("rag_input"));
+
+ this.rag_filter = attributeParser.parse((String) this.step.getAttributes().get("rag_filter"));
+
+ if(this.step.getAttributes().containsKey("rag_topk")){
+ this.topk = (int) this.step.getAttributes().get("rag_topk");
+ }else{
+ this.topk = 4;
+ }
+
+ if(this.step.getAttributes().containsKey("rag_threshold")){
+ this.threshold = (double) this.step.getAttributes().get("rag_threshold");
+ }else{
+ this.threshold = 0.8;
+ }
+
+ this.outputField = (String) this.step.getAttributes().get("rag_output_variable");
+ }
+
+ private void logParameters(){
+ logger.info("query: " + this.query);
+ logger.info("rag_filter: " + this.rag_filter);
+ logger.info("topk: " + this.topk);
+ logger.info("threshold: " + this.threshold);
+ logger.info("outputField: " + this.outputField);
+ }
+
+
+ @Override
+ public ScenarioExecution solveStep() {
+
+ logger.info("Solving step: " + this.step.getName());
+ loadParameters();
+ logParameters();
+
+ SimilaritySearchCodeInput similaritySearchCodeInput = new SimilaritySearchCodeInput();
+ similaritySearchCodeInput.setQuery(this.query);
+ similaritySearchCodeInput.setTopK(String.valueOf(this.topk));
+ similaritySearchCodeInput.setSimilarityThreshold(String.valueOf(this.threshold));
+ similaritySearchCodeInput.setFilterExpression(this.rag_filter);
+
+ ServiceInstance serviceInstance = discoveryClient.getInstances("java-source-code-service").get(0);
+ RestTemplate restTemplate = new RestTemplate();
+
+ CodeRagResponse[] ragresponse =
+ restTemplate.postForEntity(serviceInstance.getUri() + "/similarity-search-code",
+ similaritySearchCodeInput,CodeRagResponse[].class ).getBody();
+
+ String code="";
+ for (CodeRagResponse codeRagResponse : ragresponse) {
+ code += "SOURCE CODE OF "+codeRagResponse.getCodeType()+" : " + codeRagResponse.getFullyQualifiedName() +" \n";
+ code += codeRagResponse.getCode() + "\n";
+ code += "-----------------------------------\n";
+ }
+
+ this.scenarioExecution.getExecSharedMap().put(this.outputField, code);
+
+ //concatenate the content of the results into a single string
+ this.scenarioExecution.setCurrentStepId(this.step.getStepId());
+ this.scenarioExecution.setNextStepId(this.step.getNextStepId());
+
+ return this.scenarioExecution;
+ }
+
+}
diff --git a/src/main/java/com/olympus/hermione/stepSolvers/StepSolver.java b/src/main/java/com/olympus/hermione/stepSolvers/StepSolver.java
index 40f82c3..e18f343 100644
--- a/src/main/java/com/olympus/hermione/stepSolvers/StepSolver.java
+++ b/src/main/java/com/olympus/hermione/stepSolvers/StepSolver.java
@@ -1,8 +1,13 @@
package com.olympus.hermione.stepSolvers;
import org.neo4j.driver.Driver;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.vectorstore.VectorStore;
+import org.springframework.cloud.client.discovery.DiscoveryClient;
+
import com.olympus.hermione.models.ScenarioExecution;
import com.olympus.hermione.models.ScenarioStep;
import com.olympus.hermione.services.Neo4JUitilityService;
@@ -16,11 +21,14 @@ public class StepSolver {
protected ChatModel chatModel;
protected Driver graphDriver;
protected Neo4JUitilityService neo4JUitilityService;
+ protected ChatClient chatClient;
+ protected DiscoveryClient discoveryClient;
+
+ private Logger logger = LoggerFactory.getLogger(StepSolver.class);
public ScenarioExecution solveStep(){
- System.out.println("Solving step: " + this.step.getName());
-
+ logger.info("Solving step: " + this.step.getName());
this.scenarioExecution.setCurrentStepId(this.step.getStepId());
this.scenarioExecution.setNextStepId(this.step.getNextStepId());
@@ -28,13 +36,17 @@ public class StepSolver {
};
- public void init(ScenarioStep step, ScenarioExecution scenarioExecution, VectorStore vectorStore,ChatModel chatModel,Driver graphDriver,Neo4JUitilityService neo4JUitilityService){
+ public void init(ScenarioStep step, ScenarioExecution scenarioExecution, VectorStore vectorStore,ChatModel chatModel,ChatClient chatClient, Driver graphDriver,Neo4JUitilityService neo4JUitilityService,
+ DiscoveryClient discoveryClient){
+ logger.info("Initializing StepSolver");
this.scenarioExecution = scenarioExecution;
+ this.chatClient = chatClient;
this.step = step;
this.vectorStore = vectorStore;
this.chatModel = chatModel;
this.graphDriver = graphDriver;
this.neo4JUitilityService = neo4JUitilityService;
- System.out.println("Initializing StepSolver");
+ this.discoveryClient = discoveryClient;
+ logger.info("StepSolver initialized");
}
}
\ No newline at end of file
diff --git a/src/main/java/com/olympus/hermione/utility/AttributeParser.java b/src/main/java/com/olympus/hermione/utility/AttributeParser.java
index 096ba6e..0bfd3aa 100644
--- a/src/main/java/com/olympus/hermione/utility/AttributeParser.java
+++ b/src/main/java/com/olympus/hermione/utility/AttributeParser.java
@@ -16,8 +16,10 @@ public class AttributeParser {
ST template = new ST(attribute_value, '$', '$');
template.add("shared", this.scenario_execution.getExecSharedMap());
+ String output = template.render();
- return template.render();
+
+ return output;
}
}
diff --git a/src/main/resources/application.properties b/src/main/resources/application.properties
index ce54ede..6557520 100644
--- a/src/main/resources/application.properties
+++ b/src/main/resources/application.properties
@@ -16,4 +16,20 @@ spring.main.allow-circular-references=true
neo4j.uri=neo4j+s://e17e6f08.databases.neo4j.io:7687
neo4j.username=neo4j
-neo4j.password=8SrSqQ3q6q9PQNWtN9ozqSQfGce4lfh_n6kKz2JIubQ
\ No newline at end of file
+neo4j.password=8SrSqQ3q6q9PQNWtN9ozqSQfGce4lfh_n6kKz2JIubQ
+
+spring.neo4j.uri=neo4j+s://e17e6f08.databases.neo4j.io:7687
+spring.neo4j.authentication.username=neo4j
+spring.neo4j.authentication.password=8SrSqQ3q6q9PQNWtN9ozqSQfGce4lfh_n6kKz2JIubQ
+
+
+spring.ai.vectorstore.neo4j.database-name:neo4j
+spring.ai.vectorstore.neo4j.initialize-schema:true
+spring.ai.vectorstore.neo4j.label:Document
+spring.ai.vectorstore.neo4j.embedding-property:embedding
+spring.ai.vectorstore.neo4j.index-name:spring-ai-document-index
+
+
+spring.main.allow-bean-definition-overriding=true
+logging.level.org.springframework.ai.chat.client.advisor=DEBUG
+
diff --git a/src/test/java/com/olympus/hermione/Neo4jSchemaDescriptionTest.java b/src/test/java/com/olympus/hermione/Neo4jSchemaDescriptionTest.java
index f1e0ef4..687decb 100644
--- a/src/test/java/com/olympus/hermione/Neo4jSchemaDescriptionTest.java
+++ b/src/test/java/com/olympus/hermione/Neo4jSchemaDescriptionTest.java
@@ -1,14 +1,11 @@
package com.olympus.hermione;
import static org.junit.jupiter.api.Assertions.assertTrue;
-import org.junit.jupiter.api.*;
-import static org.junit.jupiter.api.Assertions.*;
import com.google.gson.JsonArray;
import com.google.gson.JsonObject;
import com.olympus.hermione.services.Neo4JUitilityService;
-import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.springframework.beans.factory.annotation.Autowired;
@@ -24,7 +21,6 @@ public class Neo4jSchemaDescriptionTest {
}
-
@Test
void testGenerateSchemaDescription() {