diff --git a/pom.xml b/pom.xml index f14e3f8..89abbc7 100644 --- a/pom.xml +++ b/pom.xml @@ -47,7 +47,12 @@ org.springframework.ai spring-ai-openai-spring-boot-starter - + + org.neo4j.driver + neo4j-java-driver + 5.8.0 + + org.springframework.boot spring-boot-starter-test diff --git a/src/main/java/com/olympus/hermione/config/Neo4jConfig.java b/src/main/java/com/olympus/hermione/config/Neo4jConfig.java new file mode 100644 index 0000000..9f0d8f2 --- /dev/null +++ b/src/main/java/com/olympus/hermione/config/Neo4jConfig.java @@ -0,0 +1,31 @@ +package com.olympus.hermione.config; + + +import org.neo4j.driver.Driver; +import org.neo4j.driver.GraphDatabase; +import org.neo4j.driver.Session; +import org.neo4j.driver.SessionConfig; +import org.neo4j.driver.AuthTokens; +import org.springframework.beans.factory.annotation.Value; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; + +@Configuration +public class Neo4jConfig { + + @Value("${neo4j.uri}") + private String uri; + @Value("${neo4j.username}") + private String username; + @Value("${neo4j.password}") + private String password; + + + @Bean + public Driver neo4jDriver() { + Driver driver = GraphDatabase.driver(uri, AuthTokens.basic(username, password)); + return driver; + } +} + + diff --git a/src/main/java/com/olympus/hermione/controllers/CanvasController.java b/src/main/java/com/olympus/hermione/controllers/CanvasController.java new file mode 100644 index 0000000..b88b9f2 --- /dev/null +++ b/src/main/java/com/olympus/hermione/controllers/CanvasController.java @@ -0,0 +1,6 @@ +package com.olympus.hermione.controllers; + + +public class CanvasController { + +} diff --git a/src/main/java/com/olympus/hermione/services/ScenarioExecutionService.java b/src/main/java/com/olympus/hermione/services/ScenarioExecutionService.java index 511f4e7..f7ccd5c 100644 --- a/src/main/java/com/olympus/hermione/services/ScenarioExecutionService.java +++ b/src/main/java/com/olympus/hermione/services/ScenarioExecutionService.java @@ -20,8 +20,10 @@ import com.olympus.hermione.repository.ScenarioExecutionRepository; import com.olympus.hermione.repository.ScenarioRepository; import com.olympus.hermione.stepSolvers.BasicAIPromptSolver; import com.olympus.hermione.stepSolvers.BasicQueryRagSolver; +import com.olympus.hermione.stepSolvers.QueryNeo4JSolver; import com.olympus.hermione.stepSolvers.StepSolver; +import org.neo4j.driver.Driver; import org.slf4j.Logger; @Service @@ -39,6 +41,10 @@ public class ScenarioExecutionService { @Autowired ChatModel chatModel; + @Autowired + Driver graphDriver; + + private Logger logger = LoggerFactory.getLogger(ScenarioExecutionService.class); /* @@ -135,11 +141,14 @@ public class ScenarioExecutionService { case "SIMPLE_QUERY_AI": solver = new BasicAIPromptSolver(); break; + case "QUERY_GRAPH": + solver = new QueryNeo4JSolver(); + break; default: break; } - solver.init(step, scenarioExecution, vectorStore,chatModel); + solver.init(step, scenarioExecution, vectorStore,chatModel,graphDriver); //must add try catch in order to catch exceptions and step execution errors ScenarioExecution scenarioExecutionNew = solver.solveStep(); diff --git a/src/main/java/com/olympus/hermione/stepSolvers/QueryNeo4JSolver.java b/src/main/java/com/olympus/hermione/stepSolvers/QueryNeo4JSolver.java new file mode 100644 index 0000000..f858b61 --- /dev/null +++ b/src/main/java/com/olympus/hermione/stepSolvers/QueryNeo4JSolver.java @@ -0,0 +1,92 @@ +package com.olympus.hermione.stepSolvers; + +import java.util.ArrayList; +import java.util.List; + +import org.neo4j.driver.Driver; +import org.neo4j.driver.Result; +import org.neo4j.driver.Session; +import org.slf4j.LoggerFactory; +import org.springframework.ai.document.Document; +import org.springframework.ai.vectorstore.SearchRequest; +import org.springframework.beans.factory.annotation.Autowired; +import org.neo4j.driver.Record; +import org.neo4j.driver.Value; + +import com.olympus.hermione.models.ScenarioExecution; +import com.olympus.hermione.utility.AttributeParser; + +import ch.qos.logback.classic.Logger; + +public class QueryNeo4JSolver extends StepSolver { + + Logger logger = (Logger) LoggerFactory.getLogger(QueryNeo4JSolver.class); + + + private String cypherQuery; + private String outputField; + + + + private void loadParameters(){ + logger.info("Loading parameters"); + + AttributeParser attributeParser = new AttributeParser(this.scenarioExecution); + + this.cypherQuery = attributeParser.parse((String) this.step.getAttributes().get("cypher_query")); + + this.outputField = attributeParser.parse((String) this.step.getAttributes().get("output_variable")); + + + } + + private void logParameters(){ + logger.info("query: " + this.cypherQuery); + logger.info("outputField: " + this.outputField); + } + + + @Override + public ScenarioExecution solveStep(){ + + logger.info("Solving step: " + this.step.getName()); + loadParameters(); + logParameters(); + StringBuilder resultString = new StringBuilder(); + + + try (Session session = graphDriver.session()) { + + String query = this.cypherQuery; + Result result = session.run(query); + while(result.hasNext()){ + Record record = result.next(); + + + Value nodeValue = record.get("sourceCode"); + String nodeString = nodeValue.toString(); + + if(nodeString.isEmpty()){ + nodeValue = record.get("output"); + nodeString = nodeValue.toString(); + } + + if (resultString.length() > 0) { + resultString.append(", "); + } + resultString.append(nodeString); + } + + this.scenarioExecution.getExecSharedMap().put(this.outputField, resultString.toString()); + + } catch (Exception e) { + logger.error("Error executing query: " + e.getMessage()); + this.scenarioExecution.getExecSharedMap().put(this.outputField, "Error executing query: " + e.getMessage()); + } + + return this.scenarioExecution; + } + + + +} diff --git a/src/main/java/com/olympus/hermione/stepSolvers/StepSolver.java b/src/main/java/com/olympus/hermione/stepSolvers/StepSolver.java index 9c71a9a..67b3721 100644 --- a/src/main/java/com/olympus/hermione/stepSolvers/StepSolver.java +++ b/src/main/java/com/olympus/hermione/stepSolvers/StepSolver.java @@ -1,5 +1,6 @@ package com.olympus.hermione.stepSolvers; +import org.neo4j.driver.Driver; import org.springframework.ai.chat.model.ChatModel; import org.springframework.ai.vectorstore.VectorStore; import com.olympus.hermione.models.ScenarioExecution; @@ -12,16 +13,18 @@ public class StepSolver { protected ScenarioStep step; protected VectorStore vectorStore; protected ChatModel chatModel; - + protected Driver graphDriver; + public ScenarioExecution solveStep(){ System.out.println("Solving step: " + this.step.getName()); return this.scenarioExecution; }; - public void init(ScenarioStep step, ScenarioExecution scenarioExecution, VectorStore vectorStore,ChatModel chatModel){ + public void init(ScenarioStep step, ScenarioExecution scenarioExecution, VectorStore vectorStore,ChatModel chatModel,Driver graphDriver){ this.scenarioExecution = scenarioExecution; this.step = step; this.vectorStore = vectorStore; this.chatModel = chatModel; + this.graphDriver = graphDriver; System.out.println("Initializing StepSolver"); } } \ No newline at end of file diff --git a/src/main/resources/application.properties b/src/main/resources/application.properties index 96ddf5e..ce54ede 100644 --- a/src/main/resources/application.properties +++ b/src/main/resources/application.properties @@ -11,5 +11,9 @@ spring.ai.vectorstore.mongodb.collection-name=vector_store spring.ai.vectorstore.mongodb.initialize-schema=false spring.ai.openai.api-key=sk-proj-j3TFJ0h348DIzMrYYfyUT3BlbkFJjk4HMc8A2ux2Asg8Y7H1 +spring.ai.openai.chat.options.model=gpt-4o-mini +spring.main.allow-circular-references=true -spring.main.allow-circular-references=true \ No newline at end of file +neo4j.uri=neo4j+s://e17e6f08.databases.neo4j.io:7687 +neo4j.username=neo4j +neo4j.password=8SrSqQ3q6q9PQNWtN9ozqSQfGce4lfh_n6kKz2JIubQ \ No newline at end of file