chore: Add qai_load_graph_schema parameter to BasicAIPromptSolver
This commit is contained in:
@@ -3,6 +3,7 @@ package com.olympus.hermione.security.utility;
|
|||||||
import org.springframework.stereotype.Component;
|
import org.springframework.stereotype.Component;
|
||||||
|
|
||||||
import io.jsonwebtoken.*;
|
import io.jsonwebtoken.*;
|
||||||
|
import io.jsonwebtoken.io.Decoders;
|
||||||
import io.jsonwebtoken.security.Keys;
|
import io.jsonwebtoken.security.Keys;
|
||||||
import io.jsonwebtoken.security.SignatureException;
|
import io.jsonwebtoken.security.SignatureException;
|
||||||
import jakarta.servlet.http.HttpServletRequest;
|
import jakarta.servlet.http.HttpServletRequest;
|
||||||
@@ -18,6 +19,16 @@ import java.util.Date;
|
|||||||
public class JwtTokenProvider {
|
public class JwtTokenProvider {
|
||||||
|
|
||||||
Key key = Keys.secretKeyFor(SignatureAlgorithm.HS512);
|
Key key = Keys.secretKeyFor(SignatureAlgorithm.HS512);
|
||||||
|
|
||||||
|
|
||||||
|
public static final String SECRET = "5367566B59703373367639792F423F4528482B4D6251655468576D5A71347437";
|
||||||
|
|
||||||
|
|
||||||
|
private Key getSignKey() {
|
||||||
|
byte[] keyBytes = Decoders.BASE64.decode(SECRET);
|
||||||
|
return Keys.hmacShaKeyFor(keyBytes);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
public String createToken(Authentication authentication) {
|
public String createToken(Authentication authentication) {
|
||||||
|
|
||||||
@@ -29,7 +40,7 @@ public class JwtTokenProvider {
|
|||||||
.setSubject(userDetails.getUsername())
|
.setSubject(userDetails.getUsername())
|
||||||
.setIssuedAt(new Date())
|
.setIssuedAt(new Date())
|
||||||
.setExpiration(expiryDate)
|
.setExpiration(expiryDate)
|
||||||
.signWith(SignatureAlgorithm.HS512, key)
|
.signWith(getSignKey(), SignatureAlgorithm.HS256)
|
||||||
.compact();
|
.compact();
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -47,7 +58,7 @@ public class JwtTokenProvider {
|
|||||||
public boolean validateToken(String token) {
|
public boolean validateToken(String token) {
|
||||||
|
|
||||||
try {
|
try {
|
||||||
Jwts.parser().setSigningKey(key).parseClaimsJws(token);
|
Jwts.parser().setSigningKey(getSignKey()).parseClaimsJws(token);
|
||||||
return true;
|
return true;
|
||||||
} catch (MalformedJwtException ex) {
|
} catch (MalformedJwtException ex) {
|
||||||
log.error("Invalid JWT token");
|
log.error("Invalid JWT token");
|
||||||
@@ -67,7 +78,7 @@ public class JwtTokenProvider {
|
|||||||
public String getUsername(String token) {
|
public String getUsername(String token) {
|
||||||
|
|
||||||
return Jwts.parser()
|
return Jwts.parser()
|
||||||
.setSigningKey(key)
|
.setSigningKey(getSignKey())
|
||||||
.parseClaimsJws(token)
|
.parseClaimsJws(token)
|
||||||
.getBody()
|
.getBody()
|
||||||
.getSubject();
|
.getSubject();
|
||||||
|
|||||||
@@ -44,14 +44,12 @@ public class ScenarioExecutionService {
|
|||||||
@Autowired
|
@Autowired
|
||||||
Driver graphDriver;
|
Driver graphDriver;
|
||||||
|
|
||||||
|
@Autowired
|
||||||
|
Neo4JUitilityService neo4JUitilityService;
|
||||||
|
|
||||||
|
|
||||||
private Logger logger = LoggerFactory.getLogger(ScenarioExecutionService.class);
|
private Logger logger = LoggerFactory.getLogger(ScenarioExecutionService.class);
|
||||||
|
|
||||||
/*
|
|
||||||
public String executeScenarioNew(ScenarioExecutionInput scenarioExecutionInput){
|
|
||||||
|
|
||||||
}
|
|
||||||
*/
|
|
||||||
|
|
||||||
|
|
||||||
public String executeScenario(String scenarioId, String input){
|
public String executeScenario(String scenarioId, String input){
|
||||||
@@ -117,6 +115,7 @@ public class ScenarioExecutionService {
|
|||||||
scenarioOutput.setScenarioExecution_id(scenarioExecution.getId());
|
scenarioOutput.setScenarioExecution_id(scenarioExecution.getId());
|
||||||
scenarioOutput.setStringOutput(scenarioExecution.getExecSharedMap().get("scenario_output").toString());
|
scenarioOutput.setStringOutput(scenarioExecution.getExecSharedMap().get("scenario_output").toString());
|
||||||
scenarioOutput.setStatus("OK");
|
scenarioOutput.setStatus("OK");
|
||||||
|
|
||||||
return scenarioOutput;
|
return scenarioOutput;
|
||||||
|
|
||||||
}else{
|
}else{
|
||||||
@@ -132,7 +131,7 @@ public class ScenarioExecutionService {
|
|||||||
|
|
||||||
|
|
||||||
private void executeScenarioStep(ScenarioStep step, ScenarioExecution scenarioExecution, int stepIndex){
|
private void executeScenarioStep(ScenarioStep step, ScenarioExecution scenarioExecution, int stepIndex){
|
||||||
logger.info("Executing step: " + step.getName());
|
logger.info("Start working on step: " + step.getName() + " with type: " + step.getType());
|
||||||
StepSolver solver=new StepSolver();
|
StepSolver solver=new StepSolver();
|
||||||
switch (step.getType()) {
|
switch (step.getType()) {
|
||||||
case "RAG_QUERY":
|
case "RAG_QUERY":
|
||||||
@@ -148,7 +147,10 @@ public class ScenarioExecutionService {
|
|||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
solver.init(step, scenarioExecution, vectorStore,chatModel,graphDriver);
|
logger.info("Initializing step: " + step.getName());
|
||||||
|
solver.init(step, scenarioExecution, vectorStore,chatModel,graphDriver,neo4JUitilityService);
|
||||||
|
|
||||||
|
logger.info("Solving step: " + step.getName());
|
||||||
|
|
||||||
//must add try catch in order to catch exceptions and step execution errors
|
//must add try catch in order to catch exceptions and step execution errors
|
||||||
ScenarioExecution scenarioExecutionNew = solver.solveStep();
|
ScenarioExecution scenarioExecutionNew = solver.solveStep();
|
||||||
|
|||||||
@@ -18,15 +18,25 @@ public class BasicAIPromptSolver extends StepSolver {
|
|||||||
private String qai_system_prompt_template;
|
private String qai_system_prompt_template;
|
||||||
private String qai_user_input;
|
private String qai_user_input;
|
||||||
private String qai_output_variable;
|
private String qai_output_variable;
|
||||||
|
private boolean qai_load_graph_schema=false;
|
||||||
|
|
||||||
Logger logger = (Logger) LoggerFactory.getLogger(BasicQueryRagSolver.class);
|
Logger logger = (Logger) LoggerFactory.getLogger(BasicQueryRagSolver.class);
|
||||||
|
|
||||||
private void loadParameters(){
|
private void loadParameters(){
|
||||||
logger.info("Loading parameters");
|
logger.info("Loading parameters");
|
||||||
|
|
||||||
|
if(this.step.getAttributes().get("qai_load_graph_schema")!=null ){
|
||||||
|
this.qai_load_graph_schema = Boolean.parseBoolean((String) this.step.getAttributes().get("qai_load_graph_schema"));
|
||||||
|
logger.info("qai_load_graph_schema: " + this.qai_load_graph_schema);
|
||||||
|
}
|
||||||
|
|
||||||
|
if(this.qai_load_graph_schema){
|
||||||
|
logger.info("Loading graph schema");
|
||||||
|
this.scenarioExecution.getExecSharedMap().put("graph_schema", this.neo4JUitilityService.generateSchemaDescription().toString());
|
||||||
|
}
|
||||||
|
|
||||||
AttributeParser attributeParser = new AttributeParser(this.scenarioExecution);
|
AttributeParser attributeParser = new AttributeParser(this.scenarioExecution);
|
||||||
|
|
||||||
|
|
||||||
this.qai_system_prompt_template = attributeParser.parse((String) this.step.getAttributes().get("qai_system_prompt_template"));
|
this.qai_system_prompt_template = attributeParser.parse((String) this.step.getAttributes().get("qai_system_prompt_template"));
|
||||||
this.qai_user_input = attributeParser.parse((String)this.step.getAttributes().get("qai_user_input"));
|
this.qai_user_input = attributeParser.parse((String)this.step.getAttributes().get("qai_user_input"));
|
||||||
|
|
||||||
|
|||||||
@@ -1,15 +1,8 @@
|
|||||||
package com.olympus.hermione.stepSolvers;
|
package com.olympus.hermione.stepSolvers;
|
||||||
|
|
||||||
import java.util.ArrayList;
|
|
||||||
import java.util.List;
|
|
||||||
|
|
||||||
import org.neo4j.driver.Driver;
|
|
||||||
import org.neo4j.driver.Result;
|
import org.neo4j.driver.Result;
|
||||||
import org.neo4j.driver.Session;
|
import org.neo4j.driver.Session;
|
||||||
import org.slf4j.LoggerFactory;
|
import org.slf4j.LoggerFactory;
|
||||||
import org.springframework.ai.document.Document;
|
|
||||||
import org.springframework.ai.vectorstore.SearchRequest;
|
|
||||||
import org.springframework.beans.factory.annotation.Autowired;
|
|
||||||
import org.neo4j.driver.Record;
|
import org.neo4j.driver.Record;
|
||||||
import org.neo4j.driver.Value;
|
import org.neo4j.driver.Value;
|
||||||
|
|
||||||
@@ -26,22 +19,18 @@ public class QueryNeo4JSolver extends StepSolver {
|
|||||||
private String cypherQuery;
|
private String cypherQuery;
|
||||||
private String outputField;
|
private String outputField;
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
private void loadParameters(){
|
private void loadParameters(){
|
||||||
logger.info("Loading parameters");
|
logger.info("Loading parameters");
|
||||||
|
|
||||||
AttributeParser attributeParser = new AttributeParser(this.scenarioExecution);
|
AttributeParser attributeParser = new AttributeParser(this.scenarioExecution);
|
||||||
|
|
||||||
this.cypherQuery = attributeParser.parse((String) this.step.getAttributes().get("cypher_query"));
|
this.cypherQuery = attributeParser.parse((String) this.step.getAttributes().get("cypher_query"));
|
||||||
|
|
||||||
this.outputField = attributeParser.parse((String) this.step.getAttributes().get("output_variable"));
|
this.outputField = attributeParser.parse((String) this.step.getAttributes().get("output_variable"));
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private void logParameters(){
|
private void logParameters(){
|
||||||
logger.info("query: " + this.cypherQuery);
|
logger.info("cypherQuery: " + this.cypherQuery);
|
||||||
logger.info("outputField: " + this.outputField);
|
logger.info("outputField: " + this.outputField);
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -52,6 +41,7 @@ public class QueryNeo4JSolver extends StepSolver {
|
|||||||
logger.info("Solving step: " + this.step.getName());
|
logger.info("Solving step: " + this.step.getName());
|
||||||
loadParameters();
|
loadParameters();
|
||||||
logParameters();
|
logParameters();
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
try (Session session = graphDriver.session()) {
|
try (Session session = graphDriver.session()) {
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import org.springframework.ai.chat.model.ChatModel;
|
|||||||
import org.springframework.ai.vectorstore.VectorStore;
|
import org.springframework.ai.vectorstore.VectorStore;
|
||||||
import com.olympus.hermione.models.ScenarioExecution;
|
import com.olympus.hermione.models.ScenarioExecution;
|
||||||
import com.olympus.hermione.models.ScenarioStep;
|
import com.olympus.hermione.models.ScenarioStep;
|
||||||
|
import com.olympus.hermione.services.Neo4JUitilityService;
|
||||||
|
|
||||||
public class StepSolver {
|
public class StepSolver {
|
||||||
|
|
||||||
@@ -14,17 +15,20 @@ public class StepSolver {
|
|||||||
protected VectorStore vectorStore;
|
protected VectorStore vectorStore;
|
||||||
protected ChatModel chatModel;
|
protected ChatModel chatModel;
|
||||||
protected Driver graphDriver;
|
protected Driver graphDriver;
|
||||||
|
protected Neo4JUitilityService neo4JUitilityService;
|
||||||
|
|
||||||
|
|
||||||
public ScenarioExecution solveStep(){
|
public ScenarioExecution solveStep(){
|
||||||
System.out.println("Solving step: " + this.step.getName());
|
System.out.println("Solving step: " + this.step.getName());
|
||||||
return this.scenarioExecution;
|
return this.scenarioExecution;
|
||||||
};
|
};
|
||||||
public void init(ScenarioStep step, ScenarioExecution scenarioExecution, VectorStore vectorStore,ChatModel chatModel,Driver graphDriver){
|
public void init(ScenarioStep step, ScenarioExecution scenarioExecution, VectorStore vectorStore,ChatModel chatModel,Driver graphDriver,Neo4JUitilityService neo4JUitilityService){
|
||||||
this.scenarioExecution = scenarioExecution;
|
this.scenarioExecution = scenarioExecution;
|
||||||
this.step = step;
|
this.step = step;
|
||||||
this.vectorStore = vectorStore;
|
this.vectorStore = vectorStore;
|
||||||
this.chatModel = chatModel;
|
this.chatModel = chatModel;
|
||||||
this.graphDriver = graphDriver;
|
this.graphDriver = graphDriver;
|
||||||
|
this.neo4JUitilityService = neo4JUitilityService;
|
||||||
System.out.println("Initializing StepSolver");
|
System.out.println("Initializing StepSolver");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Reference in New Issue
Block a user