chore: Add qai_load_graph_schema parameter to BasicAIPromptSolver

This commit is contained in:
andrea.terzani
2024-08-13 09:04:46 +02:00
parent c8c7c0589b
commit 52df0c613a
5 changed files with 43 additions and 26 deletions

View File

@@ -3,6 +3,7 @@ package com.olympus.hermione.security.utility;
import org.springframework.stereotype.Component;
import io.jsonwebtoken.*;
import io.jsonwebtoken.io.Decoders;
import io.jsonwebtoken.security.Keys;
import io.jsonwebtoken.security.SignatureException;
import jakarta.servlet.http.HttpServletRequest;
@@ -18,6 +19,16 @@ import java.util.Date;
public class JwtTokenProvider {
Key key = Keys.secretKeyFor(SignatureAlgorithm.HS512);
public static final String SECRET = "5367566B59703373367639792F423F4528482B4D6251655468576D5A71347437";
private Key getSignKey() {
byte[] keyBytes = Decoders.BASE64.decode(SECRET);
return Keys.hmacShaKeyFor(keyBytes);
}
public String createToken(Authentication authentication) {
@@ -29,7 +40,7 @@ public class JwtTokenProvider {
.setSubject(userDetails.getUsername())
.setIssuedAt(new Date())
.setExpiration(expiryDate)
.signWith(SignatureAlgorithm.HS512, key)
.signWith(getSignKey(), SignatureAlgorithm.HS256)
.compact();
}
@@ -47,7 +58,7 @@ public class JwtTokenProvider {
public boolean validateToken(String token) {
try {
Jwts.parser().setSigningKey(key).parseClaimsJws(token);
Jwts.parser().setSigningKey(getSignKey()).parseClaimsJws(token);
return true;
} catch (MalformedJwtException ex) {
log.error("Invalid JWT token");
@@ -67,7 +78,7 @@ public class JwtTokenProvider {
public String getUsername(String token) {
return Jwts.parser()
.setSigningKey(key)
.setSigningKey(getSignKey())
.parseClaimsJws(token)
.getBody()
.getSubject();

View File

@@ -44,14 +44,12 @@ public class ScenarioExecutionService {
@Autowired
Driver graphDriver;
@Autowired
Neo4JUitilityService neo4JUitilityService;
private Logger logger = LoggerFactory.getLogger(ScenarioExecutionService.class);
/*
public String executeScenarioNew(ScenarioExecutionInput scenarioExecutionInput){
}
*/
public String executeScenario(String scenarioId, String input){
@@ -117,6 +115,7 @@ public class ScenarioExecutionService {
scenarioOutput.setScenarioExecution_id(scenarioExecution.getId());
scenarioOutput.setStringOutput(scenarioExecution.getExecSharedMap().get("scenario_output").toString());
scenarioOutput.setStatus("OK");
return scenarioOutput;
}else{
@@ -132,7 +131,7 @@ public class ScenarioExecutionService {
private void executeScenarioStep(ScenarioStep step, ScenarioExecution scenarioExecution, int stepIndex){
logger.info("Executing step: " + step.getName());
logger.info("Start working on step: " + step.getName() + " with type: " + step.getType());
StepSolver solver=new StepSolver();
switch (step.getType()) {
case "RAG_QUERY":
@@ -148,7 +147,10 @@ public class ScenarioExecutionService {
break;
}
solver.init(step, scenarioExecution, vectorStore,chatModel,graphDriver);
logger.info("Initializing step: " + step.getName());
solver.init(step, scenarioExecution, vectorStore,chatModel,graphDriver,neo4JUitilityService);
logger.info("Solving step: " + step.getName());
//must add try catch in order to catch exceptions and step execution errors
ScenarioExecution scenarioExecutionNew = solver.solveStep();

View File

@@ -18,15 +18,25 @@ public class BasicAIPromptSolver extends StepSolver {
private String qai_system_prompt_template;
private String qai_user_input;
private String qai_output_variable;
private boolean qai_load_graph_schema=false;
Logger logger = (Logger) LoggerFactory.getLogger(BasicQueryRagSolver.class);
private void loadParameters(){
logger.info("Loading parameters");
if(this.step.getAttributes().get("qai_load_graph_schema")!=null ){
this.qai_load_graph_schema = Boolean.parseBoolean((String) this.step.getAttributes().get("qai_load_graph_schema"));
logger.info("qai_load_graph_schema: " + this.qai_load_graph_schema);
}
if(this.qai_load_graph_schema){
logger.info("Loading graph schema");
this.scenarioExecution.getExecSharedMap().put("graph_schema", this.neo4JUitilityService.generateSchemaDescription().toString());
}
AttributeParser attributeParser = new AttributeParser(this.scenarioExecution);
this.qai_system_prompt_template = attributeParser.parse((String) this.step.getAttributes().get("qai_system_prompt_template"));
this.qai_user_input = attributeParser.parse((String)this.step.getAttributes().get("qai_user_input"));

View File

@@ -1,15 +1,8 @@
package com.olympus.hermione.stepSolvers;
import java.util.ArrayList;
import java.util.List;
import org.neo4j.driver.Driver;
import org.neo4j.driver.Result;
import org.neo4j.driver.Session;
import org.slf4j.LoggerFactory;
import org.springframework.ai.document.Document;
import org.springframework.ai.vectorstore.SearchRequest;
import org.springframework.beans.factory.annotation.Autowired;
import org.neo4j.driver.Record;
import org.neo4j.driver.Value;
@@ -26,22 +19,18 @@ public class QueryNeo4JSolver extends StepSolver {
private String cypherQuery;
private String outputField;
private void loadParameters(){
logger.info("Loading parameters");
AttributeParser attributeParser = new AttributeParser(this.scenarioExecution);
this.cypherQuery = attributeParser.parse((String) this.step.getAttributes().get("cypher_query"));
this.outputField = attributeParser.parse((String) this.step.getAttributes().get("output_variable"));
}
private void logParameters(){
logger.info("query: " + this.cypherQuery);
logger.info("cypherQuery: " + this.cypherQuery);
logger.info("outputField: " + this.outputField);
}
@@ -52,6 +41,7 @@ public class QueryNeo4JSolver extends StepSolver {
logger.info("Solving step: " + this.step.getName());
loadParameters();
logParameters();
try (Session session = graphDriver.session()) {

View File

@@ -5,6 +5,7 @@ import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.vectorstore.VectorStore;
import com.olympus.hermione.models.ScenarioExecution;
import com.olympus.hermione.models.ScenarioStep;
import com.olympus.hermione.services.Neo4JUitilityService;
public class StepSolver {
@@ -14,17 +15,20 @@ public class StepSolver {
protected VectorStore vectorStore;
protected ChatModel chatModel;
protected Driver graphDriver;
protected Neo4JUitilityService neo4JUitilityService;
public ScenarioExecution solveStep(){
System.out.println("Solving step: " + this.step.getName());
return this.scenarioExecution;
};
public void init(ScenarioStep step, ScenarioExecution scenarioExecution, VectorStore vectorStore,ChatModel chatModel,Driver graphDriver){
public void init(ScenarioStep step, ScenarioExecution scenarioExecution, VectorStore vectorStore,ChatModel chatModel,Driver graphDriver,Neo4JUitilityService neo4JUitilityService){
this.scenarioExecution = scenarioExecution;
this.step = step;
this.vectorStore = vectorStore;
this.chatModel = chatModel;
this.graphDriver = graphDriver;
this.neo4JUitilityService = neo4JUitilityService;
System.out.println("Initializing StepSolver");
}
}