add SummarizeDoc Step Solver
This commit is contained in:
@@ -45,6 +45,7 @@ import com.olympus.hermione.repository.ScenarioExecutionRepository;
|
|||||||
import com.olympus.hermione.repository.ScenarioRepository;
|
import com.olympus.hermione.repository.ScenarioRepository;
|
||||||
import com.olympus.hermione.security.entity.User;
|
import com.olympus.hermione.security.entity.User;
|
||||||
import com.olympus.hermione.stepSolvers.AdvancedAIPromptSolver;
|
import com.olympus.hermione.stepSolvers.AdvancedAIPromptSolver;
|
||||||
|
import com.olympus.hermione.stepSolvers.SummarizeDocSolver;
|
||||||
import com.olympus.hermione.stepSolvers.BasicAIPromptSolver;
|
import com.olympus.hermione.stepSolvers.BasicAIPromptSolver;
|
||||||
import com.olympus.hermione.stepSolvers.BasicQueryRagSolver;
|
import com.olympus.hermione.stepSolvers.BasicQueryRagSolver;
|
||||||
import com.olympus.hermione.stepSolvers.QueryNeo4JSolver;
|
import com.olympus.hermione.stepSolvers.QueryNeo4JSolver;
|
||||||
@@ -258,6 +259,9 @@ public class ScenarioExecutionService {
|
|||||||
case "EXTERNAL_CODEGENIE":
|
case "EXTERNAL_CODEGENIE":
|
||||||
solver = new ExternalCodeGenieSolver();
|
solver = new ExternalCodeGenieSolver();
|
||||||
break;
|
break;
|
||||||
|
case "SUMMARIZE_DOC":
|
||||||
|
solver = new SummarizeDocSolver();
|
||||||
|
break;
|
||||||
default:
|
default:
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
@@ -304,12 +308,13 @@ public class ScenarioExecutionService {
|
|||||||
|
|
||||||
logger.info("Executing scenario: " + scenario.getName());
|
logger.info("Executing scenario: " + scenario.getName());
|
||||||
|
|
||||||
|
String folder_name = "";
|
||||||
|
|
||||||
ScenarioExecution scenarioExecution = new ScenarioExecution();
|
ScenarioExecution scenarioExecution = new ScenarioExecution();
|
||||||
scenarioExecution.setScenario(scenario);
|
scenarioExecution.setScenario(scenario);
|
||||||
//prendi i file dalla cartella temporanea se è presente una chiave con name "MultiFileUpload"
|
//prendi i file dalla cartella temporanea se è presente una chiave con name "MultiFileUpload"
|
||||||
if(scenarioExecutionInput.getInputs().containsKey("MultiFileUpload")){
|
if(scenarioExecutionInput.getInputs().containsKey("MultiFileUpload")){
|
||||||
String folder_name = scenarioExecutionInput.getInputs().get("MultiFileUpload");
|
folder_name = scenarioExecutionInput.getInputs().get("MultiFileUpload");
|
||||||
if(folder_name!=null && !folder_name.equals("")){
|
if(folder_name!=null && !folder_name.equals("")){
|
||||||
try{
|
try{
|
||||||
String base64 = folderToBase64(folder_name);
|
String base64 = folderToBase64(folder_name);
|
||||||
@@ -319,6 +324,9 @@ public class ScenarioExecutionService {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if(scenarioExecutionInput.getInputs().containsKey("SingleFileUpload")){
|
||||||
|
scenarioExecutionInput.getInputs().put("SingleFileUpload", uploadDir + folder_name + "/" + scenarioExecutionInput.getInputs().get("SingleFileUpload"));
|
||||||
|
}
|
||||||
scenarioExecution.setScenarioExecutionInput(scenarioExecutionInput);
|
scenarioExecution.setScenarioExecutionInput(scenarioExecutionInput);
|
||||||
|
|
||||||
try{
|
try{
|
||||||
|
|||||||
@@ -0,0 +1,217 @@
|
|||||||
|
package com.olympus.hermione.stepSolvers;
|
||||||
|
|
||||||
|
import java.io.FileInputStream;
|
||||||
|
import java.io.FileNotFoundException;
|
||||||
|
import java.nio.file.Files;
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.Base64;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Optional;
|
||||||
|
|
||||||
|
import org.apache.tika.Tika;
|
||||||
|
import org.slf4j.LoggerFactory;
|
||||||
|
import org.springframework.ai.chat.client.ChatClient.CallResponseSpec;
|
||||||
|
import org.springframework.ai.chat.messages.Message;
|
||||||
|
import org.springframework.ai.chat.messages.SystemMessage;
|
||||||
|
import org.springframework.ai.chat.messages.UserMessage;
|
||||||
|
import org.springframework.ai.chat.metadata.Usage;
|
||||||
|
import org.springframework.core.io.ClassPathResource;
|
||||||
|
import org.springframework.core.io.InputStreamResource;
|
||||||
|
import org.springframework.util.MimeTypeUtils;
|
||||||
|
import org.springframework.ai.model.Media;
|
||||||
|
import com.knuddels.jtokkit.Encodings;
|
||||||
|
import com.knuddels.jtokkit.api.Encoding;
|
||||||
|
import com.knuddels.jtokkit.api.EncodingRegistry;
|
||||||
|
import java.io.File;
|
||||||
|
import java.io.FileInputStream;
|
||||||
|
|
||||||
|
import com.olympus.hermione.models.ScenarioExecution;
|
||||||
|
import com.olympus.hermione.utility.AttributeParser;
|
||||||
|
|
||||||
|
import ch.qos.logback.classic.Logger;
|
||||||
|
|
||||||
|
public class SummarizeDocSolver extends StepSolver {
|
||||||
|
|
||||||
|
private String qai_system_prompt_template;
|
||||||
|
private String qai_system_prompt_template_chunk;
|
||||||
|
private String qai_path_file;
|
||||||
|
private String qai_output_variable;
|
||||||
|
private int chunk_size_token;
|
||||||
|
private String encoding_name;
|
||||||
|
private double charMax;
|
||||||
|
private int tokenCount;
|
||||||
|
private Optional<Encoding> encoding;
|
||||||
|
private int percent_summarize;
|
||||||
|
private Integer maxTokenChunk;
|
||||||
|
private String qai_custom_memory_id;
|
||||||
|
// private boolean qai_load_graph_schema=false;
|
||||||
|
|
||||||
|
Logger logger = (Logger) LoggerFactory.getLogger(SummarizeDocSolver.class);
|
||||||
|
|
||||||
|
private void loadParameters() {
|
||||||
|
logger.info("Loading parameters");
|
||||||
|
|
||||||
|
AttributeParser attributeParser = new AttributeParser(this.scenarioExecution);
|
||||||
|
|
||||||
|
this.qai_system_prompt_template = attributeParser
|
||||||
|
.parse((String) this.step.getAttributes().get("qai_system_prompt_template"));
|
||||||
|
this.qai_system_prompt_template_chunk = attributeParser
|
||||||
|
.parse((String) this.step.getAttributes().get("qai_system_prompt_template_chunk"));
|
||||||
|
this.qai_path_file = attributeParser.parse((String) this.step.getAttributes().get("qai_path_file"));
|
||||||
|
|
||||||
|
this.qai_output_variable = (String) this.step.getAttributes().get("qai_output_variable");
|
||||||
|
this.encoding_name = (String) this.step.getAttributes().get("encoding_name");
|
||||||
|
|
||||||
|
this.chunk_size_token = Integer.parseInt((String) this.step.getAttributes().get("chunk_size_token"));
|
||||||
|
|
||||||
|
this.percent_summarize = Integer.parseInt((String) this.step.getAttributes().get("percent_summarize"));
|
||||||
|
|
||||||
|
this.qai_custom_memory_id = (String) this.step.getAttributes().get("qai_custom_memory_id");
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public ScenarioExecution solveStep() {
|
||||||
|
System.out.println("Solving step: " + this.step.getName());
|
||||||
|
|
||||||
|
try {
|
||||||
|
this.scenarioExecution.setCurrentStepId(this.step.getStepId());
|
||||||
|
loadParameters();
|
||||||
|
|
||||||
|
// Estrazione del testo dal file
|
||||||
|
File file = new File(this.qai_path_file);
|
||||||
|
Tika tika = new Tika();
|
||||||
|
tika.setMaxStringLength(-1);
|
||||||
|
String text = tika.parseToString(file);
|
||||||
|
|
||||||
|
// add libreria tokenizer
|
||||||
|
// Ottieni l'encoding del modello (es. cl100k_base per GPT-4 e GPT-3.5-turbo)
|
||||||
|
EncodingRegistry registry = Encodings.newDefaultEncodingRegistry();
|
||||||
|
encoding = registry.getEncoding(this.encoding_name);
|
||||||
|
|
||||||
|
// Conta i token
|
||||||
|
// int tokenCount = encoding.get().encode(text).size();
|
||||||
|
tokenCount = encoding.get().countTokens(text);
|
||||||
|
int charCount = text.length();
|
||||||
|
|
||||||
|
// Stima media caratteri per token
|
||||||
|
double charPerToken = (double) charCount / tokenCount;
|
||||||
|
|
||||||
|
Double output_char = (double) charCount * ((double) this.percent_summarize / 100.0);
|
||||||
|
Integer output_char_int = output_char.intValue();
|
||||||
|
this.qai_system_prompt_template = this.qai_system_prompt_template.replace("number_char",
|
||||||
|
output_char_int.toString());
|
||||||
|
// **Fase di Summarization**
|
||||||
|
String summarizedText = summarize(text); // 🔹 Applica la funzione di riassunto
|
||||||
|
// String template = this.qai_system_prompt_template+" The output length should
|
||||||
|
// be of " + output_char + " characters";
|
||||||
|
logger.info("template: " + this.qai_system_prompt_template);
|
||||||
|
// Creazione dei messaggi per il modello AI
|
||||||
|
Message userMessage = new UserMessage(summarizedText);
|
||||||
|
Message systemMessage = new SystemMessage(this.qai_system_prompt_template);
|
||||||
|
|
||||||
|
CallResponseSpec resp = chatClient.prompt()
|
||||||
|
.messages(userMessage, systemMessage)
|
||||||
|
.advisors(advisor -> advisor
|
||||||
|
.param("chat_memory_conversation_id",
|
||||||
|
this.scenarioExecution.getId() + this.qai_custom_memory_id)
|
||||||
|
.param("chat_memory_response_size", 100))
|
||||||
|
.call();
|
||||||
|
|
||||||
|
String output = resp.content();
|
||||||
|
|
||||||
|
// Gestione del conteggio token usati
|
||||||
|
Usage usage = resp.chatResponse().getMetadata().getUsage();
|
||||||
|
if (usage != null) {
|
||||||
|
Long usedTokens = usage.getTotalTokens();
|
||||||
|
this.scenarioExecution.setUsedTokens(usedTokens);
|
||||||
|
} else {
|
||||||
|
logger.info("Token usage information is not available.");
|
||||||
|
}
|
||||||
|
|
||||||
|
// Salvataggio dell'output nel contesto di esecuzione
|
||||||
|
this.scenarioExecution.getExecSharedMap().put(this.qai_output_variable, output);
|
||||||
|
this.scenarioExecution.setNextStepId(this.step.getNextStepId());
|
||||||
|
|
||||||
|
} catch (Exception e) {
|
||||||
|
logger.error("Error in AnswerQuestionOnDocSolver: " + e.getMessage());
|
||||||
|
this.scenarioExecution.getExecSharedMap().put(this.qai_output_variable, "500 Internal Server Error");
|
||||||
|
}
|
||||||
|
|
||||||
|
return this.scenarioExecution;
|
||||||
|
}
|
||||||
|
|
||||||
|
private String summarize(String text) {
|
||||||
|
// Se il testo è già corto, non riassumere
|
||||||
|
logger.info("length: " + text.length());
|
||||||
|
Double chunk_size_text;
|
||||||
|
int textLengthPlus = (int) (text.length() * 1.1);
|
||||||
|
int tokenCountPlus = (int) (tokenCount * 1.1);
|
||||||
|
// chunk_size_token/(ratio+10%)
|
||||||
|
// Double ratio = Math.floor((textLengthPlus / charMax) + 1);
|
||||||
|
Double ratio = Math.floor((tokenCountPlus / chunk_size_token) + 1);
|
||||||
|
if (ratio == 1) {
|
||||||
|
return text;
|
||||||
|
} else {
|
||||||
|
chunk_size_text = Math.ceil(textLengthPlus / ratio);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Suddividere il testo in chunk
|
||||||
|
List<String> chunks = chunkText(text, chunk_size_text.intValue());
|
||||||
|
List<String> summarizedChunks = new ArrayList<>();
|
||||||
|
|
||||||
|
Double maxTokenChunkD = Math.ceil(chunk_size_token / ratio);
|
||||||
|
maxTokenChunk = maxTokenChunkD.intValue();
|
||||||
|
|
||||||
|
// Riassumere ogni chunk singolarmente
|
||||||
|
for (String chunk : chunks) {
|
||||||
|
logger.info("chunk length: " + chunk.length());
|
||||||
|
summarizedChunks.add(summarizeChunk(chunk));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Unire i riassunti
|
||||||
|
String summarizedText = String.join(" ", summarizedChunks);
|
||||||
|
int tokenCountSummarizedText = encoding.get().countTokens(summarizedText);
|
||||||
|
|
||||||
|
// Se il riassunto è ancora troppo lungo, applicare ricorsione
|
||||||
|
if (tokenCountSummarizedText > chunk_size_token) {
|
||||||
|
return summarize(summarizedText);
|
||||||
|
} else {
|
||||||
|
return summarizedText;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Metodo per suddividere il testo in chunk
|
||||||
|
private List<String> chunkText(String text, int chunkSize) {
|
||||||
|
List<String> chunks = new ArrayList<>();
|
||||||
|
int length = text.length();
|
||||||
|
for (int i = 0; i < length; i += chunkSize) {
|
||||||
|
chunks.add(text.substring(i, Math.min(length, i + chunkSize)));
|
||||||
|
}
|
||||||
|
return chunks;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Metodo di riassunto per singolo chunk
|
||||||
|
private String summarizeChunk(String chunk) {
|
||||||
|
try {
|
||||||
|
this.qai_system_prompt_template_chunk = this.qai_system_prompt_template_chunk.replace("number_char",
|
||||||
|
maxTokenChunk.toString());
|
||||||
|
|
||||||
|
Message chunkMessage = new UserMessage(chunk);
|
||||||
|
Message systemMessage = new SystemMessage(qai_system_prompt_template_chunk);
|
||||||
|
|
||||||
|
CallResponseSpec resp = chatClient.prompt()
|
||||||
|
.messages(chunkMessage, systemMessage)
|
||||||
|
.advisors(advisor -> advisor
|
||||||
|
.param("chat_memory_conversation_id",
|
||||||
|
this.scenarioExecution.getId() + this.qai_custom_memory_id)
|
||||||
|
.param("chat_memory_response_size", 100))
|
||||||
|
.call();
|
||||||
|
|
||||||
|
return resp.content();
|
||||||
|
} catch (Exception e) {
|
||||||
|
e.printStackTrace();
|
||||||
|
return chunk.substring(0, Math.max(1, chunk.length() / 10)); // Fallback
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user