diff --git a/src/main/java/com/olympus/hermione/security/utility/JwtTokenProvider.java b/src/main/java/com/olympus/hermione/security/utility/JwtTokenProvider.java index d2f0b34..248cec6 100644 --- a/src/main/java/com/olympus/hermione/security/utility/JwtTokenProvider.java +++ b/src/main/java/com/olympus/hermione/security/utility/JwtTokenProvider.java @@ -3,6 +3,7 @@ package com.olympus.hermione.security.utility; import org.springframework.stereotype.Component; import io.jsonwebtoken.*; +import io.jsonwebtoken.io.Decoders; import io.jsonwebtoken.security.Keys; import io.jsonwebtoken.security.SignatureException; import jakarta.servlet.http.HttpServletRequest; @@ -18,6 +19,16 @@ import java.util.Date; public class JwtTokenProvider { Key key = Keys.secretKeyFor(SignatureAlgorithm.HS512); + + + public static final String SECRET = "5367566B59703373367639792F423F4528482B4D6251655468576D5A71347437"; + + + private Key getSignKey() { + byte[] keyBytes = Decoders.BASE64.decode(SECRET); + return Keys.hmacShaKeyFor(keyBytes); + } + public String createToken(Authentication authentication) { @@ -29,7 +40,7 @@ public class JwtTokenProvider { .setSubject(userDetails.getUsername()) .setIssuedAt(new Date()) .setExpiration(expiryDate) - .signWith(SignatureAlgorithm.HS512, key) + .signWith(getSignKey(), SignatureAlgorithm.HS256) .compact(); } @@ -47,7 +58,7 @@ public class JwtTokenProvider { public boolean validateToken(String token) { try { - Jwts.parser().setSigningKey(key).parseClaimsJws(token); + Jwts.parser().setSigningKey(getSignKey()).parseClaimsJws(token); return true; } catch (MalformedJwtException ex) { log.error("Invalid JWT token"); @@ -67,7 +78,7 @@ public class JwtTokenProvider { public String getUsername(String token) { return Jwts.parser() - .setSigningKey(key) + .setSigningKey(getSignKey()) .parseClaimsJws(token) .getBody() .getSubject(); diff --git a/src/main/java/com/olympus/hermione/services/ScenarioExecutionService.java b/src/main/java/com/olympus/hermione/services/ScenarioExecutionService.java index f7ccd5c..64c73ed 100644 --- a/src/main/java/com/olympus/hermione/services/ScenarioExecutionService.java +++ b/src/main/java/com/olympus/hermione/services/ScenarioExecutionService.java @@ -44,14 +44,12 @@ public class ScenarioExecutionService { @Autowired Driver graphDriver; + @Autowired + Neo4JUitilityService neo4JUitilityService; + private Logger logger = LoggerFactory.getLogger(ScenarioExecutionService.class); - /* - public String executeScenarioNew(ScenarioExecutionInput scenarioExecutionInput){ - - } - */ public String executeScenario(String scenarioId, String input){ @@ -117,6 +115,7 @@ public class ScenarioExecutionService { scenarioOutput.setScenarioExecution_id(scenarioExecution.getId()); scenarioOutput.setStringOutput(scenarioExecution.getExecSharedMap().get("scenario_output").toString()); scenarioOutput.setStatus("OK"); + return scenarioOutput; }else{ @@ -132,7 +131,7 @@ public class ScenarioExecutionService { private void executeScenarioStep(ScenarioStep step, ScenarioExecution scenarioExecution, int stepIndex){ - logger.info("Executing step: " + step.getName()); + logger.info("Start working on step: " + step.getName() + " with type: " + step.getType()); StepSolver solver=new StepSolver(); switch (step.getType()) { case "RAG_QUERY": @@ -148,7 +147,10 @@ public class ScenarioExecutionService { break; } - solver.init(step, scenarioExecution, vectorStore,chatModel,graphDriver); + logger.info("Initializing step: " + step.getName()); + solver.init(step, scenarioExecution, vectorStore,chatModel,graphDriver,neo4JUitilityService); + + logger.info("Solving step: " + step.getName()); //must add try catch in order to catch exceptions and step execution errors ScenarioExecution scenarioExecutionNew = solver.solveStep(); diff --git a/src/main/java/com/olympus/hermione/stepSolvers/BasicAIPromptSolver.java b/src/main/java/com/olympus/hermione/stepSolvers/BasicAIPromptSolver.java index b2e7279..856cc68 100644 --- a/src/main/java/com/olympus/hermione/stepSolvers/BasicAIPromptSolver.java +++ b/src/main/java/com/olympus/hermione/stepSolvers/BasicAIPromptSolver.java @@ -18,15 +18,25 @@ public class BasicAIPromptSolver extends StepSolver { private String qai_system_prompt_template; private String qai_user_input; private String qai_output_variable; + private boolean qai_load_graph_schema=false; Logger logger = (Logger) LoggerFactory.getLogger(BasicQueryRagSolver.class); private void loadParameters(){ logger.info("Loading parameters"); - + + if(this.step.getAttributes().get("qai_load_graph_schema")!=null ){ + this.qai_load_graph_schema = Boolean.parseBoolean((String) this.step.getAttributes().get("qai_load_graph_schema")); + logger.info("qai_load_graph_schema: " + this.qai_load_graph_schema); + } + + if(this.qai_load_graph_schema){ + logger.info("Loading graph schema"); + this.scenarioExecution.getExecSharedMap().put("graph_schema", this.neo4JUitilityService.generateSchemaDescription().toString()); + } + AttributeParser attributeParser = new AttributeParser(this.scenarioExecution); - this.qai_system_prompt_template = attributeParser.parse((String) this.step.getAttributes().get("qai_system_prompt_template")); this.qai_user_input = attributeParser.parse((String)this.step.getAttributes().get("qai_user_input")); diff --git a/src/main/java/com/olympus/hermione/stepSolvers/QueryNeo4JSolver.java b/src/main/java/com/olympus/hermione/stepSolvers/QueryNeo4JSolver.java index 1163e33..66d6a12 100644 --- a/src/main/java/com/olympus/hermione/stepSolvers/QueryNeo4JSolver.java +++ b/src/main/java/com/olympus/hermione/stepSolvers/QueryNeo4JSolver.java @@ -1,15 +1,8 @@ package com.olympus.hermione.stepSolvers; -import java.util.ArrayList; -import java.util.List; - -import org.neo4j.driver.Driver; import org.neo4j.driver.Result; import org.neo4j.driver.Session; import org.slf4j.LoggerFactory; -import org.springframework.ai.document.Document; -import org.springframework.ai.vectorstore.SearchRequest; -import org.springframework.beans.factory.annotation.Autowired; import org.neo4j.driver.Record; import org.neo4j.driver.Value; @@ -26,22 +19,18 @@ public class QueryNeo4JSolver extends StepSolver { private String cypherQuery; private String outputField; - private void loadParameters(){ logger.info("Loading parameters"); - AttributeParser attributeParser = new AttributeParser(this.scenarioExecution); this.cypherQuery = attributeParser.parse((String) this.step.getAttributes().get("cypher_query")); this.outputField = attributeParser.parse((String) this.step.getAttributes().get("output_variable")); - - } private void logParameters(){ - logger.info("query: " + this.cypherQuery); + logger.info("cypherQuery: " + this.cypherQuery); logger.info("outputField: " + this.outputField); } @@ -52,6 +41,7 @@ public class QueryNeo4JSolver extends StepSolver { logger.info("Solving step: " + this.step.getName()); loadParameters(); logParameters(); + try (Session session = graphDriver.session()) { diff --git a/src/main/java/com/olympus/hermione/stepSolvers/StepSolver.java b/src/main/java/com/olympus/hermione/stepSolvers/StepSolver.java index 67b3721..0360de4 100644 --- a/src/main/java/com/olympus/hermione/stepSolvers/StepSolver.java +++ b/src/main/java/com/olympus/hermione/stepSolvers/StepSolver.java @@ -5,6 +5,7 @@ import org.springframework.ai.chat.model.ChatModel; import org.springframework.ai.vectorstore.VectorStore; import com.olympus.hermione.models.ScenarioExecution; import com.olympus.hermione.models.ScenarioStep; +import com.olympus.hermione.services.Neo4JUitilityService; public class StepSolver { @@ -14,17 +15,20 @@ public class StepSolver { protected VectorStore vectorStore; protected ChatModel chatModel; protected Driver graphDriver; - + protected Neo4JUitilityService neo4JUitilityService; + + public ScenarioExecution solveStep(){ System.out.println("Solving step: " + this.step.getName()); return this.scenarioExecution; }; - public void init(ScenarioStep step, ScenarioExecution scenarioExecution, VectorStore vectorStore,ChatModel chatModel,Driver graphDriver){ + public void init(ScenarioStep step, ScenarioExecution scenarioExecution, VectorStore vectorStore,ChatModel chatModel,Driver graphDriver,Neo4JUitilityService neo4JUitilityService){ this.scenarioExecution = scenarioExecution; this.step = step; this.vectorStore = vectorStore; this.chatModel = chatModel; this.graphDriver = graphDriver; + this.neo4JUitilityService = neo4JUitilityService; System.out.println("Initializing StepSolver"); } } \ No newline at end of file