Merged PR 181: upgrade Spring AI V1 e introduzione document tool

This commit is contained in:
Andrea Terzani
2025-06-28 08:24:00 +00:00
12 changed files with 414 additions and 382 deletions

46
pom.xml
View File

@@ -28,8 +28,7 @@
</scm>
<properties>
<java.version>21</java.version>
<!--<spring-ai.version>1.0.0-SNAPSHOT</spring-ai.version>-->
<spring-ai.version>1.0.0-M6</spring-ai.version>
<spring-ai.version>1.0.0</spring-ai.version>
<spring-cloud.version>2023.0.3</spring-cloud.version>
</properties>
<dependencies>
@@ -61,38 +60,26 @@
<artifactId>spring-boot-starter-oauth2-resource-server</artifactId>
</dependency>
<!-- https://mvnrepository.com/artifact/org.json/json -->
<dependency>
<groupId>org.json</groupId>
<artifactId>json</artifactId>
<version>20250107</version>
</dependency>
<dependency>
<groupId>org.json</groupId>
<artifactId>json</artifactId>
<version>20250107</version>
</dependency>
<!--<dependency>
<groupId>org.springframework.ai</groupId>
<artifactId>spring-ai-mongodb-atlas-store-spring-boot-starter</artifactId>
</dependency>-->
<dependency>
<groupId>org.springframework.ai</groupId>
<artifactId>spring-ai-openai-spring-boot-starter</artifactId>
<artifactId>spring-ai-starter-model-azure-openai</artifactId>
</dependency>
<dependency>
<groupId>org.springframework.ai</groupId>
<artifactId>spring-ai-azure-openai</artifactId>
</dependency>
<dependency>
<groupId>org.springframework.ai</groupId>
<artifactId>spring-ai-azure-store</artifactId>
</dependency>
<dependency>
<groupId>org.springframework.ai</groupId>
<artifactId>spring-ai-chroma-store</artifactId>
<artifactId>spring-ai-starter-vector-store-chroma</artifactId>
</dependency>
<dependency>
<groupId>org.neo4j.driver</groupId>
<artifactId>neo4j-java-driver</artifactId>
<version>5.8.0</version>
</dependency>
<dependency>
@@ -100,12 +87,12 @@
<artifactId>spring-boot-starter-test</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.projectlombok</groupId>
<artifactId>lombok</artifactId>
<version>1.18.34</version>
</dependency>
<dependency>
<groupId>org.projectlombok</groupId>
<artifactId>lombok</artifactId>
<version>1.18.38</version>
<scope>provided</scope>
</dependency>
<dependency>
<groupId>io.jsonwebtoken</groupId>
<artifactId>jjwt-api</artifactId>
@@ -153,7 +140,6 @@
<dependency>
<groupId>com.google.code.gson</groupId>
<artifactId>gson</artifactId>
<version>2.8.8</version>
</dependency>
<dependency>

View File

@@ -1,26 +0,0 @@
package com.olympus.hermione.config;
import java.time.Duration;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.web.client.RestTemplate;
import org.springframework.ai.autoconfigure.azure.openai.AzureOpenAIClientBuilderCustomizer;
import com.azure.core.http.HttpClient;
import com.azure.core.util.HttpClientOptions;
@Configuration
public class AzureAIConfig {

    /**
     * Registers a customizer that raises the Azure OpenAI HTTP response
     * timeout to 5 minutes, so long-running completion calls are not cut
     * off by the SDK's default timeout.
     */
    @Bean
    public AzureOpenAIClientBuilderCustomizer responseTimeoutCustomizer() {
        return builder -> {
            HttpClientOptions options = new HttpClientOptions();
            options.setResponseTimeout(Duration.ofMinutes(5));
            builder.httpClient(HttpClient.createDefault(options));
        };
    }
}

View File

@@ -2,7 +2,6 @@ package com.olympus.hermione.config;
import org.springframework.ai.azure.openai.AzureOpenAiEmbeddingModel;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.openai.OpenAiEmbeddingModel;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.context.annotation.Primary;
@@ -11,24 +10,13 @@ import org.springframework.context.annotation.Primary;
@Configuration
public class HermioneConfig {
private final OpenAiEmbeddingModel openAiEmbeddingModel;
private final AzureOpenAiEmbeddingModel azureOpenAiEmbeddingModel;
// Autowiring beans from both libraries
public HermioneConfig(OpenAiEmbeddingModel myServiceLib1,
AzureOpenAiEmbeddingModel azureOpenAiEmbeddingModel) {
this.openAiEmbeddingModel = myServiceLib1;
public HermioneConfig( AzureOpenAiEmbeddingModel azureOpenAiEmbeddingModel) {
this.azureOpenAiEmbeddingModel = azureOpenAiEmbeddingModel;
}
// Re-declaring and marking one as primary
@Bean
public EmbeddingModel primaryMyService() {
return openAiEmbeddingModel; // Choose the one you want as primary
}
// Optionally declare the other bean without primary
@Bean
@Primary
public EmbeddingModel secondaryMyService() {

View File

@@ -1,50 +0,0 @@
package com.olympus.hermione.config;
import com.azure.core.credential.AzureKeyCredential;
import com.azure.search.documents.indexes.SearchIndexClient;
import com.azure.search.documents.indexes.SearchIndexClientBuilder;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.vectorstore.VectorStore;
import org.springframework.ai.vectorstore.azure.AzureVectorStore;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import java.util.ArrayList;
import java.util.List;
// Configures the Azure AI Search client and the Spring AI VectorStore bean.
@Configuration
public class VectorStoreConfig {

    // Azure AI Search admin key (spring.ai.vectorstore.azure.api-key).
    @Value("${spring.ai.vectorstore.azure.api-key}")
    private String azureKey;

    // Azure AI Search service endpoint URL.
    @Value("${spring.ai.vectorstore.azure.url}")
    private String azureEndpoint;

    // Whether the vector index schema should be created on startup.
    // NOTE(review): bound from config but unused below because the store
    // construction is commented out — confirm intent.
    @Value("${spring.ai.vectorstore.azure.initialize-schema}")
    private boolean initSchema;

    // Low-level Azure Search index client consumed by the vector store bean.
    @Bean
    public SearchIndexClient searchIndexClient() {
        return new SearchIndexClientBuilder().endpoint(azureEndpoint)
                .credential(new AzureKeyCredential(azureKey))
                .buildClient();
    }

    // Declares the VectorStore with the KS* metadata fields used for
    // filter expressions in similarity searches.
    @Bean
    public VectorStore vectorStore(SearchIndexClient searchIndexClient, @Qualifier("azureOpenAiEmbeddingModel") EmbeddingModel embeddingModel) {
        List<AzureVectorStore.MetadataField> fields = new ArrayList<>();
        fields.add(AzureVectorStore.MetadataField.text("KsApplicationName"));
        fields.add(AzureVectorStore.MetadataField.text("KsProjectName"));
        fields.add(AzureVectorStore.MetadataField.text("KsDoctype"));
        fields.add(AzureVectorStore.MetadataField.text("KsDocSource"));
        fields.add(AzureVectorStore.MetadataField.text("KsFileSource"));
        fields.add(AzureVectorStore.MetadataField.text("KsDocumentId"));
        // NOTE(review): returning null will NPE every consumer of this bean;
        // the real construction is commented out below. Either restore it
        // (with the current Spring AI builder API) or remove this @Bean.
        //return new AzureVectorStore(searchIndexClient, embeddingModel,initSchema, fields);
        return null;
    }
}

View File

@@ -24,11 +24,8 @@ import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.client.advisor.MessageChatMemoryAdvisor;
import org.springframework.ai.chat.client.advisor.SimpleLoggerAdvisor;
import org.springframework.ai.chat.memory.ChatMemory;
import org.springframework.ai.chat.memory.InMemoryChatMemory;
import org.springframework.ai.chat.memory.MessageWindowChatMemory;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.openai.OpenAiChatModel;
import org.springframework.ai.openai.OpenAiChatOptions;
import org.springframework.ai.openai.api.OpenAiApi;
import org.springframework.ai.vectorstore.VectorStore;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
@@ -200,19 +197,16 @@ public class ScenarioExecutionService {
if (scenario.isUseChatMemory()) {
logger.info("Initializing chatClient with chat-memory advisor");
ChatMemory chatMemory = new InMemoryChatMemory();
ChatMemory chatMemory = MessageWindowChatMemory.builder().build();
chatClient = ChatClient.builder(chatModel)
.defaultAdvisors(
new MessageChatMemoryAdvisor(chatMemory), // chat-memory advisor
new SimpleLoggerAdvisor())
.defaultAdvisors(MessageChatMemoryAdvisor.builder(chatMemory).build() )
.build();
} else {
logger.info("Initializing chatClient with simple logger advisor");
logger.info("Initializing chatClient without advisor");
chatClient = ChatClient.builder(chatModel)
.defaultAdvisors(
new SimpleLoggerAdvisor())
.build();
}
@@ -437,81 +431,40 @@ public class ScenarioExecutionService {
private ChatModel createChatModel(AiModel aiModel) {
switch (aiModel.getApiProvider()) {
case "AzureOpenAI":
OpenAIClientBuilder openAIClient = new OpenAIClientBuilder()
.credential(new AzureKeyCredential(aiModel.getApiKey()))
.endpoint(aiModel.getEndpoint());
AzureOpenAiChatOptions openAIChatOptions = AzureOpenAiChatOptions.builder()
.deploymentName(aiModel.getModel())
.maxTokens(aiModel.getMaxTokens())
.temperature(Double.valueOf(aiModel.getTemperature()))
.build();
OpenAIClientBuilder openAIClientBuilder = new OpenAIClientBuilder()
.credential(new AzureKeyCredential(aiModel.getApiKey()))
.endpoint(aiModel.getEndpoint());
AzureOpenAiChatModel azureOpenaichatModel = new AzureOpenAiChatModel(openAIClient, openAIChatOptions);
// Crea un builder con opzioni base
AzureOpenAiChatOptions.Builder optionsBuilder = AzureOpenAiChatOptions.builder()
.deploymentName(aiModel.getModel());
if (aiModel.isLegacy() ) {
logger.info("Using legacy Azure OpenAI model: " + aiModel.getModel());
optionsBuilder.temperature(aiModel.getTemperature().doubleValue());
optionsBuilder.maxTokens(aiModel.getMaxTokens());
}
AzureOpenAiChatOptions openAIChatOptions = optionsBuilder.build();
var chatModel = AzureOpenAiChatModel.builder()
.openAIClientBuilder(openAIClientBuilder)
.defaultOptions(openAIChatOptions)
.build();
logger.info("AI model used: " + aiModel.getModel());
return azureOpenaichatModel;
return chatModel;
case "OpenAI":
OpenAiApi openAiApi = new OpenAiApi(aiModel.getApiKey());
OpenAiChatOptions openAiChatOptions = OpenAiChatOptions.builder()
.model(aiModel.getModel())
.temperature(Double.valueOf(aiModel.getTemperature()))
.maxTokens(aiModel.getMaxTokens())
.build();
OpenAiChatModel openaichatModel = new OpenAiChatModel(openAiApi, openAiChatOptions);
logger.info("AI model used: " + aiModel.getModel());
return openaichatModel;
case "GoogleGemini":
OpenAIClientBuilder openAIClient2 = new OpenAIClientBuilder()
.credential(new AzureKeyCredential(aiModel.getApiKey()))
.endpoint(aiModel.getEndpoint());
AzureOpenAiChatOptions openAIChatOptions2 = AzureOpenAiChatOptions.builder()
.deploymentName(aiModel.getModel())
.maxTokens(aiModel.getMaxTokens())
.temperature(Double.valueOf(aiModel.getTemperature()))
.build();
AzureOpenAiChatModel azureOpenaichatModel2 = new AzureOpenAiChatModel(openAIClient2,
openAIChatOptions2);
logger.info("AI model used : " + aiModel.getModel());
return azureOpenaichatModel2;
default:
throw new IllegalArgumentException("Unsupported AI model: " + aiModel.getName());
}
}
/*
* public List<ScenarioExecution> getListExecutionScenario(){
* logger.info("getListProjectByUser function:");
*
* List<ScenarioExecution> lstScenarioExecution = null;
*
* User principal = (User)
* SecurityContextHolder.getContext().getAuthentication().getPrincipal();
*
* if(principal.getSelectedApplication()!=null){
* lstScenarioExecution =
* scenarioExecutionRepository.getFromProjectAndAPP(principal.getId(),
* principal.getSelectedProject().getInternal_name(),
* principal.getSelectedApplication().getInternal_name(),
* Sort.by(Sort.Direction.DESC, "startDate"));
*
* }else{
* lstScenarioExecution =
* scenarioExecutionRepository.getFromProject(principal.getId(),
* principal.getSelectedProject().getInternal_name(),
* Sort.by(Sort.Direction.DESC, "startDate"));
* }
*
* return lstScenarioExecution;
*
* }
*/
public Page<ScenarioExecution> getListExecutionScenarioOptional(int page, int size, String scenarioName,
String executedBy, LocalDate startDate) {
@@ -582,12 +535,7 @@ public class ScenarioExecutionService {
criteriaList.add(Criteria.where("execSharedMap.user_input.selected_application")
.is(principal.getSelectedApplication().getInternal_name()));
}
// here:
// filtri
// here:
// here: vedi perchè non funziona link di view
// here: da sistemare filtro data a backend e i vari equals
// here: add ordinamento
for (Map.Entry<String, Object> entry : filters.entrySet()) {
String key = entry.getKey();

View File

@@ -1,21 +1,18 @@
package com.olympus.hermione.stepSolvers;
import ch.qos.logback.classic.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.client.ChatClient.CallResponseSpec;
import org.springframework.ai.chat.memory.ChatMemory;
import org.springframework.ai.chat.metadata.Usage;
import org.springframework.ai.chat.model.ChatResponse;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.olympus.hermione.dto.aientity.CiaOutputEntity;
import com.olympus.hermione.models.ScenarioExecution;
import com.olympus.hermione.utility.AttributeParser;
import java.util.ArrayList;
import java.util.List;
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.client.ChatClient.CallResponseSpec;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.SystemMessage;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.metadata.Usage;
import ch.qos.logback.classic.Logger;
public class AdvancedAIPromptSolver extends StepSolver {
@@ -61,7 +58,7 @@ public class AdvancedAIPromptSolver extends StepSolver {
this.qai_custom_memory_id = (String) this.step.getAttributes().get("qai_custom_memory_id");
logger.info("qai_custom_memory_id: {}", this.qai_custom_memory_id);
//TODO: Add memory ID attribute to have the possibility of multiple conversations
this.qai_available_tools = (String) this.step.getAttributes().get("qai_available_tools");
logger.info("qai_available_tools: {}", this.qai_available_tools);
}
@@ -82,12 +79,22 @@ public class AdvancedAIPromptSolver extends StepSolver {
logger.info("Available tools specified: {}", this.qai_available_tools);
for(String tool: this.qai_available_tools.split(",")){
logger.info("Tool: {}", tool);
if(tool.equals("SourceCodeTool")){
logger.info("Instantiating SourceCodeTool");
tools = new com.olympus.hermione.tools.SourceCodeTool(discoveryClient);
} else {
logger.warn("Unknown tool specified: {}", tool);
switch(tool.trim()){
case "DocumentationTool" -> {
logger.info("Instantiating DocumentationTool");
tools = new com.olympus.hermione.tools.DocumentationTool(
this.scenarioExecution.getScenarioExecutionInput().getInputs().get("selected_application"),
this.scenarioExecution.getScenarioExecutionInput().getInputs().get("selected_project"),
vectorStore);
}
case "SourceCodeTool" -> {
logger.info("Instantiating SourceCodeTool");
tools = new com.olympus.hermione.tools.SourceCodeTool(this.discoveryClient);
}
}
}
} else {
logger.info("No tools specified or available");
@@ -98,9 +105,7 @@ public class AdvancedAIPromptSolver extends StepSolver {
resp = chatClient.prompt()
.user(this.qai_user_input)
.system(this.qai_system_prompt_template)
.advisors(advisor -> advisor
.param("chat_memory_conversation_id", this.scenarioExecution.getId()+this.qai_custom_memory_id)
.param("chat_memory_response_size", 100))
.advisors(advisor -> advisor.param(ChatMemory.CONVERSATION_ID, this.scenarioExecution.getId()+this.qai_custom_memory_id))
.call();
}else{
logger.info("Calling chatClient with tools: {}", tools.getClass().getName());
@@ -108,12 +113,11 @@ public class AdvancedAIPromptSolver extends StepSolver {
.user(this.qai_user_input)
.system(this.qai_system_prompt_template)
.advisors(advisor -> advisor
.param("chat_memory_conversation_id", this.scenarioExecution.getId()+this.qai_custom_memory_id)
.param("chat_memory_response_size", 100))
.param(ChatMemory.CONVERSATION_ID, this.scenarioExecution.getId()+this.qai_custom_memory_id))
.tools(tools)
.call();
}
ChatResponse response = null;
if(qai_output_entityType!=null && qai_output_entityType.equals("CiaOutputEntity")){
logger.info("Output is of type CiaOutputEntity");
@@ -129,13 +133,13 @@ public class AdvancedAIPromptSolver extends StepSolver {
}
}else{
logger.info("Output is of type String");
String output = resp.content();
logger.debug("String response: {}", output);
this.scenarioExecution.getExecSharedMap().put(this.qai_output_variable, output);
response = resp.chatResponse();
this.scenarioExecution.getExecSharedMap().put(this.qai_output_variable, response.getResult().getOutput().getText());
}
Usage usage = resp.chatResponse().getMetadata().getUsage();
if (usage != null) {
if (response != null && response.getMetadata() != null) {
Usage usage = response.getMetadata().getUsage();
Integer usedTokens = usage.getTotalTokens();
logger.info("Used tokens: {}", usedTokens);
this.scenarioExecution.setUsedTokens(usedTokens);

View File

@@ -158,42 +158,30 @@ public class AdvancedQueryRagSolver extends StepSolver {
}
@Override
@Override
public ScenarioExecution solveStep(){
logger.info("Solving step: {}", this.step.getName());
loadParameters();
logParameters();
String elaboratedQuery = this.query;
OlympusChatClient chatClient = null;
String apiProvider = this.scenarioExecution.getScenario().getAiModel().getApiProvider();
logger.info("AI Model API Provider: {}", apiProvider);
if (apiProvider.equals("AzureOpenAI")) {
logger.info("Using AzureOpenApiChatClient");
chatClient = new AzureOpenApiChatClient();
}
if (apiProvider.equals("GoogleGemini")) {
logger.info("Using GoogleGeminiChatClient");
chatClient = new GoogleGeminiChatClient();
}
logger.info("Initializing chat client with endpoint: {}", this.scenarioExecution.getScenario().getAiModel().getFull_path_endpoint());
chatClient.init(this.scenarioExecution.getScenario().getAiModel().getFull_path_endpoint(),
this.scenarioExecution.getScenario().getAiModel().getApiKey(),
this.scenarioExecution.getScenario().getAiModel().getMaxTokens());
if(this.enableRephrase){
logger.info("Rephrasing query with prompt: {}", this.rephrasePrompt);
OlympusChatClientResponse resp = chatClient.getChatCompletion(this.rephrasePrompt);
elaboratedQuery = resp.getContent();
elaboratedQuery = chatClient.prompt()
.user(this.rephrasePrompt)
.call()
.content();
logger.info("Rephrased query: {}", elaboratedQuery);
}
if(this.enableMultiPhrase){
logger.info("Generating multi-phrase queries with prompt: {}", this.MultiPhrasePrompt);
OlympusChatClientResponse resp = chatClient.getChatCompletion(this.MultiPhrasePrompt);
elaboratedQuery = resp.getContent();
elaboratedQuery = chatClient.prompt()
.user(this.MultiPhrasePrompt)
.call()
.content();
logger.info("Multi-phrase queries: {}", elaboratedQuery);
}
@@ -265,28 +253,37 @@ public class AdvancedQueryRagSolver extends StepSolver {
prompt += "\n\n The query is: " + this.query;
prompt += "\n\n The output must be a single query and nothing else.";
logger.info("Prompt for retrieving other documents: {}", prompt);
OlympusChatClientResponse resp = chatClient.getChatCompletion(prompt);
//logger.info("Prompt for retrieving other documents: {}", prompt);
String otherDocQuery = chatClient.prompt()
.user(prompt)
.call()
.content();
String otherDocQuery = resp.getContent();
logger.info("Other document query: {}", otherDocQuery);
Builder request_builder = SearchRequest.builder()
.query(otherDocQuery)
.topK(otherTopK)
.similarityThreshold(0.8);
if(this.rag_filter != null && !this.rag_filter.isEmpty()){
request_builder.filterExpression(this.rag_filter);
logger.info("Using Filter expression: {}", this.rag_filter);
}
SearchRequest request = request_builder.build();
logger.info("Performing similarity search for other document query");
otherDocs.addAll(this.vectorStore.similaritySearch(request));
logger.info("Number of addictional documents: {}", otherDocs.size());
if(otherDocQuery != null && !otherDocQuery.isEmpty()){
logger.info("Performing similarity search for other document query: {}", otherDocQuery);
Builder request_builder = SearchRequest.builder()
.query(otherDocQuery)
.topK(otherTopK)
.similarityThreshold(0.8);
if(this.rag_filter != null && !this.rag_filter.isEmpty()){
request_builder.filterExpression(this.rag_filter);
logger.info("Using Filter expression: {}", this.rag_filter);
}
SearchRequest request = request_builder.build();
logger.info("Performing similarity search for other document query");
otherDocs.addAll(this.vectorStore.similaritySearch(request));
logger.info("Number of addictional documents: {}", otherDocs.size());
docs.addAll(otherDocs);
docs.addAll(otherDocs);
logger.info("Number of documents after other document retrieval: {}", docs.size());
} else {
logger.info("No other document query generated, skipping other document retrieval.");
}
}
docs = docs.stream()

View File

@@ -17,12 +17,6 @@ public class BasicQueryRagSolver extends StepSolver {
Logger logger = (Logger) LoggerFactory.getLogger(BasicQueryRagSolver.class);
// PARAMETERS
// rag_input : the field of the sharedMap to use as rag input
// rag_filter : the field of the sharedMap to use as rag filter
// rag_topk : the number of topk results to return
// rag_threshold : the similarity threshold to use
// rag_output_variable : the field of the sharedMap to store the output
private String query;
private String rag_filter;
private int topk;
@@ -71,14 +65,11 @@ public class BasicQueryRagSolver extends StepSolver {
loadParameters();
logParameters();
Builder request_builder = SearchRequest.builder()
.query(this.query)
.topK(this.topk)
.similarityThreshold(this.threshold);
logger.info("Performing similarity search with query: " + this.query);
if(this.rag_filter != null && !this.rag_filter.isEmpty()){
request_builder.filterExpression(this.rag_filter);

View File

@@ -1,30 +1,22 @@
package com.olympus.hermione.stepSolvers;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.nio.file.Files;
import java.io.File;
import java.util.ArrayList;
import java.util.Base64;
import java.util.List;
import java.util.Optional;
import org.apache.tika.Tika;
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.client.ChatClient.CallResponseSpec;
import org.springframework.ai.chat.memory.ChatMemory;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.SystemMessage;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.metadata.Usage;
import org.springframework.core.io.ClassPathResource;
import org.springframework.core.io.InputStreamResource;
import org.springframework.util.MimeTypeUtils;
import org.springframework.ai.model.Media;
import com.knuddels.jtokkit.Encodings;
import com.knuddels.jtokkit.api.Encoding;
import com.knuddels.jtokkit.api.EncodingRegistry;
import java.io.File;
import java.io.FileInputStream;
import com.olympus.hermione.models.ScenarioExecution;
import com.olympus.hermione.utility.AttributeParser;
@@ -48,7 +40,6 @@ public class SummarizeDocSolver extends StepSolver {
private boolean isChunked = false;
private Double chunk_size_token_calc;
// private boolean qai_load_graph_schema=false;
Logger logger = (Logger) LoggerFactory.getLogger(SummarizeDocSolver.class);
@@ -78,143 +69,215 @@ public class SummarizeDocSolver extends StepSolver {
this.qai_system_prompt_template_formatter = attributeParser
.parse((String) this.step.getAttributes().get("qai_system_prompt_template_formatter"));
logger.info("Parametri caricati: file path={}, encoding={}, max_chunk_size_token={}, output_variable={}",
this.qai_path_file, this.encoding_name, this.max_chunk_size_token, this.qai_output_variable);
}
@Override
public ScenarioExecution solveStep() {
public ScenarioExecution solveStep() throws Exception {
logger.info("Iniziando solveStep per: {}", this.step.getName());
System.out.println("Solving step: " + this.step.getName());
try {
this.scenarioExecution.setCurrentStepId(this.step.getStepId());
loadParameters();
this.scenarioExecution.setCurrentStepId(this.step.getStepId());
loadParameters();
// Estrazione del testo dal file
File file = new File(this.qai_path_file);
Tika tika = new Tika();
tika.setMaxStringLength(-1);
String text = tika.parseToString(file);
// add libreria tokenizer
// Ottieni l'encoding del modello (es. cl100k_base per GPT-4 e GPT-3.5-turbo)
EncodingRegistry registry = Encodings.newDefaultEncodingRegistry();
encoding = registry.getEncoding(this.encoding_name);
tokenCount = encoding.get().countTokens(text);
logger.info("token count input: " + tokenCount);
// Calcolo della dimensione del chunk in token
chunk_size_token_calc = (Double)((double)tokenCount * percent_chunk_length);
if(chunk_size_token_calc>max_chunk_size_token){
chunk_size_token_calc = (double)max_chunk_size_token;
}
// **Fase di Summarization**
String summarizedText = text;
if(tokenCount > max_input_token_not_to_be_chunked){
summarizedText = summarize(text);
}
// Creazione dei messaggi per il modello AI
Message userMessage = new UserMessage(summarizedText);
Message systemMessage = null;
int tokenCountSummary = encoding.get().countTokens(summarizedText);
if(isChunked && tokenCountSummary < max_output_token){
systemMessage = new SystemMessage(this.qai_system_prompt_template_formatter);
logger.info("template formatter: " + this.qai_system_prompt_template_formatter);
}else{
//here
systemMessage = new SystemMessage(this.qai_system_prompt_template);
logger.info("template: " + this.qai_system_prompt_template);
}
CallResponseSpec resp = chatClient.prompt()
.messages(userMessage, systemMessage)
.advisors(advisor -> advisor
.param("chat_memory_conversation_id",
this.scenarioExecution.getId() + this.qai_custom_memory_id)
.param("chat_memory_response_size", 100))
.call();
String output = resp.content();
// Gestione del conteggio token usati
Usage usage = resp.chatResponse().getMetadata().getUsage();
if (usage != null) {
Integer usedTokens = usage.getTotalTokens();
this.scenarioExecution.setUsedTokens(usedTokens);
} else {
logger.info("Token usage information is not available.");
}
tokenCount = encoding.get().countTokens(output);
logger.info("token count output: " + tokenCount);
// Salvataggio dell'output nel contesto di esecuzione
this.scenarioExecution.getExecSharedMap().put(this.qai_output_variable, output);
this.scenarioExecution.setNextStepId(this.step.getNextStepId());
} catch (Exception e) {
logger.error("Error in AnswerQuestionOnDocSolver: " + e.getMessage());
this.scenarioExecution.getExecSharedMap().put(this.qai_output_variable, "500 Internal Server Error");
// Estrazione del testo dal file
logger.info("Estrazione del testo dal file: {}", this.qai_path_file);
File file = new File(this.qai_path_file);
if (!file.exists()) {
logger.error("File non trovato: {}", this.qai_path_file);
throw new RuntimeException("File non trovato: " + this.qai_path_file);
}
logger.info("Dimensione file: {} bytes", file.length());
Tika tika = new Tika();
tika.setMaxStringLength(-1);
logger.info("Iniziando parsing del file con Tika");
String text ="";
try
{
text = tika.parseToString(file);
} catch (Exception e) {
logger.error("Errore durante l'estrazione del testo dal file: {}", e.getMessage());
throw new RuntimeException("Errore durante l'estrazione del testo dal file: " + e.getMessage());
}
logger.info("Estrazione testo completata, lunghezza: {} caratteri", text.length());
logger.info("Inizializzazione encoding: {}", this.encoding_name);
EncodingRegistry registry = Encodings.newDefaultEncodingRegistry();
encoding = registry.getEncoding(this.encoding_name);
if (!encoding.isPresent()) {
logger.error("Encoding non trovato: {}", this.encoding_name);
throw new RuntimeException("Encoding non trovato: " + this.encoding_name);
}
logger.info("Conteggio token del testo estratto");
tokenCount = encoding.get().countTokens(text);
logger.info("token count input: " + tokenCount);
chunk_size_token_calc = (Double)((double)tokenCount * percent_chunk_length);
logger.info("Calcolo dimensione chunk: {} token * {} = {} token",
tokenCount, percent_chunk_length, chunk_size_token_calc);
if(chunk_size_token_calc > max_chunk_size_token){
logger.info("Dimensione chunk calcolata ({}) supera il limite massimo, impostata a: {}",
chunk_size_token_calc, max_chunk_size_token);
chunk_size_token_calc = (double)max_chunk_size_token;
}
String summarizedText = text;
if(tokenCount > max_input_token_not_to_be_chunked){
logger.info("Il testo supera il limite massimo di token non chunked ({} > {}), procedendo con la summarizzazione",
tokenCount, max_input_token_not_to_be_chunked);
summarizedText = summarize(text);
} else {
logger.info("Il testo non supera il limite massimo di token non chunked ({} <= {}), non necessaria summarizzazione",
tokenCount, max_input_token_not_to_be_chunked);
}
logger.info("Preparazione messaggi per il modello");
Message userMessage = new UserMessage(summarizedText);
Message systemMessage = null;
int tokenCountSummary = encoding.get().countTokens(summarizedText);
logger.info("Token count del testo summarizzato: {}", tokenCountSummary);
if(isChunked && tokenCountSummary < max_output_token){
logger.info("Testo è stato chunked e il numero di token è sotto il limite massimo, " +
"usando template formatter");
systemMessage = new SystemMessage(this.qai_system_prompt_template_formatter);
logger.info("template formatter: " + this.qai_system_prompt_template_formatter);
}else{
logger.info("Usando template standard");
systemMessage = new SystemMessage(this.qai_system_prompt_template);
logger.info("template: " + this.qai_system_prompt_template);
}
logger.info("Invio richiesta al modello con conversation ID: {}",
this.scenarioExecution.getId() + this.qai_custom_memory_id);
CallResponseSpec resp = chatClient.prompt()
.messages(userMessage, systemMessage)
.advisors(advisor -> advisor
.param(ChatMemory.CONVERSATION_ID,
this.scenarioExecution.getId() + this.qai_custom_memory_id))
.call();
String output = resp.content();
logger.info("Risposta ricevuta dal modello");
// Gestione del conteggio token usati
try {
var chatResponse = resp.chatResponse();
if (chatResponse != null) {
var metadata = chatResponse.getMetadata();
if (metadata != null) {
Usage usage = metadata.getUsage();
if (usage != null) {
Integer usedTokens = usage.getTotalTokens();
this.scenarioExecution.setUsedTokens(usedTokens);
logger.info("Token utilizzati: {}", usedTokens);
} else {
logger.info("Token usage information is not available.");
}
} else {
logger.warn("Metadata non disponibili nella risposta");
}
} else {
logger.warn("Chat response non disponibile");
}
} catch (Exception e) {
logger.warn("Errore nel recuperare le informazioni di utilizzo dei token: {}", e.getMessage());
}
tokenCount = encoding.get().countTokens(output);
logger.info("token count output: " + tokenCount);
// Salvataggio dell'output nel contesto di esecuzione
this.scenarioExecution.getExecSharedMap().put(this.qai_output_variable, output);
logger.info("Output salvato nella variabile: {}", this.qai_output_variable);
this.scenarioExecution.setNextStepId(this.step.getNextStepId());
logger.info("Step completato con successo, prossimo step: {}", this.step.getNextStepId());
return this.scenarioExecution;
}
private String summarize(String text) {
logger.info("Iniziando il processo di summarizzazione, lunghezza testo: {} caratteri", text.length());
logger.info("length: " + text.length());
Double chunk_size_text;
tokenCount = encoding.get().countTokens(text);
logger.info("Token count del testo: {}", tokenCount);
int textLengthPlus = (int) (text.length() * 1.1);
int tokenCountPlus = (int) (tokenCount * 1.1);
Double ratio = Math.floor((tokenCountPlus / chunk_size_token_calc) + 1);
//Double ratio = Math.floor((tokenCountPlus / chunk_size_token) + 1);
logger.info("Ratio calcolato: {}, chunk_size_token_calc: {}", ratio, chunk_size_token_calc);
if (ratio == 1) {
logger.info("Ratio è 1, il testo non verrà suddiviso in chunk");
return text;
} else {
chunk_size_text = Math.ceil(textLengthPlus / ratio);
logger.info("Il testo verrà suddiviso in chunk di dimensione: {} caratteri", chunk_size_text.intValue());
}
// Suddividere il testo in chunk
List<String> chunks = chunkText(text, chunk_size_text.intValue());
logger.info("Testo suddiviso in {} chunk", chunks.size());
List<String> summarizedChunks = new ArrayList<>();
// Riassumere ogni chunk singolarmente
int chunkCounter = 0;
for (String chunk : chunks) {
chunkCounter++;
logger.info("Elaborazione chunk {}/{} - lunghezza: {} caratteri", chunkCounter, chunks.size(), chunk.length());
logger.info("chunk length: " + chunk.length());
summarizedChunks.add(summarizeChunk(chunk));
String summarizedChunk = summarizeChunk(chunk);
logger.info("Chunk {}/{} riassunto con successo", chunkCounter, chunks.size());
summarizedChunks.add(summarizedChunk);
}
isChunked = true;
// Unire i riassunti
String summarizedText = String.join(" ", summarizedChunks);
int tokenCountSummarizedText = encoding.get().countTokens(summarizedText);
logger.info("Testo riassunto completo: {} token, {} caratteri", tokenCountSummarizedText, summarizedText.length());
// Se il riassunto è ancora troppo lungo, applicare ricorsione
if (tokenCountSummarizedText > max_output_token) {
logger.info("Il testo riassunto è ancora troppo lungo ({}), applicando ricorsione", tokenCountSummarizedText);
return summarize(summarizedText);
} else {
logger.info("Summarizzazione completata con successo");
return summarizedText;
}
}
// Metodo per suddividere il testo in chunk
private List<String> chunkText(String text, int chunkSize) {
logger.info("Suddivisione testo in chunk di dimensione: {} caratteri", chunkSize);
List<String> chunks = new ArrayList<>();
int length = text.length();
for (int i = 0; i < length; i += chunkSize) {
chunks.add(text.substring(i, Math.min(length, i + chunkSize)));
int end = Math.min(length, i + chunkSize);
chunks.add(text.substring(i, end));
logger.debug("Creato chunk {}/{} da carattere {} a {}", chunks.size(),
(int)Math.ceil((double)length/chunkSize), i, end);
}
logger.info("Testo suddiviso in {} chunk", chunks.size());
return chunks;
}
// Metodo di riassunto per singolo chunk
private String summarizeChunk(String chunk) {
logger.info("Iniziando summarizzazione di un chunk di {} caratteri", chunk.length());
int tokenCountChunk = encoding.get().countTokens(chunk);
logger.info("Token count del chunk: {}", tokenCountChunk);
Message chunkMessage = new UserMessage(chunk);
Message systemMessage = new SystemMessage(
@@ -222,18 +285,25 @@ public class SummarizeDocSolver extends StepSolver {
logger.info("template chunk: " + this.qai_system_prompt_template_chunk);
CallResponseSpec resp = chatClient.prompt()
.messages(chunkMessage, systemMessage)
/*
* .advisors(advisor -> advisor
* .param("chat_memory_conversation_id",
* this.scenarioExecution.getId() + this.qai_custom_memory_id)
* .param("chat_memory_response_size", 100))
*/
.call();
return resp.content();
try {
CallResponseSpec resp = chatClient.prompt()
.messages(chunkMessage, systemMessage)
.call();
String result = resp.content();
if (result != null) {
int tokenCountResult = encoding.get().countTokens(result);
logger.info("Chunk riassunto con successo: {} token, {} caratteri",
tokenCountResult, result.length());
return result;
} else {
logger.warn("Risposta nulla ricevuta dal modello");
return "";
}
} catch (Exception e) {
logger.error("Errore durante la summarizzazione del chunk: {}", e.getMessage());
return "ERRORE NELLA SUMMARIZZAZIONE: " + e.getMessage();
}
}
}

View File

@@ -0,0 +1,121 @@
package com.olympus.hermione.tools;
import java.util.List;
import java.util.logging.Logger;
import org.springframework.ai.document.Document;
import org.springframework.ai.tool.annotation.Tool;
import org.springframework.ai.vectorstore.SearchRequest;
import org.springframework.ai.vectorstore.VectorStore;
/**
 * Tool per sistemi conversazionali (function calling) che esegue una ricerca
 * semantica sulla documentazione funzionale dell'applicazione in un {@link VectorStore}
 * e restituisce gli estratti più rilevanti come contesto per il modello.
 */
public class DocumentationTool {

    private static final Logger logger = Logger.getLogger(DocumentationTool.class.getName());

    private static final int DEFAULT_TOP_K = 5;
    private static final double DEFAULT_SIMILARITY_THRESHOLD = 0.7;

    // NOTE(review): `application` è assegnato ma mai letto nel codice visibile
    // (il filtro usa solo `project`) — valutare se rimuoverlo o includerlo nel filtro.
    private final String application;
    private final String project;
    private final VectorStore vectorStore;

    /**
     * Costruttore per DocumentationTool.
     *
     * @param application Nome dell'applicazione
     * @param project     Nome del progetto (usato nel filtro di ricerca)
     * @param vectorStore Vector store contenente i documenti
     */
    public DocumentationTool(String application, String project, VectorStore vectorStore) {
        this.application = application;
        this.project = project;
        this.vectorStore = vectorStore;
    }

    /**
     * Esegue una ricerca semantica sulla documentazione e formatta i risultati.
     *
     * <p>
     * This method is intended for integration with conversational AI systems (such as ChatGPT),
     * enabling them to provide accurate and context-aware answers about the application's functionality.
     * </p>
     *
     * @param query The user's question or information request regarding the application.
     * @return A formatted response containing the most relevant documentation excerpts.
     */
    // BUG FIX: la description conteneva artefatti di javadoc copiati ("\r\n" e " * ")
    // che finivano nel testo del tool presentato al modello.
    @Tool(description = "Retrieves functional information about the application based on the provided query. "
            + "It performs a semantic search over the application's documentation and returns the most relevant results.")
    public String getDocumentation(String query) {
        logRequestDetails(query);
        String filterExpression = buildFilterExpression();
        logger.info("[TOOL]LLM DocumentationTool Getting documentation for filterExpression: " + filterExpression);
        SearchRequest request = buildSearchRequest(query, filterExpression);
        List<Document> docs = this.vectorStore.similaritySearch(request);
        // BUG FIX: alcune implementazioni di VectorStore possono restituire null;
        // senza questo guard `docs.size()` lancerebbe NullPointerException.
        if (docs == null) {
            docs = List.of();
        }
        logger.info("Number of VDB retrieved documents: " + docs.size());
        return formatDocumentsResponse(docs);
    }

    /**
     * Registra i dettagli della richiesta nei log.
     *
     * @param query La query di ricerca
     */
    private void logRequestDetails(String query) {
        logger.info("[TOOL]LLM DocumentationTool Getting documentation for query: " + query);
        logger.info("[TOOL]LLM DocumentationTool Getting documentation for project: " + this.project);
        logger.info("[TOOL]LLM DocumentationTool Getting documentation for vectorStore: " + this.vectorStore);
    }

    /**
     * Costruisce l'espressione di filtro per la ricerca
     * (solo documenti del progetto corrente di tipo "functional").
     *
     * @return L'espressione di filtro formattata
     */
    private String buildFilterExpression() {
        // NOTE(review): nella DSL dei filtri Spring AI gli identificatori di metadato sono
        // tipicamente NON quotati (es. KsProjectName == 'x') — verificare che le chiavi
        // tra apici singoli siano accettate dal parser del Chroma store.
        return "'KsProjectName' == '"+ this.project +"' AND 'KsDoctype' == 'functional'";
    }

    /**
     * Costruisce la richiesta di ricerca con topK e soglia di similarità di default.
     *
     * @param query La query di ricerca
     * @param filterExpression L'espressione di filtro
     * @return La richiesta di ricerca configurata
     */
    private SearchRequest buildSearchRequest(String query, String filterExpression) {
        return SearchRequest.builder()
                .query(query)
                .topK(DEFAULT_TOP_K)
                .similarityThreshold(DEFAULT_SIMILARITY_THRESHOLD)
                .filterExpression(filterExpression)
                .build();
    }

    /**
     * Formatta i documenti recuperati in una risposta leggibile dal modello.
     *
     * @param docs Lista di documenti recuperati
     * @return Risposta formattata contenente i testi dei documenti
     */
    private String formatDocumentsResponse(List<Document> docs) {
        if (docs.isEmpty()) {
            return "Nessun documento trovato per la query.";
        }
        StringBuilder response = new StringBuilder();
        for (Document doc : docs) {
            response.append("-------USE THIS AS PART OF YOUR CONTEXT-----------------------------------------\n")
                    .append(doc.getText())
                    .append("------------------------------------------------------------------------------------------\n");
        }
        return response.toString();
    }
}

View File

@@ -5,6 +5,7 @@ import java.util.List;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.document.Document;
import org.springframework.stereotype.Service;
@@ -19,7 +20,7 @@ public class RAGDocumentRanker {
Logger logger = (Logger) LoggerFactory.getLogger(RAGDocumentRanker.class);
public List<RankedDocument> rankDocuments(List<Document> documents, String query, int threshold, OlympusChatClient chatClient) {
public List<RankedDocument> rankDocuments(List<Document> documents, String query, int threshold, ChatClient chatClient) {
List<RankedDocument> rankedDocuments = new ArrayList<>();
@@ -30,8 +31,10 @@ public class RAGDocumentRanker {
rankingPrompt +="Chunk: " + document.getText();
rankingPrompt +="Answer with a number between 1 and 10 and nothing else.";
OlympusChatClientResponse resp = chatClient.getChatCompletion(rankingPrompt);
String rank = resp.getContent().substring(0, 1);
String rank = chatClient.prompt()
.user(rankingPrompt)
.call()
.content().substring(0, 1);
int rankInt = Integer.parseInt(rank);
logger.info("Rank: " + rankInt);

View File

@@ -56,9 +56,9 @@ eureka.instance.preferIpAddress: true
hermione.fe.url = http://localhost:5173
spring.ai.vectorstore.chroma.client.host=http://128.251.239.194
spring.ai.vectorstore.chroma.client.port=8000
spring.ai.vectorstore.chroma.client.key-token=BxZWXFXC4UMSxamf5xP5SioGIg3FPfP7
spring.ai.vectorstore.chroma.initialize-schema=true
spring.ai.vectorstore.chroma.collection-name=olympus