Added Neo4j Step Solver

This commit is contained in:
andrea.terzani
2024-08-10 22:58:29 +02:00
parent afcdb01220
commit 090ecabf6a
7 changed files with 155 additions and 5 deletions

View File

@@ -47,7 +47,12 @@
<groupId>org.springframework.ai</groupId>
<artifactId>spring-ai-openai-spring-boot-starter</artifactId>
</dependency>
<dependency>
<groupId>org.neo4j.driver</groupId>
<artifactId>neo4j-java-driver</artifactId>
<version>5.8.0</version>
</dependency>
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-test</artifactId>

View File

@@ -0,0 +1,31 @@
package com.olympus.hermione.config;
import org.neo4j.driver.Driver;
import org.neo4j.driver.GraphDatabase;
import org.neo4j.driver.Session;
import org.neo4j.driver.SessionConfig;
import org.neo4j.driver.AuthTokens;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
@Configuration
public class Neo4jConfig {
@Value("${neo4j.uri}")
private String uri;
@Value("${neo4j.username}")
private String username;
@Value("${neo4j.password}")
private String password;
@Bean
public Driver neo4jDriver() {
Driver driver = GraphDatabase.driver(uri, AuthTokens.basic(username, password));
return driver;
}
}

View File

@@ -0,0 +1,6 @@
package com.olympus.hermione.controllers;
public class CanvasController {
}

View File

@@ -20,8 +20,10 @@ import com.olympus.hermione.repository.ScenarioExecutionRepository;
import com.olympus.hermione.repository.ScenarioRepository;
import com.olympus.hermione.stepSolvers.BasicAIPromptSolver;
import com.olympus.hermione.stepSolvers.BasicQueryRagSolver;
import com.olympus.hermione.stepSolvers.QueryNeo4JSolver;
import com.olympus.hermione.stepSolvers.StepSolver;
import org.neo4j.driver.Driver;
import org.slf4j.Logger;
@Service
@@ -39,6 +41,10 @@ public class ScenarioExecutionService {
@Autowired
ChatModel chatModel;
@Autowired
Driver graphDriver;
private Logger logger = LoggerFactory.getLogger(ScenarioExecutionService.class);
/*
@@ -135,11 +141,14 @@ public class ScenarioExecutionService {
case "SIMPLE_QUERY_AI":
solver = new BasicAIPromptSolver();
break;
case "QUERY_GRAPH":
solver = new QueryNeo4JSolver();
break;
default:
break;
}
solver.init(step, scenarioExecution, vectorStore,chatModel);
solver.init(step, scenarioExecution, vectorStore,chatModel,graphDriver);
//must add try catch in order to catch exceptions and step execution errors
ScenarioExecution scenarioExecutionNew = solver.solveStep();

View File

@@ -0,0 +1,92 @@
package com.olympus.hermione.stepSolvers;
import java.util.ArrayList;
import java.util.List;
import org.neo4j.driver.Driver;
import org.neo4j.driver.Result;
import org.neo4j.driver.Session;
import org.slf4j.LoggerFactory;
import org.springframework.ai.document.Document;
import org.springframework.ai.vectorstore.SearchRequest;
import org.springframework.beans.factory.annotation.Autowired;
import org.neo4j.driver.Record;
import org.neo4j.driver.Value;
import com.olympus.hermione.models.ScenarioExecution;
import com.olympus.hermione.utility.AttributeParser;
import ch.qos.logback.classic.Logger;
public class QueryNeo4JSolver extends StepSolver {
Logger logger = (Logger) LoggerFactory.getLogger(QueryNeo4JSolver.class);
private String cypherQuery;
private String outputField;
private void loadParameters(){
logger.info("Loading parameters");
AttributeParser attributeParser = new AttributeParser(this.scenarioExecution);
this.cypherQuery = attributeParser.parse((String) this.step.getAttributes().get("cypher_query"));
this.outputField = attributeParser.parse((String) this.step.getAttributes().get("output_variable"));
}
private void logParameters(){
logger.info("query: " + this.cypherQuery);
logger.info("outputField: " + this.outputField);
}
@Override
public ScenarioExecution solveStep(){
logger.info("Solving step: " + this.step.getName());
loadParameters();
logParameters();
StringBuilder resultString = new StringBuilder();
try (Session session = graphDriver.session()) {
String query = this.cypherQuery;
Result result = session.run(query);
while(result.hasNext()){
Record record = result.next();
Value nodeValue = record.get("sourceCode");
String nodeString = nodeValue.toString();
if(nodeString.isEmpty()){
nodeValue = record.get("output");
nodeString = nodeValue.toString();
}
if (resultString.length() > 0) {
resultString.append(", ");
}
resultString.append(nodeString);
}
this.scenarioExecution.getExecSharedMap().put(this.outputField, resultString.toString());
} catch (Exception e) {
logger.error("Error executing query: " + e.getMessage());
this.scenarioExecution.getExecSharedMap().put(this.outputField, "Error executing query: " + e.getMessage());
}
return this.scenarioExecution;
}
}

View File

@@ -1,5 +1,6 @@
package com.olympus.hermione.stepSolvers;
import org.neo4j.driver.Driver;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.vectorstore.VectorStore;
import com.olympus.hermione.models.ScenarioExecution;
@@ -12,16 +13,18 @@ public class StepSolver {
protected ScenarioStep step;
protected VectorStore vectorStore;
protected ChatModel chatModel;
protected Driver graphDriver;
public ScenarioExecution solveStep(){
System.out.println("Solving step: " + this.step.getName());
return this.scenarioExecution;
};
public void init(ScenarioStep step, ScenarioExecution scenarioExecution, VectorStore vectorStore,ChatModel chatModel){
public void init(ScenarioStep step, ScenarioExecution scenarioExecution, VectorStore vectorStore,ChatModel chatModel,Driver graphDriver){
this.scenarioExecution = scenarioExecution;
this.step = step;
this.vectorStore = vectorStore;
this.chatModel = chatModel;
this.graphDriver = graphDriver;
System.out.println("Initializing StepSolver");
}
}

View File

@@ -11,5 +11,9 @@ spring.ai.vectorstore.mongodb.collection-name=vector_store
spring.ai.vectorstore.mongodb.initialize-schema=false
spring.ai.openai.api-key=sk-proj-j3TFJ0h348DIzMrYYfyUT3BlbkFJjk4HMc8A2ux2Asg8Y7H1
spring.ai.openai.chat.options.model=gpt-4o-mini
spring.main.allow-circular-references=true
spring.main.allow-circular-references=true
neo4j.uri=neo4j+s://e17e6f08.databases.neo4j.io:7687
neo4j.username=neo4j
neo4j.password=8SrSqQ3q6q9PQNWtN9ozqSQfGce4lfh_n6kKz2JIubQ