chore: Add qai_load_graph_schema parameter to BasicAIPromptSolver
This commit is contained in:
@@ -3,6 +3,7 @@ package com.olympus.hermione.security.utility;
|
||||
import org.springframework.stereotype.Component;
|
||||
|
||||
import io.jsonwebtoken.*;
|
||||
import io.jsonwebtoken.io.Decoders;
|
||||
import io.jsonwebtoken.security.Keys;
|
||||
import io.jsonwebtoken.security.SignatureException;
|
||||
import jakarta.servlet.http.HttpServletRequest;
|
||||
@@ -18,6 +19,16 @@ import java.util.Date;
|
||||
public class JwtTokenProvider {
|
||||
|
||||
Key key = Keys.secretKeyFor(SignatureAlgorithm.HS512);
|
||||
|
||||
|
||||
public static final String SECRET = "5367566B59703373367639792F423F4528482B4D6251655468576D5A71347437";
|
||||
|
||||
|
||||
private Key getSignKey() {
|
||||
byte[] keyBytes = Decoders.BASE64.decode(SECRET);
|
||||
return Keys.hmacShaKeyFor(keyBytes);
|
||||
}
|
||||
|
||||
|
||||
public String createToken(Authentication authentication) {
|
||||
|
||||
@@ -29,7 +40,7 @@ public class JwtTokenProvider {
|
||||
.setSubject(userDetails.getUsername())
|
||||
.setIssuedAt(new Date())
|
||||
.setExpiration(expiryDate)
|
||||
.signWith(SignatureAlgorithm.HS512, key)
|
||||
.signWith(getSignKey(), SignatureAlgorithm.HS256)
|
||||
.compact();
|
||||
}
|
||||
|
||||
@@ -47,7 +58,7 @@ public class JwtTokenProvider {
|
||||
public boolean validateToken(String token) {
|
||||
|
||||
try {
|
||||
Jwts.parser().setSigningKey(key).parseClaimsJws(token);
|
||||
Jwts.parser().setSigningKey(getSignKey()).parseClaimsJws(token);
|
||||
return true;
|
||||
} catch (MalformedJwtException ex) {
|
||||
log.error("Invalid JWT token");
|
||||
@@ -67,7 +78,7 @@ public class JwtTokenProvider {
|
||||
public String getUsername(String token) {
|
||||
|
||||
return Jwts.parser()
|
||||
.setSigningKey(key)
|
||||
.setSigningKey(getSignKey())
|
||||
.parseClaimsJws(token)
|
||||
.getBody()
|
||||
.getSubject();
|
||||
|
||||
@@ -44,14 +44,12 @@ public class ScenarioExecutionService {
|
||||
@Autowired
|
||||
Driver graphDriver;
|
||||
|
||||
@Autowired
|
||||
Neo4JUitilityService neo4JUitilityService;
|
||||
|
||||
|
||||
private Logger logger = LoggerFactory.getLogger(ScenarioExecutionService.class);
|
||||
|
||||
/*
|
||||
public String executeScenarioNew(ScenarioExecutionInput scenarioExecutionInput){
|
||||
|
||||
}
|
||||
*/
|
||||
|
||||
|
||||
public String executeScenario(String scenarioId, String input){
|
||||
@@ -117,6 +115,7 @@ public class ScenarioExecutionService {
|
||||
scenarioOutput.setScenarioExecution_id(scenarioExecution.getId());
|
||||
scenarioOutput.setStringOutput(scenarioExecution.getExecSharedMap().get("scenario_output").toString());
|
||||
scenarioOutput.setStatus("OK");
|
||||
|
||||
return scenarioOutput;
|
||||
|
||||
}else{
|
||||
@@ -132,7 +131,7 @@ public class ScenarioExecutionService {
|
||||
|
||||
|
||||
private void executeScenarioStep(ScenarioStep step, ScenarioExecution scenarioExecution, int stepIndex){
|
||||
logger.info("Executing step: " + step.getName());
|
||||
logger.info("Start working on step: " + step.getName() + " with type: " + step.getType());
|
||||
StepSolver solver=new StepSolver();
|
||||
switch (step.getType()) {
|
||||
case "RAG_QUERY":
|
||||
@@ -148,7 +147,10 @@ public class ScenarioExecutionService {
|
||||
break;
|
||||
}
|
||||
|
||||
solver.init(step, scenarioExecution, vectorStore,chatModel,graphDriver);
|
||||
logger.info("Initializing step: " + step.getName());
|
||||
solver.init(step, scenarioExecution, vectorStore,chatModel,graphDriver,neo4JUitilityService);
|
||||
|
||||
logger.info("Solving step: " + step.getName());
|
||||
|
||||
//must add try catch in order to catch exceptions and step execution errors
|
||||
ScenarioExecution scenarioExecutionNew = solver.solveStep();
|
||||
|
||||
@@ -18,15 +18,25 @@ public class BasicAIPromptSolver extends StepSolver {
|
||||
private String qai_system_prompt_template;
|
||||
private String qai_user_input;
|
||||
private String qai_output_variable;
|
||||
private boolean qai_load_graph_schema=false;
|
||||
|
||||
Logger logger = (Logger) LoggerFactory.getLogger(BasicQueryRagSolver.class);
|
||||
|
||||
private void loadParameters(){
|
||||
logger.info("Loading parameters");
|
||||
|
||||
|
||||
if(this.step.getAttributes().get("qai_load_graph_schema")!=null ){
|
||||
this.qai_load_graph_schema = Boolean.parseBoolean((String) this.step.getAttributes().get("qai_load_graph_schema"));
|
||||
logger.info("qai_load_graph_schema: " + this.qai_load_graph_schema);
|
||||
}
|
||||
|
||||
if(this.qai_load_graph_schema){
|
||||
logger.info("Loading graph schema");
|
||||
this.scenarioExecution.getExecSharedMap().put("graph_schema", this.neo4JUitilityService.generateSchemaDescription().toString());
|
||||
}
|
||||
|
||||
AttributeParser attributeParser = new AttributeParser(this.scenarioExecution);
|
||||
|
||||
|
||||
this.qai_system_prompt_template = attributeParser.parse((String) this.step.getAttributes().get("qai_system_prompt_template"));
|
||||
this.qai_user_input = attributeParser.parse((String)this.step.getAttributes().get("qai_user_input"));
|
||||
|
||||
|
||||
@@ -1,15 +1,8 @@
|
||||
package com.olympus.hermione.stepSolvers;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
import org.neo4j.driver.Driver;
|
||||
import org.neo4j.driver.Result;
|
||||
import org.neo4j.driver.Session;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import org.springframework.ai.document.Document;
|
||||
import org.springframework.ai.vectorstore.SearchRequest;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.neo4j.driver.Record;
|
||||
import org.neo4j.driver.Value;
|
||||
|
||||
@@ -26,22 +19,18 @@ public class QueryNeo4JSolver extends StepSolver {
|
||||
private String cypherQuery;
|
||||
private String outputField;
|
||||
|
||||
|
||||
|
||||
private void loadParameters(){
|
||||
logger.info("Loading parameters");
|
||||
|
||||
AttributeParser attributeParser = new AttributeParser(this.scenarioExecution);
|
||||
|
||||
this.cypherQuery = attributeParser.parse((String) this.step.getAttributes().get("cypher_query"));
|
||||
|
||||
this.outputField = attributeParser.parse((String) this.step.getAttributes().get("output_variable"));
|
||||
|
||||
|
||||
}
|
||||
|
||||
private void logParameters(){
|
||||
logger.info("query: " + this.cypherQuery);
|
||||
logger.info("cypherQuery: " + this.cypherQuery);
|
||||
logger.info("outputField: " + this.outputField);
|
||||
}
|
||||
|
||||
@@ -52,6 +41,7 @@ public class QueryNeo4JSolver extends StepSolver {
|
||||
logger.info("Solving step: " + this.step.getName());
|
||||
loadParameters();
|
||||
logParameters();
|
||||
|
||||
|
||||
|
||||
try (Session session = graphDriver.session()) {
|
||||
|
||||
@@ -5,6 +5,7 @@ import org.springframework.ai.chat.model.ChatModel;
|
||||
import org.springframework.ai.vectorstore.VectorStore;
|
||||
import com.olympus.hermione.models.ScenarioExecution;
|
||||
import com.olympus.hermione.models.ScenarioStep;
|
||||
import com.olympus.hermione.services.Neo4JUitilityService;
|
||||
|
||||
public class StepSolver {
|
||||
|
||||
@@ -14,17 +15,20 @@ public class StepSolver {
|
||||
protected VectorStore vectorStore;
|
||||
protected ChatModel chatModel;
|
||||
protected Driver graphDriver;
|
||||
|
||||
protected Neo4JUitilityService neo4JUitilityService;
|
||||
|
||||
|
||||
public ScenarioExecution solveStep(){
|
||||
System.out.println("Solving step: " + this.step.getName());
|
||||
return this.scenarioExecution;
|
||||
};
|
||||
public void init(ScenarioStep step, ScenarioExecution scenarioExecution, VectorStore vectorStore,ChatModel chatModel,Driver graphDriver){
|
||||
public void init(ScenarioStep step, ScenarioExecution scenarioExecution, VectorStore vectorStore,ChatModel chatModel,Driver graphDriver,Neo4JUitilityService neo4JUitilityService){
|
||||
this.scenarioExecution = scenarioExecution;
|
||||
this.step = step;
|
||||
this.vectorStore = vectorStore;
|
||||
this.chatModel = chatModel;
|
||||
this.graphDriver = graphDriver;
|
||||
this.neo4JUitilityService = neo4JUitilityService;
|
||||
System.out.println("Initializing StepSolver");
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user