From f43b7c2c7a8394e42c5383b55fc0cf8922a239f5 Mon Sep 17 00:00:00 2001 From: Florinda Date: Tue, 25 Feb 2025 10:57:05 +0100 Subject: [PATCH] update summarize --- .../stepSolvers/ExternalCodeGenieSolver.java | 2 +- .../stepSolvers/SummarizeDocSolver.java | 61 +++++++++++++++---- 2 files changed, 49 insertions(+), 14 deletions(-) diff --git a/src/main/java/com/olympus/hermione/stepSolvers/ExternalCodeGenieSolver.java b/src/main/java/com/olympus/hermione/stepSolvers/ExternalCodeGenieSolver.java index 644aaa2..16bb3fa 100644 --- a/src/main/java/com/olympus/hermione/stepSolvers/ExternalCodeGenieSolver.java +++ b/src/main/java/com/olympus/hermione/stepSolvers/ExternalCodeGenieSolver.java @@ -44,7 +44,7 @@ public class ExternalCodeGenieSolver extends StepSolver { @Autowired private ScenarioExecutionRepository scenarioExecutionRepo; - Logger logger = (Logger) LoggerFactory.getLogger(BasicQueryRagSolver.class); + Logger logger = (Logger) LoggerFactory.getLogger(ExternalCodeGenieSolver.class); private void loadParameters() { logger.info("Loading parameters"); diff --git a/src/main/java/com/olympus/hermione/stepSolvers/SummarizeDocSolver.java b/src/main/java/com/olympus/hermione/stepSolvers/SummarizeDocSolver.java index 12c7b2e..c3db65f 100644 --- a/src/main/java/com/olympus/hermione/stepSolvers/SummarizeDocSolver.java +++ b/src/main/java/com/olympus/hermione/stepSolvers/SummarizeDocSolver.java @@ -46,6 +46,10 @@ public class SummarizeDocSolver extends StepSolver { private String qai_custom_memory_id; private Integer max_output_token; private String qai_system_prompt_template_minimum; + private String qai_system_prompt_template_formatter; + private boolean isChunked = false; + private Double chunk_size_token_calc; + private Double perc = 0.2; // private boolean qai_load_graph_schema=false; Logger logger = (Logger) LoggerFactory.getLogger(SummarizeDocSolver.class); @@ -75,6 +79,8 @@ public class SummarizeDocSolver extends StepSolver { this.qai_system_prompt_template_minimum = attributeParser .parse((String) this.step.getAttributes().get("qai_system_prompt_template_minimum")); + this.qai_system_prompt_template_formatter = attributeParser + .parse((String) this.step.getAttributes().get("qai_system_prompt_template_formatter")); } @Override @@ -99,13 +105,16 @@ public class SummarizeDocSolver extends StepSolver { // Conta i token // int tokenCount = encoding.get().encode(text).size(); tokenCount = encoding.get().countTokens(text); + logger.info("token count input: " + tokenCount); int charCount = text.length(); // Stima media caratteri per token // double charPerToken = (double) charCount / tokenCount; - //Double output_char = (double) charCount * ((double) this.percent_summarize / 100.0); - Double min_output_token = (double) tokenCount * ((double) this.percent_summarize / 100.0); + // Double output_charD = (double) charCount * ((double) this.percent_summarize / 100.0); + // Integer output_char = output_charD.intValue(); + Double min_output_tokenD = (double) tokenCount * ((double) this.percent_summarize / 100.0); + Integer min_output_token = min_output_tokenD.intValue(); String content = new String(""); content = this.qai_system_prompt_template.replace("max_number_token", max_output_token.toString()); @@ -115,16 +124,32 @@ public class SummarizeDocSolver extends StepSolver { "min_number_token", min_output_token.toString()); } + + chunk_size_token_calc = (Double)((double)tokenCount * perc); + if(chunk_size_token_calc>chunk_size_token){ + chunk_size_token_calc = (double)chunk_size_token; + } // **Fase di Summarization** String summarizedText = summarize(text); // 🔹 Applica la funzione di riassunto // String template = this.qai_system_prompt_template+" The output length should // be of " + output_char + " characters"; - logger.info("template: " + content); + // Creazione dei messaggi per il modello AI Message userMessage = new UserMessage(summarizedText); - Message systemMessage = new SystemMessage(content); - logger.info("template: " + systemMessage.getContent().toString()); + Message systemMessage = null; + + int tokenCountSummary = encoding.get().countTokens(summarizedText); + + if(isChunked && tokenCountSummary < max_output_token){ + systemMessage = new SystemMessage(this.qai_system_prompt_template_formatter); + logger.info("template formatter: " + this.qai_system_prompt_template_formatter); + }else{ + //here + systemMessage = new SystemMessage(content); + logger.info("template: " + content); + } + CallResponseSpec resp = chatClient.prompt() .messages(userMessage, systemMessage) .advisors(advisor -> advisor @@ -144,6 +169,9 @@ public class SummarizeDocSolver extends StepSolver { logger.info("Token usage information is not available."); } + tokenCount = encoding.get().countTokens(output); + logger.info("token count output: " + tokenCount); + // Salvataggio dell'output nel contesto di esecuzione this.scenarioExecution.getExecSharedMap().put(this.qai_output_variable, output); this.scenarioExecution.setNextStepId(this.step.getNextStepId()); @@ -160,11 +188,13 @@ public class SummarizeDocSolver extends StepSolver { // Se il testo è già corto, non riassumere logger.info("length: " + text.length()); Double chunk_size_text; + + tokenCount = encoding.get().countTokens(text); int textLengthPlus = (int) (text.length() * 1.1); int tokenCountPlus = (int) (tokenCount * 1.1); - // chunk_size_token/(ratio+10%) - // Double ratio = Math.floor((textLengthPlus / charMax) + 1); - Double ratio = Math.floor((tokenCountPlus / chunk_size_token) + 1); + + Double ratio = Math.floor((tokenCountPlus / chunk_size_token_calc) + 1); + //Double ratio = Math.floor((tokenCountPlus / chunk_size_token) + 1); if (ratio == 1) { return text; } else { @@ -174,8 +204,9 @@ public class SummarizeDocSolver extends StepSolver { // Suddividere il testo in chunk List chunks = chunkText(text, chunk_size_text.intValue()); List summarizedChunks = new ArrayList<>(); - - Double maxTokenChunkD = Math.ceil(chunk_size_token / ratio); + + //Double maxTokenChunkD = Math.ceil(chunk_size_token / ratio); + Double maxTokenChunkD = Math.ceil(chunk_size_token_calc / ratio); maxTokenChunk = maxTokenChunkD.intValue(); // Riassumere ogni chunk singolarmente @@ -184,12 +215,13 @@ public class SummarizeDocSolver extends StepSolver { summarizedChunks.add(summarizeChunk(chunk)); } + isChunked = true; // Unire i riassunti String summarizedText = String.join(" ", summarizedChunks); int tokenCountSummarizedText = encoding.get().countTokens(summarizedText); // Se il riassunto è ancora troppo lungo, applicare ricorsione - if (tokenCountSummarizedText > chunk_size_token) { + if (tokenCountSummarizedText > max_output_token) { return summarize(summarizedText); } else { return summarizedText; @@ -210,13 +242,16 @@ public class SummarizeDocSolver extends StepSolver { private String summarizeChunk(String chunk) { String content = new String(""); - if (maxTokenChunk < max_output_token) { + /* if (maxTokenChunk < max_output_token) { content = this.qai_system_prompt_template_chunk.replace("max_number_token", maxTokenChunk.toString()); }else{ content = this.qai_system_prompt_template_chunk.replace("max_number_token", max_output_token.toString()); - } + }*/ + + content = this.qai_system_prompt_template_chunk.replace("max_number_token", + maxTokenChunk.toString()); Message chunkMessage = new UserMessage(chunk); Message systemMessage = new SystemMessage(