Enhance logging in AdvancedAIPromptSolver and AdvancedQueryRagSolver for better traceability

This commit is contained in:
andrea.terzani
2025-05-22 10:16:40 +02:00
parent 8d95d410d9
commit 1a1ead8fce
2 changed files with 78 additions and 57 deletions

View File

@@ -33,50 +33,68 @@ public class AdvancedAIPromptSolver extends StepSolver {
logger.info("Loading parameters");
if(this.step.getAttributes().get("qai_load_graph_schema")!=null ){
logger.info("Found attribute 'qai_load_graph_schema': {}", this.step.getAttributes().get("qai_load_graph_schema"));
this.qai_load_graph_schema = Boolean.parseBoolean((String) this.step.getAttributes().get("qai_load_graph_schema"));
logger.info("qai_load_graph_schema: " + this.qai_load_graph_schema);
logger.info("qai_load_graph_schema: {}", this.qai_load_graph_schema);
} else {
logger.info("Attribute 'qai_load_graph_schema' not found, using default: false");
}
if(this.qai_load_graph_schema){
logger.info("Loading graph schema");
this.scenarioExecution.getExecSharedMap().put("graph_schema", this.neo4JUitilityService.generateSchemaDescription().toString());
Object schema = this.neo4JUitilityService.generateSchemaDescription();
logger.debug("Generated graph schema: {}", schema);
this.scenarioExecution.getExecSharedMap().put("graph_schema", schema.toString());
}
AttributeParser attributeParser = new AttributeParser(this.scenarioExecution);
this.qai_system_prompt_template = attributeParser.parse((String) this.step.getAttributes().get("qai_system_prompt_template"));
logger.info("Parsed qai_system_prompt_template: {}", this.qai_system_prompt_template);
this.qai_user_input = attributeParser.parse((String)this.step.getAttributes().get("qai_user_input"));
logger.info("Parsed qai_user_input: {}", this.qai_user_input);
this.qai_output_variable = (String) this.step.getAttributes().get("qai_output_variable");
logger.info("qai_output_variable: {}", this.qai_output_variable);
this.qai_output_entityType = (String) this.step.getAttributes().get("qai_output_entityType");
logger.info("qai_output_entityType: {}", this.qai_output_entityType);
this.qai_custom_memory_id = (String) this.step.getAttributes().get("qai_custom_memory_id");
logger.info("qai_custom_memory_id: {}", this.qai_custom_memory_id);
//TODO: Add memory ID attribute to have the possibility of multiple conversations
this.qai_available_tools = (String) this.step.getAttributes().get("qai_available_tools");
logger.info("qai_available_tools: {}", this.qai_available_tools);
}
@Override
public ScenarioExecution solveStep(){
logger.info("Solving step: {}", this.step.getName());
System.out.println("Solving step: " + this.step.getName());
this.scenarioExecution.setCurrentStepId(this.step.getStepId());
logger.debug("Set current step ID: {}", this.step.getStepId());
loadParameters();
Object tools = null;
if(this.qai_available_tools!=null && !this.qai_available_tools.isEmpty()){
logger.info("Available tools specified: {}", this.qai_available_tools);
for(String tool: this.qai_available_tools.split(",")){
logger.info("Tool: " + tool);
logger.info("Tool: {}", tool);
if(tool.equals("SourceCodeTool")){
logger.info("Instantiating SourceCodeTool");
tools = new com.olympus.hermione.tools.SourceCodeTool(discoveryClient);
} else {
logger.warn("Unknown tool specified: {}", tool);
}
}
} else {
logger.info("No tools specified or available");
}
CallResponseSpec resp=null;
if(tools==null){
logger.info("No tools available");
logger.info("No tools available, calling chatClient without tools");
resp = chatClient.prompt()
.user(this.qai_user_input)
.system(this.qai_system_prompt_template)
@@ -85,6 +103,7 @@ public class AdvancedAIPromptSolver extends StepSolver {
.param("chat_memory_response_size", 100))
.call();
}else{
logger.info("Calling chatClient with tools: {}", tools.getClass().getName());
resp = chatClient.prompt()
.user(this.qai_user_input)
.system(this.qai_system_prompt_template)
@@ -95,37 +114,39 @@ public class AdvancedAIPromptSolver extends StepSolver {
.call();
}
if(qai_output_entityType!=null && qai_output_entityType.equals("CiaOutputEntity")){
logger.info("Output is of type CiaOutputEntity");
CiaOutputEntity ouputEntity = resp.entity(CiaOutputEntity.class);
logger.debug("CiaOutputEntity response: {}", ouputEntity);
try {
ObjectMapper objectMapper = new ObjectMapper();
String jsonOutput = objectMapper.writeValueAsString(ouputEntity);
logger.info("Storing output entity as JSON in shared map with key: {}", this.qai_output_variable);
this.scenarioExecution.getExecSharedMap().put(this.qai_output_variable, jsonOutput);
} catch (JsonProcessingException e) {
logger.error("Error converting output entity to JSON", e);
}
}else{
logger.info("Output is of type String");
String output = resp.content();
logger.debug("String response: {}", output);
this.scenarioExecution.getExecSharedMap().put(this.qai_output_variable, output);
}
Usage usage = resp.chatResponse().getMetadata().getUsage();
if (usage != null) {
Integer usedTokens = usage.getTotalTokens();
logger.info("Used tokens: {}", usedTokens);
this.scenarioExecution.setUsedTokens(usedTokens);
} else {
logger.info("Token usage information is not available.");
}
this.scenarioExecution.setNextStepId(this.step.getNextStepId());
logger.info("Set next step ID: {}", this.step.getNextStepId());
logger.info("Step solved. Returning scenarioExecution.");
return this.scenarioExecution;
}

View File

@@ -22,7 +22,7 @@ import ch.qos.logback.classic.Logger;
public class AdvancedQueryRagSolver extends StepSolver {
Logger logger = (Logger) LoggerFactory.getLogger(BasicQueryRagSolver.class);
Logger logger = (Logger) LoggerFactory.getLogger(AdvancedQueryRagSolver.class);
// PARAMETERS
// rag_input : the field of the sharedMap to use as rag input
@@ -160,47 +160,48 @@ public class AdvancedQueryRagSolver extends StepSolver {
@Override
public ScenarioExecution solveStep(){
logger.info("Solving step: " + this.step.getName());
logger.info("Solving step: {}", this.step.getName());
loadParameters();
logParameters();
String elaboratedQuery = this.query;
OlympusChatClient chatClient = null;
if ( scenarioExecution.getScenario().getAiModel().getApiProvider().equals("AzureOpenAI")) {
String apiProvider = this.scenarioExecution.getScenario().getAiModel().getApiProvider();
logger.info("AI Model API Provider: {}", apiProvider);
if (apiProvider.equals("AzureOpenAI")) {
logger.info("Using AzureOpenApiChatClient");
chatClient = new AzureOpenApiChatClient();
}
if ( scenarioExecution.getScenario().getAiModel().getApiProvider().equals("GoogleGemini")) {
if (apiProvider.equals("GoogleGemini")) {
logger.info("Using GoogleGeminiChatClient");
chatClient = new GoogleGeminiChatClient();
}
chatClient.init(scenarioExecution.getScenario().getAiModel().getFull_path_endpoint(),
scenarioExecution.getScenario().getAiModel().getApiKey(),
scenarioExecution.getScenario().getAiModel().getMaxTokens());
logger.info("Initializing chat client with endpoint: {}", this.scenarioExecution.getScenario().getAiModel().getFull_path_endpoint());
chatClient.init(this.scenarioExecution.getScenario().getAiModel().getFull_path_endpoint(),
this.scenarioExecution.getScenario().getAiModel().getApiKey(),
this.scenarioExecution.getScenario().getAiModel().getMaxTokens());
if( this.enableRephrase ){
if(this.enableRephrase){
logger.info("Rephrasing query with prompt: {}", this.rephrasePrompt);
OlympusChatClientResponse resp = chatClient.getChatCompletion(this.rephrasePrompt);
elaboratedQuery = resp.getContent();
logger.info("Rephrased query: {}", elaboratedQuery);
}
if( this.enableMultiPhrase ){
if(this.enableMultiPhrase){
logger.info("Generating multi-phrase queries with prompt: {}", this.MultiPhrasePrompt);
OlympusChatClientResponse resp = chatClient.getChatCompletion(this.MultiPhrasePrompt);
elaboratedQuery = resp.getContent();
logger.info("Multi-phrase queries: {}", elaboratedQuery);
}
List<Document> docs = new ArrayList<Document>();
List<String> queries = new ArrayList<String>();
List<Document> docs = new ArrayList<>();
List<String> queries = new ArrayList<>();
for (String query : elaboratedQuery.split("\n")) {
logger.info("Elaborated query: " + query);
logger.info("Elaborated query: {}", query);
queries.add(query);
Builder request_builder = SearchRequest.builder()
.query(elaboratedQuery)
@@ -209,47 +210,45 @@ public class AdvancedQueryRagSolver extends StepSolver {
if(this.rag_filter != null && !this.rag_filter.isEmpty()){
request_builder.filterExpression(this.rag_filter);
logger.info("Using Filter expression: " + this.rag_filter);
logger.info("Using Filter expression: {}", this.rag_filter);
}
SearchRequest request = request_builder.build();
docs.addAll( this.vectorStore.similaritySearch(request));
logger.info("Performing similarity search for query: {}", query);
docs.addAll(this.vectorStore.similaritySearch(request));
logger.info("Documents found so far: {}", docs.size());
}
logger.info("Number of not unique VDB retrieved documents: " + docs.size());
logger.info("Number of not unique VDB retrieved documents: {}", docs.size());
docs = docs.stream().collect(Collectors.toMap(Document::getId, d -> d, (d1, d2) -> d1)).values().stream().collect(Collectors.toList());
logger.info("Number of VDB retrieved documents: " + docs.size());
logger.info("Number of VDB retrieved documents (unique): {}", docs.size());
if (enableRanking){
List<Document> rankedDocs = new ArrayList<Document>();
List<Document> rankedDocs = new ArrayList<>();
logger.info("Ranking documents");
RAGDocumentRanker ranker = new RAGDocumentRanker();
List<RankedDocument> rankedDocument = ranker.rankDocuments(docs, this.query, this.rank_threshold, chatClient);
for (RankedDocument rankedDoc : rankedDocument) {
rankedDocs.add(rankedDoc.getDocument());
}
logger.info("Number of ranked documents: " + rankedDocs.size());
logger.info("Number of ranked documents: {}", rankedDocs.size());
docs = rankedDocs;
}
if(enableNeighborRetrieve){
logger.info("Retrieving neighbor documents: #PRE " + this.preNeighbor + " #POST " + this.afterNeighbor);
List<Document> neighborDocs = new ArrayList<Document>();
logger.info("Retrieving neighbor documents: #PRE {} #POST {}", this.preNeighbor, this.afterNeighbor);
List<Document> neighborDocs = new ArrayList<>();
for(Document doc : docs){
neighborDocs.addAll( retrieveNeighborDocuments(doc, preNeighbor, afterNeighbor));
logger.info("Number of neighbor documents: " + neighborDocs.size());
logger.info("Retrieving neighbors for doc ID: {}", doc.getId());
neighborDocs.addAll(retrieveNeighborDocuments(doc, preNeighbor, afterNeighbor));
logger.info("Number of neighbor documents for this doc: {}", neighborDocs.size());
}
docs.addAll(neighborDocs);
logger.info("Number of documents after neighbor retrieval: " + docs.size());
logger.info("Number of documents after neighbor retrieval: {}", docs.size());
}
if(enableRetrieveOther){
List<Document> otherDocs = new ArrayList<Document>();
List<Document> otherDocs = new ArrayList<>();
logger.info("Retrieving other documents");
String resultString = "";
@@ -266,10 +265,11 @@ public class AdvancedQueryRagSolver extends StepSolver {
prompt += "\n\n The query is: " + this.query;
prompt += "\n\n The output must be a single query and nothing else.";
logger.info("Prompt for retrieving other documents: {}", prompt);
OlympusChatClientResponse resp = chatClient.getChatCompletion(prompt);
String otherDocQuery = resp.getContent();
logger.info("Other document query: " + otherDocQuery);
logger.info("Other document query: {}", otherDocQuery);
Builder request_builder = SearchRequest.builder()
.query(otherDocQuery)
@@ -278,24 +278,24 @@ public class AdvancedQueryRagSolver extends StepSolver {
if(this.rag_filter != null && !this.rag_filter.isEmpty()){
request_builder.filterExpression(this.rag_filter);
logger.info("Using Filter expression: " + this.rag_filter);
logger.info("Using Filter expression: {}", this.rag_filter);
}
SearchRequest request = request_builder.build();
otherDocs.addAll( this.vectorStore.similaritySearch(request));
logger.info("Number of addictional documents: " + otherDocs.size());
logger.info("Performing similarity search for other document query");
otherDocs.addAll(this.vectorStore.similaritySearch(request));
logger.info("Number of addictional documents: {}", otherDocs.size());
docs.addAll(otherDocs);
}
docs = docs.stream()
.filter(d -> d.getMetadata().get("KsDocumentId") != null)
.filter(d -> d.getMetadata().get("KsDocumentIndex") != null)
.collect(Collectors.toMap(Document::getId, d -> d, (d1, d2) -> d1)).values().stream().collect(Collectors.toList());
//Sort by KsDocumentId and KsDocumentIndex
logger.info("Sorting documents by KsDocumentId and KsDocumentIndex");
docs.sort((d1, d2) -> {
String docId1 = (String) d1.getMetadata().get("KsDocumentId");
String docId2 = (String) d2.getMetadata().get("KsDocumentId");
@@ -309,24 +309,24 @@ public class AdvancedQueryRagSolver extends StepSolver {
}
});
List<String> source_doc = new ArrayList<String>();
List<String> source_doc = new ArrayList<>();
String resultString = "";
for(Document doc : docs){
logger.info("Adding document to result: Source {} - Chunk Index {}", doc.getMetadata().get("KsFileSource"), doc.getMetadata().get("KsDocumentIndex"));
resultString += "Document Source: " + doc.getMetadata().get("KsFileSource") + " - Chunk Index " + doc.getMetadata().get("KsDocumentIndex")+ "\n";
resultString += "--------------------------------------------------------------\n";
resultString += doc.getText() + "\n";
source_doc.add((String)doc.getMetadata().get("KsFileSource"));
}
logger.info("Storing results in scenarioExecution shared map. Output field: {}", this.outputField);
this.scenarioExecution.getExecSharedMap().put("tech_rag_source_documents", source_doc);
this.scenarioExecution.getExecSharedMap().put("tech_rag_query", queries);
this.scenarioExecution.getExecSharedMap().put(this.outputField, resultString);
this.scenarioExecution.setCurrentStepId(this.step.getStepId());
this.scenarioExecution.setNextStepId(this.step.getNextStepId());
logger.info("Step risolto. Restituisco scenarioExecution.");
return this.scenarioExecution;
}