From 49d935d07d40fee86f4a1107694cac8c1a082d7c Mon Sep 17 00:00:00 2001 From: Andrea Terzani Date: Thu, 26 Jun 2025 17:34:54 +0200 Subject: [PATCH] first iteration with scenario working --- pom.xml | 46 ++- .../hermione/config/AzureAIConfig.java | 26 -- .../hermione/config/HermioneConfig.java | 14 +- .../hermione/config/VectorStoreConfig.java | 50 ---- .../services/ScenarioExecutionService.java | 116 +++----- .../stepSolvers/AdvancedAIPromptSolver.java | 58 ++-- .../stepSolvers/AdvancedQueryRagSolver.java | 77 +++-- .../stepSolvers/BasicQueryRagSolver.java | 11 +- .../stepSolvers/SummarizeDocSolver.java | 270 +++++++++++------- .../hermione/tools/DocumentationTool.java | 66 +++++ .../hermione/utility/RAGDocumentRanker.java | 9 +- src/main/resources/application.properties | 4 +- 12 files changed, 365 insertions(+), 382 deletions(-) delete mode 100644 src/main/java/com/olympus/hermione/config/AzureAIConfig.java delete mode 100644 src/main/java/com/olympus/hermione/config/VectorStoreConfig.java create mode 100644 src/main/java/com/olympus/hermione/tools/DocumentationTool.java diff --git a/pom.xml b/pom.xml index 970ccb1..735ac9d 100644 --- a/pom.xml +++ b/pom.xml @@ -28,8 +28,7 @@ 21 - - 1.0.0-M6 + 1.0.0 2023.0.3 @@ -61,38 +60,26 @@ spring-boot-starter-oauth2-resource-server - - - org.json - json - 20250107 - + + org.json + json + 20250107 + - org.springframework.ai - spring-ai-openai-spring-boot-starter + spring-ai-starter-model-azure-openai + org.springframework.ai - spring-ai-azure-openai - - - org.springframework.ai - spring-ai-azure-store - - - org.springframework.ai - spring-ai-chroma-store + spring-ai-starter-vector-store-chroma + org.neo4j.driver neo4j-java-driver - 5.8.0 @@ -100,12 +87,12 @@ spring-boot-starter-test test - - org.projectlombok - lombok - 1.18.34 - - + + org.projectlombok + lombok + 1.18.38 + provided + io.jsonwebtoken jjwt-api @@ -153,7 +140,6 @@ com.google.code.gson gson - 2.8.8 diff --git 
a/src/main/java/com/olympus/hermione/config/AzureAIConfig.java b/src/main/java/com/olympus/hermione/config/AzureAIConfig.java deleted file mode 100644 index 81a4336..0000000 --- a/src/main/java/com/olympus/hermione/config/AzureAIConfig.java +++ /dev/null @@ -1,26 +0,0 @@ -package com.olympus.hermione.config; - -import java.time.Duration; - -import org.springframework.context.annotation.Bean; -import org.springframework.context.annotation.Configuration; -import org.springframework.web.client.RestTemplate; -import org.springframework.ai.autoconfigure.azure.openai.AzureOpenAIClientBuilderCustomizer; -import com.azure.core.http.HttpClient; -import com.azure.core.util.HttpClientOptions; - - -@Configuration -public class AzureAIConfig { - - @Bean - public AzureOpenAIClientBuilderCustomizer responseTimeoutCustomizer() { - return openAiClientBuilder -> { - HttpClientOptions clientOptions = new HttpClientOptions() - .setResponseTimeout(Duration.ofMinutes(5)); - openAiClientBuilder.httpClient(HttpClient.createDefault(clientOptions)); - }; - } - - -} \ No newline at end of file diff --git a/src/main/java/com/olympus/hermione/config/HermioneConfig.java b/src/main/java/com/olympus/hermione/config/HermioneConfig.java index 5daa131..5619e7a 100644 --- a/src/main/java/com/olympus/hermione/config/HermioneConfig.java +++ b/src/main/java/com/olympus/hermione/config/HermioneConfig.java @@ -2,7 +2,6 @@ package com.olympus.hermione.config; import org.springframework.ai.azure.openai.AzureOpenAiEmbeddingModel; import org.springframework.ai.embedding.EmbeddingModel; -import org.springframework.ai.openai.OpenAiEmbeddingModel; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; import org.springframework.context.annotation.Primary; @@ -11,24 +10,13 @@ import org.springframework.context.annotation.Primary; @Configuration public class HermioneConfig { - private final OpenAiEmbeddingModel openAiEmbeddingModel; private final 
AzureOpenAiEmbeddingModel azureOpenAiEmbeddingModel; - // Autowiring beans from both libraries - public HermioneConfig(OpenAiEmbeddingModel myServiceLib1, - AzureOpenAiEmbeddingModel azureOpenAiEmbeddingModel) { - this.openAiEmbeddingModel = myServiceLib1; + public HermioneConfig( AzureOpenAiEmbeddingModel azureOpenAiEmbeddingModel) { this.azureOpenAiEmbeddingModel = azureOpenAiEmbeddingModel; } - // Re-declaring and marking one as primary - @Bean - public EmbeddingModel primaryMyService() { - return openAiEmbeddingModel; // Choose the one you want as primary - } - - // Optionally declare the other bean without primary @Bean @Primary public EmbeddingModel secondaryMyService() { diff --git a/src/main/java/com/olympus/hermione/config/VectorStoreConfig.java b/src/main/java/com/olympus/hermione/config/VectorStoreConfig.java deleted file mode 100644 index 83f4571..0000000 --- a/src/main/java/com/olympus/hermione/config/VectorStoreConfig.java +++ /dev/null @@ -1,50 +0,0 @@ -package com.olympus.hermione.config; - -import com.azure.core.credential.AzureKeyCredential; -import com.azure.search.documents.indexes.SearchIndexClient; -import com.azure.search.documents.indexes.SearchIndexClientBuilder; -import org.springframework.ai.embedding.EmbeddingModel; -import org.springframework.ai.vectorstore.VectorStore; -import org.springframework.ai.vectorstore.azure.AzureVectorStore; -import org.springframework.beans.factory.annotation.Qualifier; -import org.springframework.beans.factory.annotation.Value; -import org.springframework.context.annotation.Bean; -import org.springframework.context.annotation.Configuration; - -import java.util.ArrayList; -import java.util.List; - -@Configuration -public class VectorStoreConfig { - - - @Value("${spring.ai.vectorstore.azure.api-key}") - private String azureKey; - @Value("${spring.ai.vectorstore.azure.url}") - private String azureEndpoint; - @Value("${spring.ai.vectorstore.azure.initialize-schema}") - private boolean initSchema; - - @Bean - 
public SearchIndexClient searchIndexClient() { - return new SearchIndexClientBuilder().endpoint(azureEndpoint) - .credential(new AzureKeyCredential(azureKey)) - .buildClient(); - } - - @Bean - public VectorStore vectorStore(SearchIndexClient searchIndexClient, @Qualifier("azureOpenAiEmbeddingModel") EmbeddingModel embeddingModel) { - List fields = new ArrayList<>(); - - fields.add(AzureVectorStore.MetadataField.text("KsApplicationName")); - fields.add(AzureVectorStore.MetadataField.text("KsProjectName")); - fields.add(AzureVectorStore.MetadataField.text("KsDoctype")); - fields.add(AzureVectorStore.MetadataField.text("KsDocSource")); - fields.add(AzureVectorStore.MetadataField.text("KsFileSource")); - fields.add(AzureVectorStore.MetadataField.text("KsDocumentId")); - - //return new AzureVectorStore(searchIndexClient, embeddingModel,initSchema, fields); - return null; - } -} - diff --git a/src/main/java/com/olympus/hermione/services/ScenarioExecutionService.java b/src/main/java/com/olympus/hermione/services/ScenarioExecutionService.java index e70141c..2752554 100644 --- a/src/main/java/com/olympus/hermione/services/ScenarioExecutionService.java +++ b/src/main/java/com/olympus/hermione/services/ScenarioExecutionService.java @@ -24,11 +24,8 @@ import org.springframework.ai.chat.client.ChatClient; import org.springframework.ai.chat.client.advisor.MessageChatMemoryAdvisor; import org.springframework.ai.chat.client.advisor.SimpleLoggerAdvisor; import org.springframework.ai.chat.memory.ChatMemory; -import org.springframework.ai.chat.memory.InMemoryChatMemory; +import org.springframework.ai.chat.memory.MessageWindowChatMemory; import org.springframework.ai.chat.model.ChatModel; -import org.springframework.ai.openai.OpenAiChatModel; -import org.springframework.ai.openai.OpenAiChatOptions; -import org.springframework.ai.openai.api.OpenAiApi; import org.springframework.ai.vectorstore.VectorStore; import org.springframework.beans.factory.annotation.Autowired; import 
org.springframework.beans.factory.annotation.Value; @@ -200,19 +197,16 @@ public class ScenarioExecutionService { if (scenario.isUseChatMemory()) { logger.info("Initializing chatClient with chat-memory advisor"); - ChatMemory chatMemory = new InMemoryChatMemory(); + ChatMemory chatMemory = MessageWindowChatMemory.builder().build(); + chatClient = ChatClient.builder(chatModel) - .defaultAdvisors( - new MessageChatMemoryAdvisor(chatMemory), // chat-memory advisor - new SimpleLoggerAdvisor()) + .defaultAdvisors(MessageChatMemoryAdvisor.builder(chatMemory).build() ) .build(); } else { - logger.info("Initializing chatClient with simple logger advisor"); + logger.info("Initializing chatClient without advisor"); chatClient = ChatClient.builder(chatModel) - .defaultAdvisors( - new SimpleLoggerAdvisor()) .build(); } @@ -437,81 +431,46 @@ public class ScenarioExecutionService { private ChatModel createChatModel(AiModel aiModel) { switch (aiModel.getApiProvider()) { case "AzureOpenAI": - OpenAIClientBuilder openAIClient = new OpenAIClientBuilder() - .credential(new AzureKeyCredential(aiModel.getApiKey())) - .endpoint(aiModel.getEndpoint()); + - AzureOpenAiChatOptions openAIChatOptions = AzureOpenAiChatOptions.builder() - .deploymentName(aiModel.getModel()) - .maxTokens(aiModel.getMaxTokens()) - .temperature(Double.valueOf(aiModel.getTemperature())) - .build(); + OpenAIClientBuilder openAIClientBuilder = new OpenAIClientBuilder() + .credential(new AzureKeyCredential(aiModel.getApiKey())) + .endpoint(aiModel.getEndpoint()); - AzureOpenAiChatModel azureOpenaichatModel = new AzureOpenAiChatModel(openAIClient, openAIChatOptions); + // Crea un builder con opzioni base + AzureOpenAiChatOptions.Builder optionsBuilder = AzureOpenAiChatOptions.builder() + .deploymentName(aiModel.getModel()); + + // Aggiungi maxTokens solo se non è nullo + if (aiModel.getMaxTokens() != null) { + optionsBuilder.maxTokens(aiModel.getMaxTokens()); + logger.info("Setting maxTokens: " + 
aiModel.getMaxTokens()); + } + + // Aggiungi temperature solo se non è nullo + if (aiModel.getTemperature() != null) { + optionsBuilder.temperature(Double.valueOf(aiModel.getTemperature())); + logger.info("Setting temperature: " + aiModel.getTemperature()); + } + + AzureOpenAiChatOptions openAIChatOptions = optionsBuilder.build(); + + var chatModel = AzureOpenAiChatModel.builder() + .openAIClientBuilder(openAIClientBuilder) + .defaultOptions(openAIChatOptions) + .build(); + logger.info("AI model used: " + aiModel.getModel()); - return azureOpenaichatModel; + return chatModel; - case "OpenAI": - OpenAiApi openAiApi = new OpenAiApi(aiModel.getApiKey()); - OpenAiChatOptions openAiChatOptions = OpenAiChatOptions.builder() - .model(aiModel.getModel()) - .temperature(Double.valueOf(aiModel.getTemperature())) - .maxTokens(aiModel.getMaxTokens()) - .build(); - OpenAiChatModel openaichatModel = new OpenAiChatModel(openAiApi, openAiChatOptions); - logger.info("AI model used: " + aiModel.getModel()); - return openaichatModel; - - case "GoogleGemini": - OpenAIClientBuilder openAIClient2 = new OpenAIClientBuilder() - .credential(new AzureKeyCredential(aiModel.getApiKey())) - .endpoint(aiModel.getEndpoint()); - - AzureOpenAiChatOptions openAIChatOptions2 = AzureOpenAiChatOptions.builder() - .deploymentName(aiModel.getModel()) - .maxTokens(aiModel.getMaxTokens()) - .temperature(Double.valueOf(aiModel.getTemperature())) - .build(); - - AzureOpenAiChatModel azureOpenaichatModel2 = new AzureOpenAiChatModel(openAIClient2, - openAIChatOptions2); - - logger.info("AI model used : " + aiModel.getModel()); - return azureOpenaichatModel2; default: throw new IllegalArgumentException("Unsupported AI model: " + aiModel.getName()); } } - /* - * public List getListExecutionScenario(){ - * logger.info("getListProjectByUser function:"); - * - * List lstScenarioExecution = null; - * - * User principal = (User) - * SecurityContextHolder.getContext().getAuthentication().getPrincipal(); - * - * 
if(principal.getSelectedApplication()!=null){ - * lstScenarioExecution = - * scenarioExecutionRepository.getFromProjectAndAPP(principal.getId(), - * principal.getSelectedProject().getInternal_name(), - * principal.getSelectedApplication().getInternal_name(), - * Sort.by(Sort.Direction.DESC, "startDate")); - * - * }else{ - * lstScenarioExecution = - * scenarioExecutionRepository.getFromProject(principal.getId(), - * principal.getSelectedProject().getInternal_name(), - * Sort.by(Sort.Direction.DESC, "startDate")); - * } - * - * return lstScenarioExecution; - * - * } - */ + public Page getListExecutionScenarioOptional(int page, int size, String scenarioName, String executedBy, LocalDate startDate) { @@ -582,12 +541,7 @@ public class ScenarioExecutionService { criteriaList.add(Criteria.where("execSharedMap.user_input.selected_application") .is(principal.getSelectedApplication().getInternal_name())); } - // here: - // filtri - // here: - // here: vedi perchè non funziona link di view - // here: da sistemare filtro data a backend e i vari equals - // here: add ordinamento + for (Map.Entry entry : filters.entrySet()) { String key = entry.getKey(); diff --git a/src/main/java/com/olympus/hermione/stepSolvers/AdvancedAIPromptSolver.java b/src/main/java/com/olympus/hermione/stepSolvers/AdvancedAIPromptSolver.java index 3534d08..e6bcab9 100644 --- a/src/main/java/com/olympus/hermione/stepSolvers/AdvancedAIPromptSolver.java +++ b/src/main/java/com/olympus/hermione/stepSolvers/AdvancedAIPromptSolver.java @@ -1,21 +1,18 @@ package com.olympus.hermione.stepSolvers; -import ch.qos.logback.classic.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.ai.chat.client.ChatClient.CallResponseSpec; +import org.springframework.ai.chat.memory.ChatMemory; +import org.springframework.ai.chat.metadata.Usage; +import org.springframework.ai.chat.model.ChatResponse; + import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.databind.ObjectMapper; 
import com.olympus.hermione.dto.aientity.CiaOutputEntity; import com.olympus.hermione.models.ScenarioExecution; import com.olympus.hermione.utility.AttributeParser; -import java.util.ArrayList; -import java.util.List; - -import org.slf4j.LoggerFactory; -import org.springframework.ai.chat.client.ChatClient.CallResponseSpec; -import org.springframework.ai.chat.messages.Message; -import org.springframework.ai.chat.messages.SystemMessage; -import org.springframework.ai.chat.messages.UserMessage; -import org.springframework.ai.chat.metadata.Usage; +import ch.qos.logback.classic.Logger; public class AdvancedAIPromptSolver extends StepSolver { @@ -61,7 +58,7 @@ public class AdvancedAIPromptSolver extends StepSolver { this.qai_custom_memory_id = (String) this.step.getAttributes().get("qai_custom_memory_id"); logger.info("qai_custom_memory_id: {}", this.qai_custom_memory_id); - //TODO: Add memory ID attribute to have the possibility of multiple conversations + this.qai_available_tools = (String) this.step.getAttributes().get("qai_available_tools"); logger.info("qai_available_tools: {}", this.qai_available_tools); } @@ -82,12 +79,22 @@ public class AdvancedAIPromptSolver extends StepSolver { logger.info("Available tools specified: {}", this.qai_available_tools); for(String tool: this.qai_available_tools.split(",")){ logger.info("Tool: {}", tool); - if(tool.equals("SourceCodeTool")){ - logger.info("Instantiating SourceCodeTool"); - tools = new com.olympus.hermione.tools.SourceCodeTool(discoveryClient); - } else { - logger.warn("Unknown tool specified: {}", tool); + + switch(tool.trim()){ + case "DocumentationTool" -> { + logger.info("Instantiating DocumentationTool"); + tools = new com.olympus.hermione.tools.DocumentationTool( + this.scenarioExecution.getScenarioExecutionInput().getInputs().get("selected_application"), + this.scenarioExecution.getScenarioExecutionInput().getInputs().get("selected_project"), + vectorStore); + } + + case "SourceCodeTool" -> { + 
logger.info("Instantiating SourceCodeTool"); + tools = new com.olympus.hermione.tools.SourceCodeTool(this.discoveryClient); + } } + } } else { logger.info("No tools specified or available"); @@ -98,9 +105,7 @@ public class AdvancedAIPromptSolver extends StepSolver { resp = chatClient.prompt() .user(this.qai_user_input) .system(this.qai_system_prompt_template) - .advisors(advisor -> advisor - .param("chat_memory_conversation_id", this.scenarioExecution.getId()+this.qai_custom_memory_id) - .param("chat_memory_response_size", 100)) + .advisors(advisor -> advisor.param(ChatMemory.CONVERSATION_ID, this.scenarioExecution.getId()+this.qai_custom_memory_id)) .call(); }else{ logger.info("Calling chatClient with tools: {}", tools.getClass().getName()); @@ -108,12 +113,11 @@ public class AdvancedAIPromptSolver extends StepSolver { .user(this.qai_user_input) .system(this.qai_system_prompt_template) .advisors(advisor -> advisor - .param("chat_memory_conversation_id", this.scenarioExecution.getId()+this.qai_custom_memory_id) - .param("chat_memory_response_size", 100)) + .param(ChatMemory.CONVERSATION_ID, this.scenarioExecution.getId()+this.qai_custom_memory_id)) .tools(tools) .call(); } - + ChatResponse response = null; if(qai_output_entityType!=null && qai_output_entityType.equals("CiaOutputEntity")){ logger.info("Output is of type CiaOutputEntity"); @@ -129,13 +133,13 @@ public class AdvancedAIPromptSolver extends StepSolver { } }else{ logger.info("Output is of type String"); - String output = resp.content(); - logger.debug("String response: {}", output); - this.scenarioExecution.getExecSharedMap().put(this.qai_output_variable, output); + response = resp.chatResponse(); + this.scenarioExecution.getExecSharedMap().put(this.qai_output_variable, response.getResult().getOutput().getText()); } - Usage usage = resp.chatResponse().getMetadata().getUsage(); - if (usage != null) { + + if (response != null && response.getMetadata() != null) { + Usage usage = 
response.getMetadata().getUsage(); Integer usedTokens = usage.getTotalTokens(); logger.info("Used tokens: {}", usedTokens); this.scenarioExecution.setUsedTokens(usedTokens); diff --git a/src/main/java/com/olympus/hermione/stepSolvers/AdvancedQueryRagSolver.java b/src/main/java/com/olympus/hermione/stepSolvers/AdvancedQueryRagSolver.java index 4eb9e7b..40a24f4 100644 --- a/src/main/java/com/olympus/hermione/stepSolvers/AdvancedQueryRagSolver.java +++ b/src/main/java/com/olympus/hermione/stepSolvers/AdvancedQueryRagSolver.java @@ -158,42 +158,30 @@ public class AdvancedQueryRagSolver extends StepSolver { } - @Override + @Override public ScenarioExecution solveStep(){ logger.info("Solving step: {}", this.step.getName()); loadParameters(); logParameters(); String elaboratedQuery = this.query; - OlympusChatClient chatClient = null; - - String apiProvider = this.scenarioExecution.getScenario().getAiModel().getApiProvider(); - logger.info("AI Model API Provider: {}", apiProvider); - if (apiProvider.equals("AzureOpenAI")) { - logger.info("Using AzureOpenApiChatClient"); - chatClient = new AzureOpenApiChatClient(); - } - if (apiProvider.equals("GoogleGemini")) { - logger.info("Using GoogleGeminiChatClient"); - chatClient = new GoogleGeminiChatClient(); - } - - logger.info("Initializing chat client with endpoint: {}", this.scenarioExecution.getScenario().getAiModel().getFull_path_endpoint()); - chatClient.init(this.scenarioExecution.getScenario().getAiModel().getFull_path_endpoint(), - this.scenarioExecution.getScenario().getAiModel().getApiKey(), - this.scenarioExecution.getScenario().getAiModel().getMaxTokens()); if(this.enableRephrase){ logger.info("Rephrasing query with prompt: {}", this.rephrasePrompt); - OlympusChatClientResponse resp = chatClient.getChatCompletion(this.rephrasePrompt); - elaboratedQuery = resp.getContent(); + elaboratedQuery = chatClient.prompt() + .user(this.rephrasePrompt) + .call() + .content(); + logger.info("Rephrased query: {}", 
elaboratedQuery); } if(this.enableMultiPhrase){ logger.info("Generating multi-phrase queries with prompt: {}", this.MultiPhrasePrompt); - OlympusChatClientResponse resp = chatClient.getChatCompletion(this.MultiPhrasePrompt); - elaboratedQuery = resp.getContent(); + elaboratedQuery = chatClient.prompt() + .user(this.MultiPhrasePrompt) + .call() + .content(); logger.info("Multi-phrase queries: {}", elaboratedQuery); } @@ -265,28 +253,37 @@ public class AdvancedQueryRagSolver extends StepSolver { prompt += "\n\n The query is: " + this.query; prompt += "\n\n The output must be a single query and nothing else."; - logger.info("Prompt for retrieving other documents: {}", prompt); - OlympusChatClientResponse resp = chatClient.getChatCompletion(prompt); + //logger.info("Prompt for retrieving other documents: {}", prompt); + + String otherDocQuery = chatClient.prompt() + .user(prompt) + .call() + .content(); - String otherDocQuery = resp.getContent(); logger.info("Other document query: {}", otherDocQuery); - Builder request_builder = SearchRequest.builder() - .query(otherDocQuery) - .topK(otherTopK) - .similarityThreshold(0.8); - - if(this.rag_filter != null && !this.rag_filter.isEmpty()){ - request_builder.filterExpression(this.rag_filter); - logger.info("Using Filter expression: {}", this.rag_filter); - } - - SearchRequest request = request_builder.build(); - logger.info("Performing similarity search for other document query"); - otherDocs.addAll(this.vectorStore.similaritySearch(request)); - logger.info("Number of addictional documents: {}", otherDocs.size()); + if(otherDocQuery != null && !otherDocQuery.isEmpty()){ + logger.info("Performing similarity search for other document query: {}", otherDocQuery); + Builder request_builder = SearchRequest.builder() + .query(otherDocQuery) + .topK(otherTopK) + .similarityThreshold(0.8); + + if(this.rag_filter != null && !this.rag_filter.isEmpty()){ + request_builder.filterExpression(this.rag_filter); + logger.info("Using Filter 
expression: {}", this.rag_filter); + } + + SearchRequest request = request_builder.build(); + logger.info("Performing similarity search for other document query"); + otherDocs.addAll(this.vectorStore.similaritySearch(request)); + logger.info("Number of addictional documents: {}", otherDocs.size()); - docs.addAll(otherDocs); + docs.addAll(otherDocs); + logger.info("Number of documents after other document retrieval: {}", docs.size()); + } else { + logger.info("No other document query generated, skipping other document retrieval."); + } } docs = docs.stream() diff --git a/src/main/java/com/olympus/hermione/stepSolvers/BasicQueryRagSolver.java b/src/main/java/com/olympus/hermione/stepSolvers/BasicQueryRagSolver.java index dbf05c9..e58e507 100644 --- a/src/main/java/com/olympus/hermione/stepSolvers/BasicQueryRagSolver.java +++ b/src/main/java/com/olympus/hermione/stepSolvers/BasicQueryRagSolver.java @@ -17,12 +17,6 @@ public class BasicQueryRagSolver extends StepSolver { Logger logger = (Logger) LoggerFactory.getLogger(BasicQueryRagSolver.class); - // PARAMETERS - // rag_input : the field of the sharedMap to use as rag input - // rag_filter : the field of the sharedMap to use as rag filter - // rag_topk : the number of topk results to return - // rag_threshold : the similarity threshold to use - // rag_output_variable : the field of the sharedMap to store the output private String query; private String rag_filter; private int topk; @@ -71,14 +65,11 @@ public class BasicQueryRagSolver extends StepSolver { loadParameters(); logParameters(); - - Builder request_builder = SearchRequest.builder() .query(this.query) .topK(this.topk) .similarityThreshold(this.threshold); - - + logger.info("Performing similarity search with query: " + this.query); if(this.rag_filter != null && !this.rag_filter.isEmpty()){ request_builder.filterExpression(this.rag_filter); diff --git a/src/main/java/com/olympus/hermione/stepSolvers/SummarizeDocSolver.java 
b/src/main/java/com/olympus/hermione/stepSolvers/SummarizeDocSolver.java index 478a383..08eb9fb 100644 --- a/src/main/java/com/olympus/hermione/stepSolvers/SummarizeDocSolver.java +++ b/src/main/java/com/olympus/hermione/stepSolvers/SummarizeDocSolver.java @@ -1,30 +1,22 @@ package com.olympus.hermione.stepSolvers; -import java.io.FileInputStream; -import java.io.FileNotFoundException; -import java.nio.file.Files; +import java.io.File; import java.util.ArrayList; -import java.util.Base64; import java.util.List; import java.util.Optional; import org.apache.tika.Tika; import org.slf4j.LoggerFactory; import org.springframework.ai.chat.client.ChatClient.CallResponseSpec; +import org.springframework.ai.chat.memory.ChatMemory; import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.messages.SystemMessage; import org.springframework.ai.chat.messages.UserMessage; import org.springframework.ai.chat.metadata.Usage; -import org.springframework.core.io.ClassPathResource; -import org.springframework.core.io.InputStreamResource; -import org.springframework.util.MimeTypeUtils; -import org.springframework.ai.model.Media; + import com.knuddels.jtokkit.Encodings; import com.knuddels.jtokkit.api.Encoding; import com.knuddels.jtokkit.api.EncodingRegistry; -import java.io.File; -import java.io.FileInputStream; - import com.olympus.hermione.models.ScenarioExecution; import com.olympus.hermione.utility.AttributeParser; @@ -48,7 +40,6 @@ public class SummarizeDocSolver extends StepSolver { private boolean isChunked = false; private Double chunk_size_token_calc; - // private boolean qai_load_graph_schema=false; Logger logger = (Logger) LoggerFactory.getLogger(SummarizeDocSolver.class); @@ -78,143 +69,215 @@ public class SummarizeDocSolver extends StepSolver { this.qai_system_prompt_template_formatter = attributeParser .parse((String) this.step.getAttributes().get("qai_system_prompt_template_formatter")); + + logger.info("Parametri caricati: file path={}, 
encoding={}, max_chunk_size_token={}, output_variable={}", + this.qai_path_file, this.encoding_name, this.max_chunk_size_token, this.qai_output_variable); } @Override - public ScenarioExecution solveStep() { + public ScenarioExecution solveStep() throws Exception { + logger.info("Iniziando solveStep per: {}", this.step.getName()); System.out.println("Solving step: " + this.step.getName()); - try { - this.scenarioExecution.setCurrentStepId(this.step.getStepId()); - loadParameters(); + + this.scenarioExecution.setCurrentStepId(this.step.getStepId()); + loadParameters(); - // Estrazione del testo dal file - File file = new File(this.qai_path_file); - Tika tika = new Tika(); - tika.setMaxStringLength(-1); - String text = tika.parseToString(file); - - // add libreria tokenizer - // Ottieni l'encoding del modello (es. cl100k_base per GPT-4 e GPT-3.5-turbo) - EncodingRegistry registry = Encodings.newDefaultEncodingRegistry(); - encoding = registry.getEncoding(this.encoding_name); - - tokenCount = encoding.get().countTokens(text); - logger.info("token count input: " + tokenCount); - - // Calcolo della dimensione del chunk in token - chunk_size_token_calc = (Double)((double)tokenCount * percent_chunk_length); - if(chunk_size_token_calc>max_chunk_size_token){ - chunk_size_token_calc = (double)max_chunk_size_token; - - } - // **Fase di Summarization** - String summarizedText = text; - - if(tokenCount > max_input_token_not_to_be_chunked){ - summarizedText = summarize(text); - } - - // Creazione dei messaggi per il modello AI - Message userMessage = new UserMessage(summarizedText); - Message systemMessage = null; - - int tokenCountSummary = encoding.get().countTokens(summarizedText); - - if(isChunked && tokenCountSummary < max_output_token){ - systemMessage = new SystemMessage(this.qai_system_prompt_template_formatter); - logger.info("template formatter: " + this.qai_system_prompt_template_formatter); - }else{ - //here - systemMessage = new 
SystemMessage(this.qai_system_prompt_template); - logger.info("template: " + this.qai_system_prompt_template); - } - - CallResponseSpec resp = chatClient.prompt() - .messages(userMessage, systemMessage) - .advisors(advisor -> advisor - .param("chat_memory_conversation_id", - this.scenarioExecution.getId() + this.qai_custom_memory_id) - .param("chat_memory_response_size", 100)) - .call(); - - String output = resp.content(); - - // Gestione del conteggio token usati - Usage usage = resp.chatResponse().getMetadata().getUsage(); - if (usage != null) { - Integer usedTokens = usage.getTotalTokens(); - this.scenarioExecution.setUsedTokens(usedTokens); - } else { - logger.info("Token usage information is not available."); - } - - tokenCount = encoding.get().countTokens(output); - logger.info("token count output: " + tokenCount); - - // Salvataggio dell'output nel contesto di esecuzione - this.scenarioExecution.getExecSharedMap().put(this.qai_output_variable, output); - this.scenarioExecution.setNextStepId(this.step.getNextStepId()); - - } catch (Exception e) { - logger.error("Error in AnswerQuestionOnDocSolver: " + e.getMessage()); - this.scenarioExecution.getExecSharedMap().put(this.qai_output_variable, "500 Internal Server Error"); + // Estrazione del testo dal file + logger.info("Estrazione del testo dal file: {}", this.qai_path_file); + File file = new File(this.qai_path_file); + if (!file.exists()) { + logger.error("File non trovato: {}", this.qai_path_file); + throw new RuntimeException("File non trovato: " + this.qai_path_file); } + + logger.info("Dimensione file: {} bytes", file.length()); + Tika tika = new Tika(); + tika.setMaxStringLength(-1); + logger.info("Iniziando parsing del file con Tika"); + String text =""; + try + { + text = tika.parseToString(file); + } catch (Exception e) { + logger.error("Errore durante l'estrazione del testo dal file: {}", e.getMessage()); + throw new RuntimeException("Errore durante l'estrazione del testo dal file: " + 
e.getMessage()); + } + + logger.info("Estrazione testo completata, lunghezza: {} caratteri", text.length()); + + logger.info("Inizializzazione encoding: {}", this.encoding_name); + EncodingRegistry registry = Encodings.newDefaultEncodingRegistry(); + encoding = registry.getEncoding(this.encoding_name); + if (!encoding.isPresent()) { + logger.error("Encoding non trovato: {}", this.encoding_name); + throw new RuntimeException("Encoding non trovato: " + this.encoding_name); + } + + logger.info("Conteggio token del testo estratto"); + tokenCount = encoding.get().countTokens(text); + logger.info("token count input: " + tokenCount); + + chunk_size_token_calc = (Double)((double)tokenCount * percent_chunk_length); + logger.info("Calcolo dimensione chunk: {} token * {} = {} token", + tokenCount, percent_chunk_length, chunk_size_token_calc); + + if(chunk_size_token_calc > max_chunk_size_token){ + logger.info("Dimensione chunk calcolata ({}) supera il limite massimo, impostata a: {}", + chunk_size_token_calc, max_chunk_size_token); + chunk_size_token_calc = (double)max_chunk_size_token; + } + String summarizedText = text; + + if(tokenCount > max_input_token_not_to_be_chunked){ + logger.info("Il testo supera il limite massimo di token non chunked ({} > {}), procedendo con la summarizzazione", + tokenCount, max_input_token_not_to_be_chunked); + summarizedText = summarize(text); + } else { + logger.info("Il testo non supera il limite massimo di token non chunked ({} <= {}), non necessaria summarizzazione", + tokenCount, max_input_token_not_to_be_chunked); + } + + logger.info("Preparazione messaggi per il modello"); + Message userMessage = new UserMessage(summarizedText); + Message systemMessage = null; + + int tokenCountSummary = encoding.get().countTokens(summarizedText); + logger.info("Token count del testo summarizzato: {}", tokenCountSummary); + + if(isChunked && tokenCountSummary < max_output_token){ + logger.info("Testo è stato chunked e il numero di token è sotto il 
limite massimo, " + + "usando template formatter"); + systemMessage = new SystemMessage(this.qai_system_prompt_template_formatter); + logger.info("template formatter: " + this.qai_system_prompt_template_formatter); + }else{ + logger.info("Usando template standard"); + systemMessage = new SystemMessage(this.qai_system_prompt_template); + logger.info("template: " + this.qai_system_prompt_template); + } + + logger.info("Invio richiesta al modello con conversation ID: {}", + this.scenarioExecution.getId() + this.qai_custom_memory_id); + + CallResponseSpec resp = chatClient.prompt() + .messages(userMessage, systemMessage) + .advisors(advisor -> advisor + .param(ChatMemory.CONVERSATION_ID, + this.scenarioExecution.getId() + this.qai_custom_memory_id)) + .call(); + + String output = resp.content(); + logger.info("Risposta ricevuta dal modello"); + + // Gestione del conteggio token usati + try { + var chatResponse = resp.chatResponse(); + if (chatResponse != null) { + var metadata = chatResponse.getMetadata(); + if (metadata != null) { + Usage usage = metadata.getUsage(); + if (usage != null) { + Integer usedTokens = usage.getTotalTokens(); + this.scenarioExecution.setUsedTokens(usedTokens); + logger.info("Token utilizzati: {}", usedTokens); + } else { + logger.info("Token usage information is not available."); + } + } else { + logger.warn("Metadata non disponibili nella risposta"); + } + } else { + logger.warn("Chat response non disponibile"); + } + } catch (Exception e) { + logger.warn("Errore nel recuperare le informazioni di utilizzo dei token: {}", e.getMessage()); + } + + tokenCount = encoding.get().countTokens(output); + logger.info("token count output: " + tokenCount); + + // Salvataggio dell'output nel contesto di esecuzione + this.scenarioExecution.getExecSharedMap().put(this.qai_output_variable, output); + logger.info("Output salvato nella variabile: {}", this.qai_output_variable); + this.scenarioExecution.setNextStepId(this.step.getNextStepId()); + 
logger.info("Step completato con successo, prossimo step: {}", this.step.getNextStepId()); return this.scenarioExecution; } private String summarize(String text) { + logger.info("Iniziando il processo di summarizzazione, lunghezza testo: {} caratteri", text.length()); logger.info("length: " + text.length()); Double chunk_size_text; tokenCount = encoding.get().countTokens(text); + logger.info("Token count del testo: {}", tokenCount); int textLengthPlus = (int) (text.length() * 1.1); int tokenCountPlus = (int) (tokenCount * 1.1); Double ratio = Math.floor((tokenCountPlus / chunk_size_token_calc) + 1); //Double ratio = Math.floor((tokenCountPlus / chunk_size_token) + 1); + logger.info("Ratio calcolato: {}, chunk_size_token_calc: {}", ratio, chunk_size_token_calc); if (ratio == 1) { + logger.info("Ratio è 1, il testo non verrà suddiviso in chunk"); return text; } else { chunk_size_text = Math.ceil(textLengthPlus / ratio); + logger.info("Il testo verrà suddiviso in chunk di dimensione: {} caratteri", chunk_size_text.intValue()); } // Suddividere il testo in chunk List chunks = chunkText(text, chunk_size_text.intValue()); + logger.info("Testo suddiviso in {} chunk", chunks.size()); List summarizedChunks = new ArrayList<>(); // Riassumere ogni chunk singolarmente + int chunkCounter = 0; for (String chunk : chunks) { + chunkCounter++; + logger.info("Elaborazione chunk {}/{} - lunghezza: {} caratteri", chunkCounter, chunks.size(), chunk.length()); logger.info("chunk length: " + chunk.length()); - summarizedChunks.add(summarizeChunk(chunk)); + String summarizedChunk = summarizeChunk(chunk); + logger.info("Chunk {}/{} riassunto con successo", chunkCounter, chunks.size()); + summarizedChunks.add(summarizedChunk); } isChunked = true; // Unire i riassunti String summarizedText = String.join(" ", summarizedChunks); int tokenCountSummarizedText = encoding.get().countTokens(summarizedText); + logger.info("Testo riassunto completo: {} token, {} caratteri", tokenCountSummarizedText, 
summarizedText.length()); // Se il riassunto è ancora troppo lungo, applicare ricorsione if (tokenCountSummarizedText > max_output_token) { + logger.info("Il testo riassunto è ancora troppo lungo ({}), applicando ricorsione", tokenCountSummarizedText); return summarize(summarizedText); } else { + logger.info("Summarizzazione completata con successo"); return summarizedText; } } // Metodo per suddividere il testo in chunk private List chunkText(String text, int chunkSize) { + logger.info("Suddivisione testo in chunk di dimensione: {} caratteri", chunkSize); List chunks = new ArrayList<>(); int length = text.length(); for (int i = 0; i < length; i += chunkSize) { - chunks.add(text.substring(i, Math.min(length, i + chunkSize))); + int end = Math.min(length, i + chunkSize); + chunks.add(text.substring(i, end)); + logger.debug("Creato chunk {}/{} da carattere {} a {}", chunks.size(), + (int)Math.ceil((double)length/chunkSize), i, end); } + logger.info("Testo suddiviso in {} chunk", chunks.size()); return chunks; } // Metodo di riassunto per singolo chunk private String summarizeChunk(String chunk) { + logger.info("Iniziando summarizzazione di un chunk di {} caratteri", chunk.length()); + int tokenCountChunk = encoding.get().countTokens(chunk); + logger.info("Token count del chunk: {}", tokenCountChunk); Message chunkMessage = new UserMessage(chunk); Message systemMessage = new SystemMessage( @@ -222,18 +285,25 @@ public class SummarizeDocSolver extends StepSolver { logger.info("template chunk: " + this.qai_system_prompt_template_chunk); - CallResponseSpec resp = chatClient.prompt() - .messages(chunkMessage, systemMessage) - /* - * .advisors(advisor -> advisor - * .param("chat_memory_conversation_id", - * this.scenarioExecution.getId() + this.qai_custom_memory_id) - * .param("chat_memory_response_size", 100)) - */ - .call(); - - return resp.content(); + try { + CallResponseSpec resp = chatClient.prompt() + .messages(chunkMessage, systemMessage) + .call(); + String result 
= resp.content(); + if (result != null) { + int tokenCountResult = encoding.get().countTokens(result); + logger.info("Chunk riassunto con successo: {} token, {} caratteri", + tokenCountResult, result.length()); + return result; + } else { + logger.warn("Risposta nulla ricevuta dal modello"); + return ""; + } + } catch (Exception e) { + logger.error("Errore durante la summarizzazione del chunk: {}", e.getMessage()); + return "ERRORE NELLA SUMMARIZZAZIONE: " + e.getMessage(); + } } } diff --git a/src/main/java/com/olympus/hermione/tools/DocumentationTool.java b/src/main/java/com/olympus/hermione/tools/DocumentationTool.java new file mode 100644 index 0000000..49daa01 --- /dev/null +++ b/src/main/java/com/olympus/hermione/tools/DocumentationTool.java @@ -0,0 +1,66 @@ +package com.olympus.hermione.tools; + +import java.util.List; +import java.util.logging.Logger; + +import org.springframework.ai.document.Document; +import org.springframework.ai.tool.annotation.Tool; +import org.springframework.ai.vectorstore.SearchRequest; +import org.springframework.ai.vectorstore.VectorStore; + + + + +public class DocumentationTool { + + Logger logger = Logger.getLogger(DocumentationTool.class.getName()); + + private String application; + private String project; + private VectorStore vectorStore; + + + public DocumentationTool(String application, String project, VectorStore vectorStore) { + this.application = application; + this.project = project; + this.vectorStore = vectorStore; + } + + @Tool(description = "Find any functional information about the application. 
Can be used to answer any question with a specific focus on the application") + public String getDocumentation(String query) { + + logger.info("[TOOL]LLM DocumentationTool Getting documentation for query: " + query); + logger.info("[TOOL]LLM DocumentationTool Getting documentation for project: " + this.project); + logger.info("[TOOL]LLM DocumentationTool Getting documentation for vectorStore: " + this.vectorStore); + String filterExpression = "'KsProjectName' == '"+ this.project +"' AND 'KsDoctype' == 'functional'"; + + logger.info("[TOOL]LLM DocumentationTool Getting documentation for filterExpression: " + filterExpression); + + + SearchRequest request = SearchRequest.builder() + .query(query) + .topK(3) + .similarityThreshold(0.7) + .filterExpression( filterExpression).build(); + + + List docs = this.vectorStore.similaritySearch(request); + + logger.info("Number of VDB retrieved documents: " + docs.size()); + + String docToolResponse = ""; + for (Document doc : docs) { + docToolResponse += "-------USE THIS AS PART OF YOUR CONTEXT----------------------------------------- \n"; + + docToolResponse += doc.getText() + "------------------------------------------------------------------------------------------ \n"; + } + + + + + return docToolResponse; + } + + + +} diff --git a/src/main/java/com/olympus/hermione/utility/RAGDocumentRanker.java b/src/main/java/com/olympus/hermione/utility/RAGDocumentRanker.java index 99168e6..354fef1 100644 --- a/src/main/java/com/olympus/hermione/utility/RAGDocumentRanker.java +++ b/src/main/java/com/olympus/hermione/utility/RAGDocumentRanker.java @@ -5,6 +5,7 @@ import java.util.List; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.springframework.ai.chat.client.ChatClient; import org.springframework.ai.document.Document; import org.springframework.stereotype.Service; @@ -19,7 +20,7 @@ public class RAGDocumentRanker { Logger logger = (Logger) LoggerFactory.getLogger(RAGDocumentRanker.class); - public List 
rankDocuments(List documents, String query, int threshold, OlympusChatClient chatClient) { + public List rankDocuments(List documents, String query, int threshold, ChatClient chatClient) { List rankedDocuments = new ArrayList<>(); @@ -30,8 +31,10 @@ public class RAGDocumentRanker { rankingPrompt +="Chunk: " + document.getText(); rankingPrompt +="Answer with a number between 1 and 10 and nothing else."; - OlympusChatClientResponse resp = chatClient.getChatCompletion(rankingPrompt); - String rank = resp.getContent().substring(0, 1); + String rank = chatClient.prompt() + .user(rankingPrompt) + .call() + .content().substring(0, 1); int rankInt = Integer.parseInt(rank); logger.info("Rank: " + rankInt); diff --git a/src/main/resources/application.properties b/src/main/resources/application.properties index 0303374..337bd92 100644 --- a/src/main/resources/application.properties +++ b/src/main/resources/application.properties @@ -56,9 +56,9 @@ eureka.instance.preferIpAddress: true hermione.fe.url = http://localhost:5173 -spring.ai.vectorstore.chroma.client.host=http://108.142.74.161 +spring.ai.vectorstore.chroma.client.host=http://128.251.239.194 spring.ai.vectorstore.chroma.client.port=8000 -spring.ai.vectorstore.chroma.client.key-token=tKAJfN1Yv5lP7pKorJHGfHMQhNEcM9uu +spring.ai.vectorstore.chroma.client.key-token=BxZWXFXC4UMSxamf5xP5SioGIg3FPfP7 spring.ai.vectorstore.chroma.initialize-schema=true spring.ai.vectorstore.chroma.collection-name=olympus