Refactor token handling and update dependencies; add AzureAIConfig for response timeout customization

This commit is contained in:
andrea.terzani
2025-02-21 13:02:37 +01:00
parent b0e4abc172
commit 3402b4ada2
13 changed files with 122 additions and 73 deletions

View File

@@ -29,7 +29,7 @@
<properties> <properties>
<java.version>21</java.version> <java.version>21</java.version>
<!--<spring-ai.version>1.0.0-SNAPSHOT</spring-ai.version>--> <!--<spring-ai.version>1.0.0-SNAPSHOT</spring-ai.version>-->
<spring-ai.version>1.0.0-M2</spring-ai.version> <spring-ai.version>1.0.0-M6</spring-ai.version>
<spring-cloud.version>2023.0.3</spring-cloud.version> <spring-cloud.version>2023.0.3</spring-cloud.version>
</properties> </properties>
<dependencies> <dependencies>
@@ -48,6 +48,13 @@
<version>1.1.0</version> <version>1.1.0</version>
</dependency> </dependency>
<!-- https://mvnrepository.com/artifact/org.json/json -->
<dependency>
<groupId>org.json</groupId>
<artifactId>json</artifactId>
<version>20250107</version>
</dependency>
<!--<dependency> <!--<dependency>
<groupId>org.springframework.ai</groupId> <groupId>org.springframework.ai</groupId>

View File

@@ -0,0 +1,26 @@
package com.olympus.hermione.config;
import java.time.Duration;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.web.client.RestTemplate;
import org.springframework.ai.autoconfigure.azure.openai.AzureOpenAIClientBuilderCustomizer;
import com.azure.core.http.HttpClient;
import com.azure.core.util.HttpClientOptions;
@Configuration
public class AzureAIConfig {

    /**
     * Registers a customizer for the Azure OpenAI client builder so that
     * long-running completions are not cut off: the underlying HTTP client
     * is replaced with one whose response timeout is raised to five minutes.
     *
     * @return a customizer applied by Spring AI's Azure OpenAI auto-configuration
     */
    @Bean
    public AzureOpenAIClientBuilderCustomizer responseTimeoutCustomizer() {
        return builder -> {
            // Default SDK response timeout is too short for slow generations;
            // extend it to 5 minutes before building the HTTP client.
            HttpClientOptions options = new HttpClientOptions();
            options.setResponseTimeout(Duration.ofMinutes(5));
            builder.httpClient(HttpClient.createDefault(options));
        };
    }
}

View File

@@ -43,7 +43,8 @@ public class VectorStoreConfig {
fields.add(AzureVectorStore.MetadataField.text("KsFileSource")); fields.add(AzureVectorStore.MetadataField.text("KsFileSource"));
fields.add(AzureVectorStore.MetadataField.text("KsDocumentId")); fields.add(AzureVectorStore.MetadataField.text("KsDocumentId"));
return new AzureVectorStore(searchIndexClient, embeddingModel,initSchema, fields); //return new AzureVectorStore(searchIndexClient, embeddingModel,initSchema, fields);
return null;
} }
} }

View File

@@ -1,15 +1,12 @@
package com.olympus.hermione.models; package com.olympus.hermione.models;
import java.util.HashMap;
import java.util.Date; import java.util.Date;
import java.util.HashMap;
import org.springframework.data.annotation.Id; import org.springframework.data.annotation.Id;
import org.springframework.data.mongodb.core.mapping.Document; import org.springframework.data.mongodb.core.mapping.Document;
import org.springframework.data.mongodb.core.mapping.DocumentReference;
import com.olympus.hermione.dto.ScenarioExecutionInput; import com.olympus.hermione.dto.ScenarioExecutionInput;
import com.olympus.hermione.models.Scenario;
import com.olympus.hermione.security.entity.User;
import lombok.Getter; import lombok.Getter;
import lombok.Setter; import lombok.Setter;
@@ -38,6 +35,6 @@ public class ScenarioExecution {
private Date startDate; private Date startDate;
private Date endDate; private Date endDate;
private Long usedTokens; private Integer usedTokens;
private String rating; private String rating;
} }

View File

@@ -41,7 +41,7 @@ public class CanvasExecutionService {
Prompt prompt = new Prompt(List.of(userMessage,systemMessage)); Prompt prompt = new Prompt(List.of(userMessage,systemMessage));
List<Generation> response = chatModel.call(prompt).getResults(); List<Generation> response = chatModel.call(prompt).getResults();
canvasOutput.setStringOutput(response.get(0).getOutput().getContent()); canvasOutput.setStringOutput(response.get(0).getOutput().getText());
return canvasOutput; return canvasOutput;
} else{ } else{
logger.error("Input is not correct"); logger.error("Input is not correct");
@@ -61,7 +61,7 @@ public class CanvasExecutionService {
Prompt prompt = new Prompt(List.of(userMessage,systemMessage)); Prompt prompt = new Prompt(List.of(userMessage,systemMessage));
List<Generation> response = chatModel.call(prompt).getResults(); List<Generation> response = chatModel.call(prompt).getResults();
canvasOutput.setStringOutput(response.get(0).getOutput().getContent()); canvasOutput.setStringOutput(response.get(0).getOutput().getText());
return canvasOutput; return canvasOutput;
} else{ } else{
logger.error("Input is not correct"); logger.error("Input is not correct");
@@ -79,7 +79,7 @@ public class CanvasExecutionService {
Prompt prompt = new Prompt(List.of(userMessage,systemMessage)); Prompt prompt = new Prompt(List.of(userMessage,systemMessage));
List<Generation> response = chatModel.call(prompt).getResults(); List<Generation> response = chatModel.call(prompt).getResults();
canvasOutput.setStringOutput(response.get(0).getOutput().getContent()); canvasOutput.setStringOutput(response.get(0).getOutput().getText());
return canvasOutput; return canvasOutput;
} else{ } else{
logger.error("Input is not correct"); logger.error("Input is not correct");

View File

@@ -34,7 +34,6 @@ import org.springframework.scheduling.annotation.Async;
import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;
import com.olympus.hermione.client.OlympusChatClient;
import com.olympus.hermione.dto.ScenarioExecutionInput; import com.olympus.hermione.dto.ScenarioExecutionInput;
import com.olympus.hermione.dto.ScenarioOutput; import com.olympus.hermione.dto.ScenarioOutput;
import com.olympus.hermione.models.AiModel; import com.olympus.hermione.models.AiModel;
@@ -210,7 +209,7 @@ public class ScenarioExecutionService {
ScenarioStep startStep = steps.stream().filter(step -> step.getStepId().equals(startStepId)).findFirst().orElse(null); ScenarioStep startStep = steps.stream().filter(step -> step.getStepId().equals(startStepId)).findFirst().orElse(null);
executeScenarioStep(startStep, scenarioExecution); executeScenarioStep(startStep, scenarioExecution);
Long usedTokens = (scenarioExecution.getUsedTokens() != null) ? scenarioExecution.getUsedTokens() : Long.valueOf(0); Integer usedTokens = (scenarioExecution.getUsedTokens() != null) ? scenarioExecution.getUsedTokens() : 0;
while (scenarioExecution.getNextStepId()!=null) { while (scenarioExecution.getNextStepId()!=null) {
ScenarioStep step = steps.stream().filter(s -> s.getStepId().equals(scenarioExecution.getNextStepId())).findFirst().orElse(null); ScenarioStep step = steps.stream().filter(s -> s.getStepId().equals(scenarioExecution.getNextStepId())).findFirst().orElse(null);
@@ -219,7 +218,7 @@ public class ScenarioExecutionService {
if (scenarioExecution.getUsedTokens() != null && scenarioExecution.getUsedTokens() != 0) { if (scenarioExecution.getUsedTokens() != null && scenarioExecution.getUsedTokens() != 0) {
usedTokens += scenarioExecution.getUsedTokens(); usedTokens += scenarioExecution.getUsedTokens();
scenarioExecution.setUsedTokens(Long.valueOf(0)); //resetting value for next step if is not an AI step scenarioExecution.setUsedTokens(0); //resetting value for next step if is not an AI step
} }
if(scenarioExecution.getLatestStepStatus() != null && scenarioExecution.getLatestStepStatus().equals("ERROR")){ if(scenarioExecution.getLatestStepStatus() != null && scenarioExecution.getLatestStepStatus().equals("ERROR")){
@@ -426,17 +425,19 @@ public class ScenarioExecutionService {
private ChatModel createChatModel(AiModel aiModel){ private ChatModel createChatModel(AiModel aiModel){
switch(aiModel.getApiProvider()){ switch(aiModel.getApiProvider()){
case "AzureOpenAI": case "AzureOpenAI":
OpenAIClient openAIClient = new OpenAIClientBuilder() OpenAIClientBuilder openAIClient = new OpenAIClientBuilder()
.credential(new AzureKeyCredential(aiModel.getApiKey())) .credential(new AzureKeyCredential(aiModel.getApiKey()))
.endpoint(aiModel.getEndpoint()) .endpoint(aiModel.getEndpoint());
.buildClient();
AzureOpenAiChatOptions openAIChatOptions = AzureOpenAiChatOptions.builder() AzureOpenAiChatOptions openAIChatOptions = AzureOpenAiChatOptions.builder()
.withDeploymentName(aiModel.getModel()) .deploymentName(aiModel.getModel())
.withTemperature(aiModel.getTemperature()) .maxTokens(aiModel.getMaxTokens())
.withMaxTokens(aiModel.getMaxTokens()) .temperature(Double.valueOf(aiModel.getTemperature()))
.build(); .build();
AzureOpenAiChatModel azureOpenaichatModel = new AzureOpenAiChatModel(openAIClient, openAIChatOptions); AzureOpenAiChatModel azureOpenaichatModel = new AzureOpenAiChatModel(openAIClient, openAIChatOptions);
logger.info("AI model used: " + aiModel.getModel()); logger.info("AI model used: " + aiModel.getModel());
return azureOpenaichatModel; return azureOpenaichatModel;
@@ -445,9 +446,9 @@ public class ScenarioExecutionService {
case "OpenAI": case "OpenAI":
OpenAiApi openAiApi = new OpenAiApi(aiModel.getApiKey()); OpenAiApi openAiApi = new OpenAiApi(aiModel.getApiKey());
OpenAiChatOptions openAiChatOptions = OpenAiChatOptions.builder() OpenAiChatOptions openAiChatOptions = OpenAiChatOptions.builder()
.withModel(aiModel.getModel()) .model(aiModel.getModel())
.withTemperature(aiModel.getTemperature()) .temperature(Double.valueOf(aiModel.getTemperature()))
.withMaxTokens(aiModel.getMaxTokens()) .maxTokens(aiModel.getMaxTokens())
.build(); .build();
OpenAiChatModel openaichatModel = new OpenAiChatModel(openAiApi,openAiChatOptions); OpenAiChatModel openaichatModel = new OpenAiChatModel(openAiApi,openAiChatOptions);
@@ -456,19 +457,21 @@ public class ScenarioExecutionService {
case "GoogleGemini": case "GoogleGemini":
OpenAIClient tempopenAIClient = new OpenAIClientBuilder() OpenAIClientBuilder openAIClient2 = new OpenAIClientBuilder()
.credential(new AzureKeyCredential(aiModel.getApiKey())) .credential(new AzureKeyCredential(aiModel.getApiKey()))
.endpoint(aiModel.getEndpoint()) .endpoint(aiModel.getEndpoint());
.buildClient();
AzureOpenAiChatOptions tempopenAIChatOptions = AzureOpenAiChatOptions.builder()
.withDeploymentName(aiModel.getModel()) AzureOpenAiChatOptions openAIChatOptions2 = AzureOpenAiChatOptions.builder()
.withTemperature(aiModel.getTemperature()) .deploymentName(aiModel.getModel())
.withMaxTokens(aiModel.getMaxTokens()) .maxTokens(aiModel.getMaxTokens())
.temperature(Double.valueOf(aiModel.getTemperature()))
.build(); .build();
AzureOpenAiChatModel tempazureOpenaichatModel = new AzureOpenAiChatModel(tempopenAIClient, tempopenAIChatOptions); AzureOpenAiChatModel azureOpenaichatModel2 = new AzureOpenAiChatModel(openAIClient2, openAIChatOptions2);
logger.info("AI model used: GoogleGemini");
return tempazureOpenaichatModel; logger.info("AI model used : " + aiModel.getModel());
return azureOpenaichatModel2;
default: default:
throw new IllegalArgumentException("Unsupported AI model: " + aiModel.getName()); throw new IllegalArgumentException("Unsupported AI model: " + aiModel.getName());
} }

View File

@@ -92,7 +92,7 @@ public class AdvancedAIPromptSolver extends StepSolver {
Usage usage = resp.chatResponse().getMetadata().getUsage(); Usage usage = resp.chatResponse().getMetadata().getUsage();
if (usage != null) { if (usage != null) {
Long usedTokens = usage.getTotalTokens(); Integer usedTokens = usage.getTotalTokens();
this.scenarioExecution.setUsedTokens(usedTokens); this.scenarioExecution.setUsedTokens(usedTokens);
} else { } else {
logger.info("Token usage information is not available."); logger.info("Token usage information is not available.");

View File

@@ -65,7 +65,7 @@ public class BasicAIPromptSolver extends StepSolver {
Usage usage = resp.chatResponse().getMetadata().getUsage(); Usage usage = resp.chatResponse().getMetadata().getUsage();
if (usage != null) { if (usage != null) {
Long usedTokens = usage.getTotalTokens(); Integer usedTokens = usage.getTotalTokens();
this.scenarioExecution.setUsedTokens(usedTokens); this.scenarioExecution.setUsedTokens(usedTokens);
} else { } else {
logger.info("Token usage information is not available."); logger.info("Token usage information is not available.");

View File

@@ -6,8 +6,7 @@ import java.util.ArrayList;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import org.springframework.ai.document.Document; import org.springframework.ai.document.Document;
import org.springframework.ai.vectorstore.SearchRequest; import org.springframework.ai.vectorstore.SearchRequest;
import org.springframework.ai.vectorstore.SearchRequest.Builder;
import com.olympus.hermione.models.ScenarioExecution; import com.olympus.hermione.models.ScenarioExecution;
import com.olympus.hermione.utility.AttributeParser; import com.olympus.hermione.utility.AttributeParser;
@@ -72,22 +71,29 @@ public class BasicQueryRagSolver extends StepSolver {
loadParameters(); loadParameters();
logParameters(); logParameters();
SearchRequest request = SearchRequest.defaults()
.withQuery(this.query)
.withTopK(this.topk) Builder request_builder = SearchRequest.builder()
.withSimilarityThreshold(this.threshold); .query(this.query)
.topK(this.topk)
.similarityThreshold(this.threshold);
if(this.rag_filter != null && !this.rag_filter.isEmpty()){ if(this.rag_filter != null && !this.rag_filter.isEmpty()){
request.withFilterExpression(this.rag_filter); request_builder.filterExpression(this.rag_filter);
logger.info("Using Filter expression: " + this.rag_filter); logger.info("Using Filter expression: " + this.rag_filter);
} }
SearchRequest request = request_builder.build();
List<Document> docs = this.vectorStore.similaritySearch(request); List<Document> docs = this.vectorStore.similaritySearch(request);
logger.info("Number of VDB retrieved documents: " + docs.size()); logger.info("Number of VDB retrieved documents: " + docs.size());
List<String> result = new ArrayList<String>(); List<String> result = new ArrayList<String>();
for (Document doc : docs) { for (Document doc : docs) {
result.add(doc.getContent()); result.add(doc.getText());
} }
//concatenate the content of the results into a single string //concatenate the content of the results into a single string

View File

@@ -7,6 +7,7 @@ import java.util.List;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import org.springframework.ai.document.Document; import org.springframework.ai.document.Document;
import org.springframework.ai.vectorstore.SearchRequest; import org.springframework.ai.vectorstore.SearchRequest;
import org.springframework.ai.vectorstore.filter.Filter;
public class DeleteDocTempSolver extends StepSolver { public class DeleteDocTempSolver extends StepSolver {
@@ -56,7 +57,13 @@ public class DeleteDocTempSolver extends StepSolver {
this.scenarioExecution.setCurrentStepId(this.step.getStepId()); this.scenarioExecution.setCurrentStepId(this.step.getStepId());
loadParameters(); loadParameters();
SearchRequest searchRequest = SearchRequest.defaults()
vectorStore.delete(rag_filter);
/*SearchRequest searchRequest = SearchRequest.defaults()
.withQuery(this.query) .withQuery(this.query)
.withTopK(this.topk) .withTopK(this.topk)
.withSimilarityThreshold(this.threshold) .withSimilarityThreshold(this.threshold)
@@ -65,7 +72,7 @@ public class DeleteDocTempSolver extends StepSolver {
List<Document> docs = vectorStore.similaritySearch(searchRequest); List<Document> docs = vectorStore.similaritySearch(searchRequest);
List<String> ids = docs.stream().map(Document::getId).toList(); List<String> ids = docs.stream().map(Document::getId).toList();
vectorStore.delete(ids); vectorStore.delete(ids);*/
this.scenarioExecution.setNextStepId(this.step.getNextStepId()); this.scenarioExecution.setNextStepId(this.step.getNextStepId());

View File

@@ -24,7 +24,6 @@ public class OlynmpusChatClientSolver extends StepSolver{
private String qai_system_prompt_template; private String qai_system_prompt_template;
private String qai_user_input; private String qai_user_input;
private String qai_output_variable; private String qai_output_variable;
private boolean qai_load_graph_schema=false;
Logger logger = (Logger) LoggerFactory.getLogger(BasicQueryRagSolver.class); Logger logger = (Logger) LoggerFactory.getLogger(BasicQueryRagSolver.class);
@@ -73,8 +72,7 @@ public class OlynmpusChatClientSolver extends StepSolver{
Long usedTokens = (long) resp.getTotalTokens(); this.scenarioExecution.setUsedTokens(resp.getTotalTokens());
this.scenarioExecution.setUsedTokens(usedTokens);
this.scenarioExecution.setNextStepId(this.step.getNextStepId()); this.scenarioExecution.setNextStepId(this.step.getNextStepId());

View File

@@ -124,7 +124,7 @@ public class SummarizeDocSolver extends StepSolver {
// Creazione dei messaggi per il modello AI // Creazione dei messaggi per il modello AI
Message userMessage = new UserMessage(summarizedText); Message userMessage = new UserMessage(summarizedText);
Message systemMessage = new SystemMessage(content); Message systemMessage = new SystemMessage(content);
logger.info("template: " + systemMessage.getContent().toString()); logger.info("template: " + systemMessage.getText());
CallResponseSpec resp = chatClient.prompt() CallResponseSpec resp = chatClient.prompt()
.messages(userMessage, systemMessage) .messages(userMessage, systemMessage)
.advisors(advisor -> advisor .advisors(advisor -> advisor
@@ -138,7 +138,7 @@ public class SummarizeDocSolver extends StepSolver {
// Gestione del conteggio token usati // Gestione del conteggio token usati
Usage usage = resp.chatResponse().getMetadata().getUsage(); Usage usage = resp.chatResponse().getMetadata().getUsage();
if (usage != null) { if (usage != null) {
Long usedTokens = usage.getTotalTokens(); Integer usedTokens = usage.getTotalTokens();
this.scenarioExecution.setUsedTokens(usedTokens); this.scenarioExecution.setUsedTokens(usedTokens);
} else { } else {
logger.info("Token usage information is not available."); logger.info("Token usage information is not available.");

View File

@@ -27,9 +27,6 @@ spring.main.allow-circular-references=true
spring.ai.azure.openai.api-key=4eHwvw6h7vHxTmI2870cR3EpEBs5L9sXZabr9nz37y39TXtk0xY5JQQJ99AKAC5RqLJXJ3w3AAABACOGLdow spring.ai.azure.openai.api-key=4eHwvw6h7vHxTmI2870cR3EpEBs5L9sXZabr9nz37y39TXtk0xY5JQQJ99AKAC5RqLJXJ3w3AAABACOGLdow
spring.ai.azure.openai.endpoint=https://ai-olympus-new.openai.azure.com/ spring.ai.azure.openai.endpoint=https://ai-olympus-new.openai.azure.com/
#spring.ai.azure.openai.api-key=9fb33cc69d914d4c8225b974876510b5
#spring.ai.azure.openai.endpoint=https://ai-olympus.openai.azure.com/
spring.ai.azure.openai.chat.options.deployment-name=gpt-4o-mini spring.ai.azure.openai.chat.options.deployment-name=gpt-4o-mini
spring.ai.azure.openai.chat.options.temperature=0.7 spring.ai.azure.openai.chat.options.temperature=0.7
@@ -66,8 +63,15 @@ spring.ai.vectorstore.chroma.client.port=8000
spring.ai.vectorstore.chroma.client.key-token=tKAJfN1Yv5lP7pKorJHGfHMQhNEcM9uu spring.ai.vectorstore.chroma.client.key-token=tKAJfN1Yv5lP7pKorJHGfHMQhNEcM9uu
spring.ai.vectorstore.chroma.initialize-schema=true spring.ai.vectorstore.chroma.initialize-schema=true
spring.ai.vectorstore.chroma.collection-name=olympus spring.ai.vectorstore.chroma.collection-name=olympus
file.upload-dir=/mnt/hermione_storage/documents/file_input_scenarios/
spring.servlet.multipart.max-file-size=10MB spring.servlet.multipart.max-file-size=10MB
spring.servlet.multipart.max-request-size=10MB spring.servlet.multipart.max-request-size=10MB
file.upload-dir=C:\\mnt\\hermione_storage\\documents\\file_input_scenarios\\
generic-file-parser-module.url=http://generic-file-parser-module-service.olympus.svc.cluster.local:8080 generic-file-parser-module.url=http://generic-file-parser-module-service.olympus.svc.cluster.local:8080
java-parser-module.url: http://java-parser-module-service.olympus.svc.cluster.local:8080
java-re-module.url: http://java-re-module-service.olympus.svc.cluster.local:8080
jsp-parser-module.url: http://jsp-parser-module-service.olympus.svc.cluster.local:8080