Initial commit
This commit is contained in:
13
src/main/java/com/olympus/hermione/HermioneApplication.java
Normal file
13
src/main/java/com/olympus/hermione/HermioneApplication.java
Normal file
@@ -0,0 +1,13 @@
|
||||
package com.olympus.hermione;
|
||||
|
||||
import org.springframework.boot.SpringApplication;
|
||||
import org.springframework.boot.autoconfigure.SpringBootApplication;
|
||||
|
||||
@SpringBootApplication
|
||||
public class HermioneApplication {
|
||||
|
||||
public static void main(String[] args) {
|
||||
SpringApplication.run(HermioneApplication.class, args);
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,28 @@
|
||||
package com.olympus.hermione.controllers;
|
||||
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.web.bind.annotation.GetMapping;
|
||||
import org.springframework.web.bind.annotation.PostMapping;
|
||||
import org.springframework.web.bind.annotation.RestController;
|
||||
|
||||
import com.olympus.hermione.models.Scenario;
|
||||
import com.olympus.hermione.repository.ScenarioRepository;
|
||||
|
||||
import org.springframework.web.bind.annotation.RequestBody;
|
||||
|
||||
@RestController
|
||||
public class ScenarioController {
|
||||
|
||||
@Autowired
|
||||
ScenarioRepository scenarioRepository;
|
||||
|
||||
@GetMapping("/scenarios")
|
||||
public Iterable<Scenario> getScenarios() {
|
||||
return scenarioRepository.findAll();
|
||||
}
|
||||
|
||||
@PostMapping("scenarios")
|
||||
public Scenario createScenario(@RequestBody Scenario scenario) {
|
||||
return scenarioRepository.save(scenario);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,21 @@
|
||||
package com.olympus.hermione.controllers;
|
||||
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.web.bind.annotation.GetMapping;
|
||||
import org.springframework.web.bind.annotation.RestController;
|
||||
|
||||
import com.olympus.hermione.services.ScenarioExecutionService;
|
||||
|
||||
@RestController
|
||||
public class TestController {
|
||||
|
||||
@Autowired
|
||||
ScenarioExecutionService scenarioExecutionService;
|
||||
|
||||
@GetMapping("/test/scenario_execution")
|
||||
public String testScenarioExecution(){
|
||||
|
||||
String result = scenarioExecutionService.executeScenario("66a5ff66fc2ae458a491f161","What's the purpose if an inputset in ATF? How it can be defined?");
|
||||
return result;
|
||||
}
|
||||
}
|
||||
18
src/main/java/com/olympus/hermione/models/Scenario.java
Normal file
18
src/main/java/com/olympus/hermione/models/Scenario.java
Normal file
@@ -0,0 +1,18 @@
|
||||
package com.olympus.hermione.models;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
import org.springframework.data.annotation.Id;
|
||||
import org.springframework.data.mongodb.core.mapping.Document;
|
||||
|
||||
import lombok.Getter;
|
||||
import lombok.Setter;
|
||||
|
||||
@Document(collection = "scenarios")
|
||||
@Getter @Setter
|
||||
public class Scenario {
|
||||
@Id
|
||||
private String id;
|
||||
private String name;
|
||||
private List<ScenarioStep> steps;
|
||||
}
|
||||
@@ -0,0 +1,25 @@
|
||||
package com.olympus.hermione.models;
|
||||
|
||||
import java.util.HashMap;
|
||||
|
||||
import org.springframework.data.annotation.Id;
|
||||
import org.springframework.data.mongodb.core.mapping.Document;
|
||||
|
||||
import lombok.Getter;
|
||||
import lombok.Setter;
|
||||
|
||||
@Document(collection = "scenario_executions")
|
||||
@Getter
|
||||
@Setter
|
||||
public class ScenarioExecution {
|
||||
|
||||
@Id
|
||||
private String id;
|
||||
private HashMap<String,Object> execSharedMap;
|
||||
|
||||
private Scenario scenario;
|
||||
|
||||
private String scenarioId;
|
||||
|
||||
|
||||
}
|
||||
23
src/main/java/com/olympus/hermione/models/ScenarioStep.java
Normal file
23
src/main/java/com/olympus/hermione/models/ScenarioStep.java
Normal file
@@ -0,0 +1,23 @@
|
||||
package com.olympus.hermione.models;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
|
||||
import org.springframework.data.annotation.Id;
|
||||
import org.springframework.data.mongodb.core.mapping.Document;
|
||||
|
||||
import lombok.Getter;
|
||||
import lombok.Setter;
|
||||
|
||||
@Getter
|
||||
@Setter
|
||||
public class ScenarioStep {
|
||||
|
||||
private String type;
|
||||
private String description;
|
||||
private String name;
|
||||
private int order;
|
||||
|
||||
private HashMap<String,Object> attributes;
|
||||
|
||||
}
|
||||
@@ -0,0 +1,14 @@
|
||||
package com.olympus.hermione.models;
|
||||
|
||||
import lombok.Getter;
|
||||
import lombok.Setter;
|
||||
|
||||
@Getter
|
||||
@Setter
|
||||
public class ScenarioStepAttribute {
|
||||
|
||||
private String name;
|
||||
private String value;
|
||||
|
||||
|
||||
}
|
||||
@@ -0,0 +1,13 @@
|
||||
package com.olympus.hermione.repository;
|
||||
|
||||
import org.springframework.data.repository.CrudRepository;
|
||||
import org.springframework.stereotype.Repository;
|
||||
|
||||
import com.olympus.hermione.models.ScenarioExecution;
|
||||
|
||||
@Repository
|
||||
public interface ScenarioExecutionRepository extends CrudRepository<ScenarioExecution, String> {
|
||||
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,11 @@
|
||||
package com.olympus.hermione.repository;
|
||||
|
||||
import org.springframework.data.repository.CrudRepository;
|
||||
import org.springframework.stereotype.Repository;
|
||||
|
||||
import com.olympus.hermione.models.Scenario;
|
||||
|
||||
@Repository
|
||||
public interface ScenarioRepository extends CrudRepository<Scenario, String> {
|
||||
|
||||
}
|
||||
@@ -0,0 +1,103 @@
|
||||
package com.olympus.hermione.services;
|
||||
|
||||
import java.util.Comparator;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Optional;
|
||||
import java.util.Vector;
|
||||
|
||||
import org.slf4j.LoggerFactory;
|
||||
import org.springframework.ai.chat.client.ChatClient;
|
||||
import org.springframework.ai.chat.model.ChatModel;
|
||||
import org.springframework.ai.vectorstore.VectorStore;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.stereotype.Service;
|
||||
|
||||
import com.olympus.hermione.models.Scenario;
|
||||
import com.olympus.hermione.models.ScenarioExecution;
|
||||
import com.olympus.hermione.models.ScenarioStep;
|
||||
import com.olympus.hermione.repository.ScenarioExecutionRepository;
|
||||
import com.olympus.hermione.repository.ScenarioRepository;
|
||||
import com.olympus.hermione.stepSolvers.BasicAIPromptSolver;
|
||||
import com.olympus.hermione.stepSolvers.BasicQueryRagSolver;
|
||||
import com.olympus.hermione.stepSolvers.StepSolver;
|
||||
|
||||
import org.slf4j.Logger;
|
||||
|
||||
@Service
|
||||
public class ScenarioExecutionService {
|
||||
|
||||
@Autowired
|
||||
private ScenarioRepository scenarioRepository;
|
||||
|
||||
@Autowired
|
||||
private ScenarioExecutionRepository scenarioExecutionRepository;
|
||||
|
||||
@Autowired
|
||||
VectorStore vectorStore;
|
||||
|
||||
@Autowired
|
||||
ChatModel chatModel;
|
||||
|
||||
private Logger logger = LoggerFactory.getLogger(ScenarioExecutionService.class);
|
||||
|
||||
public String executeScenario(String scenarioId, String input){
|
||||
|
||||
Optional<Scenario> o_scenario = scenarioRepository.findById(scenarioId);
|
||||
|
||||
if(o_scenario.isPresent()){
|
||||
Scenario scenario = o_scenario.get();
|
||||
logger.info("Executing scenario: " + scenario.getName());
|
||||
|
||||
ScenarioExecution scenarioExecution = new ScenarioExecution();
|
||||
scenarioExecution.setScenario(scenario);
|
||||
|
||||
HashMap<String, Object> execSharedMap = new HashMap<String, Object>();
|
||||
execSharedMap.put("user_input", input);
|
||||
scenarioExecution.setExecSharedMap(execSharedMap);
|
||||
scenarioExecutionRepository.save(scenarioExecution);
|
||||
|
||||
List<ScenarioStep> steps = scenario.getSteps();
|
||||
|
||||
steps.sort(Comparator.comparingInt(ScenarioStep::getOrder));
|
||||
for (ScenarioStep step : steps) {
|
||||
executeScenarioStep(step, scenarioExecution, step.getOrder());
|
||||
}
|
||||
|
||||
return scenarioExecution.getExecSharedMap().get("scenario_output").toString();
|
||||
|
||||
}else{
|
||||
logger.error("Scenario not found with id: " + scenarioId);
|
||||
}
|
||||
|
||||
;
|
||||
return "";
|
||||
}
|
||||
|
||||
private void executeScenarioStep(ScenarioStep step, ScenarioExecution scenarioExecution, int stepIndex){
|
||||
logger.info("Executing step: " + step.getName());
|
||||
StepSolver solver=new StepSolver();
|
||||
switch (step.getType()) {
|
||||
case "RAG_QUERY":
|
||||
solver = new BasicQueryRagSolver();
|
||||
break;
|
||||
case "SIMPLE_QUERY_AI":
|
||||
solver = new BasicAIPromptSolver();
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
|
||||
solver.init(step, scenarioExecution, vectorStore,chatModel);
|
||||
|
||||
//must add try catch in order to catch exceptions and step execution errors
|
||||
ScenarioExecution scenarioExecutionNew = solver.solveStep();
|
||||
|
||||
scenarioExecutionRepository.save(scenarioExecutionNew);
|
||||
|
||||
|
||||
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
@@ -0,0 +1,5 @@
|
||||
package com.olympus.hermione.services;
|
||||
|
||||
public class TestService {
|
||||
|
||||
}
|
||||
@@ -0,0 +1,67 @@
|
||||
package com.olympus.hermione.stepSolvers;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
|
||||
import org.slf4j.LoggerFactory;
|
||||
import org.springframework.ai.chat.messages.Message;
|
||||
import org.springframework.ai.chat.messages.SystemMessage;
|
||||
import org.springframework.ai.chat.messages.UserMessage;
|
||||
import org.springframework.ai.chat.model.Generation;
|
||||
import org.springframework.ai.chat.prompt.Prompt;
|
||||
import org.springframework.ai.chat.prompt.SystemPromptTemplate;
|
||||
import org.springframework.ai.document.Document;
|
||||
import org.springframework.ai.vectorstore.SearchRequest;
|
||||
|
||||
import com.olympus.hermione.models.ScenarioExecution;
|
||||
import com.olympus.hermione.utility.AttributeParser;
|
||||
|
||||
import ch.qos.logback.classic.Logger;
|
||||
|
||||
public class BasicAIPromptSolver extends StepSolver {
|
||||
|
||||
private String qai_system_prompt_template;
|
||||
private String qai_user_input;
|
||||
private String qai_output_variable;
|
||||
|
||||
Logger logger = (Logger) LoggerFactory.getLogger(BasicQueryRagSolver.class);
|
||||
|
||||
private void loadParameters(){
|
||||
logger.info("Loading parameters");
|
||||
|
||||
AttributeParser attributeParser = new AttributeParser(this.scenarioExecution);
|
||||
|
||||
|
||||
this.qai_system_prompt_template = attributeParser.parse((String) this.step.getAttributes().get("qai_system_prompt_template"));
|
||||
this.qai_user_input = attributeParser.parse((String)this.step.getAttributes().get("qai_user_input"));
|
||||
|
||||
this.qai_output_variable = (String) this.step.getAttributes().get("qai_output_variable");
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
public ScenarioExecution solveStep(){
|
||||
|
||||
System.out.println("Solving step: " + this.step.getName());
|
||||
|
||||
loadParameters();
|
||||
|
||||
|
||||
String userText = this.qai_user_input;
|
||||
|
||||
Message userMessage = new UserMessage(userText);
|
||||
Message systemMessage = new SystemMessage(this.qai_system_prompt_template);
|
||||
|
||||
Prompt prompt = new Prompt(List.of(userMessage, systemMessage));
|
||||
|
||||
List<Generation> response = chatModel.call(prompt).getResults();
|
||||
|
||||
String output = response.get(0).getOutput().getContent();
|
||||
|
||||
this.scenarioExecution.getExecSharedMap().put(this.qai_output_variable, output);
|
||||
|
||||
return this.scenarioExecution;
|
||||
}
|
||||
|
||||
}
|
||||
@@ -0,0 +1,102 @@
|
||||
package com.olympus.hermione.stepSolvers;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
|
||||
import org.slf4j.LoggerFactory;
|
||||
import org.springframework.ai.document.Document;
|
||||
import org.springframework.ai.vectorstore.SearchRequest;
|
||||
|
||||
|
||||
|
||||
import com.olympus.hermione.models.ScenarioExecution;
|
||||
import com.olympus.hermione.utility.AttributeParser;
|
||||
|
||||
import ch.qos.logback.classic.Logger;
|
||||
|
||||
public class BasicQueryRagSolver extends StepSolver {
|
||||
|
||||
Logger logger = (Logger) LoggerFactory.getLogger(BasicQueryRagSolver.class);
|
||||
|
||||
// PARAMETERS
|
||||
// rag_input : the field of the sharedMap to use as rag input
|
||||
// rag_filter : the field of the sharedMap to use as rag filter
|
||||
// rag_topk : the number of topk results to return
|
||||
// rag_threshold : the similarity threshold to use
|
||||
// rag_output_variable : the field of the sharedMap to store the output
|
||||
private String query;
|
||||
private String rag_filter;
|
||||
private int topk;
|
||||
private double threshold;
|
||||
private String outputField;
|
||||
|
||||
|
||||
|
||||
private void loadParameters(){
|
||||
logger.info("Loading parameters");
|
||||
|
||||
AttributeParser attributeParser = new AttributeParser(this.scenarioExecution);
|
||||
|
||||
this.query = attributeParser.parse((String) this.step.getAttributes().get("rag_input"));
|
||||
|
||||
this.rag_filter = attributeParser.parse((String) this.step.getAttributes().get("rag_filter"));
|
||||
|
||||
if(this.step.getAttributes().containsKey("rag_topk")){
|
||||
this.topk = (int) this.step.getAttributes().get("rag_topk");
|
||||
}else{
|
||||
this.topk = 4;
|
||||
}
|
||||
|
||||
if(this.step.getAttributes().containsKey("rag_threshold")){
|
||||
this.threshold = (double) this.step.getAttributes().get("rag_threshold");
|
||||
}else{
|
||||
this.threshold = 0.8;
|
||||
}
|
||||
|
||||
this.outputField = (String) this.step.getAttributes().get("rag_output_variable");
|
||||
}
|
||||
|
||||
private void logParameters(){
|
||||
logger.info("query: " + this.query);
|
||||
logger.info("rag_filter: " + this.rag_filter);
|
||||
logger.info("topk: " + this.topk);
|
||||
logger.info("threshold: " + this.threshold);
|
||||
logger.info("outputField: " + this.outputField);
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public ScenarioExecution solveStep(){
|
||||
|
||||
System.out.println("Solving step: " + this.step.getName());
|
||||
|
||||
loadParameters();
|
||||
logParameters();
|
||||
|
||||
SearchRequest request = SearchRequest.defaults()
|
||||
.withQuery(this.query)
|
||||
.withTopK(this.topk)
|
||||
.withSimilarityThreshold(this.threshold);
|
||||
|
||||
if(this.rag_filter != null && !this.rag_filter.isEmpty()){
|
||||
request.withFilterExpression(this.rag_filter);
|
||||
}
|
||||
|
||||
List<Document> docs = this.vectorStore.similaritySearch(request);
|
||||
|
||||
List<String> result = new ArrayList<String>();
|
||||
for (Document doc : docs) {
|
||||
result.add(doc.getContent());
|
||||
}
|
||||
|
||||
//concatenate the content of the results into a single string
|
||||
String resultString = String.join("\n", result);
|
||||
this.scenarioExecution.getExecSharedMap().put(this.outputField, resultString);
|
||||
|
||||
return this.scenarioExecution;
|
||||
}
|
||||
|
||||
|
||||
|
||||
}
|
||||
@@ -0,0 +1,35 @@
|
||||
package com.olympus.hermione.stepSolvers;
|
||||
|
||||
import org.springframework.ai.chat.client.ChatClient;
|
||||
import org.springframework.ai.chat.model.ChatModel;
|
||||
import org.springframework.ai.vectorstore.VectorStore;
|
||||
|
||||
import com.olympus.hermione.models.Scenario;
|
||||
import com.olympus.hermione.models.ScenarioExecution;
|
||||
import com.olympus.hermione.models.ScenarioStep;
|
||||
|
||||
import lombok.Getter;
|
||||
import lombok.Setter;
|
||||
|
||||
|
||||
public class StepSolver {
|
||||
|
||||
|
||||
|
||||
protected ScenarioExecution scenarioExecution;
|
||||
protected ScenarioStep step;
|
||||
protected VectorStore vectorStore;
|
||||
protected ChatModel chatModel;
|
||||
|
||||
public ScenarioExecution solveStep(){
|
||||
System.out.println("Solving step: " + this.step.getName());
|
||||
return this.scenarioExecution;
|
||||
};
|
||||
public void init(ScenarioStep step, ScenarioExecution scenarioExecution, VectorStore vectorStore,ChatModel chatModel){
|
||||
this.scenarioExecution = scenarioExecution;
|
||||
this.step = step;
|
||||
this.vectorStore = vectorStore;
|
||||
this.chatModel = chatModel;
|
||||
System.out.println("Initializing StepSolver");
|
||||
};
|
||||
}
|
||||
@@ -0,0 +1,23 @@
|
||||
package com.olympus.hermione.utility;
|
||||
|
||||
import org.stringtemplate.v4.ST;
|
||||
|
||||
import com.olympus.hermione.models.ScenarioExecution;
|
||||
|
||||
public class AttributeParser {
|
||||
|
||||
private ScenarioExecution scenario_execution;
|
||||
|
||||
public AttributeParser(ScenarioExecution scenario_execution){
|
||||
this.scenario_execution = scenario_execution;
|
||||
}
|
||||
|
||||
public String parse( String attribute_value ){
|
||||
|
||||
ST template = new ST(attribute_value, '$', '$');
|
||||
template.add("shared", this.scenario_execution.getExecSharedMap());
|
||||
|
||||
return template.render();
|
||||
|
||||
}
|
||||
}
|
||||
13
src/main/resources/application.properties
Normal file
13
src/main/resources/application.properties
Normal file
@@ -0,0 +1,13 @@
|
||||
spring.application.name=hermione
|
||||
server.port=8081
|
||||
|
||||
spring.data.mongodb.uri=mongodb+srv://olympus_adm:26111979@olympus.l6qor4p.mongodb.net/?retryWrites=true&w=majority&appName=Olympus
|
||||
spring.data.mongodb.database=olympus
|
||||
spring.data.mongodb.username=olympus_adm
|
||||
spring.data.mongodb.password=XXXXXX
|
||||
|
||||
spring.ai.vectorstore.mongodb.indexName=vector_index
|
||||
spring.ai.vectorstore.mongodb.collection-name=vector_store
|
||||
spring.ai.vectorstore.mongodb.initialize-schema=false
|
||||
|
||||
spring.ai.openai.api-key=sk-proj-k4jrXXXUYQN8yQG2vNmWT3BlbkFJ0Ge9EfKcrMxduVFQZlyO
|
||||
@@ -0,0 +1,13 @@
|
||||
package com.olympus.hermione;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.springframework.boot.test.context.SpringBootTest;
|
||||
|
||||
@SpringBootTest
|
||||
class HermioneApplicationTests {
|
||||
|
||||
@Test
|
||||
void contextLoads() {
|
||||
}
|
||||
|
||||
}
|
||||
Reference in New Issue
Block a user