diff --git a/src/main/java/com/olympus/hermione/services/ScenarioExecutionService.java b/src/main/java/com/olympus/hermione/services/ScenarioExecutionService.java index 7a0f127..639655a 100644 --- a/src/main/java/com/olympus/hermione/services/ScenarioExecutionService.java +++ b/src/main/java/com/olympus/hermione/services/ScenarioExecutionService.java @@ -45,6 +45,7 @@ import com.olympus.hermione.repository.ScenarioExecutionRepository; import com.olympus.hermione.repository.ScenarioRepository; import com.olympus.hermione.security.entity.User; import com.olympus.hermione.stepSolvers.AdvancedAIPromptSolver; +import com.olympus.hermione.stepSolvers.SummarizeDocSolver; import com.olympus.hermione.stepSolvers.BasicAIPromptSolver; import com.olympus.hermione.stepSolvers.BasicQueryRagSolver; import com.olympus.hermione.stepSolvers.QueryNeo4JSolver; @@ -258,6 +259,9 @@ public class ScenarioExecutionService { case "EXTERNAL_CODEGENIE": solver = new ExternalCodeGenieSolver(); break; + case "SUMMARIZE_DOC": + solver = new SummarizeDocSolver(); + break; default: break; } @@ -304,12 +308,13 @@ public class ScenarioExecutionService { logger.info("Executing scenario: " + scenario.getName()); + String folder_name = ""; ScenarioExecution scenarioExecution = new ScenarioExecution(); scenarioExecution.setScenario(scenario); //prendi i file dalla cartella temporanea se è presente una chiave con name "MultiFileUpload" if(scenarioExecutionInput.getInputs().containsKey("MultiFileUpload")){ - String folder_name = scenarioExecutionInput.getInputs().get("MultiFileUpload"); + folder_name = scenarioExecutionInput.getInputs().get("MultiFileUpload"); if(folder_name!=null && !folder_name.equals("")){ try{ String base64 = folderToBase64(folder_name); @@ -319,6 +324,9 @@ public class ScenarioExecutionService { } } } + if(scenarioExecutionInput.getInputs().containsKey("SingleFileUpload")){ + scenarioExecutionInput.getInputs().put("SingleFileUpload", uploadDir + folder_name + "/" + scenarioExecutionInput.getInputs().get("SingleFileUpload")); + } scenarioExecution.setScenarioExecutionInput(scenarioExecutionInput); try{ diff --git a/src/main/java/com/olympus/hermione/stepSolvers/SummarizeDocSolver.java b/src/main/java/com/olympus/hermione/stepSolvers/SummarizeDocSolver.java new file mode 100644 index 0000000..759a004 --- /dev/null +++ b/src/main/java/com/olympus/hermione/stepSolvers/SummarizeDocSolver.java @@ -0,0 +1,217 @@ +package com.olympus.hermione.stepSolvers; + +import java.io.FileInputStream; +import java.io.FileNotFoundException; +import java.nio.file.Files; +import java.util.ArrayList; +import java.util.Base64; +import java.util.List; +import java.util.Optional; + +import org.apache.tika.Tika; +import org.slf4j.LoggerFactory; +import org.springframework.ai.chat.client.ChatClient.CallResponseSpec; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.SystemMessage; +import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.ai.chat.metadata.Usage; +import org.springframework.core.io.ClassPathResource; +import org.springframework.core.io.InputStreamResource; +import org.springframework.util.MimeTypeUtils; +import org.springframework.ai.model.Media; +import com.knuddels.jtokkit.Encodings; +import com.knuddels.jtokkit.api.Encoding; +import com.knuddels.jtokkit.api.EncodingRegistry; +import java.io.File; +import java.io.FileInputStream; + +import com.olympus.hermione.models.ScenarioExecution; +import com.olympus.hermione.utility.AttributeParser; + +import ch.qos.logback.classic.Logger; + +public class SummarizeDocSolver extends StepSolver { + + private String qai_system_prompt_template; + private String qai_system_prompt_template_chunk; + private String qai_path_file; + private String qai_output_variable; + private int chunk_size_token; + private String encoding_name; + private double charMax; + private int tokenCount; + private Optional encoding; + private int percent_summarize; + private Integer maxTokenChunk; + private String qai_custom_memory_id; + // private boolean qai_load_graph_schema=false; + + Logger logger = (Logger) LoggerFactory.getLogger(SummarizeDocSolver.class); + + private void loadParameters() { + logger.info("Loading parameters"); + + AttributeParser attributeParser = new AttributeParser(this.scenarioExecution); + + this.qai_system_prompt_template = attributeParser + .parse((String) this.step.getAttributes().get("qai_system_prompt_template")); + this.qai_system_prompt_template_chunk = attributeParser + .parse((String) this.step.getAttributes().get("qai_system_prompt_template_chunk")); + this.qai_path_file = attributeParser.parse((String) this.step.getAttributes().get("qai_path_file")); + + this.qai_output_variable = (String) this.step.getAttributes().get("qai_output_variable"); + this.encoding_name = (String) this.step.getAttributes().get("encoding_name"); + + this.chunk_size_token = Integer.parseInt((String) this.step.getAttributes().get("chunk_size_token")); + + this.percent_summarize = Integer.parseInt((String) this.step.getAttributes().get("percent_summarize")); + + this.qai_custom_memory_id = (String) this.step.getAttributes().get("qai_custom_memory_id"); + } + + @Override + public ScenarioExecution solveStep() { + System.out.println("Solving step: " + this.step.getName()); + + try { + this.scenarioExecution.setCurrentStepId(this.step.getStepId()); + loadParameters(); + + // Estrazione del testo dal file + File file = new File(this.qai_path_file); + Tika tika = new Tika(); + tika.setMaxStringLength(-1); + String text = tika.parseToString(file); + + // add libreria tokenizer + // Ottieni l'encoding del modello (es. cl100k_base per GPT-4 e GPT-3.5-turbo) + EncodingRegistry registry = Encodings.newDefaultEncodingRegistry(); + encoding = registry.getEncoding(this.encoding_name); + + // Conta i token + // int tokenCount = encoding.get().encode(text).size(); + tokenCount = encoding.get().countTokens(text); + int charCount = text.length(); + + // Stima media caratteri per token + double charPerToken = (double) charCount / tokenCount; + + Double output_char = (double) charCount * ((double) this.percent_summarize / 100.0); + Integer output_char_int = output_char.intValue(); + this.qai_system_prompt_template = this.qai_system_prompt_template.replace("number_char", + output_char_int.toString()); + // **Fase di Summarization** + String summarizedText = summarize(text); // 🔹 Applica la funzione di riassunto + // String template = this.qai_system_prompt_template+" The output length should + // be of " + output_char + " characters"; + logger.info("template: " + this.qai_system_prompt_template); + // Creazione dei messaggi per il modello AI + Message userMessage = new UserMessage(summarizedText); + Message systemMessage = new SystemMessage(this.qai_system_prompt_template); + + CallResponseSpec resp = chatClient.prompt() + .messages(userMessage, systemMessage) + .advisors(advisor -> advisor + .param("chat_memory_conversation_id", + this.scenarioExecution.getId() + this.qai_custom_memory_id) + .param("chat_memory_response_size", 100)) + .call(); + + String output = resp.content(); + + // Gestione del conteggio token usati + Usage usage = resp.chatResponse().getMetadata().getUsage(); + if (usage != null) { + Long usedTokens = usage.getTotalTokens(); + this.scenarioExecution.setUsedTokens(usedTokens); + } else { + logger.info("Token usage information is not available."); + } + + // Salvataggio dell'output nel contesto di esecuzione + this.scenarioExecution.getExecSharedMap().put(this.qai_output_variable, output); + this.scenarioExecution.setNextStepId(this.step.getNextStepId()); + + } catch (Exception e) { + logger.error("Error in AnswerQuestionOnDocSolver: " + e.getMessage()); + this.scenarioExecution.getExecSharedMap().put(this.qai_output_variable, "500 Internal Server Error"); + } + + return this.scenarioExecution; + } + + private String summarize(String text) { + // Se il testo è già corto, non riassumere + logger.info("length: " + text.length()); + Double chunk_size_text; + int textLengthPlus = (int) (text.length() * 1.1); + int tokenCountPlus = (int) (tokenCount * 1.1); + // chunk_size_token/(ratio+10%) + // Double ratio = Math.floor((textLengthPlus / charMax) + 1); + Double ratio = Math.floor((tokenCountPlus / chunk_size_token) + 1); + if (ratio == 1) { + return text; + } else { + chunk_size_text = Math.ceil(textLengthPlus / ratio); + } + + // Suddividere il testo in chunk + List chunks = chunkText(text, chunk_size_text.intValue()); + List summarizedChunks = new ArrayList<>(); + + Double maxTokenChunkD = Math.ceil(chunk_size_token / ratio); + maxTokenChunk = maxTokenChunkD.intValue(); + + // Riassumere ogni chunk singolarmente + for (String chunk : chunks) { + logger.info("chunk length: " + chunk.length()); + summarizedChunks.add(summarizeChunk(chunk)); + } + + // Unire i riassunti + String summarizedText = String.join(" ", summarizedChunks); + int tokenCountSummarizedText = encoding.get().countTokens(summarizedText); + + // Se il riassunto è ancora troppo lungo, applicare ricorsione + if (tokenCountSummarizedText > chunk_size_token) { + return summarize(summarizedText); + } else { + return summarizedText; + } + } + + // Metodo per suddividere il testo in chunk + private List chunkText(String text, int chunkSize) { + List chunks = new ArrayList<>(); + int length = text.length(); + for (int i = 0; i < length; i += chunkSize) { + chunks.add(text.substring(i, Math.min(length, i + chunkSize))); + } + return chunks; + } + + // Metodo di riassunto per singolo chunk + private String summarizeChunk(String chunk) { + try { + this.qai_system_prompt_template_chunk = this.qai_system_prompt_template_chunk.replace("number_char", + maxTokenChunk.toString()); + + Message chunkMessage = new UserMessage(chunk); + Message systemMessage = new SystemMessage(qai_system_prompt_template_chunk); + + CallResponseSpec resp = chatClient.prompt() + .messages(chunkMessage, systemMessage) + .advisors(advisor -> advisor + .param("chat_memory_conversation_id", + this.scenarioExecution.getId() + this.qai_custom_memory_id) + .param("chat_memory_response_size", 100)) + .call(); + + return resp.content(); + } catch (Exception e) { + e.printStackTrace(); + return chunk.substring(0, Math.max(1, chunk.length() / 10)); // Fallback + } + } + +}