chore: Add qai_load_graph_schema parameter to BasicAIPromptSolver

This commit is contained in:
andrea.terzani
2024-08-13 09:04:46 +02:00
parent c8c7c0589b
commit 52df0c613a
5 changed files with 43 additions and 26 deletions

View File

@@ -3,6 +3,7 @@ package com.olympus.hermione.security.utility;
import org.springframework.stereotype.Component; import org.springframework.stereotype.Component;
import io.jsonwebtoken.*; import io.jsonwebtoken.*;
import io.jsonwebtoken.io.Decoders;
import io.jsonwebtoken.security.Keys; import io.jsonwebtoken.security.Keys;
import io.jsonwebtoken.security.SignatureException; import io.jsonwebtoken.security.SignatureException;
import jakarta.servlet.http.HttpServletRequest; import jakarta.servlet.http.HttpServletRequest;
@@ -19,6 +20,16 @@ public class JwtTokenProvider {
Key key = Keys.secretKeyFor(SignatureAlgorithm.HS512); Key key = Keys.secretKeyFor(SignatureAlgorithm.HS512);
public static final String SECRET = "5367566B59703373367639792F423F4528482B4D6251655468576D5A71347437";
private Key getSignKey() {
byte[] keyBytes = Decoders.BASE64.decode(SECRET);
return Keys.hmacShaKeyFor(keyBytes);
}
public String createToken(Authentication authentication) { public String createToken(Authentication authentication) {
UserDetails userDetails = (UserDetails) authentication.getPrincipal(); UserDetails userDetails = (UserDetails) authentication.getPrincipal();
@@ -29,7 +40,7 @@ public class JwtTokenProvider {
.setSubject(userDetails.getUsername()) .setSubject(userDetails.getUsername())
.setIssuedAt(new Date()) .setIssuedAt(new Date())
.setExpiration(expiryDate) .setExpiration(expiryDate)
.signWith(SignatureAlgorithm.HS512, key) .signWith(getSignKey(), SignatureAlgorithm.HS256)
.compact(); .compact();
} }
@@ -47,7 +58,7 @@ public class JwtTokenProvider {
public boolean validateToken(String token) { public boolean validateToken(String token) {
try { try {
Jwts.parser().setSigningKey(key).parseClaimsJws(token); Jwts.parser().setSigningKey(getSignKey()).parseClaimsJws(token);
return true; return true;
} catch (MalformedJwtException ex) { } catch (MalformedJwtException ex) {
log.error("Invalid JWT token"); log.error("Invalid JWT token");
@@ -67,7 +78,7 @@ public class JwtTokenProvider {
public String getUsername(String token) { public String getUsername(String token) {
return Jwts.parser() return Jwts.parser()
.setSigningKey(key) .setSigningKey(getSignKey())
.parseClaimsJws(token) .parseClaimsJws(token)
.getBody() .getBody()
.getSubject(); .getSubject();

View File

@@ -44,14 +44,12 @@ public class ScenarioExecutionService {
@Autowired @Autowired
Driver graphDriver; Driver graphDriver;
@Autowired
Neo4JUitilityService neo4JUitilityService;
private Logger logger = LoggerFactory.getLogger(ScenarioExecutionService.class); private Logger logger = LoggerFactory.getLogger(ScenarioExecutionService.class);
/*
public String executeScenarioNew(ScenarioExecutionInput scenarioExecutionInput){
}
*/
public String executeScenario(String scenarioId, String input){ public String executeScenario(String scenarioId, String input){
@@ -117,6 +115,7 @@ public class ScenarioExecutionService {
scenarioOutput.setScenarioExecution_id(scenarioExecution.getId()); scenarioOutput.setScenarioExecution_id(scenarioExecution.getId());
scenarioOutput.setStringOutput(scenarioExecution.getExecSharedMap().get("scenario_output").toString()); scenarioOutput.setStringOutput(scenarioExecution.getExecSharedMap().get("scenario_output").toString());
scenarioOutput.setStatus("OK"); scenarioOutput.setStatus("OK");
return scenarioOutput; return scenarioOutput;
}else{ }else{
@@ -132,7 +131,7 @@ public class ScenarioExecutionService {
private void executeScenarioStep(ScenarioStep step, ScenarioExecution scenarioExecution, int stepIndex){ private void executeScenarioStep(ScenarioStep step, ScenarioExecution scenarioExecution, int stepIndex){
logger.info("Executing step: " + step.getName()); logger.info("Start working on step: " + step.getName() + " with type: " + step.getType());
StepSolver solver=new StepSolver(); StepSolver solver=new StepSolver();
switch (step.getType()) { switch (step.getType()) {
case "RAG_QUERY": case "RAG_QUERY":
@@ -148,7 +147,10 @@ public class ScenarioExecutionService {
break; break;
} }
solver.init(step, scenarioExecution, vectorStore,chatModel,graphDriver); logger.info("Initializing step: " + step.getName());
solver.init(step, scenarioExecution, vectorStore,chatModel,graphDriver,neo4JUitilityService);
logger.info("Solving step: " + step.getName());
//must add try catch in order to catch exceptions and step execution errors //must add try catch in order to catch exceptions and step execution errors
ScenarioExecution scenarioExecutionNew = solver.solveStep(); ScenarioExecution scenarioExecutionNew = solver.solveStep();

View File

@@ -18,14 +18,24 @@ public class BasicAIPromptSolver extends StepSolver {
private String qai_system_prompt_template; private String qai_system_prompt_template;
private String qai_user_input; private String qai_user_input;
private String qai_output_variable; private String qai_output_variable;
private boolean qai_load_graph_schema=false;
Logger logger = (Logger) LoggerFactory.getLogger(BasicQueryRagSolver.class); Logger logger = (Logger) LoggerFactory.getLogger(BasicQueryRagSolver.class);
private void loadParameters(){ private void loadParameters(){
logger.info("Loading parameters"); logger.info("Loading parameters");
AttributeParser attributeParser = new AttributeParser(this.scenarioExecution); if(this.step.getAttributes().get("qai_load_graph_schema")!=null ){
this.qai_load_graph_schema = Boolean.parseBoolean((String) this.step.getAttributes().get("qai_load_graph_schema"));
logger.info("qai_load_graph_schema: " + this.qai_load_graph_schema);
}
if(this.qai_load_graph_schema){
logger.info("Loading graph schema");
this.scenarioExecution.getExecSharedMap().put("graph_schema", this.neo4JUitilityService.generateSchemaDescription().toString());
}
AttributeParser attributeParser = new AttributeParser(this.scenarioExecution);
this.qai_system_prompt_template = attributeParser.parse((String) this.step.getAttributes().get("qai_system_prompt_template")); this.qai_system_prompt_template = attributeParser.parse((String) this.step.getAttributes().get("qai_system_prompt_template"));
this.qai_user_input = attributeParser.parse((String)this.step.getAttributes().get("qai_user_input")); this.qai_user_input = attributeParser.parse((String)this.step.getAttributes().get("qai_user_input"));

View File

@@ -1,15 +1,8 @@
package com.olympus.hermione.stepSolvers; package com.olympus.hermione.stepSolvers;
import java.util.ArrayList;
import java.util.List;
import org.neo4j.driver.Driver;
import org.neo4j.driver.Result; import org.neo4j.driver.Result;
import org.neo4j.driver.Session; import org.neo4j.driver.Session;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import org.springframework.ai.document.Document;
import org.springframework.ai.vectorstore.SearchRequest;
import org.springframework.beans.factory.annotation.Autowired;
import org.neo4j.driver.Record; import org.neo4j.driver.Record;
import org.neo4j.driver.Value; import org.neo4j.driver.Value;
@@ -27,21 +20,17 @@ public class QueryNeo4JSolver extends StepSolver {
private String outputField; private String outputField;
private void loadParameters(){ private void loadParameters(){
logger.info("Loading parameters"); logger.info("Loading parameters");
AttributeParser attributeParser = new AttributeParser(this.scenarioExecution); AttributeParser attributeParser = new AttributeParser(this.scenarioExecution);
this.cypherQuery = attributeParser.parse((String) this.step.getAttributes().get("cypher_query")); this.cypherQuery = attributeParser.parse((String) this.step.getAttributes().get("cypher_query"));
this.outputField = attributeParser.parse((String) this.step.getAttributes().get("output_variable")); this.outputField = attributeParser.parse((String) this.step.getAttributes().get("output_variable"));
} }
private void logParameters(){ private void logParameters(){
logger.info("query: " + this.cypherQuery); logger.info("cypherQuery: " + this.cypherQuery);
logger.info("outputField: " + this.outputField); logger.info("outputField: " + this.outputField);
} }
@@ -54,6 +43,7 @@ public class QueryNeo4JSolver extends StepSolver {
logParameters(); logParameters();
try (Session session = graphDriver.session()) { try (Session session = graphDriver.session()) {
String query = this.cypherQuery; String query = this.cypherQuery;

View File

@@ -5,6 +5,7 @@ import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.vectorstore.VectorStore; import org.springframework.ai.vectorstore.VectorStore;
import com.olympus.hermione.models.ScenarioExecution; import com.olympus.hermione.models.ScenarioExecution;
import com.olympus.hermione.models.ScenarioStep; import com.olympus.hermione.models.ScenarioStep;
import com.olympus.hermione.services.Neo4JUitilityService;
public class StepSolver { public class StepSolver {
@@ -14,17 +15,20 @@ public class StepSolver {
protected VectorStore vectorStore; protected VectorStore vectorStore;
protected ChatModel chatModel; protected ChatModel chatModel;
protected Driver graphDriver; protected Driver graphDriver;
protected Neo4JUitilityService neo4JUitilityService;
public ScenarioExecution solveStep(){ public ScenarioExecution solveStep(){
System.out.println("Solving step: " + this.step.getName()); System.out.println("Solving step: " + this.step.getName());
return this.scenarioExecution; return this.scenarioExecution;
}; };
public void init(ScenarioStep step, ScenarioExecution scenarioExecution, VectorStore vectorStore,ChatModel chatModel,Driver graphDriver){ public void init(ScenarioStep step, ScenarioExecution scenarioExecution, VectorStore vectorStore,ChatModel chatModel,Driver graphDriver,Neo4JUitilityService neo4JUitilityService){
this.scenarioExecution = scenarioExecution; this.scenarioExecution = scenarioExecution;
this.step = step; this.step = step;
this.vectorStore = vectorStore; this.vectorStore = vectorStore;
this.chatModel = chatModel; this.chatModel = chatModel;
this.graphDriver = graphDriver; this.graphDriver = graphDriver;
this.neo4JUitilityService = neo4JUitilityService;
System.out.println("Initializing StepSolver"); System.out.println("Initializing StepSolver");
} }
} }