Merged PR 58: add summarize scenario

This commit is contained in:
D'Alia, Florinda
2025-02-17 10:58:10 +00:00
3 changed files with 236 additions and 1 deletions

10
pom.xml
View File

@@ -42,6 +42,11 @@
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-web</artifactId>
</dependency>
<dependency>
<groupId>com.knuddels</groupId>
<artifactId>jtokkit</artifactId>
<version>1.1.0</version>
</dependency>
<!--<dependency>
@@ -133,6 +138,11 @@
<artifactId>common</artifactId>
<version>1.0</version>
</dependency>
<dependency>
<groupId>org.springframework.ai</groupId>
<artifactId>spring-ai-tika-document-reader</artifactId>
</dependency>
</dependencies>

View File

@@ -45,6 +45,7 @@ import com.olympus.hermione.repository.ScenarioExecutionRepository;
import com.olympus.hermione.repository.ScenarioRepository;
import com.olympus.hermione.security.entity.User;
import com.olympus.hermione.stepSolvers.AdvancedAIPromptSolver;
import com.olympus.hermione.stepSolvers.SummarizeDocSolver;
import com.olympus.hermione.stepSolvers.BasicAIPromptSolver;
import com.olympus.hermione.stepSolvers.BasicQueryRagSolver;
import com.olympus.hermione.stepSolvers.QueryNeo4JSolver;
@@ -258,6 +259,9 @@ public class ScenarioExecutionService {
case "EXTERNAL_CODEGENIE":
solver = new ExternalCodeGenieSolver();
break;
case "SUMMARIZE_DOC":
solver = new SummarizeDocSolver();
break;
default:
break;
}
@@ -304,12 +308,13 @@ public class ScenarioExecutionService {
logger.info("Executing scenario: " + scenario.getName());
String folder_name = "";
ScenarioExecution scenarioExecution = new ScenarioExecution();
scenarioExecution.setScenario(scenario);
//prendi i file dalla cartella temporanea se è presente una chiave con name "MultiFileUpload"
if(scenarioExecutionInput.getInputs().containsKey("MultiFileUpload")){
String folder_name = scenarioExecutionInput.getInputs().get("MultiFileUpload");
folder_name = scenarioExecutionInput.getInputs().get("MultiFileUpload");
if(folder_name!=null && !folder_name.equals("")){
try{
String base64 = folderToBase64(folder_name);
@@ -319,6 +324,9 @@ public class ScenarioExecutionService {
}
}
}
if(scenarioExecutionInput.getInputs().containsKey("SingleFileUpload")){
scenarioExecutionInput.getInputs().put("SingleFileUpload", uploadDir + folder_name + "/" + scenarioExecutionInput.getInputs().get("SingleFileUpload"));
}
scenarioExecution.setScenarioExecutionInput(scenarioExecutionInput);
try{

View File

@@ -0,0 +1,217 @@
package com.olympus.hermione.stepSolvers;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.nio.file.Files;
import java.util.ArrayList;
import java.util.Base64;
import java.util.List;
import java.util.Optional;
import org.apache.tika.Tika;
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.client.ChatClient.CallResponseSpec;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.SystemMessage;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.metadata.Usage;
import org.springframework.core.io.ClassPathResource;
import org.springframework.core.io.InputStreamResource;
import org.springframework.util.MimeTypeUtils;
import org.springframework.ai.model.Media;
import com.knuddels.jtokkit.Encodings;
import com.knuddels.jtokkit.api.Encoding;
import com.knuddels.jtokkit.api.EncodingRegistry;
import java.io.File;
import java.io.FileInputStream;
import com.olympus.hermione.models.ScenarioExecution;
import com.olympus.hermione.utility.AttributeParser;
import ch.qos.logback.classic.Logger;
public class SummarizeDocSolver extends StepSolver {

    // Final-pass system prompt; the literal token "number_char" is replaced with
    // the desired output length (in characters) before the model call.
    private String qai_system_prompt_template;
    // Per-chunk system prompt; also carries a "number_char" placeholder,
    // replaced with the per-chunk token budget.
    private String qai_system_prompt_template_chunk;
    // Path of the document to summarize (any format Tika can extract text from).
    private String qai_path_file;
    // Key under which the final summary is stored in the execution's shared map.
    private String qai_output_variable;
    // Maximum number of tokens a single model call should receive.
    private int chunk_size_token;
    // jtokkit encoding name (e.g. "cl100k_base" for GPT-4 / GPT-3.5-turbo).
    private String encoding_name;
    // Token count of the full extracted document text (set in solveStep).
    private int tokenCount;
    // Tokenizer resolved from encoding_name; validated non-empty in solveStep.
    private Optional<Encoding> encoding;
    // Requested summary length as a percentage of the original character count.
    private int percent_summarize;
    // Token budget passed to each chunk summary; set in summarize().
    private Integer maxTokenChunk;
    // Suffix appended to the execution id to build the chat-memory conversation id.
    private String qai_custom_memory_id;

    Logger logger = (Logger) LoggerFactory.getLogger(SummarizeDocSolver.class);

    /**
     * Reads this step's configuration from the step attributes, resolving any
     * placeholder expressions through {@link AttributeParser}.
     */
    private void loadParameters() {
        logger.info("Loading parameters");
        AttributeParser attributeParser = new AttributeParser(this.scenarioExecution);
        this.qai_system_prompt_template = attributeParser
                .parse((String) this.step.getAttributes().get("qai_system_prompt_template"));
        this.qai_system_prompt_template_chunk = attributeParser
                .parse((String) this.step.getAttributes().get("qai_system_prompt_template_chunk"));
        this.qai_path_file = attributeParser.parse((String) this.step.getAttributes().get("qai_path_file"));
        this.qai_output_variable = (String) this.step.getAttributes().get("qai_output_variable");
        this.encoding_name = (String) this.step.getAttributes().get("encoding_name");
        this.chunk_size_token = Integer.parseInt((String) this.step.getAttributes().get("chunk_size_token"));
        this.percent_summarize = Integer.parseInt((String) this.step.getAttributes().get("percent_summarize"));
        this.qai_custom_memory_id = (String) this.step.getAttributes().get("qai_custom_memory_id");
    }

    /**
     * Extracts the text of the configured document, map-reduce-summarizes it
     * until it fits in one prompt, then asks the model for the final summary.
     * On any failure the output variable is set to "500 Internal Server Error".
     *
     * @return the (mutated) scenario execution, with the summary stored under
     *         {@code qai_output_variable} and the next step id set on success
     */
    @Override
    public ScenarioExecution solveStep() {
        logger.info("Solving step: " + this.step.getName());
        try {
            this.scenarioExecution.setCurrentStepId(this.step.getStepId());
            loadParameters();

            // Extract the raw text from the file.
            File file = new File(this.qai_path_file);
            Tika tika = new Tika();
            tika.setMaxStringLength(-1); // no truncation: we need the whole document
            String text = tika.parseToString(file);

            // Resolve the tokenizer and fail fast with a clear message instead of
            // letting Optional.get() throw an opaque NoSuchElementException.
            EncodingRegistry registry = Encodings.newDefaultEncodingRegistry();
            encoding = registry.getEncoding(this.encoding_name);
            if (!encoding.isPresent()) {
                throw new IllegalArgumentException("Unknown tokenizer encoding: " + this.encoding_name);
            }
            tokenCount = encoding.get().countTokens(text);

            // Target output size: percent_summarize % of the original character count.
            int charCount = text.length();
            Double output_char = (double) charCount * ((double) this.percent_summarize / 100.0);
            Integer output_char_int = output_char.intValue();
            this.qai_system_prompt_template = this.qai_system_prompt_template.replace("number_char",
                    output_char_int.toString());

            // Map-reduce phase: shrink the text until it fits a single prompt.
            String summarizedText = summarize(text);
            logger.info("template: " + this.qai_system_prompt_template);

            // Final summarization call on the (possibly pre-summarized) text.
            Message userMessage = new UserMessage(summarizedText);
            Message systemMessage = new SystemMessage(this.qai_system_prompt_template);
            CallResponseSpec resp = chatClient.prompt()
                    .messages(userMessage, systemMessage)
                    .advisors(advisor -> advisor
                            .param("chat_memory_conversation_id",
                                    this.scenarioExecution.getId() + this.qai_custom_memory_id)
                            .param("chat_memory_response_size", 100))
                    .call();
            String output = resp.content();

            // Record token usage when the provider reports it.
            Usage usage = resp.chatResponse().getMetadata().getUsage();
            if (usage != null) {
                Long usedTokens = usage.getTotalTokens();
                this.scenarioExecution.setUsedTokens(usedTokens);
            } else {
                logger.info("Token usage information is not available.");
            }

            this.scenarioExecution.getExecSharedMap().put(this.qai_output_variable, output);
            this.scenarioExecution.setNextStepId(this.step.getNextStepId());
        } catch (Exception e) {
            // FIX: the message previously named the wrong class
            // ("AnswerQuestionOnDocSolver") and discarded the stack trace.
            logger.error("Error in SummarizeDocSolver: " + e.getMessage(), e);
            this.scenarioExecution.getExecSharedMap().put(this.qai_output_variable, "500 Internal Server Error");
        }
        return this.scenarioExecution;
    }

    /**
     * Recursively summarizes {@code text} until its token count fits within
     * {@code chunk_size_token}: split into chunks, summarize each chunk with the
     * model, join, and recurse if the joined result is still too long.
     */
    private String summarize(String text) {
        logger.info("length: " + text.length());
        // FIX: count tokens of the text actually being summarized. The previous
        // version reused the field `tokenCount` (the ORIGINAL document's count),
        // so recursive calls on the much shorter summary over-chunked badly.
        int currentTokenCount = encoding.get().countTokens(text);
        // 10% safety margin on both character and token counts so a chunk does
        // not overrun the model window once prompt overhead is added.
        int textLengthPlus = (int) (text.length() * 1.1);
        int tokenCountPlus = (int) (currentTokenCount * 1.1);
        // Integer division + 1 = number of chunks needed so each stays under
        // chunk_size_token tokens.
        Double ratio = Math.floor((tokenCountPlus / chunk_size_token) + 1);
        if (ratio == 1) {
            // Already fits in a single prompt: nothing to do.
            return text;
        }
        Double chunk_size_text = Math.ceil(textLengthPlus / ratio);

        // Split the text and summarize each chunk independently.
        List<String> chunks = chunkText(text, chunk_size_text.intValue());
        Double maxTokenChunkD = Math.ceil(chunk_size_token / ratio);
        maxTokenChunk = maxTokenChunkD.intValue();
        List<String> summarizedChunks = new ArrayList<>();
        for (String chunk : chunks) {
            logger.info("chunk length: " + chunk.length());
            summarizedChunks.add(summarizeChunk(chunk));
        }
        String summarizedText = String.join(" ", summarizedChunks);

        // Recurse until the combined summary fits within one chunk's budget.
        int tokenCountSummarizedText = encoding.get().countTokens(summarizedText);
        if (tokenCountSummarizedText > chunk_size_token) {
            return summarize(summarizedText);
        }
        return summarizedText;
    }

    /**
     * Splits {@code text} into consecutive substrings of at most
     * {@code chunkSize} characters (the last chunk may be shorter).
     */
    private List<String> chunkText(String text, int chunkSize) {
        List<String> chunks = new ArrayList<>();
        int length = text.length();
        for (int i = 0; i < length; i += chunkSize) {
            chunks.add(text.substring(i, Math.min(length, i + chunkSize)));
        }
        return chunks;
    }

    /**
     * Summarizes a single chunk with the model, constraining the answer to
     * {@code maxTokenChunk} via the per-chunk prompt template. On failure the
     * first 10% of the chunk is returned so the pipeline can continue.
     */
    private String summarizeChunk(String chunk) {
        try {
            // replace() is a no-op after the first call, so repeated invocations
            // on the mutated field are harmless.
            this.qai_system_prompt_template_chunk = this.qai_system_prompt_template_chunk.replace("number_char",
                    maxTokenChunk.toString());
            Message chunkMessage = new UserMessage(chunk);
            Message systemMessage = new SystemMessage(qai_system_prompt_template_chunk);
            CallResponseSpec resp = chatClient.prompt()
                    .messages(chunkMessage, systemMessage)
                    .advisors(advisor -> advisor
                            .param("chat_memory_conversation_id",
                                    this.scenarioExecution.getId() + this.qai_custom_memory_id)
                            .param("chat_memory_response_size", 100))
                    .call();
            return resp.content();
        } catch (Exception e) {
            // FIX: was e.printStackTrace(); route through SLF4J with the cause.
            logger.error("Error summarizing chunk: " + e.getMessage(), e);
            return chunk.substring(0, Math.max(1, chunk.length() / 10)); // Fallback
        }
    }
}