Merged PR 58: add summarize scenario
This commit is contained in:
10
pom.xml
10
pom.xml
@@ -42,6 +42,11 @@
|
||||
<groupId>org.springframework.boot</groupId>
|
||||
<artifactId>spring-boot-starter-web</artifactId>
|
||||
</dependency>
|
||||
<dependency>
|
||||
<groupId>com.knuddels</groupId>
|
||||
<artifactId>jtokkit</artifactId>
|
||||
<version>1.1.0</version>
|
||||
</dependency>
|
||||
|
||||
|
||||
<!--<dependency>
|
||||
@@ -133,6 +138,11 @@
|
||||
<artifactId>common</artifactId>
|
||||
<version>1.0</version>
|
||||
</dependency>
|
||||
|
||||
<dependency>
|
||||
<groupId>org.springframework.ai</groupId>
|
||||
<artifactId>spring-ai-tika-document-reader</artifactId>
|
||||
</dependency>
|
||||
</dependencies>
|
||||
|
||||
|
||||
|
||||
@@ -45,6 +45,7 @@ import com.olympus.hermione.repository.ScenarioExecutionRepository;
|
||||
import com.olympus.hermione.repository.ScenarioRepository;
|
||||
import com.olympus.hermione.security.entity.User;
|
||||
import com.olympus.hermione.stepSolvers.AdvancedAIPromptSolver;
|
||||
import com.olympus.hermione.stepSolvers.SummarizeDocSolver;
|
||||
import com.olympus.hermione.stepSolvers.BasicAIPromptSolver;
|
||||
import com.olympus.hermione.stepSolvers.BasicQueryRagSolver;
|
||||
import com.olympus.hermione.stepSolvers.QueryNeo4JSolver;
|
||||
@@ -258,6 +259,9 @@ public class ScenarioExecutionService {
|
||||
case "EXTERNAL_CODEGENIE":
|
||||
solver = new ExternalCodeGenieSolver();
|
||||
break;
|
||||
case "SUMMARIZE_DOC":
|
||||
solver = new SummarizeDocSolver();
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
@@ -304,12 +308,13 @@ public class ScenarioExecutionService {
|
||||
|
||||
logger.info("Executing scenario: " + scenario.getName());
|
||||
|
||||
String folder_name = "";
|
||||
|
||||
ScenarioExecution scenarioExecution = new ScenarioExecution();
|
||||
scenarioExecution.setScenario(scenario);
|
||||
// take the files from the temporary folder if a key named "MultiFileUpload" is present
|
||||
if(scenarioExecutionInput.getInputs().containsKey("MultiFileUpload")){
|
||||
String folder_name = scenarioExecutionInput.getInputs().get("MultiFileUpload");
|
||||
folder_name = scenarioExecutionInput.getInputs().get("MultiFileUpload");
|
||||
if(folder_name!=null && !folder_name.equals("")){
|
||||
try{
|
||||
String base64 = folderToBase64(folder_name);
|
||||
@@ -319,6 +324,9 @@ public class ScenarioExecutionService {
|
||||
}
|
||||
}
|
||||
}
|
||||
if(scenarioExecutionInput.getInputs().containsKey("SingleFileUpload")){
|
||||
scenarioExecutionInput.getInputs().put("SingleFileUpload", uploadDir + folder_name + "/" + scenarioExecutionInput.getInputs().get("SingleFileUpload"));
|
||||
}
|
||||
scenarioExecution.setScenarioExecutionInput(scenarioExecutionInput);
|
||||
|
||||
try{
|
||||
|
||||
@@ -0,0 +1,217 @@
|
||||
package com.olympus.hermione.stepSolvers;
|
||||
|
||||
import java.io.FileInputStream;
|
||||
import java.io.FileNotFoundException;
|
||||
import java.nio.file.Files;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Base64;
|
||||
import java.util.List;
|
||||
import java.util.Optional;
|
||||
|
||||
import org.apache.tika.Tika;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import org.springframework.ai.chat.client.ChatClient.CallResponseSpec;
|
||||
import org.springframework.ai.chat.messages.Message;
|
||||
import org.springframework.ai.chat.messages.SystemMessage;
|
||||
import org.springframework.ai.chat.messages.UserMessage;
|
||||
import org.springframework.ai.chat.metadata.Usage;
|
||||
import org.springframework.core.io.ClassPathResource;
|
||||
import org.springframework.core.io.InputStreamResource;
|
||||
import org.springframework.util.MimeTypeUtils;
|
||||
import org.springframework.ai.model.Media;
|
||||
import com.knuddels.jtokkit.Encodings;
|
||||
import com.knuddels.jtokkit.api.Encoding;
|
||||
import com.knuddels.jtokkit.api.EncodingRegistry;
|
||||
import java.io.File;
|
||||
import java.io.FileInputStream;
|
||||
|
||||
import com.olympus.hermione.models.ScenarioExecution;
|
||||
import com.olympus.hermione.utility.AttributeParser;
|
||||
|
||||
import ch.qos.logback.classic.Logger;
|
||||
|
||||
public class SummarizeDocSolver extends StepSolver {
|
||||
|
||||
private String qai_system_prompt_template;
|
||||
private String qai_system_prompt_template_chunk;
|
||||
private String qai_path_file;
|
||||
private String qai_output_variable;
|
||||
private int chunk_size_token;
|
||||
private String encoding_name;
|
||||
private double charMax;
|
||||
private int tokenCount;
|
||||
private Optional<Encoding> encoding;
|
||||
private int percent_summarize;
|
||||
private Integer maxTokenChunk;
|
||||
private String qai_custom_memory_id;
|
||||
// private boolean qai_load_graph_schema=false;
|
||||
|
||||
Logger logger = (Logger) LoggerFactory.getLogger(SummarizeDocSolver.class);
|
||||
|
||||
private void loadParameters() {
|
||||
logger.info("Loading parameters");
|
||||
|
||||
AttributeParser attributeParser = new AttributeParser(this.scenarioExecution);
|
||||
|
||||
this.qai_system_prompt_template = attributeParser
|
||||
.parse((String) this.step.getAttributes().get("qai_system_prompt_template"));
|
||||
this.qai_system_prompt_template_chunk = attributeParser
|
||||
.parse((String) this.step.getAttributes().get("qai_system_prompt_template_chunk"));
|
||||
this.qai_path_file = attributeParser.parse((String) this.step.getAttributes().get("qai_path_file"));
|
||||
|
||||
this.qai_output_variable = (String) this.step.getAttributes().get("qai_output_variable");
|
||||
this.encoding_name = (String) this.step.getAttributes().get("encoding_name");
|
||||
|
||||
this.chunk_size_token = Integer.parseInt((String) this.step.getAttributes().get("chunk_size_token"));
|
||||
|
||||
this.percent_summarize = Integer.parseInt((String) this.step.getAttributes().get("percent_summarize"));
|
||||
|
||||
this.qai_custom_memory_id = (String) this.step.getAttributes().get("qai_custom_memory_id");
|
||||
}
|
||||
|
||||
@Override
|
||||
public ScenarioExecution solveStep() {
|
||||
System.out.println("Solving step: " + this.step.getName());
|
||||
|
||||
try {
|
||||
this.scenarioExecution.setCurrentStepId(this.step.getStepId());
|
||||
loadParameters();
|
||||
|
||||
// Estrazione del testo dal file
|
||||
File file = new File(this.qai_path_file);
|
||||
Tika tika = new Tika();
|
||||
tika.setMaxStringLength(-1);
|
||||
String text = tika.parseToString(file);
|
||||
|
||||
// add libreria tokenizer
|
||||
// Ottieni l'encoding del modello (es. cl100k_base per GPT-4 e GPT-3.5-turbo)
|
||||
EncodingRegistry registry = Encodings.newDefaultEncodingRegistry();
|
||||
encoding = registry.getEncoding(this.encoding_name);
|
||||
|
||||
// Conta i token
|
||||
// int tokenCount = encoding.get().encode(text).size();
|
||||
tokenCount = encoding.get().countTokens(text);
|
||||
int charCount = text.length();
|
||||
|
||||
// Stima media caratteri per token
|
||||
double charPerToken = (double) charCount / tokenCount;
|
||||
|
||||
Double output_char = (double) charCount * ((double) this.percent_summarize / 100.0);
|
||||
Integer output_char_int = output_char.intValue();
|
||||
this.qai_system_prompt_template = this.qai_system_prompt_template.replace("number_char",
|
||||
output_char_int.toString());
|
||||
// **Fase di Summarization**
|
||||
String summarizedText = summarize(text); // 🔹 Applica la funzione di riassunto
|
||||
// String template = this.qai_system_prompt_template+" The output length should
|
||||
// be of " + output_char + " characters";
|
||||
logger.info("template: " + this.qai_system_prompt_template);
|
||||
// Creazione dei messaggi per il modello AI
|
||||
Message userMessage = new UserMessage(summarizedText);
|
||||
Message systemMessage = new SystemMessage(this.qai_system_prompt_template);
|
||||
|
||||
CallResponseSpec resp = chatClient.prompt()
|
||||
.messages(userMessage, systemMessage)
|
||||
.advisors(advisor -> advisor
|
||||
.param("chat_memory_conversation_id",
|
||||
this.scenarioExecution.getId() + this.qai_custom_memory_id)
|
||||
.param("chat_memory_response_size", 100))
|
||||
.call();
|
||||
|
||||
String output = resp.content();
|
||||
|
||||
// Gestione del conteggio token usati
|
||||
Usage usage = resp.chatResponse().getMetadata().getUsage();
|
||||
if (usage != null) {
|
||||
Long usedTokens = usage.getTotalTokens();
|
||||
this.scenarioExecution.setUsedTokens(usedTokens);
|
||||
} else {
|
||||
logger.info("Token usage information is not available.");
|
||||
}
|
||||
|
||||
// Salvataggio dell'output nel contesto di esecuzione
|
||||
this.scenarioExecution.getExecSharedMap().put(this.qai_output_variable, output);
|
||||
this.scenarioExecution.setNextStepId(this.step.getNextStepId());
|
||||
|
||||
} catch (Exception e) {
|
||||
logger.error("Error in AnswerQuestionOnDocSolver: " + e.getMessage());
|
||||
this.scenarioExecution.getExecSharedMap().put(this.qai_output_variable, "500 Internal Server Error");
|
||||
}
|
||||
|
||||
return this.scenarioExecution;
|
||||
}
|
||||
|
||||
private String summarize(String text) {
|
||||
// Se il testo è già corto, non riassumere
|
||||
logger.info("length: " + text.length());
|
||||
Double chunk_size_text;
|
||||
int textLengthPlus = (int) (text.length() * 1.1);
|
||||
int tokenCountPlus = (int) (tokenCount * 1.1);
|
||||
// chunk_size_token/(ratio+10%)
|
||||
// Double ratio = Math.floor((textLengthPlus / charMax) + 1);
|
||||
Double ratio = Math.floor((tokenCountPlus / chunk_size_token) + 1);
|
||||
if (ratio == 1) {
|
||||
return text;
|
||||
} else {
|
||||
chunk_size_text = Math.ceil(textLengthPlus / ratio);
|
||||
}
|
||||
|
||||
// Suddividere il testo in chunk
|
||||
List<String> chunks = chunkText(text, chunk_size_text.intValue());
|
||||
List<String> summarizedChunks = new ArrayList<>();
|
||||
|
||||
Double maxTokenChunkD = Math.ceil(chunk_size_token / ratio);
|
||||
maxTokenChunk = maxTokenChunkD.intValue();
|
||||
|
||||
// Riassumere ogni chunk singolarmente
|
||||
for (String chunk : chunks) {
|
||||
logger.info("chunk length: " + chunk.length());
|
||||
summarizedChunks.add(summarizeChunk(chunk));
|
||||
}
|
||||
|
||||
// Unire i riassunti
|
||||
String summarizedText = String.join(" ", summarizedChunks);
|
||||
int tokenCountSummarizedText = encoding.get().countTokens(summarizedText);
|
||||
|
||||
// Se il riassunto è ancora troppo lungo, applicare ricorsione
|
||||
if (tokenCountSummarizedText > chunk_size_token) {
|
||||
return summarize(summarizedText);
|
||||
} else {
|
||||
return summarizedText;
|
||||
}
|
||||
}
|
||||
|
||||
// Metodo per suddividere il testo in chunk
|
||||
private List<String> chunkText(String text, int chunkSize) {
|
||||
List<String> chunks = new ArrayList<>();
|
||||
int length = text.length();
|
||||
for (int i = 0; i < length; i += chunkSize) {
|
||||
chunks.add(text.substring(i, Math.min(length, i + chunkSize)));
|
||||
}
|
||||
return chunks;
|
||||
}
|
||||
|
||||
// Metodo di riassunto per singolo chunk
|
||||
private String summarizeChunk(String chunk) {
|
||||
try {
|
||||
this.qai_system_prompt_template_chunk = this.qai_system_prompt_template_chunk.replace("number_char",
|
||||
maxTokenChunk.toString());
|
||||
|
||||
Message chunkMessage = new UserMessage(chunk);
|
||||
Message systemMessage = new SystemMessage(qai_system_prompt_template_chunk);
|
||||
|
||||
CallResponseSpec resp = chatClient.prompt()
|
||||
.messages(chunkMessage, systemMessage)
|
||||
.advisors(advisor -> advisor
|
||||
.param("chat_memory_conversation_id",
|
||||
this.scenarioExecution.getId() + this.qai_custom_memory_id)
|
||||
.param("chat_memory_response_size", 100))
|
||||
.call();
|
||||
|
||||
return resp.content();
|
||||
} catch (Exception e) {
|
||||
e.printStackTrace();
|
||||
return chunk.substring(0, Math.max(1, chunk.length() / 10)); // Fallback
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
Reference in New Issue
Block a user