diff --git a/pom.xml b/pom.xml index ead9aa4..8657c2d 100644 --- a/pom.xml +++ b/pom.xml @@ -28,7 +28,7 @@ 21 - 1.0.0-M1 + 1.0.0-SNAPSHOT @@ -84,6 +84,19 @@ org.springframework.boot spring-boot-starter-security + + + org.springframework.cloud + spring-cloud-starter-netflix-eureka-client + 4.1.3 + + + + org.springframework.cloud + spring-cloud-starter-feign + 1.4.7.RELEASE + + org.antlr diff --git a/src/main/java/com/olympus/hermione/config/Neo4jConfig.java b/src/main/java/com/olympus/hermione/config/Neo4jConfig.java index 9f0d8f2..e12120d 100644 --- a/src/main/java/com/olympus/hermione/config/Neo4jConfig.java +++ b/src/main/java/com/olympus/hermione/config/Neo4jConfig.java @@ -3,8 +3,6 @@ package com.olympus.hermione.config; import org.neo4j.driver.Driver; import org.neo4j.driver.GraphDatabase; -import org.neo4j.driver.Session; -import org.neo4j.driver.SessionConfig; import org.neo4j.driver.AuthTokens; import org.springframework.beans.factory.annotation.Value; import org.springframework.context.annotation.Bean; diff --git a/src/main/java/com/olympus/hermione/controllers/ExecutionController.java b/src/main/java/com/olympus/hermione/controllers/ExecutionController.java index a20746e..993bf76 100644 --- a/src/main/java/com/olympus/hermione/controllers/ExecutionController.java +++ b/src/main/java/com/olympus/hermione/controllers/ExecutionController.java @@ -1,8 +1,4 @@ package com.olympus.hermione.controllers; - -import java.util.Iterator; -import java.util.List; - import org.springframework.beans.factory.annotation.Autowired; import org.springframework.web.bind.annotation.GetMapping; import org.springframework.web.bind.annotation.RequestParam; @@ -18,8 +14,6 @@ public class ExecutionController { @GetMapping("/execution") public ScenarioExecution getOldExections(@RequestParam String id){ - - return scenarioExecutionRepository.findById(id).get(); } } diff --git a/src/main/java/com/olympus/hermione/controllers/ScenarioController.java b/src/main/java/com/olympus/hermione/controllers/ScenarioController.java index 3d90a40..49487f6 100644 --- a/src/main/java/com/olympus/hermione/controllers/ScenarioController.java +++ b/src/main/java/com/olympus/hermione/controllers/ScenarioController.java @@ -40,7 +40,19 @@ public class ScenarioController { @PostMapping("scenarios/execute") public ScenarioOutput executeScenario(@RequestBody ScenarioExecutionInput scenarioExecutionInput) { - return scenarioExecutionService.executeScenario(scenarioExecutionInput); + ScenarioOutput preOutput = scenarioExecutionService.prepareScrenarioExecution(scenarioExecutionInput); + scenarioExecutionService.executeScenarioAsync(preOutput); + + while(preOutput.getStatus().equals("IN_PROGRESS")) { + try { + Thread.sleep(1000); + preOutput = scenarioExecutionService.getExecutionProgress(preOutput.getScenarioExecution_id()); + } catch (InterruptedException e) { + e.printStackTrace(); + } + } + return preOutput; + } @PostMapping("scenarios/execute-async") diff --git a/src/main/java/com/olympus/hermione/controllers/TestController.java b/src/main/java/com/olympus/hermione/controllers/TestController.java index cc683b6..ca10329 100644 --- a/src/main/java/com/olympus/hermione/controllers/TestController.java +++ b/src/main/java/com/olympus/hermione/controllers/TestController.java @@ -1,40 +1,20 @@ package com.olympus.hermione.controllers; import org.springframework.beans.factory.annotation.Autowired; -import org.springframework.data.mongodb.core.aggregation.ComparisonOperators.Ne; import org.springframework.web.bind.annotation.GetMapping; -import org.springframework.web.bind.annotation.PostMapping; -import org.springframework.web.bind.annotation.RequestBody; import org.springframework.web.bind.annotation.RestController; - -import com.olympus.hermione.dto.ScenarioExecutionInput; -import com.olympus.hermione.dto.ScenarioOutput; import com.olympus.hermione.services.Neo4JUitilityService; -import com.olympus.hermione.services.ScenarioExecutionService; @RestController public class TestController { - @Autowired - ScenarioExecutionService scenarioExecutionService; - - - @Autowired Neo4JUitilityService neo4JUitilityService; - @PostMapping("test/execute") - public ScenarioOutput executeScenario(@RequestBody ScenarioExecutionInput scenarioExecutionInput) { - return scenarioExecutionService.executeScenario(scenarioExecutionInput); - } - - - @GetMapping("/test/generate_schema_description") public String testGenerateSchemaDescription(){ return neo4JUitilityService.generateSchemaDescription().toString(); } - } diff --git a/src/main/java/com/olympus/hermione/dto/ScenarioOutput.java b/src/main/java/com/olympus/hermione/dto/ScenarioOutput.java index 8f52ada..34fca54 100644 --- a/src/main/java/com/olympus/hermione/dto/ScenarioOutput.java +++ b/src/main/java/com/olympus/hermione/dto/ScenarioOutput.java @@ -11,4 +11,6 @@ public class ScenarioOutput { private String scenarioExecution_id; private String status; private String message; + private String currentStepDescription; + private String currentStepId; } \ No newline at end of file diff --git a/src/main/java/com/olympus/hermione/dto/aientity/Change.java b/src/main/java/com/olympus/hermione/dto/aientity/Change.java new file mode 100644 index 0000000..06d9f14 --- /dev/null +++ b/src/main/java/com/olympus/hermione/dto/aientity/Change.java @@ -0,0 +1,93 @@ +package com.olympus.hermione.dto.aientity; + +import com.fasterxml.jackson.annotation.JsonProperty; + +public class Change { + private String filename; + private String classname; + private String methodname; + @JsonProperty("new_code") + private String newCode; + @JsonProperty("previous_code") + private String previousCode; + @JsonProperty("change_description") + private String changeDescription; + + // Constructors + public Change() {} + + + public Change(String filename, String classname, String methodname, String newCode, String previousCode, String changeDescription) { + this.filename = filename; + this.classname = classname; + this.methodname = methodname; + this.newCode = newCode; + this.previousCode = previousCode; + this.changeDescription = changeDescription; + } + + // Getters and Setters + + + + public String getFilename() { + return filename; + } + + public void setFilename(String filename) { + this.filename = filename; + } + + public String getClassname() { + return classname; + } + + public void setClassname(String classname) { + this.classname = classname; + } + + public String getMethodname() { + return methodname; + } + + public void setMethodname(String methodname) { + this.methodname = methodname; + } + + public String getNewCode() { + return newCode; + } + + public void setNewCode(String newCode) { + this.newCode = newCode; + } + + public String getPreviousCode() { + return previousCode; + } + + public void setPreviousCode(String previousCode) { + this.previousCode = previousCode; + } + + public String getChangeDescription() { + return changeDescription; + } + + public void setChangeDescription(String changeDescription) { + this.changeDescription = changeDescription; + } + + // toString Method + @Override + public String toString() { + return "Change{" + + "filename='" + filename + '\'' + + ", classname='" + classname + '\'' + + ", methodname='" + methodname + '\'' + + ", newCode='" + newCode + '\'' + + ", previousCode='" + previousCode + '\'' + + ", changeDescription='" + changeDescription + '\'' + + '}'; + } +} diff --git a/src/main/java/com/olympus/hermione/dto/aientity/CiaOutputEntity.java b/src/main/java/com/olympus/hermione/dto/aientity/CiaOutputEntity.java new file mode 100644 index 0000000..de0f2af --- /dev/null +++ b/src/main/java/com/olympus/hermione/dto/aientity/CiaOutputEntity.java @@ -0,0 +1,66 @@ +package com.olympus.hermione.dto.aientity; + +import java.util.List; + + +public class CiaOutputEntity { + private String application; + private String title; + private String description; + private List changes; + + // Constructors + public CiaOutputEntity() {} + + public CiaOutputEntity(String application, String title, String description, List changes) { + this.application = application; + this.title = title; + this.description = description; + this.changes = changes; + } + + public void setApplication(String application) { + this.application = application; + } + + public String getApplication() { + return application; + } + + // Getters and Setters + public String getTitle() { + return title; + } + + public void setTitle(String title) { + this.title = title; + } + + public String getDescription() { + return description; + } + + public void setDescription(String description) { + this.description = description; + } + + public List getChanges() { + return changes; + } + + public void setChanges(List changes) { + this.changes = changes; + } + + // toString Method + @Override + public String toString() { + return "CodeChangeRequest{" + + "title='" + title + '\'' + + ", description='" + description + '\'' + + ", changes=" + changes ; + + + + } +} diff --git a/src/main/java/com/olympus/hermione/dto/external/CodeRagResponse.java b/src/main/java/com/olympus/hermione/dto/external/CodeRagResponse.java new file mode 100644 index 0000000..8f74b81 --- /dev/null +++ b/src/main/java/com/olympus/hermione/dto/external/CodeRagResponse.java @@ -0,0 +1,16 @@ +package com.olympus.hermione.dto.external; + +import lombok.Getter; +import lombok.Setter; + +@Getter +@Setter +public class CodeRagResponse { + + private String fullyQualifiedName; + private String code; + private String description; + private String similarityScore; + private String documentId; + private String codeType; +} diff --git a/src/main/java/com/olympus/hermione/dto/external/SimilaritySearchCodeInput.java b/src/main/java/com/olympus/hermione/dto/external/SimilaritySearchCodeInput.java new file mode 100644 index 0000000..51e069a --- /dev/null +++ b/src/main/java/com/olympus/hermione/dto/external/SimilaritySearchCodeInput.java @@ -0,0 +1,23 @@ +package com.olympus.hermione.dto.external; + +import lombok.Getter; +import lombok.Setter; + +@Getter +@Setter +public class SimilaritySearchCodeInput { + + private String query; + private String topK; + private String similarityThreshold; + private String filterExpression; + +} + + +/* + * .withQuery(query) + .withTopK(3) + .withSimilarityThreshold(0.8) + .withFilterExpression(b.eq("doc_type","GeneratedMethodDoc").build()) + */ \ No newline at end of file diff --git a/src/main/java/com/olympus/hermione/models/Scenario.java b/src/main/java/com/olympus/hermione/models/Scenario.java index 43ce81a..979216f 100644 --- a/src/main/java/com/olympus/hermione/models/Scenario.java +++ b/src/main/java/com/olympus/hermione/models/Scenario.java @@ -19,4 +19,5 @@ public class Scenario { private List steps; private List inputs; + private boolean useChatMemory=false; } diff --git a/src/main/java/com/olympus/hermione/models/ScenarioExecution.java b/src/main/java/com/olympus/hermione/models/ScenarioExecution.java index eb18abf..475fdf2 100644 --- a/src/main/java/com/olympus/hermione/models/ScenarioExecution.java +++ b/src/main/java/com/olympus/hermione/models/ScenarioExecution.java @@ -25,6 +25,7 @@ public class ScenarioExecution { private String currentStepId; private String nextStepId; + private String currentStepDescription; private ScenarioExecutionInput scenarioExecutionInput; diff --git a/src/main/java/com/olympus/hermione/services/ScenarioExecutionService.java b/src/main/java/com/olympus/hermione/services/ScenarioExecutionService.java index e5b2fdd..2f1ea67 100644 --- a/src/main/java/com/olympus/hermione/services/ScenarioExecutionService.java +++ b/src/main/java/com/olympus/hermione/services/ScenarioExecutionService.java @@ -1,14 +1,20 @@ package com.olympus.hermione.services; -import java.util.Comparator; import java.util.HashMap; import java.util.List; import java.util.Optional; import org.slf4j.LoggerFactory; +import org.springframework.ai.chat.client.ChatClient; +import org.springframework.ai.chat.client.advisor.MessageChatMemoryAdvisor; +import org.springframework.ai.chat.client.advisor.SimpleLoggerAdvisor; +import org.springframework.ai.chat.memory.ChatMemory; +import org.springframework.ai.chat.memory.InMemoryChatMemory; import org.springframework.ai.chat.model.ChatModel; + import org.springframework.ai.vectorstore.VectorStore; import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.cloud.client.discovery.DiscoveryClient; import org.springframework.scheduling.annotation.Async; import org.springframework.stereotype.Service; @@ -19,9 +25,11 @@ import com.olympus.hermione.models.ScenarioExecution; import com.olympus.hermione.models.ScenarioStep; import com.olympus.hermione.repository.ScenarioExecutionRepository; import com.olympus.hermione.repository.ScenarioRepository; +import com.olympus.hermione.stepSolvers.AdvancedAIPromptSolver; import com.olympus.hermione.stepSolvers.BasicAIPromptSolver; import com.olympus.hermione.stepSolvers.BasicQueryRagSolver; import com.olympus.hermione.stepSolvers.QueryNeo4JSolver; +import com.olympus.hermione.stepSolvers.SourceCodeRagSolver; import com.olympus.hermione.stepSolvers.StepSolver; import org.neo4j.driver.Driver; @@ -38,6 +46,9 @@ public class ScenarioExecutionService { @Autowired VectorStore vectorStore; + + @Autowired + DiscoveryClient discoveryClient; @Autowired ChatModel chatModel; @@ -48,6 +59,8 @@ public class ScenarioExecutionService { @Autowired Neo4JUitilityService neo4JUitilityService; + private ChatClient chatClient; + private Logger logger = LoggerFactory.getLogger(ScenarioExecutionService.class); @@ -57,6 +70,7 @@ public class ScenarioExecutionService { if(o_scenarioExecution.isPresent()){ ScenarioExecution scenarioExecution = o_scenarioExecution.get(); + scenarioExecution.setCurrentStepId(scenarioExecution.getCurrentStepId()); if(scenarioExecution.getExecSharedMap().get("scenario_output")!=null){ scenarioOutput.setScenarioExecution_id(scenarioExecution.getId()); @@ -85,12 +99,11 @@ public class ScenarioExecutionService { if(o_scenarioExecution.isPresent()){ ScenarioExecution scenarioExecution = o_scenarioExecution.get(); - HashMap inputs = scenarioExecution.getScenarioExecutionInput().getInputs(); - Optional o_scenario = scenarioRepository.findById(scenarioExecution.getScenario().getId()); Scenario scenario = o_scenario.get(); + logger.info("Start execution of scenario: " + scenario.getName()); HashMap execSharedMap = new HashMap(); @@ -98,6 +111,31 @@ public class ScenarioExecutionService { scenarioExecution.setExecSharedMap(execSharedMap); scenarioExecutionRepository.save(scenarioExecution); + //TODO: Chatmodel should be initialized with the scenario specific infos + + if(scenario.isUseChatMemory()){ + logger.info("Initializing chatClient with chat-memory advisor"); + + ChatMemory chatMemory = new InMemoryChatMemory(); + chatClient = ChatClient.builder(chatModel) + .defaultAdvisors( + new MessageChatMemoryAdvisor(chatMemory), // chat-memory advisor + new SimpleLoggerAdvisor() + ) + .build(); + + }else{ + logger.info("Initializing chatClient with simple logger advisor"); + + chatClient = ChatClient.builder(chatModel) + .defaultAdvisors( + new SimpleLoggerAdvisor() + ) + .build(); + } + + + List steps = scenario.getSteps(); String startStepId=scenario.getStartWithStepId(); @@ -115,10 +153,56 @@ public class ScenarioExecutionService { } + + + + private void executeScenarioStep(ScenarioStep step, ScenarioExecution scenarioExecution){ + logger.info("Start working on step: " + step.getName() + " with type: " + step.getType()); + StepSolver solver=new StepSolver(); + switch (step.getType()) { + case "ADVANCED_QUERY_AI": + solver = new AdvancedAIPromptSolver(); + break; + case "RAG_QUERY": + solver = new BasicQueryRagSolver(); + break; + case "SIMPLE_QUERY_AI": + solver = new BasicAIPromptSolver(); + break; + case "QUERY_GRAPH": + solver = new QueryNeo4JSolver(); + break; + case "RAG_SOURCE_CODE": + solver = new SourceCodeRagSolver(); + break; + + default: + break; + } + + + logger.info("Initializing step: " + step.getStepId()); + solver.init(step, scenarioExecution, + vectorStore,chatModel,chatClient,graphDriver,neo4JUitilityService + ,discoveryClient); + + logger.info("Solving step: " + step.getStepId()); + scenarioExecution.setCurrentStepId(step.getStepId()); + scenarioExecution.setCurrentStepDescription(step.getName()); + + ScenarioExecution scenarioExecutionNew = solver.solveStep(); + + logger.info("Step solved: " + step.getStepId()); + logger.info("Next step: " + scenarioExecutionNew.getNextStepId()); + + scenarioExecutionRepository.save(scenarioExecutionNew); + + } + + public ScenarioOutput prepareScrenarioExecution(ScenarioExecutionInput scenarioExecutionInput){ String scenarioId = scenarioExecutionInput.getScenario_id(); - HashMap inputs = scenarioExecutionInput.getInputs(); ScenarioOutput scenarioOutput = new ScenarioOutput(); Optional o_scenario = scenarioRepository.findById(scenarioId); @@ -134,94 +218,11 @@ public class ScenarioExecutionService { scenarioExecutionRepository.save(scenarioExecution); scenarioOutput.setScenarioExecution_id(scenarioExecution.getId()); - scenarioOutput.setStatus("STARTED"); + scenarioOutput.setStatus("IN_PROGRESS"); } return scenarioOutput; } - - public ScenarioOutput executeScenario(ScenarioExecutionInput scenarioExecutionInput){ - - String scenarioId = scenarioExecutionInput.getScenario_id(); - HashMap inputs = scenarioExecutionInput.getInputs(); - ScenarioOutput scenarioOutput = new ScenarioOutput(); - - Optional o_scenario = scenarioRepository.findById(scenarioId); - - if(o_scenario.isPresent()){ - Scenario scenario = o_scenario.get(); - logger.info("Executing scenario: " + scenario.getName()); - - ScenarioExecution scenarioExecution = new ScenarioExecution(); - scenarioExecution.setScenario(scenario); - - HashMap execSharedMap = new HashMap(); - execSharedMap.put("user_input", inputs); - scenarioExecution.setExecSharedMap(execSharedMap); - scenarioExecutionRepository.save(scenarioExecution); - - List steps = scenario.getSteps(); - String startStepId=scenario.getStartWithStepId(); - - ScenarioStep startStep = steps.stream().filter(step -> step.getStepId().equals(startStepId)).findFirst().orElse(null); - - executeScenarioStep(startStep, scenarioExecution); - - while (scenarioExecution.getNextStepId()!=null) { - ScenarioStep step = steps.stream().filter(s -> s.getStepId().equals(scenarioExecution.getNextStepId())).findFirst().orElse(null); - executeScenarioStep(step, scenarioExecution); - } - - scenarioOutput.setScenarioExecution_id(scenarioExecution.getId()); - scenarioOutput.setStringOutput(scenarioExecution.getExecSharedMap().get("scenario_output").toString()); - scenarioOutput.setStatus("OK"); - - return scenarioOutput; - - }else{ - logger.error("Scenario not found with id: " + scenarioId); - scenarioOutput.setScenarioExecution_id(null); - scenarioOutput.setStatus("ERROR"); - scenarioOutput.setMessage("Scenario not found"); - } - - return scenarioOutput; - } - - - - private void executeScenarioStep(ScenarioStep step, ScenarioExecution scenarioExecution){ - logger.info("Start working on step: " + step.getName() + " with type: " + step.getType()); - StepSolver solver=new StepSolver(); - switch (step.getType()) { - case "RAG_QUERY": - solver = new BasicQueryRagSolver(); - break; - case "SIMPLE_QUERY_AI": - solver = new BasicAIPromptSolver(); - break; - case "QUERY_GRAPH": - solver = new QueryNeo4JSolver(); - break; - default: - break; - } - - logger.info("Initializing step: " + step.getStepId()); - solver.init(step, scenarioExecution, vectorStore,chatModel,graphDriver,neo4JUitilityService); - - logger.info("Solving step: " + step.getStepId()); - scenarioExecution.setCurrentStepId(step.getStepId()); - - ScenarioExecution scenarioExecutionNew = solver.solveStep(); - - logger.info("Step solved: " + step.getStepId()); - logger.info("Next step: " + scenarioExecutionNew.getNextStepId()); - - scenarioExecutionRepository.save(scenarioExecutionNew); - - } - } diff --git a/src/main/java/com/olympus/hermione/stepSolvers/AdvancedAIPromptSolver.java b/src/main/java/com/olympus/hermione/stepSolvers/AdvancedAIPromptSolver.java new file mode 100644 index 0000000..75ae268 --- /dev/null +++ b/src/main/java/com/olympus/hermione/stepSolvers/AdvancedAIPromptSolver.java @@ -0,0 +1,95 @@ +package com.olympus.hermione.stepSolvers; + +import org.slf4j.LoggerFactory; +import org.springframework.ai.chat.client.ChatClient.CallResponseSpec; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.SystemMessage; +import org.springframework.ai.chat.messages.UserMessage; +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.olympus.hermione.dto.aientity.CiaOutputEntity; +import com.olympus.hermione.models.ScenarioExecution; +import com.olympus.hermione.utility.AttributeParser; + +import ch.qos.logback.classic.Logger; + + +public class AdvancedAIPromptSolver extends StepSolver { + + private String qai_system_prompt_template; + private String qai_user_input; + private String qai_output_variable; + private boolean qai_load_graph_schema=false; + private String qai_output_entityType; + + + Logger logger = (Logger) LoggerFactory.getLogger(BasicQueryRagSolver.class); + + private void loadParameters(){ + logger.info("Loading parameters"); + + if(this.step.getAttributes().get("qai_load_graph_schema")!=null ){ + this.qai_load_graph_schema = Boolean.parseBoolean((String) this.step.getAttributes().get("qai_load_graph_schema")); + logger.info("qai_load_graph_schema: " + this.qai_load_graph_schema); + } + + if(this.qai_load_graph_schema){ + logger.info("Loading graph schema"); + this.scenarioExecution.getExecSharedMap().put("graph_schema", this.neo4JUitilityService.generateSchemaDescription().toString()); + } + + AttributeParser attributeParser = new AttributeParser(this.scenarioExecution); + + this.qai_system_prompt_template = attributeParser.parse((String) this.step.getAttributes().get("qai_system_prompt_template")); + this.qai_user_input = attributeParser.parse((String)this.step.getAttributes().get("qai_user_input")); + + this.qai_output_variable = (String) this.step.getAttributes().get("qai_output_variable"); + this.qai_output_entityType = (String) this.step.getAttributes().get("qai_output_entityType"); + + //TODO: Add memory ID attribute to have the possibility of multiple conversations + + } + + @Override + public ScenarioExecution solveStep(){ + + System.out.println("Solving step: " + this.step.getName()); + + this.scenarioExecution.setCurrentStepId(this.step.getStepId()); + + loadParameters(); + + String userText = this.qai_user_input; + + Message userMessage = new UserMessage(userText); + Message systemMessage = new SystemMessage(this.qai_system_prompt_template); + + CallResponseSpec resp = chatClient.prompt() + .messages(userMessage,systemMessage) + .advisors(advisor -> advisor.param("chat_memory_conversation_id", this.scenarioExecution.getId()) + .param("chat_memory_response_size", 100)) + .call(); + + + if(qai_output_entityType!=null && qai_output_entityType.equals("CiaOutputEntity")){ + logger.info("Output is of type CiaOutputEntity"); + CiaOutputEntity ouputEntity = resp.entity(CiaOutputEntity.class); + try { + ObjectMapper objectMapper = new ObjectMapper(); + String jsonOutput = objectMapper.writeValueAsString(ouputEntity); + this.scenarioExecution.getExecSharedMap().put(this.qai_output_variable, jsonOutput); + } catch (JsonProcessingException e) { + logger.error("Error converting output entity to JSON", e); + } + }else{ + + String output = resp.content(); + this.scenarioExecution.getExecSharedMap().put(this.qai_output_variable, output); + } + + this.scenarioExecution.setNextStepId(this.step.getNextStepId()); + + return this.scenarioExecution; + } + +} diff --git a/src/main/java/com/olympus/hermione/stepSolvers/BasicAIPromptSolver.java b/src/main/java/com/olympus/hermione/stepSolvers/BasicAIPromptSolver.java index 137471a..524d6fc 100644 --- a/src/main/java/com/olympus/hermione/stepSolvers/BasicAIPromptSolver.java +++ b/src/main/java/com/olympus/hermione/stepSolvers/BasicAIPromptSolver.java @@ -53,10 +53,7 @@ public class BasicAIPromptSolver extends StepSolver { loadParameters(); - - String userText = this.qai_user_input; - - Message userMessage = new UserMessage(userText); + Message userMessage = new UserMessage(this.qai_user_input); Message systemMessage = new SystemMessage(this.qai_system_prompt_template); Prompt prompt = new Prompt(List.of(userMessage, systemMessage)); diff --git a/src/main/java/com/olympus/hermione/stepSolvers/SourceCodeRagSolver.java b/src/main/java/com/olympus/hermione/stepSolvers/SourceCodeRagSolver.java new file mode 100644 index 0000000..dbacaea --- /dev/null +++ b/src/main/java/com/olympus/hermione/stepSolvers/SourceCodeRagSolver.java @@ -0,0 +1,99 @@ +package com.olympus.hermione.stepSolvers; + + + +import org.slf4j.LoggerFactory; +import org.springframework.cloud.client.ServiceInstance; +import org.springframework.web.client.RestTemplate; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.ObjectWriter; +import com.olympus.hermione.dto.external.CodeRagResponse; +import com.olympus.hermione.dto.external.SimilaritySearchCodeInput; +import com.olympus.hermione.models.ScenarioExecution; +import com.olympus.hermione.utility.AttributeParser; + +import ch.qos.logback.classic.Logger; + +public class SourceCodeRagSolver extends StepSolver { + + Logger logger = (Logger) LoggerFactory.getLogger(SourceCodeRagSolver.class); + + + private String query; + private String rag_filter; + private int topk; + private double threshold; + private String outputField; + + + + private void loadParameters(){ + logger.info("Loading parameters"); + + AttributeParser attributeParser = new AttributeParser(this.scenarioExecution); + + this.query = attributeParser.parse((String) this.step.getAttributes().get("rag_input")); + + this.rag_filter = attributeParser.parse((String) this.step.getAttributes().get("rag_filter")); + + if(this.step.getAttributes().containsKey("rag_topk")){ + this.topk = (int) this.step.getAttributes().get("rag_topk"); + }else{ + this.topk = 4; + } + + if(this.step.getAttributes().containsKey("rag_threshold")){ + this.threshold = (double) this.step.getAttributes().get("rag_threshold"); + }else{ + this.threshold = 0.8; + } + + this.outputField = (String) this.step.getAttributes().get("rag_output_variable"); + } + + private void logParameters(){ + logger.info("query: " + this.query); + logger.info("rag_filter: " + this.rag_filter); + logger.info("topk: " + this.topk); + logger.info("threshold: " + this.threshold); + logger.info("outputField: " + this.outputField); + } + + + @Override + public ScenarioExecution solveStep() { + + logger.info("Solving step: " + this.step.getName()); + loadParameters(); + logParameters(); + + SimilaritySearchCodeInput similaritySearchCodeInput = new SimilaritySearchCodeInput(); + similaritySearchCodeInput.setQuery(this.query); + similaritySearchCodeInput.setTopK(String.valueOf(this.topk)); + similaritySearchCodeInput.setSimilarityThreshold(String.valueOf(this.threshold)); + similaritySearchCodeInput.setFilterExpression(this.rag_filter); + + ServiceInstance serviceInstance = discoveryClient.getInstances("java-source-code-service").get(0); + RestTemplate restTemplate = new RestTemplate(); + + CodeRagResponse[] ragresponse = + restTemplate.postForEntity(serviceInstance.getUri() + "/similarity-search-code", + similaritySearchCodeInput,CodeRagResponse[].class ).getBody(); + + String code=""; + for (CodeRagResponse codeRagResponse : ragresponse) { + code += "SOURCE CODE OF "+codeRagResponse.getCodeType()+" : " + codeRagResponse.getFullyQualifiedName() +" \n"; + code += codeRagResponse.getCode() + "\n"; + code += "-----------------------------------\n"; + } + + this.scenarioExecution.getExecSharedMap().put(this.outputField, code); + + //concatenate the content of the results into a single string + this.scenarioExecution.setCurrentStepId(this.step.getStepId()); + this.scenarioExecution.setNextStepId(this.step.getNextStepId()); + + return this.scenarioExecution; + } + +} diff --git a/src/main/java/com/olympus/hermione/stepSolvers/StepSolver.java b/src/main/java/com/olympus/hermione/stepSolvers/StepSolver.java index 40f82c3..e18f343 100644 --- a/src/main/java/com/olympus/hermione/stepSolvers/StepSolver.java +++ b/src/main/java/com/olympus/hermione/stepSolvers/StepSolver.java @@ -1,8 +1,13 @@ package com.olympus.hermione.stepSolvers; import org.neo4j.driver.Driver; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.ai.chat.client.ChatClient; import org.springframework.ai.chat.model.ChatModel; import org.springframework.ai.vectorstore.VectorStore; +import org.springframework.cloud.client.discovery.DiscoveryClient; + import com.olympus.hermione.models.ScenarioExecution; import com.olympus.hermione.models.ScenarioStep; import com.olympus.hermione.services.Neo4JUitilityService; @@ -16,11 +21,14 @@ public class StepSolver { protected ChatModel chatModel; protected Driver graphDriver; protected Neo4JUitilityService neo4JUitilityService; + protected ChatClient chatClient; + protected DiscoveryClient discoveryClient; + + private Logger logger = LoggerFactory.getLogger(StepSolver.class); public ScenarioExecution solveStep(){ - System.out.println("Solving step: " + this.step.getName()); - + logger.info("Solving step: " + this.step.getName()); this.scenarioExecution.setCurrentStepId(this.step.getStepId()); this.scenarioExecution.setNextStepId(this.step.getNextStepId()); @@ -28,13 +36,17 @@ public class StepSolver { }; - public void init(ScenarioStep step, ScenarioExecution scenarioExecution, VectorStore vectorStore,ChatModel chatModel,Driver graphDriver,Neo4JUitilityService neo4JUitilityService){ + public void init(ScenarioStep step, ScenarioExecution scenarioExecution, VectorStore vectorStore,ChatModel chatModel,ChatClient chatClient, Driver graphDriver,Neo4JUitilityService neo4JUitilityService, + DiscoveryClient discoveryClient){ + logger.info("Initializing StepSolver"); this.scenarioExecution = scenarioExecution; + this.chatClient = chatClient; this.step = step; this.vectorStore = vectorStore; this.chatModel = chatModel; this.graphDriver = graphDriver; this.neo4JUitilityService = neo4JUitilityService; - System.out.println("Initializing StepSolver"); + this.discoveryClient = discoveryClient; + logger.info("StepSolver initialized"); } } \ No newline at end of file diff --git a/src/main/java/com/olympus/hermione/utility/AttributeParser.java b/src/main/java/com/olympus/hermione/utility/AttributeParser.java index 096ba6e..0bfd3aa 100644 --- a/src/main/java/com/olympus/hermione/utility/AttributeParser.java +++ b/src/main/java/com/olympus/hermione/utility/AttributeParser.java @@ -16,8 +16,10 @@ public class AttributeParser { ST template = new ST(attribute_value, '$', '$'); template.add("shared", this.scenario_execution.getExecSharedMap()); + String output = template.render(); - return template.render(); + + return output; } } diff --git a/src/main/resources/application.properties b/src/main/resources/application.properties index ce54ede..6557520 100644 --- a/src/main/resources/application.properties +++ b/src/main/resources/application.properties @@ -16,4 +16,20 @@ spring.main.allow-circular-references=true neo4j.uri=neo4j+s://e17e6f08.databases.neo4j.io:7687 neo4j.username=neo4j -neo4j.password=8SrSqQ3q6q9PQNWtN9ozqSQfGce4lfh_n6kKz2JIubQ \ No newline at end of file +neo4j.password=8SrSqQ3q6q9PQNWtN9ozqSQfGce4lfh_n6kKz2JIubQ + +spring.neo4j.uri=neo4j+s://e17e6f08.databases.neo4j.io:7687 +spring.neo4j.authentication.username=neo4j +spring.neo4j.authentication.password=8SrSqQ3q6q9PQNWtN9ozqSQfGce4lfh_n6kKz2JIubQ + + +spring.ai.vectorstore.neo4j.database-name:neo4j +spring.ai.vectorstore.neo4j.initialize-schema:true +spring.ai.vectorstore.neo4j.label:Document +spring.ai.vectorstore.neo4j.embedding-property:embedding +spring.ai.vectorstore.neo4j.index-name:spring-ai-document-index + + +spring.main.allow-bean-definition-overriding=true +logging.level.org.springframework.ai.chat.client.advisor=DEBUG + diff --git a/src/test/java/com/olympus/hermione/Neo4jSchemaDescriptionTest.java b/src/test/java/com/olympus/hermione/Neo4jSchemaDescriptionTest.java index f1e0ef4..687decb 100644 --- a/src/test/java/com/olympus/hermione/Neo4jSchemaDescriptionTest.java +++ b/src/test/java/com/olympus/hermione/Neo4jSchemaDescriptionTest.java @@ -1,14 +1,11 @@ package com.olympus.hermione; import static org.junit.jupiter.api.Assertions.assertTrue; -import org.junit.jupiter.api.*; -import static org.junit.jupiter.api.Assertions.*; import com.google.gson.JsonArray; import com.google.gson.JsonObject; import com.olympus.hermione.services.Neo4JUitilityService; -import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.springframework.beans.factory.annotation.Autowired; @@ -24,7 +21,6 @@ public class Neo4jSchemaDescriptionTest { } - @Test void testGenerateSchemaDescription() {