diff --git a/pom.xml b/pom.xml
index 970ccb1..735ac9d 100644
--- a/pom.xml
+++ b/pom.xml
@@ -28,8 +28,7 @@
21
-
- 1.0.0-M6
+ 1.0.0
2023.0.3
@@ -61,38 +60,26 @@
spring-boot-starter-oauth2-resource-server
-
-
- org.json
- json
- 20250107
-
+
+ org.json
+ json
+ 20250107
+
-
org.springframework.ai
- spring-ai-openai-spring-boot-starter
+ spring-ai-starter-model-azure-openai
+
org.springframework.ai
- spring-ai-azure-openai
-
-
- org.springframework.ai
- spring-ai-azure-store
-
-
- org.springframework.ai
- spring-ai-chroma-store
+ spring-ai-starter-vector-store-chroma
+
org.neo4j.driver
neo4j-java-driver
- 5.8.0
@@ -100,12 +87,12 @@
spring-boot-starter-test
test
-
- org.projectlombok
- lombok
- 1.18.34
-
-
+
+ org.projectlombok
+ lombok
+ 1.18.38
+ provided
+
io.jsonwebtoken
jjwt-api
@@ -153,7 +140,6 @@
com.google.code.gson
gson
- 2.8.8
diff --git a/src/main/java/com/olympus/hermione/config/AzureAIConfig.java b/src/main/java/com/olympus/hermione/config/AzureAIConfig.java
deleted file mode 100644
index 81a4336..0000000
--- a/src/main/java/com/olympus/hermione/config/AzureAIConfig.java
+++ /dev/null
@@ -1,26 +0,0 @@
-package com.olympus.hermione.config;
-
-import java.time.Duration;
-
-import org.springframework.context.annotation.Bean;
-import org.springframework.context.annotation.Configuration;
-import org.springframework.web.client.RestTemplate;
-import org.springframework.ai.autoconfigure.azure.openai.AzureOpenAIClientBuilderCustomizer;
-import com.azure.core.http.HttpClient;
-import com.azure.core.util.HttpClientOptions;
-
-
-@Configuration
-public class AzureAIConfig {
-
- @Bean
- public AzureOpenAIClientBuilderCustomizer responseTimeoutCustomizer() {
- return openAiClientBuilder -> {
- HttpClientOptions clientOptions = new HttpClientOptions()
- .setResponseTimeout(Duration.ofMinutes(5));
- openAiClientBuilder.httpClient(HttpClient.createDefault(clientOptions));
- };
- }
-
-
-}
\ No newline at end of file
diff --git a/src/main/java/com/olympus/hermione/config/HermioneConfig.java b/src/main/java/com/olympus/hermione/config/HermioneConfig.java
index 5daa131..5619e7a 100644
--- a/src/main/java/com/olympus/hermione/config/HermioneConfig.java
+++ b/src/main/java/com/olympus/hermione/config/HermioneConfig.java
@@ -2,7 +2,6 @@ package com.olympus.hermione.config;
import org.springframework.ai.azure.openai.AzureOpenAiEmbeddingModel;
import org.springframework.ai.embedding.EmbeddingModel;
-import org.springframework.ai.openai.OpenAiEmbeddingModel;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.context.annotation.Primary;
@@ -11,24 +10,13 @@ import org.springframework.context.annotation.Primary;
@Configuration
public class HermioneConfig {
- private final OpenAiEmbeddingModel openAiEmbeddingModel;
private final AzureOpenAiEmbeddingModel azureOpenAiEmbeddingModel;
- // Autowiring beans from both libraries
- public HermioneConfig(OpenAiEmbeddingModel myServiceLib1,
- AzureOpenAiEmbeddingModel azureOpenAiEmbeddingModel) {
- this.openAiEmbeddingModel = myServiceLib1;
+ public HermioneConfig( AzureOpenAiEmbeddingModel azureOpenAiEmbeddingModel) {
this.azureOpenAiEmbeddingModel = azureOpenAiEmbeddingModel;
}
- // Re-declaring and marking one as primary
- @Bean
- public EmbeddingModel primaryMyService() {
- return openAiEmbeddingModel; // Choose the one you want as primary
- }
-
- // Optionally declare the other bean without primary
@Bean
@Primary
public EmbeddingModel secondaryMyService() {
diff --git a/src/main/java/com/olympus/hermione/config/VectorStoreConfig.java b/src/main/java/com/olympus/hermione/config/VectorStoreConfig.java
deleted file mode 100644
index 83f4571..0000000
--- a/src/main/java/com/olympus/hermione/config/VectorStoreConfig.java
+++ /dev/null
@@ -1,50 +0,0 @@
-package com.olympus.hermione.config;
-
-import com.azure.core.credential.AzureKeyCredential;
-import com.azure.search.documents.indexes.SearchIndexClient;
-import com.azure.search.documents.indexes.SearchIndexClientBuilder;
-import org.springframework.ai.embedding.EmbeddingModel;
-import org.springframework.ai.vectorstore.VectorStore;
-import org.springframework.ai.vectorstore.azure.AzureVectorStore;
-import org.springframework.beans.factory.annotation.Qualifier;
-import org.springframework.beans.factory.annotation.Value;
-import org.springframework.context.annotation.Bean;
-import org.springframework.context.annotation.Configuration;
-
-import java.util.ArrayList;
-import java.util.List;
-
-@Configuration
-public class VectorStoreConfig {
-
-
- @Value("${spring.ai.vectorstore.azure.api-key}")
- private String azureKey;
- @Value("${spring.ai.vectorstore.azure.url}")
- private String azureEndpoint;
- @Value("${spring.ai.vectorstore.azure.initialize-schema}")
- private boolean initSchema;
-
- @Bean
- public SearchIndexClient searchIndexClient() {
- return new SearchIndexClientBuilder().endpoint(azureEndpoint)
- .credential(new AzureKeyCredential(azureKey))
- .buildClient();
- }
-
- @Bean
- public VectorStore vectorStore(SearchIndexClient searchIndexClient, @Qualifier("azureOpenAiEmbeddingModel") EmbeddingModel embeddingModel) {
- List fields = new ArrayList<>();
-
- fields.add(AzureVectorStore.MetadataField.text("KsApplicationName"));
- fields.add(AzureVectorStore.MetadataField.text("KsProjectName"));
- fields.add(AzureVectorStore.MetadataField.text("KsDoctype"));
- fields.add(AzureVectorStore.MetadataField.text("KsDocSource"));
- fields.add(AzureVectorStore.MetadataField.text("KsFileSource"));
- fields.add(AzureVectorStore.MetadataField.text("KsDocumentId"));
-
- //return new AzureVectorStore(searchIndexClient, embeddingModel,initSchema, fields);
- return null;
- }
-}
-
diff --git a/src/main/java/com/olympus/hermione/services/ScenarioExecutionService.java b/src/main/java/com/olympus/hermione/services/ScenarioExecutionService.java
index e70141c..2752554 100644
--- a/src/main/java/com/olympus/hermione/services/ScenarioExecutionService.java
+++ b/src/main/java/com/olympus/hermione/services/ScenarioExecutionService.java
@@ -24,11 +24,8 @@ import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.client.advisor.MessageChatMemoryAdvisor;
import org.springframework.ai.chat.client.advisor.SimpleLoggerAdvisor;
import org.springframework.ai.chat.memory.ChatMemory;
-import org.springframework.ai.chat.memory.InMemoryChatMemory;
+import org.springframework.ai.chat.memory.MessageWindowChatMemory;
import org.springframework.ai.chat.model.ChatModel;
-import org.springframework.ai.openai.OpenAiChatModel;
-import org.springframework.ai.openai.OpenAiChatOptions;
-import org.springframework.ai.openai.api.OpenAiApi;
import org.springframework.ai.vectorstore.VectorStore;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
@@ -200,19 +197,16 @@ public class ScenarioExecutionService {
if (scenario.isUseChatMemory()) {
logger.info("Initializing chatClient with chat-memory advisor");
- ChatMemory chatMemory = new InMemoryChatMemory();
+ ChatMemory chatMemory = MessageWindowChatMemory.builder().build();
+
chatClient = ChatClient.builder(chatModel)
- .defaultAdvisors(
- new MessageChatMemoryAdvisor(chatMemory), // chat-memory advisor
- new SimpleLoggerAdvisor())
+ .defaultAdvisors(MessageChatMemoryAdvisor.builder(chatMemory).build() )
.build();
} else {
- logger.info("Initializing chatClient with simple logger advisor");
+ logger.info("Initializing chatClient without advisor");
chatClient = ChatClient.builder(chatModel)
- .defaultAdvisors(
- new SimpleLoggerAdvisor())
.build();
}
@@ -437,81 +431,46 @@ public class ScenarioExecutionService {
private ChatModel createChatModel(AiModel aiModel) {
switch (aiModel.getApiProvider()) {
case "AzureOpenAI":
- OpenAIClientBuilder openAIClient = new OpenAIClientBuilder()
- .credential(new AzureKeyCredential(aiModel.getApiKey()))
- .endpoint(aiModel.getEndpoint());
+
- AzureOpenAiChatOptions openAIChatOptions = AzureOpenAiChatOptions.builder()
- .deploymentName(aiModel.getModel())
- .maxTokens(aiModel.getMaxTokens())
- .temperature(Double.valueOf(aiModel.getTemperature()))
- .build();
+ OpenAIClientBuilder openAIClientBuilder = new OpenAIClientBuilder()
+ .credential(new AzureKeyCredential(aiModel.getApiKey()))
+ .endpoint(aiModel.getEndpoint());
- AzureOpenAiChatModel azureOpenaichatModel = new AzureOpenAiChatModel(openAIClient, openAIChatOptions);
+ // Create a builder with the base options
+ AzureOpenAiChatOptions.Builder optionsBuilder = AzureOpenAiChatOptions.builder()
+ .deploymentName(aiModel.getModel());
+
+ // Add maxTokens only when it is non-null
+ if (aiModel.getMaxTokens() != null) {
+ optionsBuilder.maxTokens(aiModel.getMaxTokens());
+ logger.info("Setting maxTokens: " + aiModel.getMaxTokens());
+ }
+
+ // Add temperature only when it is non-null
+ if (aiModel.getTemperature() != null) {
+ optionsBuilder.temperature(Double.valueOf(aiModel.getTemperature()));
+ logger.info("Setting temperature: " + aiModel.getTemperature());
+ }
+
+ AzureOpenAiChatOptions openAIChatOptions = optionsBuilder.build();
+
+ var chatModel = AzureOpenAiChatModel.builder()
+ .openAIClientBuilder(openAIClientBuilder)
+ .defaultOptions(openAIChatOptions)
+ .build();
+
logger.info("AI model used: " + aiModel.getModel());
- return azureOpenaichatModel;
+ return chatModel;
- case "OpenAI":
- OpenAiApi openAiApi = new OpenAiApi(aiModel.getApiKey());
- OpenAiChatOptions openAiChatOptions = OpenAiChatOptions.builder()
- .model(aiModel.getModel())
- .temperature(Double.valueOf(aiModel.getTemperature()))
- .maxTokens(aiModel.getMaxTokens())
- .build();
- OpenAiChatModel openaichatModel = new OpenAiChatModel(openAiApi, openAiChatOptions);
- logger.info("AI model used: " + aiModel.getModel());
- return openaichatModel;
-
- case "GoogleGemini":
- OpenAIClientBuilder openAIClient2 = new OpenAIClientBuilder()
- .credential(new AzureKeyCredential(aiModel.getApiKey()))
- .endpoint(aiModel.getEndpoint());
-
- AzureOpenAiChatOptions openAIChatOptions2 = AzureOpenAiChatOptions.builder()
- .deploymentName(aiModel.getModel())
- .maxTokens(aiModel.getMaxTokens())
- .temperature(Double.valueOf(aiModel.getTemperature()))
- .build();
-
- AzureOpenAiChatModel azureOpenaichatModel2 = new AzureOpenAiChatModel(openAIClient2,
- openAIChatOptions2);
-
- logger.info("AI model used : " + aiModel.getModel());
- return azureOpenaichatModel2;
default:
throw new IllegalArgumentException("Unsupported AI model: " + aiModel.getName());
}
}
- /*
- * public List getListExecutionScenario(){
- * logger.info("getListProjectByUser function:");
- *
- * List lstScenarioExecution = null;
- *
- * User principal = (User)
- * SecurityContextHolder.getContext().getAuthentication().getPrincipal();
- *
- * if(principal.getSelectedApplication()!=null){
- * lstScenarioExecution =
- * scenarioExecutionRepository.getFromProjectAndAPP(principal.getId(),
- * principal.getSelectedProject().getInternal_name(),
- * principal.getSelectedApplication().getInternal_name(),
- * Sort.by(Sort.Direction.DESC, "startDate"));
- *
- * }else{
- * lstScenarioExecution =
- * scenarioExecutionRepository.getFromProject(principal.getId(),
- * principal.getSelectedProject().getInternal_name(),
- * Sort.by(Sort.Direction.DESC, "startDate"));
- * }
- *
- * return lstScenarioExecution;
- *
- * }
- */
+
public Page getListExecutionScenarioOptional(int page, int size, String scenarioName,
String executedBy, LocalDate startDate) {
@@ -582,12 +541,7 @@ public class ScenarioExecutionService {
criteriaList.add(Criteria.where("execSharedMap.user_input.selected_application")
.is(principal.getSelectedApplication().getInternal_name()));
}
- // here:
- // filtri
- // here:
- // here: vedi perchè non funziona link di view
- // here: da sistemare filtro data a backend e i vari equals
- // here: add ordinamento
+
for (Map.Entry entry : filters.entrySet()) {
String key = entry.getKey();
diff --git a/src/main/java/com/olympus/hermione/stepSolvers/AdvancedAIPromptSolver.java b/src/main/java/com/olympus/hermione/stepSolvers/AdvancedAIPromptSolver.java
index 3534d08..e6bcab9 100644
--- a/src/main/java/com/olympus/hermione/stepSolvers/AdvancedAIPromptSolver.java
+++ b/src/main/java/com/olympus/hermione/stepSolvers/AdvancedAIPromptSolver.java
@@ -1,21 +1,18 @@
package com.olympus.hermione.stepSolvers;
-import ch.qos.logback.classic.Logger;
+import org.slf4j.LoggerFactory;
+import org.springframework.ai.chat.client.ChatClient.CallResponseSpec;
+import org.springframework.ai.chat.memory.ChatMemory;
+import org.springframework.ai.chat.metadata.Usage;
+import org.springframework.ai.chat.model.ChatResponse;
+
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.olympus.hermione.dto.aientity.CiaOutputEntity;
import com.olympus.hermione.models.ScenarioExecution;
import com.olympus.hermione.utility.AttributeParser;
-import java.util.ArrayList;
-import java.util.List;
-
-import org.slf4j.LoggerFactory;
-import org.springframework.ai.chat.client.ChatClient.CallResponseSpec;
-import org.springframework.ai.chat.messages.Message;
-import org.springframework.ai.chat.messages.SystemMessage;
-import org.springframework.ai.chat.messages.UserMessage;
-import org.springframework.ai.chat.metadata.Usage;
+import ch.qos.logback.classic.Logger;
public class AdvancedAIPromptSolver extends StepSolver {
@@ -61,7 +58,7 @@ public class AdvancedAIPromptSolver extends StepSolver {
this.qai_custom_memory_id = (String) this.step.getAttributes().get("qai_custom_memory_id");
logger.info("qai_custom_memory_id: {}", this.qai_custom_memory_id);
- //TODO: Add memory ID attribute to have the possibility of multiple conversations
+
this.qai_available_tools = (String) this.step.getAttributes().get("qai_available_tools");
logger.info("qai_available_tools: {}", this.qai_available_tools);
}
@@ -82,12 +79,22 @@ public class AdvancedAIPromptSolver extends StepSolver {
logger.info("Available tools specified: {}", this.qai_available_tools);
for(String tool: this.qai_available_tools.split(",")){
logger.info("Tool: {}", tool);
- if(tool.equals("SourceCodeTool")){
- logger.info("Instantiating SourceCodeTool");
- tools = new com.olympus.hermione.tools.SourceCodeTool(discoveryClient);
- } else {
- logger.warn("Unknown tool specified: {}", tool);
+
+ switch(tool.trim()){
+ case "DocumentationTool" -> {
+ logger.info("Instantiating DocumentationTool");
+ tools = new com.olympus.hermione.tools.DocumentationTool(
+ this.scenarioExecution.getScenarioExecutionInput().getInputs().get("selected_application"),
+ this.scenarioExecution.getScenarioExecutionInput().getInputs().get("selected_project"),
+ vectorStore);
+ }
+
+ case "SourceCodeTool" -> {
+ logger.info("Instantiating SourceCodeTool");
+ tools = new com.olympus.hermione.tools.SourceCodeTool(this.discoveryClient);
+ }
}
+
}
} else {
logger.info("No tools specified or available");
@@ -98,9 +105,7 @@ public class AdvancedAIPromptSolver extends StepSolver {
resp = chatClient.prompt()
.user(this.qai_user_input)
.system(this.qai_system_prompt_template)
- .advisors(advisor -> advisor
- .param("chat_memory_conversation_id", this.scenarioExecution.getId()+this.qai_custom_memory_id)
- .param("chat_memory_response_size", 100))
+ .advisors(advisor -> advisor.param(ChatMemory.CONVERSATION_ID, this.scenarioExecution.getId()+this.qai_custom_memory_id))
.call();
}else{
logger.info("Calling chatClient with tools: {}", tools.getClass().getName());
@@ -108,12 +113,11 @@ public class AdvancedAIPromptSolver extends StepSolver {
.user(this.qai_user_input)
.system(this.qai_system_prompt_template)
.advisors(advisor -> advisor
- .param("chat_memory_conversation_id", this.scenarioExecution.getId()+this.qai_custom_memory_id)
- .param("chat_memory_response_size", 100))
+ .param(ChatMemory.CONVERSATION_ID, this.scenarioExecution.getId()+this.qai_custom_memory_id))
.tools(tools)
.call();
}
-
+ ChatResponse response = null;
if(qai_output_entityType!=null && qai_output_entityType.equals("CiaOutputEntity")){
logger.info("Output is of type CiaOutputEntity");
@@ -129,13 +133,13 @@ public class AdvancedAIPromptSolver extends StepSolver {
}
}else{
logger.info("Output is of type String");
- String output = resp.content();
- logger.debug("String response: {}", output);
- this.scenarioExecution.getExecSharedMap().put(this.qai_output_variable, output);
+ response = resp.chatResponse();
+ this.scenarioExecution.getExecSharedMap().put(this.qai_output_variable, response.getResult().getOutput().getText());
}
- Usage usage = resp.chatResponse().getMetadata().getUsage();
- if (usage != null) {
+
+ if (response != null && response.getMetadata() != null) {
+ Usage usage = response.getMetadata().getUsage();
Integer usedTokens = usage.getTotalTokens();
logger.info("Used tokens: {}", usedTokens);
this.scenarioExecution.setUsedTokens(usedTokens);
diff --git a/src/main/java/com/olympus/hermione/stepSolvers/AdvancedQueryRagSolver.java b/src/main/java/com/olympus/hermione/stepSolvers/AdvancedQueryRagSolver.java
index 4eb9e7b..40a24f4 100644
--- a/src/main/java/com/olympus/hermione/stepSolvers/AdvancedQueryRagSolver.java
+++ b/src/main/java/com/olympus/hermione/stepSolvers/AdvancedQueryRagSolver.java
@@ -158,42 +158,30 @@ public class AdvancedQueryRagSolver extends StepSolver {
}
- @Override
+ @Override
public ScenarioExecution solveStep(){
logger.info("Solving step: {}", this.step.getName());
loadParameters();
logParameters();
String elaboratedQuery = this.query;
- OlympusChatClient chatClient = null;
-
- String apiProvider = this.scenarioExecution.getScenario().getAiModel().getApiProvider();
- logger.info("AI Model API Provider: {}", apiProvider);
- if (apiProvider.equals("AzureOpenAI")) {
- logger.info("Using AzureOpenApiChatClient");
- chatClient = new AzureOpenApiChatClient();
- }
- if (apiProvider.equals("GoogleGemini")) {
- logger.info("Using GoogleGeminiChatClient");
- chatClient = new GoogleGeminiChatClient();
- }
-
- logger.info("Initializing chat client with endpoint: {}", this.scenarioExecution.getScenario().getAiModel().getFull_path_endpoint());
- chatClient.init(this.scenarioExecution.getScenario().getAiModel().getFull_path_endpoint(),
- this.scenarioExecution.getScenario().getAiModel().getApiKey(),
- this.scenarioExecution.getScenario().getAiModel().getMaxTokens());
if(this.enableRephrase){
logger.info("Rephrasing query with prompt: {}", this.rephrasePrompt);
- OlympusChatClientResponse resp = chatClient.getChatCompletion(this.rephrasePrompt);
- elaboratedQuery = resp.getContent();
+ elaboratedQuery = chatClient.prompt()
+ .user(this.rephrasePrompt)
+ .call()
+ .content();
+
logger.info("Rephrased query: {}", elaboratedQuery);
}
if(this.enableMultiPhrase){
logger.info("Generating multi-phrase queries with prompt: {}", this.MultiPhrasePrompt);
- OlympusChatClientResponse resp = chatClient.getChatCompletion(this.MultiPhrasePrompt);
- elaboratedQuery = resp.getContent();
+ elaboratedQuery = chatClient.prompt()
+ .user(this.MultiPhrasePrompt)
+ .call()
+ .content();
logger.info("Multi-phrase queries: {}", elaboratedQuery);
}
@@ -265,28 +253,37 @@ public class AdvancedQueryRagSolver extends StepSolver {
prompt += "\n\n The query is: " + this.query;
prompt += "\n\n The output must be a single query and nothing else.";
- logger.info("Prompt for retrieving other documents: {}", prompt);
- OlympusChatClientResponse resp = chatClient.getChatCompletion(prompt);
+ //logger.info("Prompt for retrieving other documents: {}", prompt);
+
+ String otherDocQuery = chatClient.prompt()
+ .user(prompt)
+ .call()
+ .content();
- String otherDocQuery = resp.getContent();
logger.info("Other document query: {}", otherDocQuery);
- Builder request_builder = SearchRequest.builder()
- .query(otherDocQuery)
- .topK(otherTopK)
- .similarityThreshold(0.8);
-
- if(this.rag_filter != null && !this.rag_filter.isEmpty()){
- request_builder.filterExpression(this.rag_filter);
- logger.info("Using Filter expression: {}", this.rag_filter);
- }
-
- SearchRequest request = request_builder.build();
- logger.info("Performing similarity search for other document query");
- otherDocs.addAll(this.vectorStore.similaritySearch(request));
- logger.info("Number of addictional documents: {}", otherDocs.size());
+ if(otherDocQuery != null && !otherDocQuery.isEmpty()){
+ logger.info("Performing similarity search for other document query: {}", otherDocQuery);
+ Builder request_builder = SearchRequest.builder()
+ .query(otherDocQuery)
+ .topK(otherTopK)
+ .similarityThreshold(0.8);
+
+ if(this.rag_filter != null && !this.rag_filter.isEmpty()){
+ request_builder.filterExpression(this.rag_filter);
+ logger.info("Using Filter expression: {}", this.rag_filter);
+ }
+
+ SearchRequest request = request_builder.build();
+ logger.info("Performing similarity search for other document query");
+ otherDocs.addAll(this.vectorStore.similaritySearch(request));
+ logger.info("Number of additional documents: {}", otherDocs.size());
- docs.addAll(otherDocs);
+ docs.addAll(otherDocs);
+ logger.info("Number of documents after other document retrieval: {}", docs.size());
+ } else {
+ logger.info("No other document query generated, skipping other document retrieval.");
+ }
}
docs = docs.stream()
diff --git a/src/main/java/com/olympus/hermione/stepSolvers/BasicQueryRagSolver.java b/src/main/java/com/olympus/hermione/stepSolvers/BasicQueryRagSolver.java
index dbf05c9..e58e507 100644
--- a/src/main/java/com/olympus/hermione/stepSolvers/BasicQueryRagSolver.java
+++ b/src/main/java/com/olympus/hermione/stepSolvers/BasicQueryRagSolver.java
@@ -17,12 +17,6 @@ public class BasicQueryRagSolver extends StepSolver {
Logger logger = (Logger) LoggerFactory.getLogger(BasicQueryRagSolver.class);
- // PARAMETERS
- // rag_input : the field of the sharedMap to use as rag input
- // rag_filter : the field of the sharedMap to use as rag filter
- // rag_topk : the number of topk results to return
- // rag_threshold : the similarity threshold to use
- // rag_output_variable : the field of the sharedMap to store the output
private String query;
private String rag_filter;
private int topk;
@@ -71,14 +65,11 @@ public class BasicQueryRagSolver extends StepSolver {
loadParameters();
logParameters();
-
-
Builder request_builder = SearchRequest.builder()
.query(this.query)
.topK(this.topk)
.similarityThreshold(this.threshold);
-
-
+ logger.info("Performing similarity search with query: " + this.query);
if(this.rag_filter != null && !this.rag_filter.isEmpty()){
request_builder.filterExpression(this.rag_filter);
diff --git a/src/main/java/com/olympus/hermione/stepSolvers/SummarizeDocSolver.java b/src/main/java/com/olympus/hermione/stepSolvers/SummarizeDocSolver.java
index 478a383..08eb9fb 100644
--- a/src/main/java/com/olympus/hermione/stepSolvers/SummarizeDocSolver.java
+++ b/src/main/java/com/olympus/hermione/stepSolvers/SummarizeDocSolver.java
@@ -1,30 +1,22 @@
package com.olympus.hermione.stepSolvers;
-import java.io.FileInputStream;
-import java.io.FileNotFoundException;
-import java.nio.file.Files;
+import java.io.File;
import java.util.ArrayList;
-import java.util.Base64;
import java.util.List;
import java.util.Optional;
import org.apache.tika.Tika;
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.client.ChatClient.CallResponseSpec;
+import org.springframework.ai.chat.memory.ChatMemory;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.SystemMessage;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.metadata.Usage;
-import org.springframework.core.io.ClassPathResource;
-import org.springframework.core.io.InputStreamResource;
-import org.springframework.util.MimeTypeUtils;
-import org.springframework.ai.model.Media;
+
import com.knuddels.jtokkit.Encodings;
import com.knuddels.jtokkit.api.Encoding;
import com.knuddels.jtokkit.api.EncodingRegistry;
-import java.io.File;
-import java.io.FileInputStream;
-
import com.olympus.hermione.models.ScenarioExecution;
import com.olympus.hermione.utility.AttributeParser;
@@ -48,7 +40,6 @@ public class SummarizeDocSolver extends StepSolver {
private boolean isChunked = false;
private Double chunk_size_token_calc;
- // private boolean qai_load_graph_schema=false;
Logger logger = (Logger) LoggerFactory.getLogger(SummarizeDocSolver.class);
@@ -78,143 +69,215 @@ public class SummarizeDocSolver extends StepSolver {
this.qai_system_prompt_template_formatter = attributeParser
.parse((String) this.step.getAttributes().get("qai_system_prompt_template_formatter"));
+
+ logger.info("Parametri caricati: file path={}, encoding={}, max_chunk_size_token={}, output_variable={}",
+ this.qai_path_file, this.encoding_name, this.max_chunk_size_token, this.qai_output_variable);
}
@Override
- public ScenarioExecution solveStep() {
+ public ScenarioExecution solveStep() throws Exception {
+ logger.info("Iniziando solveStep per: {}", this.step.getName());
System.out.println("Solving step: " + this.step.getName());
- try {
- this.scenarioExecution.setCurrentStepId(this.step.getStepId());
- loadParameters();
+
+ this.scenarioExecution.setCurrentStepId(this.step.getStepId());
+ loadParameters();
- // Estrazione del testo dal file
- File file = new File(this.qai_path_file);
- Tika tika = new Tika();
- tika.setMaxStringLength(-1);
- String text = tika.parseToString(file);
-
- // add libreria tokenizer
- // Ottieni l'encoding del modello (es. cl100k_base per GPT-4 e GPT-3.5-turbo)
- EncodingRegistry registry = Encodings.newDefaultEncodingRegistry();
- encoding = registry.getEncoding(this.encoding_name);
-
- tokenCount = encoding.get().countTokens(text);
- logger.info("token count input: " + tokenCount);
-
- // Calcolo della dimensione del chunk in token
- chunk_size_token_calc = (Double)((double)tokenCount * percent_chunk_length);
- if(chunk_size_token_calc>max_chunk_size_token){
- chunk_size_token_calc = (double)max_chunk_size_token;
-
- }
- // **Fase di Summarization**
- String summarizedText = text;
-
- if(tokenCount > max_input_token_not_to_be_chunked){
- summarizedText = summarize(text);
- }
-
- // Creazione dei messaggi per il modello AI
- Message userMessage = new UserMessage(summarizedText);
- Message systemMessage = null;
-
- int tokenCountSummary = encoding.get().countTokens(summarizedText);
-
- if(isChunked && tokenCountSummary < max_output_token){
- systemMessage = new SystemMessage(this.qai_system_prompt_template_formatter);
- logger.info("template formatter: " + this.qai_system_prompt_template_formatter);
- }else{
- //here
- systemMessage = new SystemMessage(this.qai_system_prompt_template);
- logger.info("template: " + this.qai_system_prompt_template);
- }
-
- CallResponseSpec resp = chatClient.prompt()
- .messages(userMessage, systemMessage)
- .advisors(advisor -> advisor
- .param("chat_memory_conversation_id",
- this.scenarioExecution.getId() + this.qai_custom_memory_id)
- .param("chat_memory_response_size", 100))
- .call();
-
- String output = resp.content();
-
- // Gestione del conteggio token usati
- Usage usage = resp.chatResponse().getMetadata().getUsage();
- if (usage != null) {
- Integer usedTokens = usage.getTotalTokens();
- this.scenarioExecution.setUsedTokens(usedTokens);
- } else {
- logger.info("Token usage information is not available.");
- }
-
- tokenCount = encoding.get().countTokens(output);
- logger.info("token count output: " + tokenCount);
-
- // Salvataggio dell'output nel contesto di esecuzione
- this.scenarioExecution.getExecSharedMap().put(this.qai_output_variable, output);
- this.scenarioExecution.setNextStepId(this.step.getNextStepId());
-
- } catch (Exception e) {
- logger.error("Error in AnswerQuestionOnDocSolver: " + e.getMessage());
- this.scenarioExecution.getExecSharedMap().put(this.qai_output_variable, "500 Internal Server Error");
+ // Extract the text from the file
+ logger.info("Estrazione del testo dal file: {}", this.qai_path_file);
+ File file = new File(this.qai_path_file);
+ if (!file.exists()) {
+ logger.error("File non trovato: {}", this.qai_path_file);
+ throw new RuntimeException("File non trovato: " + this.qai_path_file);
}
+
+ logger.info("Dimensione file: {} bytes", file.length());
+ Tika tika = new Tika();
+ tika.setMaxStringLength(-1);
+ logger.info("Iniziando parsing del file con Tika");
+ String text ="";
+ try
+ {
+ text = tika.parseToString(file);
+ } catch (Exception e) {
+ logger.error("Errore durante l'estrazione del testo dal file: {}", e.getMessage());
+ throw new RuntimeException("Errore durante l'estrazione del testo dal file: " + e.getMessage());
+ }
+
+ logger.info("Estrazione testo completata, lunghezza: {} caratteri", text.length());
+
+ logger.info("Inizializzazione encoding: {}", this.encoding_name);
+ EncodingRegistry registry = Encodings.newDefaultEncodingRegistry();
+ encoding = registry.getEncoding(this.encoding_name);
+ if (!encoding.isPresent()) {
+ logger.error("Encoding non trovato: {}", this.encoding_name);
+ throw new RuntimeException("Encoding non trovato: " + this.encoding_name);
+ }
+
+ logger.info("Conteggio token del testo estratto");
+ tokenCount = encoding.get().countTokens(text);
+ logger.info("token count input: " + tokenCount);
+
+ chunk_size_token_calc = (Double)((double)tokenCount * percent_chunk_length);
+ logger.info("Calcolo dimensione chunk: {} token * {} = {} token",
+ tokenCount, percent_chunk_length, chunk_size_token_calc);
+
+ if(chunk_size_token_calc > max_chunk_size_token){
+ logger.info("Dimensione chunk calcolata ({}) supera il limite massimo, impostata a: {}",
+ chunk_size_token_calc, max_chunk_size_token);
+ chunk_size_token_calc = (double)max_chunk_size_token;
+ }
+ String summarizedText = text;
+
+ if(tokenCount > max_input_token_not_to_be_chunked){
+ logger.info("Il testo supera il limite massimo di token non chunked ({} > {}), procedendo con la summarizzazione",
+ tokenCount, max_input_token_not_to_be_chunked);
+ summarizedText = summarize(text);
+ } else {
+ logger.info("Il testo non supera il limite massimo di token non chunked ({} <= {}), non necessaria summarizzazione",
+ tokenCount, max_input_token_not_to_be_chunked);
+ }
+
+ logger.info("Preparazione messaggi per il modello");
+ Message userMessage = new UserMessage(summarizedText);
+ Message systemMessage = null;
+
+ int tokenCountSummary = encoding.get().countTokens(summarizedText);
+ logger.info("Token count del testo summarizzato: {}", tokenCountSummary);
+
+ if(isChunked && tokenCountSummary < max_output_token){
+ logger.info("Testo è stato chunked e il numero di token è sotto il limite massimo, " +
+ "usando template formatter");
+ systemMessage = new SystemMessage(this.qai_system_prompt_template_formatter);
+ logger.info("template formatter: " + this.qai_system_prompt_template_formatter);
+ }else{
+ logger.info("Usando template standard");
+ systemMessage = new SystemMessage(this.qai_system_prompt_template);
+ logger.info("template: " + this.qai_system_prompt_template);
+ }
+
+ logger.info("Invio richiesta al modello con conversation ID: {}",
+ this.scenarioExecution.getId() + this.qai_custom_memory_id);
+
+ CallResponseSpec resp = chatClient.prompt()
+ .messages(userMessage, systemMessage)
+ .advisors(advisor -> advisor
+ .param(ChatMemory.CONVERSATION_ID,
+ this.scenarioExecution.getId() + this.qai_custom_memory_id))
+ .call();
+
+ String output = resp.content();
+ logger.info("Risposta ricevuta dal modello");
+
+ // Handle the used-token accounting
+ try {
+ var chatResponse = resp.chatResponse();
+ if (chatResponse != null) {
+ var metadata = chatResponse.getMetadata();
+ if (metadata != null) {
+ Usage usage = metadata.getUsage();
+ if (usage != null) {
+ Integer usedTokens = usage.getTotalTokens();
+ this.scenarioExecution.setUsedTokens(usedTokens);
+ logger.info("Token utilizzati: {}", usedTokens);
+ } else {
+ logger.info("Token usage information is not available.");
+ }
+ } else {
+ logger.warn("Metadata non disponibili nella risposta");
+ }
+ } else {
+ logger.warn("Chat response non disponibile");
+ }
+ } catch (Exception e) {
+ logger.warn("Errore nel recuperare le informazioni di utilizzo dei token: {}", e.getMessage());
+ }
+
+ tokenCount = encoding.get().countTokens(output);
+ logger.info("token count output: " + tokenCount);
+
+ // Save the output into the execution context
+ this.scenarioExecution.getExecSharedMap().put(this.qai_output_variable, output);
+ logger.info("Output salvato nella variabile: {}", this.qai_output_variable);
+ this.scenarioExecution.setNextStepId(this.step.getNextStepId());
+ logger.info("Step completato con successo, prossimo step: {}", this.step.getNextStepId());
return this.scenarioExecution;
}
private String summarize(String text) {
+        // Recursively summarizes `text`: when its token count fits a single model call the
+        // text is returned unchanged; otherwise it is split into chunks, each chunk is
+        // summarized via summarizeChunk, and the joined result is re-summarized until it
+        // drops under max_output_token.
+        logger.info("Iniziando il processo di summarizzazione, lunghezza testo: {} caratteri", text.length());
        logger.info("length: " + text.length());
        Double chunk_size_text;
        tokenCount = encoding.get().countTokens(text);
+        logger.info("Token count del testo: {}", tokenCount);
+        // 10% safety margin on both character and token counts so chunk sizing errs large.
        int textLengthPlus = (int) (text.length() * 1.1);
        int tokenCountPlus = (int) (tokenCount * 1.1);
        Double ratio = Math.floor((tokenCountPlus / chunk_size_token_calc) + 1);
        //Double ratio = Math.floor((tokenCountPlus / chunk_size_token) + 1);
+        logger.info("Ratio calcolato: {}, chunk_size_token_calc: {}", ratio, chunk_size_token_calc);
+        // ratio == 1 means the whole text fits in one chunk; no splitting needed.
        if (ratio == 1) {
+            logger.info("Ratio è 1, il testo non verrà suddiviso in chunk");
            return text;
        } else {
            chunk_size_text = Math.ceil(textLengthPlus / ratio);
+            logger.info("Il testo verrà suddiviso in chunk di dimensione: {} caratteri", chunk_size_text.intValue());
        }
        // Suddividere il testo in chunk
+        // NOTE(review): generic type parameters appear stripped in this patch text
+        // (List vs List<String>) — likely lost in transport; verify against the repo source.
        List chunks = chunkText(text, chunk_size_text.intValue());
+        logger.info("Testo suddiviso in {} chunk", chunks.size());
        List summarizedChunks = new ArrayList<>();
        // Riassumere ogni chunk singolarmente
+        int chunkCounter = 0;
        for (String chunk : chunks) {
+            chunkCounter++;
+            logger.info("Elaborazione chunk {}/{} - lunghezza: {} caratteri", chunkCounter, chunks.size(), chunk.length());
            logger.info("chunk length: " + chunk.length());
-            summarizedChunks.add(summarizeChunk(chunk));
+            String summarizedChunk = summarizeChunk(chunk);
+            logger.info("Chunk {}/{} riassunto con successo", chunkCounter, chunks.size());
+            summarizedChunks.add(summarizedChunk);
        }
        isChunked = true;
        // Unire i riassunti
        String summarizedText = String.join(" ", summarizedChunks);
        int tokenCountSummarizedText = encoding.get().countTokens(summarizedText);
+        logger.info("Testo riassunto completo: {} token, {} caratteri", tokenCountSummarizedText, summarizedText.length());
        // Se il riassunto è ancora troppo lungo, applicare ricorsione
+        // Recursion terminates assuming each summarization pass shrinks the token count —
+        // TODO confirm there is a guard elsewhere against a non-shrinking model response.
        if (tokenCountSummarizedText > max_output_token) {
+            logger.info("Il testo riassunto è ancora troppo lungo ({}), applicando ricorsione", tokenCountSummarizedText);
            return summarize(summarizedText);
        } else {
+            logger.info("Summarizzazione completata con successo");
            return summarizedText;
        }
    }
// Metodo per suddividere il testo in chunk
private List chunkText(String text, int chunkSize) {
+        // Splits `text` into consecutive substrings of at most `chunkSize` characters;
+        // the last chunk may be shorter. No overlap between chunks.
+        logger.info("Suddivisione testo in chunk di dimensione: {} caratteri", chunkSize);
        List chunks = new ArrayList<>();
        int length = text.length();
        for (int i = 0; i < length; i += chunkSize) {
-            chunks.add(text.substring(i, Math.min(length, i + chunkSize)));
+            int end = Math.min(length, i + chunkSize);
+            chunks.add(text.substring(i, end));
+            logger.debug("Creato chunk {}/{} da carattere {} a {}", chunks.size(),
+                (int)Math.ceil((double)length/chunkSize), i, end);
        }
+        logger.info("Testo suddiviso in {} chunk", chunks.size());
        return chunks;
    }
// Metodo di riassunto per singolo chunk
private String summarizeChunk(String chunk) {
+        // Summarizes a single chunk with the chunk-specific system prompt. On model failure
+        // the error is embedded in the returned text rather than thrown.
+        logger.info("Iniziando summarizzazione di un chunk di {} caratteri", chunk.length());
+        int tokenCountChunk = encoding.get().countTokens(chunk);
+        logger.info("Token count del chunk: {}", tokenCountChunk);
        Message chunkMessage = new UserMessage(chunk);
        Message systemMessage = new SystemMessage(
+        // NOTE(review): the patch hunk boundary below cuts out the SystemMessage argument
+        // lines — this block is incomplete in this view; verify against the full file.
@@ -222,18 +285,25 @@ public class SummarizeDocSolver extends StepSolver {
        logger.info("template chunk: " + this.qai_system_prompt_template_chunk);
-        CallResponseSpec resp = chatClient.prompt()
-                .messages(chunkMessage, systemMessage)
-                /*
-                 * .advisors(advisor -> advisor
-                 * .param("chat_memory_conversation_id",
-                 * this.scenarioExecution.getId() + this.qai_custom_memory_id)
-                 * .param("chat_memory_response_size", 100))
-                 */
-                .call();
-
-        return resp.content();
+        try {
+            CallResponseSpec resp = chatClient.prompt()
+                    .messages(chunkMessage, systemMessage)
+                    .call();
+            String result = resp.content();
+            if (result != null) {
+                int tokenCountResult = encoding.get().countTokens(result);
+                logger.info("Chunk riassunto con successo: {} token, {} caratteri",
+                        tokenCountResult, result.length());
+                return result;
+            } else {
+                logger.warn("Risposta nulla ricevuta dal modello");
+                return "";
+            }
+        } catch (Exception e) {
+            logger.error("Errore durante la summarizzazione del chunk: {}", e.getMessage());
+            // NOTE(review): returning the error string feeds it into the joined summary seen
+            // by summarize(); consider rethrowing or returning "" — confirm intended behavior.
+            return "ERRORE NELLA SUMMARIZZAZIONE: " + e.getMessage();
+        }
    }
}
diff --git a/src/main/java/com/olympus/hermione/tools/DocumentationTool.java b/src/main/java/com/olympus/hermione/tools/DocumentationTool.java
new file mode 100644
index 0000000..49daa01
--- /dev/null
+++ b/src/main/java/com/olympus/hermione/tools/DocumentationTool.java
@@ -0,0 +1,66 @@
+package com.olympus.hermione.tools;
+
+import java.util.List;
+import java.util.logging.Logger;
+
+import org.springframework.ai.document.Document;
+import org.springframework.ai.tool.annotation.Tool;
+import org.springframework.ai.vectorstore.SearchRequest;
+import org.springframework.ai.vectorstore.VectorStore;
+
+
+
+
+/**
+ * LLM tool exposing functional documentation lookup over the project's vector store.
+ * Registered via Spring AI's {@code @Tool} mechanism; the description string is what the
+ * model uses to decide when to invoke it.
+ */
+public class DocumentationTool {
+
+    // JUL logger; the rest of the project uses SLF4J — consider aligning later.
+    Logger logger = Logger.getLogger(DocumentationTool.class.getName());
+
+    // Application name — currently unused by getDocumentation; kept for future filters.
+    private final String application;
+    // Project name injected into the vector-store metadata filter.
+    private final String project;
+    // Vector store queried for functional documentation chunks.
+    private final VectorStore vectorStore;
+
+    public DocumentationTool(String application, String project, VectorStore vectorStore) {
+        this.application = application;
+        this.project = project;
+        this.vectorStore = vectorStore;
+    }
+
+    /**
+     * Retrieves up to 3 functional-documentation chunks for the configured project whose
+     * similarity to {@code query} is at least 0.7, concatenated with delimiter banners.
+     *
+     * @param query natural-language question coming from the model
+     * @return concatenated document texts, or an empty string when nothing matches
+     */
+    @Tool(description = "Find any functional information about the application. Can be used to answer any question with a specific focus on the application")
+    public String getDocumentation(String query) {
+
+        logger.info("[TOOL]LLM DocumentationTool Getting documentation for query: " + query);
+        logger.info("[TOOL]LLM DocumentationTool Getting documentation for project: " + this.project);
+        logger.info("[TOOL]LLM DocumentationTool Getting documentation for vectorStore: " + this.vectorStore);
+
+        // NOTE(review): project name is concatenated into the filter expression; fine for
+        // trusted config, but prefer the typed FilterExpressionBuilder if it may contain quotes.
+        String filterExpression = "'KsProjectName' == '" + this.project + "' AND 'KsDoctype' == 'functional'";
+        logger.info("[TOOL]LLM DocumentationTool Getting documentation for filterExpression: " + filterExpression);
+
+        SearchRequest request = SearchRequest.builder()
+                .query(query)
+                .topK(3)
+                .similarityThreshold(0.7)
+                .filterExpression(filterExpression)
+                .build();
+
+        List<Document> docs = this.vectorStore.similaritySearch(request);
+        logger.info("Number of VDB retrieved documents: " + docs.size());
+
+        // Build the context block handed back to the model; StringBuilder avoids O(n^2) concat.
+        StringBuilder docToolResponse = new StringBuilder();
+        for (Document doc : docs) {
+            docToolResponse.append("-------USE THIS AS PART OF YOUR CONTEXT----------------------------------------- \n");
+            docToolResponse.append(doc.getText())
+                    .append("------------------------------------------------------------------------------------------ \n");
+        }
+
+        return docToolResponse.toString();
+    }
+
+}
diff --git a/src/main/java/com/olympus/hermione/utility/RAGDocumentRanker.java b/src/main/java/com/olympus/hermione/utility/RAGDocumentRanker.java
index 99168e6..354fef1 100644
--- a/src/main/java/com/olympus/hermione/utility/RAGDocumentRanker.java
+++ b/src/main/java/com/olympus/hermione/utility/RAGDocumentRanker.java
@@ -5,6 +5,7 @@ import java.util.List;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
+import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.document.Document;
import org.springframework.stereotype.Service;
@@ -19,7 +20,7 @@ public class RAGDocumentRanker {
Logger logger = (Logger) LoggerFactory.getLogger(RAGDocumentRanker.class);
- public List rankDocuments(List documents, String query, int threshold, OlympusChatClient chatClient) {
+ public List rankDocuments(List documents, String query, int threshold, ChatClient chatClient) {
List rankedDocuments = new ArrayList<>();
@@ -30,8 +31,10 @@ public class RAGDocumentRanker {
rankingPrompt +="Chunk: " + document.getText();
rankingPrompt +="Answer with a number between 1 and 10 and nothing else.";
- OlympusChatClientResponse resp = chatClient.getChatCompletion(rankingPrompt);
- String rank = resp.getContent().substring(0, 1);
+ String rank = chatClient.prompt()
+ .user(rankingPrompt)
+ .call()
+ .content().substring(0, 1);
int rankInt = Integer.parseInt(rank);
logger.info("Rank: " + rankInt);
diff --git a/src/main/resources/application.properties b/src/main/resources/application.properties
index 0303374..337bd92 100644
--- a/src/main/resources/application.properties
+++ b/src/main/resources/application.properties
@@ -56,9 +56,9 @@ eureka.instance.preferIpAddress: true
hermione.fe.url = http://localhost:5173
-spring.ai.vectorstore.chroma.client.host=http://108.142.74.161
+spring.ai.vectorstore.chroma.client.host=http://128.251.239.194
spring.ai.vectorstore.chroma.client.port=8000
-spring.ai.vectorstore.chroma.client.key-token=tKAJfN1Yv5lP7pKorJHGfHMQhNEcM9uu
+# NOTE(review): key-token is committed in plain text (and the replaced token above is now
+# leaked in VCS history) — move this to an environment variable / secret store and rotate both.
+spring.ai.vectorstore.chroma.client.key-token=BxZWXFXC4UMSxamf5xP5SioGIg3FPfP7
spring.ai.vectorstore.chroma.initialize-schema=true
spring.ai.vectorstore.chroma.collection-name=olympus