Initial commit

This commit is contained in:
andrea.terzani
2024-07-29 08:48:44 +02:00
commit 157dacbe31
22 changed files with 1086 additions and 0 deletions

View File

@@ -0,0 +1,13 @@
package com.olympus.hermione;
import org.springframework.boot.SpringApplication;
import org.springframework.boot.autoconfigure.SpringBootApplication;
@SpringBootApplication
public class HermioneApplication {
public static void main(String[] args) {
SpringApplication.run(HermioneApplication.class, args);
}
}

View File

@@ -0,0 +1,28 @@
package com.olympus.hermione.controllers;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RestController;
import com.olympus.hermione.models.Scenario;
import com.olympus.hermione.repository.ScenarioRepository;
import org.springframework.web.bind.annotation.RequestBody;
@RestController
public class ScenarioController {
@Autowired
ScenarioRepository scenarioRepository;
@GetMapping("/scenarios")
public Iterable<Scenario> getScenarios() {
return scenarioRepository.findAll();
}
@PostMapping("scenarios")
public Scenario createScenario(@RequestBody Scenario scenario) {
return scenarioRepository.save(scenario);
}
}

View File

@@ -0,0 +1,21 @@
package com.olympus.hermione.controllers;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RestController;
import com.olympus.hermione.services.ScenarioExecutionService;
@RestController
public class TestController {
@Autowired
ScenarioExecutionService scenarioExecutionService;
@GetMapping("/test/scenario_execution")
public String testScenarioExecution(){
String result = scenarioExecutionService.executeScenario("66a5ff66fc2ae458a491f161","What's the purpose if an inputset in ATF? How it can be defined?");
return result;
}
}

View File

@@ -0,0 +1,18 @@
package com.olympus.hermione.models;
import java.util.List;
import org.springframework.data.annotation.Id;
import org.springframework.data.mongodb.core.mapping.Document;
import lombok.Getter;
import lombok.Setter;
@Document(collection = "scenarios")
@Getter @Setter
public class Scenario {
@Id
private String id;
private String name;
private List<ScenarioStep> steps;
}

View File

@@ -0,0 +1,25 @@
package com.olympus.hermione.models;
import java.util.HashMap;
import org.springframework.data.annotation.Id;
import org.springframework.data.mongodb.core.mapping.Document;
import lombok.Getter;
import lombok.Setter;
@Document(collection = "scenario_executions")
@Getter
@Setter
public class ScenarioExecution {
@Id
private String id;
private HashMap<String,Object> execSharedMap;
private Scenario scenario;
private String scenarioId;
}

View File

@@ -0,0 +1,23 @@
package com.olympus.hermione.models;
import java.util.HashMap;
import java.util.List;
import org.springframework.data.annotation.Id;
import org.springframework.data.mongodb.core.mapping.Document;
import lombok.Getter;
import lombok.Setter;
@Getter
@Setter
public class ScenarioStep {
private String type;
private String description;
private String name;
private int order;
private HashMap<String,Object> attributes;
}

View File

@@ -0,0 +1,14 @@
package com.olympus.hermione.models;
import lombok.Getter;
import lombok.Setter;
@Getter
@Setter
public class ScenarioStepAttribute {
private String name;
private String value;
}

View File

@@ -0,0 +1,13 @@
package com.olympus.hermione.repository;
import org.springframework.data.repository.CrudRepository;
import org.springframework.stereotype.Repository;
import com.olympus.hermione.models.ScenarioExecution;
@Repository
public interface ScenarioExecutionRepository extends CrudRepository<ScenarioExecution, String> {
}

View File

@@ -0,0 +1,11 @@
package com.olympus.hermione.repository;
import org.springframework.data.repository.CrudRepository;
import org.springframework.stereotype.Repository;
import com.olympus.hermione.models.Scenario;
@Repository
public interface ScenarioRepository extends CrudRepository<Scenario, String> {
}

View File

@@ -0,0 +1,103 @@
package com.olympus.hermione.services;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Optional;
import java.util.Vector;
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.vectorstore.VectorStore;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
import com.olympus.hermione.models.Scenario;
import com.olympus.hermione.models.ScenarioExecution;
import com.olympus.hermione.models.ScenarioStep;
import com.olympus.hermione.repository.ScenarioExecutionRepository;
import com.olympus.hermione.repository.ScenarioRepository;
import com.olympus.hermione.stepSolvers.BasicAIPromptSolver;
import com.olympus.hermione.stepSolvers.BasicQueryRagSolver;
import com.olympus.hermione.stepSolvers.StepSolver;
import org.slf4j.Logger;
@Service
public class ScenarioExecutionService {
@Autowired
private ScenarioRepository scenarioRepository;
@Autowired
private ScenarioExecutionRepository scenarioExecutionRepository;
@Autowired
VectorStore vectorStore;
@Autowired
ChatModel chatModel;
private Logger logger = LoggerFactory.getLogger(ScenarioExecutionService.class);
public String executeScenario(String scenarioId, String input){
Optional<Scenario> o_scenario = scenarioRepository.findById(scenarioId);
if(o_scenario.isPresent()){
Scenario scenario = o_scenario.get();
logger.info("Executing scenario: " + scenario.getName());
ScenarioExecution scenarioExecution = new ScenarioExecution();
scenarioExecution.setScenario(scenario);
HashMap<String, Object> execSharedMap = new HashMap<String, Object>();
execSharedMap.put("user_input", input);
scenarioExecution.setExecSharedMap(execSharedMap);
scenarioExecutionRepository.save(scenarioExecution);
List<ScenarioStep> steps = scenario.getSteps();
steps.sort(Comparator.comparingInt(ScenarioStep::getOrder));
for (ScenarioStep step : steps) {
executeScenarioStep(step, scenarioExecution, step.getOrder());
}
return scenarioExecution.getExecSharedMap().get("scenario_output").toString();
}else{
logger.error("Scenario not found with id: " + scenarioId);
}
;
return "";
}
private void executeScenarioStep(ScenarioStep step, ScenarioExecution scenarioExecution, int stepIndex){
logger.info("Executing step: " + step.getName());
StepSolver solver=new StepSolver();
switch (step.getType()) {
case "RAG_QUERY":
solver = new BasicQueryRagSolver();
break;
case "SIMPLE_QUERY_AI":
solver = new BasicAIPromptSolver();
break;
default:
break;
}
solver.init(step, scenarioExecution, vectorStore,chatModel);
//must add try catch in order to catch exceptions and step execution errors
ScenarioExecution scenarioExecutionNew = solver.solveStep();
scenarioExecutionRepository.save(scenarioExecutionNew);
}
}

View File

@@ -0,0 +1,5 @@
package com.olympus.hermione.services;
public class TestService {
}

View File

@@ -0,0 +1,67 @@
package com.olympus.hermione.stepSolvers;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.SystemMessage;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.model.Generation;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.chat.prompt.SystemPromptTemplate;
import org.springframework.ai.document.Document;
import org.springframework.ai.vectorstore.SearchRequest;
import com.olympus.hermione.models.ScenarioExecution;
import com.olympus.hermione.utility.AttributeParser;
import ch.qos.logback.classic.Logger;
public class BasicAIPromptSolver extends StepSolver {
private String qai_system_prompt_template;
private String qai_user_input;
private String qai_output_variable;
Logger logger = (Logger) LoggerFactory.getLogger(BasicQueryRagSolver.class);
private void loadParameters(){
logger.info("Loading parameters");
AttributeParser attributeParser = new AttributeParser(this.scenarioExecution);
this.qai_system_prompt_template = attributeParser.parse((String) this.step.getAttributes().get("qai_system_prompt_template"));
this.qai_user_input = attributeParser.parse((String)this.step.getAttributes().get("qai_user_input"));
this.qai_output_variable = (String) this.step.getAttributes().get("qai_output_variable");
}
@Override
public ScenarioExecution solveStep(){
System.out.println("Solving step: " + this.step.getName());
loadParameters();
String userText = this.qai_user_input;
Message userMessage = new UserMessage(userText);
Message systemMessage = new SystemMessage(this.qai_system_prompt_template);
Prompt prompt = new Prompt(List.of(userMessage, systemMessage));
List<Generation> response = chatModel.call(prompt).getResults();
String output = response.get(0).getOutput().getContent();
this.scenarioExecution.getExecSharedMap().put(this.qai_output_variable, output);
return this.scenarioExecution;
}
}

View File

@@ -0,0 +1,102 @@
package com.olympus.hermione.stepSolvers;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import org.slf4j.LoggerFactory;
import org.springframework.ai.document.Document;
import org.springframework.ai.vectorstore.SearchRequest;
import com.olympus.hermione.models.ScenarioExecution;
import com.olympus.hermione.utility.AttributeParser;
import ch.qos.logback.classic.Logger;
public class BasicQueryRagSolver extends StepSolver {
Logger logger = (Logger) LoggerFactory.getLogger(BasicQueryRagSolver.class);
// PARAMETERS
// rag_input : the field of the sharedMap to use as rag input
// rag_filter : the field of the sharedMap to use as rag filter
// rag_topk : the number of topk results to return
// rag_threshold : the similarity threshold to use
// rag_output_variable : the field of the sharedMap to store the output
private String query;
private String rag_filter;
private int topk;
private double threshold;
private String outputField;
private void loadParameters(){
logger.info("Loading parameters");
AttributeParser attributeParser = new AttributeParser(this.scenarioExecution);
this.query = attributeParser.parse((String) this.step.getAttributes().get("rag_input"));
this.rag_filter = attributeParser.parse((String) this.step.getAttributes().get("rag_filter"));
if(this.step.getAttributes().containsKey("rag_topk")){
this.topk = (int) this.step.getAttributes().get("rag_topk");
}else{
this.topk = 4;
}
if(this.step.getAttributes().containsKey("rag_threshold")){
this.threshold = (double) this.step.getAttributes().get("rag_threshold");
}else{
this.threshold = 0.8;
}
this.outputField = (String) this.step.getAttributes().get("rag_output_variable");
}
private void logParameters(){
logger.info("query: " + this.query);
logger.info("rag_filter: " + this.rag_filter);
logger.info("topk: " + this.topk);
logger.info("threshold: " + this.threshold);
logger.info("outputField: " + this.outputField);
}
@Override
public ScenarioExecution solveStep(){
System.out.println("Solving step: " + this.step.getName());
loadParameters();
logParameters();
SearchRequest request = SearchRequest.defaults()
.withQuery(this.query)
.withTopK(this.topk)
.withSimilarityThreshold(this.threshold);
if(this.rag_filter != null && !this.rag_filter.isEmpty()){
request.withFilterExpression(this.rag_filter);
}
List<Document> docs = this.vectorStore.similaritySearch(request);
List<String> result = new ArrayList<String>();
for (Document doc : docs) {
result.add(doc.getContent());
}
//concatenate the content of the results into a single string
String resultString = String.join("\n", result);
this.scenarioExecution.getExecSharedMap().put(this.outputField, resultString);
return this.scenarioExecution;
}
}

View File

@@ -0,0 +1,35 @@
package com.olympus.hermione.stepSolvers;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.vectorstore.VectorStore;
import com.olympus.hermione.models.Scenario;
import com.olympus.hermione.models.ScenarioExecution;
import com.olympus.hermione.models.ScenarioStep;
import lombok.Getter;
import lombok.Setter;
public class StepSolver {
protected ScenarioExecution scenarioExecution;
protected ScenarioStep step;
protected VectorStore vectorStore;
protected ChatModel chatModel;
public ScenarioExecution solveStep(){
System.out.println("Solving step: " + this.step.getName());
return this.scenarioExecution;
};
public void init(ScenarioStep step, ScenarioExecution scenarioExecution, VectorStore vectorStore,ChatModel chatModel){
this.scenarioExecution = scenarioExecution;
this.step = step;
this.vectorStore = vectorStore;
this.chatModel = chatModel;
System.out.println("Initializing StepSolver");
};
}

View File

@@ -0,0 +1,23 @@
package com.olympus.hermione.utility;
import org.stringtemplate.v4.ST;
import com.olympus.hermione.models.ScenarioExecution;
public class AttributeParser {
private ScenarioExecution scenario_execution;
public AttributeParser(ScenarioExecution scenario_execution){
this.scenario_execution = scenario_execution;
}
public String parse( String attribute_value ){
ST template = new ST(attribute_value, '$', '$');
template.add("shared", this.scenario_execution.getExecSharedMap());
return template.render();
}
}

View File

@@ -0,0 +1,13 @@
spring.application.name=hermione
server.port=8081
spring.data.mongodb.uri=mongodb+srv://olympus_adm:26111979@olympus.l6qor4p.mongodb.net/?retryWrites=true&w=majority&appName=Olympus
spring.data.mongodb.database=olympus
spring.data.mongodb.username=olympus_adm
spring.data.mongodb.password=XXXXXX
spring.ai.vectorstore.mongodb.indexName=vector_index
spring.ai.vectorstore.mongodb.collection-name=vector_store
spring.ai.vectorstore.mongodb.initialize-schema=false
spring.ai.openai.api-key=sk-proj-k4jrXXXUYQN8yQG2vNmWT3BlbkFJ0Ge9EfKcrMxduVFQZlyO

View File

@@ -0,0 +1,13 @@
package com.olympus.hermione;
import org.junit.jupiter.api.Test;
import org.springframework.boot.test.context.SpringBootTest;
@SpringBootTest
class HermioneApplicationTests {
@Test
void contextLoads() {
}
}