diff --git a/src/main/java/com/olympus/hermione/controllers/ExecutionController.java b/src/main/java/com/olympus/hermione/controllers/ExecutionController.java index 0ba2ebf..6961a44 100644 --- a/src/main/java/com/olympus/hermione/controllers/ExecutionController.java +++ b/src/main/java/com/olympus/hermione/controllers/ExecutionController.java @@ -24,15 +24,4 @@ public class ExecutionController { return scenarioExecutionRepository.findById(id).get(); } - // @PostMapping("/updateRating") - // public String updateScenarioExecRating(@RequestBody ScenarioExecution scenarioExecution){ - // String result = scenarioExecutionService.updateRating2(scenarioExecution); - // return result; - // } - - @GetMapping("/updateRating") - public String updateScenarioExecRating(@RequestParam String id, @RequestParam String rating){ - String result = scenarioExecutionService.updateRating(id, rating); - return result; - } } diff --git a/src/main/java/com/olympus/hermione/controllers/ScenarioController.java b/src/main/java/com/olympus/hermione/controllers/ScenarioController.java index 9744fb6..581a62f 100644 --- a/src/main/java/com/olympus/hermione/controllers/ScenarioController.java +++ b/src/main/java/com/olympus/hermione/controllers/ScenarioController.java @@ -1,13 +1,22 @@ package com.olympus.hermione.controllers; +import java.net.URLDecoder; +import java.nio.charset.StandardCharsets; +import java.time.LocalDate; import java.util.List; +import java.util.Map; import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.data.domain.Page; +import org.springframework.format.annotation.DateTimeFormat; import org.springframework.web.bind.annotation.GetMapping; import org.springframework.web.bind.annotation.PathVariable; import org.springframework.web.bind.annotation.PostMapping; import org.springframework.web.bind.annotation.RestController; +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.JsonMappingException; +import com.fasterxml.jackson.databind.ObjectMapper; import com.olympus.hermione.dto.ScenarioExecutionInput; import com.olympus.hermione.dto.ScenarioOutput; import com.olympus.hermione.models.Scenario; @@ -20,6 +29,7 @@ import com.olympus.model.Application; import com.olympus.model.Project; import org.springframework.web.bind.annotation.RequestBody; +import org.springframework.web.bind.annotation.RequestParam; @RestController public class ScenarioController { @@ -33,6 +43,7 @@ public class ScenarioController { @Autowired ScenarioService scenarioService; + @GetMapping("/scenarios") public Iterable getScenarios() { return scenarioRepository.findAll(); @@ -88,17 +99,35 @@ public class ScenarioController { @GetMapping("/scenarios/getExecutionProgress/{id}") public ScenarioOutput getExecutionProgress(@PathVariable String id) { return scenarioExecutionService.getExecutionProgress(id); - } - + } @GetMapping("/scenarios/execute/{id}") public ScenarioExecution getScenarioExecution(@PathVariable String id) { - return scenarioExecutionRepository.findById(id).get(); + ScenarioExecution scenarioExecution = scenarioExecutionRepository.findById(id).get(); + String apiKey = scenarioExecution.getScenario().getAiModel().getApiKey(); + String endpoint = scenarioExecution.getScenario().getAiModel().getEndpoint(); + List availableForProject = scenarioExecution.getScenario().getAvailableForProjects(); + List availableForApplication = scenarioExecution.getScenario().getAvailableForApplications(); + if (apiKey != null) { + scenarioExecution.getScenario().getAiModel().setApiKey("**********"); } + if (endpoint != null) { + scenarioExecution.getScenario().getAiModel().setEndpoint("**********"); + } + if (availableForProject != null) { + scenarioExecution.getScenario().setAvailableForProjects(null); + } + if (availableForApplication != null) { + scenarioExecution.getScenario().setAvailableForApplications(null); + } + return scenarioExecution; +} - @GetMapping("/scenariosByUser") - public List getScenarioByUser() { - return scenarioExecutionService.getListExecutionScenarioByUser(); + @PostMapping("/executions") + public Page getAdvancedExecutions( + @RequestBody Map allParams) { + + return scenarioExecutionService.getFilteredExecutions(allParams); } @GetMapping("/getScenariosForRE") @@ -106,6 +135,10 @@ public class ScenarioController { return scenarioService.getListScenariosForRE(); } - + @GetMapping("/updateRating") + public String updateScenarioExecRating(@RequestParam String id, @RequestParam String rating){ + String result = scenarioService.updateRating(id, rating); + return result; + } } diff --git a/src/main/java/com/olympus/hermione/models/AiModel.java b/src/main/java/com/olympus/hermione/models/AiModel.java index aefeb7b..0dff9f1 100644 --- a/src/main/java/com/olympus/hermione/models/AiModel.java +++ b/src/main/java/com/olympus/hermione/models/AiModel.java @@ -19,5 +19,6 @@ public class AiModel { private Integer maxTokens; private String apiKey; private boolean isDefault = false; + private boolean isCanvasDefault = false; } diff --git a/src/main/java/com/olympus/hermione/models/Scenario.java b/src/main/java/com/olympus/hermione/models/Scenario.java index 53d57a1..978606e 100644 --- a/src/main/java/com/olympus/hermione/models/Scenario.java +++ b/src/main/java/com/olympus/hermione/models/Scenario.java @@ -21,6 +21,7 @@ public class Scenario { private String id; private String name; private String description; + private String hint; private String startWithStepId; private List steps; private List inputs; @@ -36,6 +37,7 @@ public class Scenario { private boolean useChatMemory=false; private String outputType; - + private boolean canvasEnabled=false; + private boolean chatEnabled=false; } diff --git a/src/main/java/com/olympus/hermione/repository/AiModelRepository.java b/src/main/java/com/olympus/hermione/repository/AiModelRepository.java index b22ef14..e4333ad 100644 --- a/src/main/java/com/olympus/hermione/repository/AiModelRepository.java +++ b/src/main/java/com/olympus/hermione/repository/AiModelRepository.java @@ -1,13 +1,11 @@ package com.olympus.hermione.repository; import org.springframework.stereotype.Repository; - -import java.util.Optional; - import org.springframework.data.mongodb.repository.MongoRepository; - import com.olympus.hermione.models.AiModel; +import com.olympus.hermione.models.AiModel; @Repository public interface AiModelRepository extends MongoRepository{ AiModel findByIsDefault(Boolean isDefault); + AiModel findByIsCanvasDefault(Boolean isCanvasDefault); } diff --git a/src/main/java/com/olympus/hermione/repository/ScenarioExecutionRepository.java b/src/main/java/com/olympus/hermione/repository/ScenarioExecutionRepository.java index 6c7e3b6..47a1aad 100644 --- a/src/main/java/com/olympus/hermione/repository/ScenarioExecutionRepository.java +++ b/src/main/java/com/olympus/hermione/repository/ScenarioExecutionRepository.java @@ -2,6 +2,8 @@ package com.olympus.hermione.repository; import java.util.List; +import org.springframework.data.domain.Page; +import org.springframework.data.domain.Pageable; import org.springframework.data.domain.Sort; import org.springframework.data.mongodb.repository.MongoRepository; import org.springframework.data.mongodb.repository.Query; @@ -21,14 +23,31 @@ public interface ScenarioExecutionRepository extends MongoRepository findByExecutedByUserIdAndInputs(String userId, String selectedProject, String selectedApplication); - @Query("{ $and: [ { 'executedByUserId': ?0 }, { 'scenarioExecutionInput.inputs.selected_project': ?1 }, { 'scenarioExecutionInput.inputs.selected_application': ?2 } ] }") - List getFromProjectAndAPP(String userId, String value1, String value2, Sort sort); - //findByExecutedByUserIdAndInputsOrderByStartDateDesc + // @Query("{ $and: [ { 'executedByUserId': ?0 }, { 'scenarioExecutionInput.inputs.selected_project': ?1 }, { 'scenarioExecutionInput.inputs.selected_application': ?2 } ] }") + // List getFromProjectAndAPP(String userId, String value1, String value2, Sort sort); + // //findByExecutedByUserIdAndInputsOrderByStartDateDesc - @Query("{ $and: [ { 'executedByUserId': ?0 }, { 'scenarioExecutionInput.inputs.selected_project': ?1 } ] }") - List getFromProject(String userId, String value1, Sort sort); + @Query("{ $and: [ { 'scenarioExecutionInput.inputs.selected_project': ?0 }, { 'scenarioExecutionInput.inputs.selected_application': ?1 } ] }") + Page getFromProjectAndAPP( String value1, String value2, Pageable pageable); + + @Query("{ $and: [ " + + " { 'scenarioExecutionInput.inputs.selected_project': ?0 }, " + + " { 'scenarioExecutionInput.inputs.selected_application': ?1 }, " + + " { $or: [ { 'scenarioExecutionInput.inputs.optional_field': ?2 }, { 'scenarioExecutionInput.inputs.optional_field': { $exists: false } } ] } " + + "] }") + Page getFromProjectAndAPPOptional( + String value1, + String value2, + String optionalValue, + Pageable pageable); + // @Query("{ $and: [ { 'executedByUserId': ?0 }, { 'scenarioExecutionInput.inputs.selected_project': ?1 } ] }") + // List getFromProject(String userId, String value1, Sort sort); + + @Query("{ $and: [ { 'scenarioExecutionInput.inputs.selected_project': ?0 } ] }") + Page getFromProject( String value1, Pageable pageable); + } diff --git a/src/main/java/com/olympus/hermione/repository/ScenarioRepository.java b/src/main/java/com/olympus/hermione/repository/ScenarioRepository.java index 7b5324f..ef0062c 100644 --- a/src/main/java/com/olympus/hermione/repository/ScenarioRepository.java +++ b/src/main/java/com/olympus/hermione/repository/ScenarioRepository.java @@ -19,20 +19,30 @@ public interface ScenarioRepository extends MongoRepository { List findByUsableFor(String projectId);*/ //List findByAvailableForProjects_Id(String projectId); - List findByAvailableForProjects_IdAndVisible(String projectId, String visible); + //List findByAvailableForProjects_IdAndVisible(String projectId, String visible); + List findByAvailableForProjects_IdAndVisibleIn(String projectId, List visible); + + List findByAvailableForApplications_IdAndVisibleIn(String projectId, List visible); + + List findByAvailableForProjectsIsNullAndAvailableForApplicationsIsNullAndVisibleIn(List visible); - List findByAvailableForApplications_IdAndVisible(String projectId, String visible); - - List findByAvailableForProjectsIsNullAndAvailableForApplicationsIsNullAndVisible(String visible); - - - @Query("{ 'visible': 'Y', 'category': 'RE', " + + /* @Query("{ 'visible': 'Y', 'category': 'RE', " + " $or: [ " + " { 'availableForProjects': ObjectId(?0) }, " + " { 'availableForApplications': ObjectId(?1) } " + " ] " + "}") List findByVisibleAndCategoryAndProjectOrApplication(String projectId, String applicationId); +*/ + +@Query("{ 'visible': { $in: ?2 }, 'category': 'RE', " + + " $or: [ " + + " { 'availableForProjects': ObjectId(?0) }, " + + " { 'availableForApplications': ObjectId(?1) } " + + " ] " + + "}") +List findByVisibleInAndCategoryAndProjectOrApplication(String projectId, String applicationId, List visibleValues); + } diff --git a/src/main/java/com/olympus/hermione/services/CanvasExecutionService.java b/src/main/java/com/olympus/hermione/services/CanvasExecutionService.java index ddd3732..17a5141 100644 --- a/src/main/java/com/olympus/hermione/services/CanvasExecutionService.java +++ b/src/main/java/com/olympus/hermione/services/CanvasExecutionService.java @@ -1,91 +1,115 @@ package com.olympus.hermione.services; -import java.util.List; - import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import org.springframework.ai.chat.messages.Message; -import org.springframework.ai.chat.messages.SystemMessage; -import org.springframework.ai.chat.messages.UserMessage; -import org.springframework.ai.chat.model.ChatModel; -import org.springframework.ai.chat.model.Generation; -import org.springframework.ai.chat.prompt.Prompt; import org.springframework.beans.factory.annotation.Autowired; -import org.springframework.beans.factory.annotation.Qualifier; import org.springframework.stereotype.Service; - +import com.olympus.hermione.client.OlympusChatClient; +import com.olympus.hermione.client.AzureOpenApiChatClient; +import com.olympus.hermione.client.GoogleGeminiChatClient; +import com.olympus.hermione.client.OlympusChatClientResponse; import com.olympus.hermione.dto.CanvasExecutionInput; import com.olympus.hermione.dto.CanvasOutput; +import com.olympus.hermione.models.AiModel; +import com.olympus.hermione.repository.AiModelRepository; @Service public class CanvasExecutionService { - - private final ChatModel chatModel; - @Autowired - public CanvasExecutionService(@Qualifier("openAiChatModel") ChatModel chatModel) { - this.chatModel = chatModel; - } + private AiModelRepository aiModelRepository; + + private OlympusChatClient olympusChatClient = null; private Logger logger = LoggerFactory.getLogger(CanvasExecutionService.class); - - public CanvasOutput askHermione(CanvasExecutionInput canvasExecutionInput){ - CanvasOutput canvasOutput = new CanvasOutput(); - String input = canvasExecutionInput.getInput(); - if(input != null){ - Message userMessage = new UserMessage(input); - Message systemMessage = new SystemMessage("Please answer the question accurately."); - Prompt prompt = new Prompt(List.of(userMessage,systemMessage)); - List response = chatModel.call(prompt).getResults(); - - canvasOutput.setStringOutput(response.get(0).getOutput().getText()); - return canvasOutput; - } else{ - logger.error("Input is not correct"); - canvasOutput.setStringOutput(null); + public OlympusChatClient createCanvasClient(AiModelRepository aiModelRepository){ + AiModel aiModel = aiModelRepository.findByIsCanvasDefault(true); + if (aiModel == null) { + logger.error("No default AI model found"); + throw new IllegalArgumentException("No default AI model found"); } - return canvasOutput; - + if (aiModel.getApiProvider().equals("AzureOpenAI")) { + olympusChatClient = new AzureOpenApiChatClient(); + } + if (aiModel.getApiProvider().equals("GoogleGemini")) { + olympusChatClient = new GoogleGeminiChatClient(); + } + olympusChatClient.init(aiModel.getEndpoint(), aiModel.getApiKey(), aiModel.getMaxTokens()); + + return olympusChatClient; + } + + public CanvasOutput askHermione(CanvasExecutionInput canvasExecutionInput){ + try{ + CanvasOutput canvasOutput = new CanvasOutput(); + String input = canvasExecutionInput.getInput(); + if(input != null){ + olympusChatClient = createCanvasClient(aiModelRepository); + String userText = input; + String systemText = "Answer the following question: "; + OlympusChatClientResponse resp = olympusChatClient.getChatCompletion(systemText +"\n"+ userText); + + canvasOutput.setStringOutput(resp.getContent()); + return canvasOutput; + } else{ + logger.error("Input is not correct"); + canvasOutput.setStringOutput(null); + } + return canvasOutput; + } catch (Exception e){ + logger.error("Error in askHermione: " + e.getMessage()); + return null; + } } public CanvasOutput rephraseText(CanvasExecutionInput canvasExecutionInput) { - CanvasOutput canvasOutput = new CanvasOutput(); - String input = canvasExecutionInput.getInput(); - if(input != null){ - Message userMessage = new UserMessage(input); - Message systemMessage = new SystemMessage("Please rephrase the following text: "); - Prompt prompt = new Prompt(List.of(userMessage,systemMessage)); - List response = chatModel.call(prompt).getResults(); - - canvasOutput.setStringOutput(response.get(0).getOutput().getText()); + try{ + CanvasOutput canvasOutput = new CanvasOutput(); + String input = canvasExecutionInput.getInput(); + if(input != null){ + olympusChatClient = createCanvasClient(aiModelRepository); + String userText = input; + String systemText = "Rephrase the following text: "; + OlympusChatClientResponse resp = olympusChatClient.getChatCompletion(systemText+"\n"+userText); + + canvasOutput.setStringOutput(resp.getContent()); + return canvasOutput; + } else{ + logger.error("Input is not correct"); + canvasOutput.setStringOutput(null); + } return canvasOutput; - } else{ - logger.error("Input is not correct"); - canvasOutput.setStringOutput(null); + } catch (Exception e){ + logger.error("Error in rephraseText: " + e.getMessage()); + return null; } - return canvasOutput; } + public CanvasOutput summarizeText(CanvasExecutionInput canvasExecutionInput) { - CanvasOutput canvasOutput = new CanvasOutput(); - String input = canvasExecutionInput.getInput(); - if(input != null){ - Message userMessage = new UserMessage(input); - Message systemMessage = new SystemMessage("Please summarize the following text: "); - Prompt prompt = new Prompt(List.of(userMessage,systemMessage)); - List response = chatModel.call(prompt).getResults(); - - canvasOutput.setStringOutput(response.get(0).getOutput().getText()); + try{ + CanvasOutput canvasOutput = new CanvasOutput(); + String input = canvasExecutionInput.getInput(); + if(input != null){ + olympusChatClient = createCanvasClient(aiModelRepository); + String userText = input; + String systemText = "Summarize the following text: "; + OlympusChatClientResponse resp = olympusChatClient.getChatCompletion(systemText+"\n"+userText); + + canvasOutput.setStringOutput(resp.getContent()); + return canvasOutput; + } else{ + logger.error("Input is not correct"); + canvasOutput.setStringOutput(null); + } return canvasOutput; - } else{ - logger.error("Input is not correct"); - canvasOutput.setStringOutput(null); + } catch (Exception e){ + logger.error("Error in summarizeText: " + e.getMessage()); + return null; } - return canvasOutput; } } diff --git a/src/main/java/com/olympus/hermione/services/ScenarioExecutionService.java b/src/main/java/com/olympus/hermione/services/ScenarioExecutionService.java index a03a2c3..a6f8229 100644 --- a/src/main/java/com/olympus/hermione/services/ScenarioExecutionService.java +++ b/src/main/java/com/olympus/hermione/services/ScenarioExecutionService.java @@ -4,14 +4,18 @@ import java.io.File; import java.io.FileInputStream; import java.io.FileOutputStream; import java.nio.file.Files; +import java.time.LocalDate; +import java.util.ArrayList; import java.util.Base64; import java.util.HashMap; import java.util.List; +import java.util.Map; import java.util.Optional; import java.util.OptionalInt; import java.util.stream.IntStream; import java.util.zip.ZipEntry; import java.util.zip.ZipOutputStream; +import org.springframework.data.domain.PageRequest; import org.slf4j.LoggerFactory; import org.springframework.ai.azure.openai.AzureOpenAiChatModel; @@ -29,7 +33,14 @@ import org.springframework.ai.vectorstore.VectorStore; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Value; import org.springframework.cloud.client.discovery.DiscoveryClient; +import org.springframework.data.domain.Page; +import org.springframework.data.domain.PageImpl; +import org.springframework.data.domain.PageRequest; +import org.springframework.data.domain.Pageable; import org.springframework.data.domain.Sort; +import org.springframework.data.mongodb.core.MongoTemplate; +import org.springframework.data.mongodb.core.query.Criteria; +import org.springframework.data.mongodb.core.query.Query; import org.springframework.scheduling.annotation.Async; import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.stereotype.Service; @@ -55,6 +66,7 @@ import com.olympus.hermione.stepSolvers.QueryNeo4JSolver; import com.olympus.hermione.stepSolvers.SourceCodeRagSolver; import com.olympus.hermione.stepSolvers.StepSolver; +import org.bson.types.ObjectId; import org.neo4j.driver.Driver; import org.slf4j.Logger; @@ -65,18 +77,18 @@ import com.olympus.hermione.stepSolvers.ExternalAgentSolver; import com.olympus.hermione.stepSolvers.ExternalCodeGenieSolver; import com.olympus.hermione.stepSolvers.OlynmpusChatClientSolver; - - - @Service public class ScenarioExecutionService { @Value("${file.upload-dir}") private String uploadDir; - + + @Autowired + private MongoTemplate mongoTemplate; + @Autowired private ScenarioRepository scenarioRepository; - + @Autowired private ScenarioExecutionRepository scenarioExecutionRepository; @@ -85,10 +97,9 @@ public class ScenarioExecutionService { @Autowired VectorStore vectorStore; - - @Autowired - DiscoveryClient discoveryClient; + @Autowired + DiscoveryClient discoveryClient; @Autowired Driver graphDriver; @@ -101,44 +112,49 @@ public class ScenarioExecutionService { private Logger logger = LoggerFactory.getLogger(ScenarioExecutionService.class); - public ScenarioOutput getExecutionProgress(String scenarioExecutionId){ + public ScenarioOutput getExecutionProgress(String scenarioExecutionId) { ScenarioOutput scenarioOutput = new ScenarioOutput(); - Optional o_scenarioExecution = scenarioExecutionRepository.findById(scenarioExecutionId); - - if(o_scenarioExecution.isPresent()){ + Optional o_scenarioExecution = scenarioExecutionRepository.findById(scenarioExecutionId); + + if (o_scenarioExecution.isPresent()) { ScenarioExecution scenarioExecution = o_scenarioExecution.get(); scenarioExecution.setCurrentStepId(scenarioExecution.getCurrentStepId()); - if(scenarioExecution.getExecSharedMap()!=null && scenarioExecution.getExecSharedMap().get("scenario_output")!=null){ + if (scenarioExecution.getExecSharedMap() != null + && scenarioExecution.getExecSharedMap().get("scenario_output") != null) { scenarioOutput.setScenarioExecution_id(scenarioExecution.getId()); scenarioOutput.setStatus("OK"); scenarioOutput.setStringOutput(scenarioExecution.getExecSharedMap().get("scenario_output").toString()); scenarioOutput.setMessage(""); - }else{ + } else { scenarioOutput.setScenarioExecution_id(scenarioExecution.getId()); scenarioOutput.setStatus("IN_PROGRESS"); - if (scenarioExecution.getCurrentStepId() != null){ + if (scenarioExecution.getCurrentStepId() != null) { OptionalInt indexOpt = IntStream.range(0, scenarioExecution.getScenario().getSteps().size()) - .filter(i -> scenarioExecution.getScenario().getSteps().get(i).getStepId().equals(scenarioExecution.getCurrentStepId())) - .map(i -> i + 1) - .findFirst(); + .filter(i -> scenarioExecution.getScenario().getSteps().get(i).getStepId() + .equals(scenarioExecution.getCurrentStepId())) + .map(i -> i + 1) + .findFirst(); scenarioOutput.setCurrentStepId(scenarioExecution.getCurrentStepId()); scenarioOutput.setCurrentStepDescription(scenarioExecution.getCurrentStepDescription()); - scenarioOutput.setMessage("Executing step "+indexOpt.getAsInt()+"/"+scenarioExecution.getScenario().getSteps().size()+": "+scenarioOutput.getCurrentStepDescription()+"."); - }else { + scenarioOutput.setMessage("Executing step " + indexOpt.getAsInt() + "/" + + scenarioExecution.getScenario().getSteps().size() + ": " + + scenarioOutput.getCurrentStepDescription() + "."); + } else { scenarioOutput.setMessage("Starting execution..."); } } - - if(scenarioExecution.getLatestStepStatus()!= null && scenarioExecution.getLatestStepStatus().equals("ERROR")){ + + if (scenarioExecution.getLatestStepStatus() != null + && scenarioExecution.getLatestStepStatus().equals("ERROR")) { scenarioOutput.setScenarioExecution_id(scenarioExecution.getId()); scenarioOutput.setStatus("ERROR"); scenarioOutput.setMessage(scenarioExecution.getLatestStepOutput()); } - - }else{ + + } else { scenarioOutput.setScenarioExecution_id(null); scenarioOutput.setStatus("ERROR"); scenarioOutput.setMessage("Scenario Execution not found"); @@ -148,16 +164,16 @@ public class ScenarioExecutionService { } @Async - public void executeScenarioAsync(ScenarioOutput scenarioOutput){ + public void executeScenarioAsync(ScenarioOutput scenarioOutput) { String scenarioExecutionID = scenarioOutput.getScenarioExecution_id(); - Optional o_scenarioExecution = scenarioExecutionRepository.findById(scenarioExecutionID); + Optional o_scenarioExecution = scenarioExecutionRepository.findById(scenarioExecutionID); - if(o_scenarioExecution.isPresent()){ + if (o_scenarioExecution.isPresent()) { ScenarioExecution scenarioExecution = o_scenarioExecution.get(); - HashMap inputs = scenarioExecution.getScenarioExecutionInput().getInputs(); - Optional o_scenario = scenarioRepository.findById(scenarioExecution.getScenario().getId()); + HashMap inputs = scenarioExecution.getScenarioExecutionInput().getInputs(); + Optional o_scenario = scenarioRepository.findById(scenarioExecution.getScenario().getId()); Scenario scenario = o_scenario.get(); @@ -167,81 +183,76 @@ public class ScenarioExecutionService { HashMap execSharedMap = new HashMap(); execSharedMap.put("user_input", inputs); scenarioExecution.setExecSharedMap(execSharedMap); - + AiModel aiModel; - if(scenario.getAiModel() != null){ - aiModel = scenario.getAiModel(); - }else { + if (scenario.getAiModel() != null) { + aiModel = scenario.getAiModel(); + } else { aiModel = aiModelRepository.findByIsDefault(true); scenario.setAiModel(aiModel); scenarioExecution.setScenario(scenario); } - - ChatModel chatModel = createChatModel(aiModel); - + + ChatModel chatModel = createChatModel(aiModel); + scenarioExecutionRepository.save(scenarioExecution); - - if(scenario.isUseChatMemory()){ + if (scenario.isUseChatMemory()) { logger.info("Initializing chatClient with chat-memory advisor"); ChatMemory chatMemory = new InMemoryChatMemory(); chatClient = ChatClient.builder(chatModel) - .defaultAdvisors( - new MessageChatMemoryAdvisor(chatMemory), // chat-memory advisor - new SimpleLoggerAdvisor() - ) - .build(); + .defaultAdvisors( + new MessageChatMemoryAdvisor(chatMemory), // chat-memory advisor + new SimpleLoggerAdvisor()) + .build(); - }else{ + } else { logger.info("Initializing chatClient with simple logger advisor"); chatClient = ChatClient.builder(chatModel) - .defaultAdvisors( - new SimpleLoggerAdvisor() - ) - .build(); + .defaultAdvisors( + new SimpleLoggerAdvisor()) + .build(); } - List steps = scenario.getSteps(); - String startStepId=scenario.getStartWithStepId(); + String startStepId = scenario.getStartWithStepId(); + + ScenarioStep startStep = steps.stream().filter(step -> step.getStepId().equals(startStepId)).findFirst() + .orElse(null); - ScenarioStep startStep = steps.stream().filter(step -> step.getStepId().equals(startStepId)).findFirst().orElse(null); - executeScenarioStep(startStep, scenarioExecution); Integer usedTokens = (scenarioExecution.getUsedTokens() != null) ? scenarioExecution.getUsedTokens() : 0; - while (scenarioExecution.getNextStepId()!=null) { - ScenarioStep step = steps.stream().filter(s -> s.getStepId().equals(scenarioExecution.getNextStepId())).findFirst().orElse(null); + while (scenarioExecution.getNextStepId() != null) { + ScenarioStep step = steps.stream().filter(s -> s.getStepId().equals(scenarioExecution.getNextStepId())) + .findFirst().orElse(null); executeScenarioStep(step, scenarioExecution); if (scenarioExecution.getUsedTokens() != null && scenarioExecution.getUsedTokens() != 0) { usedTokens += scenarioExecution.getUsedTokens(); - scenarioExecution.setUsedTokens(0); //resetting value for next step if is not an AI step + scenarioExecution.setUsedTokens(0); // resetting value for next step if is not an AI step } - if(scenarioExecution.getLatestStepStatus() != null && scenarioExecution.getLatestStepStatus().equals("ERROR")){ + if (scenarioExecution.getLatestStepStatus() != null + && scenarioExecution.getLatestStepStatus().equals("ERROR")) { logger.error("Error while executing step: " + step.getStepId()); break; } } - + scenarioExecution.setUsedTokens(usedTokens); scenarioExecution.setEndDate(new java.util.Date()); scenarioExecutionRepository.save(scenarioExecution); } } - - - - - private void executeScenarioStep(ScenarioStep step, ScenarioExecution scenarioExecution){ + private void executeScenarioStep(ScenarioStep step, ScenarioExecution scenarioExecution) { logger.info("Start working on step: " + step.getName() + " with type: " + step.getType()); - StepSolver solver=new StepSolver(); - switch (step.getType()) { + StepSolver solver = new StepSolver(); + switch (step.getType()) { case "ADVANCED_QUERY_AI": solver = new AdvancedAIPromptSolver(); break; @@ -282,88 +293,86 @@ public class ScenarioExecutionService { break; } - logger.info("Initializing step: " + step.getStepId()); logger.info("Solving step: " + step.getStepId()); scenarioExecution.setCurrentStepId(step.getStepId()); scenarioExecution.setCurrentStepDescription(step.getName()); - ScenarioExecution scenarioExecutionNew=scenarioExecution; + ScenarioExecution scenarioExecutionNew = scenarioExecution; - try{ + try { solver.init(step, scenarioExecution, - vectorStore,chatModel,chatClient,graphDriver,neo4JUitilityService - ,discoveryClient); + vectorStore, chatModel, chatClient, graphDriver, neo4JUitilityService, discoveryClient); scenarioExecutionNew = solver.solveStep(); logger.info("Step solved: " + step.getStepId()); logger.info("Next step: " + scenarioExecutionNew.getNextStepId()); - }catch (Exception e){ + } catch (Exception e) { logger.error("Error while solving step: " + step.getStepId() + " => " + e.getMessage()); scenarioExecutionNew.setNextStepId(null); scenarioExecutionNew.setLatestStepStatus("ERROR"); scenarioExecutionNew.setLatestStepOutput(e.getMessage()); } - scenarioExecutionRepository.save(scenarioExecutionNew); } - - public ScenarioOutput prepareScrenarioExecution(ScenarioExecutionInput scenarioExecutionInput){ + public ScenarioOutput prepareScrenarioExecution(ScenarioExecutionInput scenarioExecutionInput) { String scenarioId = scenarioExecutionInput.getScenario_id(); ScenarioOutput scenarioOutput = new ScenarioOutput(); - Optional o_scenario = scenarioRepository.findById(scenarioId); + Optional o_scenario = scenarioRepository.findById(scenarioId); - if(o_scenario.isPresent()){ + if (o_scenario.isPresent()) { Scenario scenario = o_scenario.get(); logger.info("Executing scenario: " + scenario.getName()); - + String folder_name = ""; ScenarioExecution scenarioExecution = new ScenarioExecution(); scenarioExecution.setScenario(scenario); - //prendi i file dalla cartella temporanea se è presente una chiave con name "MultiFileUpload" - if(scenarioExecutionInput.getInputs().containsKey("MultiFileUpload")){ + // prendi i file dalla cartella temporanea se è presente una chiave con name + // "MultiFileUpload" + if (scenarioExecutionInput.getInputs().containsKey("MultiFileUpload")) { folder_name = scenarioExecutionInput.getInputs().get("MultiFileUpload"); - if(folder_name!=null && !folder_name.equals("")){ - try{ + if (folder_name != null && !folder_name.equals("")) { + try { String base64 = folderToBase64(folder_name); scenarioExecutionInput.getInputs().put("MultiFileUpload", base64); - }catch(Exception e){ + } catch (Exception e) { logger.error("Error while converting folder to base64: " + e.getMessage()); } } } - if(scenarioExecutionInput.getInputs().containsKey("SingleFileUpload")){ - scenarioExecutionInput.getInputs().put("SingleFileUpload", uploadDir + folder_name + "/" + scenarioExecutionInput.getInputs().get("SingleFileUpload")); + if (scenarioExecutionInput.getInputs().containsKey("SingleFileUpload")) { + scenarioExecutionInput.getInputs().put("SingleFileUpload", + uploadDir + folder_name + "/" + scenarioExecutionInput.getInputs().get("SingleFileUpload")); } scenarioExecution.setScenarioExecutionInput(scenarioExecutionInput); - try{ - if( SecurityContextHolder.getContext() != null - && SecurityContextHolder.getContext().getAuthentication() != null - && SecurityContextHolder.getContext().getAuthentication().getPrincipal()!= null - ){ - + try { + if (SecurityContextHolder.getContext() != null + && SecurityContextHolder.getContext().getAuthentication() != null + && SecurityContextHolder.getContext().getAuthentication().getPrincipal() != null) { + User principal = (User) SecurityContextHolder.getContext().getAuthentication().getPrincipal(); scenarioExecution.setExecutedByUserId(principal.getId().toString()); scenarioExecution.setExecutedByUsername(principal.getUsername()); - if(principal.getSelectedApplication()!=null){ - scenarioExecutionInput.getInputs().put("selected_application", principal.getSelectedApplication().getInternal_name()); + if (principal.getSelectedApplication() != null) { + scenarioExecutionInput.getInputs().put("selected_application", + principal.getSelectedApplication().getInternal_name()); } - scenarioExecutionInput.getInputs().put("selected_project", principal.getSelectedProject().getInternal_name()); + scenarioExecutionInput.getInputs().put("selected_project", + principal.getSelectedProject().getInternal_name()); } - }catch(Exception e){ + } catch (Exception e) { logger.error("Error while saving user information in scenario execution: " + e.getMessage()); } - scenarioExecutionRepository.save(scenarioExecution); @@ -375,7 +384,7 @@ public class ScenarioExecutionService { return scenarioOutput; } - public String folderToBase64(String folderName) throws Exception { + public String folderToBase64(String folderName) throws Exception { File folder = new File(uploadDir, folderName); if (!folder.exists() || !folder.isDirectory()) { throw new IllegalArgumentException("La cartella specificata non esiste o non è una directory"); @@ -385,7 +394,7 @@ public class ScenarioExecutionService { File zipFile = File.createTempFile("folder", ".zip"); try (FileOutputStream fos = new FileOutputStream(zipFile); - ZipOutputStream zos = new ZipOutputStream(fos)) { + ZipOutputStream zos = new ZipOutputStream(fos)) { zipFolder(folder, folder.getName(), zos); } @@ -425,107 +434,257 @@ public class ScenarioExecutionService { } } - - private ChatModel createChatModel(AiModel aiModel){ - switch(aiModel.getApiProvider()){ + private ChatModel createChatModel(AiModel aiModel) { + switch (aiModel.getApiProvider()) { case "AzureOpenAI": - OpenAIClientBuilder openAIClient = new OpenAIClientBuilder() - .credential(new AzureKeyCredential(aiModel.getApiKey())) - .endpoint(aiModel.getEndpoint()); - + OpenAIClientBuilder openAIClient = new OpenAIClientBuilder() + .credential(new AzureKeyCredential(aiModel.getApiKey())) + .endpoint(aiModel.getEndpoint()); AzureOpenAiChatOptions openAIChatOptions = AzureOpenAiChatOptions.builder() - .deploymentName(aiModel.getModel()) - .maxTokens(aiModel.getMaxTokens()) - .temperature(Double.valueOf(aiModel.getTemperature())) - .build(); - + .deploymentName(aiModel.getModel()) + .maxTokens(aiModel.getMaxTokens()) + .temperature(Double.valueOf(aiModel.getTemperature())) + .build(); + AzureOpenAiChatModel azureOpenaichatModel = new AzureOpenAiChatModel(openAIClient, openAIChatOptions); logger.info("AI model used: " + aiModel.getModel()); return azureOpenaichatModel; - - case "OpenAI": OpenAiApi openAiApi = new OpenAiApi(aiModel.getApiKey()); OpenAiChatOptions openAiChatOptions = OpenAiChatOptions.builder() - .model(aiModel.getModel()) - .temperature(Double.valueOf(aiModel.getTemperature())) - .maxTokens(aiModel.getMaxTokens()) - .build(); - - OpenAiChatModel openaichatModel = new OpenAiChatModel(openAiApi,openAiChatOptions); - logger.info("AI model used: " + aiModel.getModel()); - return openaichatModel; + .model(aiModel.getModel()) + .temperature(Double.valueOf(aiModel.getTemperature())) + .maxTokens(aiModel.getMaxTokens()) + .build(); + OpenAiChatModel openaichatModel = new OpenAiChatModel(openAiApi, openAiChatOptions); + logger.info("AI model used: " + aiModel.getModel()); + return openaichatModel; case "GoogleGemini": - OpenAIClientBuilder openAIClient2 = new OpenAIClientBuilder() + OpenAIClientBuilder openAIClient2 = new OpenAIClientBuilder() .credential(new AzureKeyCredential(aiModel.getApiKey())) .endpoint(aiModel.getEndpoint()); - - AzureOpenAiChatOptions openAIChatOptions2 = AzureOpenAiChatOptions.builder() - .deploymentName(aiModel.getModel()) + AzureOpenAiChatOptions openAIChatOptions2 = AzureOpenAiChatOptions.builder() + .deploymentName(aiModel.getModel()) .maxTokens(aiModel.getMaxTokens()) .temperature(Double.valueOf(aiModel.getTemperature())) .build(); - AzureOpenAiChatModel azureOpenaichatModel2 = new AzureOpenAiChatModel(openAIClient2, openAIChatOptions2); + AzureOpenAiChatModel azureOpenaichatModel2 = new AzureOpenAiChatModel(openAIClient2, + openAIChatOptions2); - logger.info("AI model used : " + aiModel.getModel()); - return azureOpenaichatModel2; + logger.info("AI model used : " + aiModel.getModel()); + return azureOpenaichatModel2; default: throw new IllegalArgumentException("Unsupported AI model: " + aiModel.getName()); } } - public List getListExecutionScenarioByUser(){ - logger.info("getListProjectByUser function:"); + /* + * public List getListExecutionScenario(){ + * logger.info("getListProjectByUser function:"); + * + * List lstScenarioExecution = null; + * + * User principal = (User) + * SecurityContextHolder.getContext().getAuthentication().getPrincipal(); + * + * if(principal.getSelectedApplication()!=null){ + * lstScenarioExecution = + * scenarioExecutionRepository.getFromProjectAndAPP(principal.getId(), + * principal.getSelectedProject().getInternal_name(), + * principal.getSelectedApplication().getInternal_name(), + * Sort.by(Sort.Direction.DESC, "startDate")); + * + * }else{ + * lstScenarioExecution = + * scenarioExecutionRepository.getFromProject(principal.getId(), + * principal.getSelectedProject().getInternal_name(), + * Sort.by(Sort.Direction.DESC, "startDate")); + * } + * + * return lstScenarioExecution; + * + * } + */ - List lstScenarioExecution = null; + public Page getListExecutionScenarioOptional(int page, int size, String scenarioName, + String executedBy, LocalDate startDate) { + logger.info("getListExecutionScenarioOptional function:"); + + List criteriaList = new ArrayList<>(); + User principal = (User) SecurityContextHolder.getContext().getAuthentication().getPrincipal(); + + // Filtro per progetto selezionato + if (principal.getSelectedProject() != null && principal.getSelectedProject().getId() != null) { + criteriaList.add(Criteria.where("scenarioExecutionInput.inputs.selected_project") + .is(principal.getSelectedProject().getInternal_name())); + } + + // Filtro per applicazione selezionata + if (principal.getSelectedApplication() != null && principal.getSelectedApplication().getId() != null) { + criteriaList.add(Criteria.where("scenarioExecutionInput.inputs.selected_application") + .is(principal.getSelectedApplication().getInternal_name())); + } + + // Filtro per nome scenario + if (scenarioName != null) { + criteriaList.add(Criteria.where("scenarioExecutionInput.scenario.name").is(scenarioName)); + } + + // Filtro per utente che ha eseguito + if (executedBy != null) { + criteriaList.add(Criteria.where("executedByUsername").is(executedBy)); + } + + // Filtro per data di inizio + if (startDate != null) { + criteriaList.add(Criteria.where("startDate").gte(startDate)); + } + + // Costruisce la query solo con i criteri presenti + Criteria criteria = new Criteria(); + if (!criteriaList.isEmpty()) { + criteria = new Criteria().andOperator(criteriaList.toArray(new Criteria[0])); + } + + Query query = new Query(criteria); + long count = mongoTemplate.count(query, ScenarioExecution.class); // Conta i documenti senza paginazione + + // Applica la paginazione solo per ottenere i risultati + Pageable pageable = PageRequest.of(page, size, Sort.by(Sort.Direction.DESC, "startDate")); + query.with(pageable); + + List results = mongoTemplate.find(query, ScenarioExecution.class); + + return new PageImpl<>(results, pageable, count); + } + + public Page getFilteredExecutions(Map filters) { + logger.info("getFilteredExecutions function:"); + + List criteriaList = new ArrayList<>(); + + int page = 0; + int size = 10; + String sortField = "startDate"; + int sortOrder = -1; User principal = (User) SecurityContextHolder.getContext().getAuthentication().getPrincipal(); - - if(principal.getSelectedApplication()!=null){ - lstScenarioExecution = scenarioExecutionRepository.getFromProjectAndAPP(principal.getId(), principal.getSelectedProject().getInternal_name(), principal.getSelectedApplication().getInternal_name(), Sort.by(Sort.Direction.DESC, "startDate")); - - }else{ - lstScenarioExecution = scenarioExecutionRepository.getFromProject(principal.getId(), principal.getSelectedProject().getInternal_name(), Sort.by(Sort.Direction.DESC, "startDate")); - } + criteriaList.add(Criteria.where("execSharedMap.user_input.selected_project") + .is(principal.getSelectedProject().getInternal_name())); + criteriaList.add(Criteria.where("execSharedMap.user_input.selected_application") + .is(principal.getSelectedApplication().getInternal_name())); + // here: + // filtri + // here: + // here: vedi perchè non funziona link di view + // here: da sistemare filtro data a backend e i vari equals + // here: add ordinamento - return lstScenarioExecution; + for (Map.Entry entry : filters.entrySet()) { + String key = entry.getKey(); + Object value = entry.getValue(); - } + if ("page".equals(key)) { + page = (int) value; + } else if ("size".equals(key)) { + size = (int) value; + }else if ("sortField".equals(key)) { + if(value!=null) { + sortField = (String) value; + } - public String updateRating2(ScenarioExecution scenaExec){ - logger.info("updateRating function:"); - String result = "KO"; - try{ - scenarioExecutionRepository.save(scenaExec); - result = "OK"; - }catch(Exception e){ - logger.error("Exception in updateRating: {}", e.getMessage()); - } + }else if("sortOrder".equals(key)){ + if(value!=null) { + sortOrder = (int) value; + } + }else if (value instanceof Map) { + Map valueMap = (Map) value; + String operator = (String) valueMap.get("operator"); + List> constraints = (List>) valueMap.get("constraints"); - return result; + for (Map constraint : constraints) { + Object constraintValue = constraint.get("value"); + String matchMode = (String) constraint.get("matchMode"); - } + if (constraintValue != null && !constraintValue.toString().isEmpty()) { + switch (matchMode) { + case "contains": + if (key.equals("_id")) { + criteriaList.add(Criteria.where("_id").is(new ObjectId((String) constraintValue))); + } else { + criteriaList.add(Criteria.where(key).regex((String) constraintValue, "i")); + } + break; + case "notContains": + criteriaList.add(Criteria.where(key).not().regex((String) constraintValue, "i")); + break; + case "dateIs": + criteriaList.add(Criteria.where(key).is( + LocalDate.parse(((String) constraintValue).substring(0, 10)))); - public String updateRating(String id, String rating){ - logger.info("updateRating function:"); - String result = "KO"; - try{ - Optional o_scenarioExecution = scenarioExecutionRepository.findById(id); - if(o_scenarioExecution.isPresent()){ - o_scenarioExecution.get().setRating(rating); - scenarioExecutionRepository.save(o_scenarioExecution.get()); - result = "OK"; + break; + case "dateIsNot": + criteriaList.add(Criteria.where(key).ne(Criteria.where(key).is( + LocalDate.parse(((String) constraintValue).substring(0, 10))))); + break; + case "dateBefore": + criteriaList.add(Criteria.where(key).lt(Criteria.where(key).is( + LocalDate.parse(((String) constraintValue).substring(0, 10))))); + break; + case "dateAfter": + criteriaList.add(Criteria.where(key).gt(Criteria.where(key).is( + LocalDate.parse(((String) constraintValue).substring(0, 10))))); + break; + case "equals": + criteriaList.add(Criteria.where(key).is(constraintValue)); + break; + case "notEquals": + criteriaList.add(Criteria.where(key).ne(constraintValue)); + break; + case "startsWith": + criteriaList.add(Criteria.where(key).regex("^" + constraintValue, "i")); + break; + case "endsWith": + criteriaList.add(Criteria.where(key).regex(constraintValue + "$", "i")); + break; + default: + criteriaList.add(Criteria.where(key).is(constraintValue)); + break; + } + } + } + } else if (value instanceof String) { + criteriaList.add(Criteria.where(key).regex((String) value, "i")); } - }catch(Exception e){ - logger.error("Exception in updateRating: {}", e.getMessage()); } - return result; + + Criteria criteria = new Criteria(); + if (!criteriaList.isEmpty()) { + criteria = new Criteria().andOperator(criteriaList.toArray(new Criteria[0])); + } + + Query query = new Query(criteria); + long count = mongoTemplate.count(query, ScenarioExecution.class); + + Pageable pageable = null; + if(sortOrder == -1) { + pageable = PageRequest.of(page, size, Sort.by(Sort.Direction.DESC, sortField)); + }else { + pageable = PageRequest.of(page, size, Sort.by(Sort.Direction.ASC, sortField)); + } + query.with(pageable); + + List results = mongoTemplate.find(query, ScenarioExecution.class); + + return new PageImpl<>(results, pageable, count); } + + } diff --git a/src/main/java/com/olympus/hermione/services/ScenarioService.java b/src/main/java/com/olympus/hermione/services/ScenarioService.java index 1916439..6f71acc 100644 --- a/src/main/java/com/olympus/hermione/services/ScenarioService.java +++ b/src/main/java/com/olympus/hermione/services/ScenarioService.java @@ -7,13 +7,17 @@ import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.stereotype.Service; import com.olympus.hermione.models.Scenario; +import com.olympus.hermione.models.ScenarioExecution; import com.olympus.hermione.repository.ProjectRepository; +import com.olympus.hermione.repository.ScenarioExecutionRepository; import com.olympus.hermione.repository.ScenarioRepository; import com.olympus.hermione.security.entity.User; import com.olympus.model.Project; import java.util.ArrayList; import java.util.List; +import java.util.Optional; + import org.bson.types.ObjectId; @Service @@ -24,11 +28,24 @@ public class ScenarioService { @Autowired private ScenarioRepository scenarioRepo; + @Autowired + private ScenarioExecutionRepository scenarioExecutionRepository; + public List getListScenariosByProject(String project){ logger.info("getListProjectByUser function:"); List lstScenarios = null; + + List scenAccept = new ArrayList(); + scenAccept.add("Y"); + try{ - lstScenarios = scenarioRepo.findByAvailableForProjects_IdAndVisible(project, "Y"); + + User user = (User) SecurityContextHolder.getContext().getAuthentication().getPrincipal(); + if(user.getRole().name().equals("ADMIN")){ + scenAccept.add("DRAFT"); + } + lstScenarios = scenarioRepo.findByAvailableForProjects_IdAndVisibleIn(project, scenAccept); + //lstScenarios = scenarioRepo.findByAvailableForProjects_IdAndVisible(project, "Y"); }catch(Exception e){ logger.error("Exception ScenarioRepository:", e.getMessage()); } @@ -41,8 +58,16 @@ public class ScenarioService { public List getListScenariosByApplication(String app){ logger.info("getListProjectByUser function:"); List lstScenarios = null; + List scenAccept = new ArrayList(); + scenAccept.add("Y"); try{ - lstScenarios = scenarioRepo.findByAvailableForApplications_IdAndVisible(app, "Y"); + + User user = (User) SecurityContextHolder.getContext().getAuthentication().getPrincipal(); + if(user.getRole().name().equals("ADMIN")){ + scenAccept.add("DRAFT"); + } + lstScenarios = scenarioRepo.findByAvailableForApplications_IdAndVisibleIn(app, scenAccept); + // }catch(Exception e){ logger.error("Exception ScenarioRepository:", e.getMessage()); } @@ -55,8 +80,16 @@ public class ScenarioService { public List getListScenariosCross(){ logger.info("getListProjectByUser function:"); List lstScenarios = null; + + List scenAccept = new ArrayList(); + scenAccept.add("Y"); try{ - lstScenarios = scenarioRepo.findByAvailableForProjectsIsNullAndAvailableForApplicationsIsNullAndVisible("Y"); + + User user = (User) SecurityContextHolder.getContext().getAuthentication().getPrincipal(); + if(user.getRole().name().equals("ADMIN")){ + scenAccept.add("DRAFT"); + } + lstScenarios = scenarioRepo.findByAvailableForProjectsIsNullAndAvailableForApplicationsIsNullAndVisibleIn(scenAccept); }catch(Exception e){ logger.error("Exception ScenarioRepository:", e.getMessage()); } @@ -72,10 +105,16 @@ public class ScenarioService { Scenario scenarioDefault = new Scenario(); scenarioDefault.setName("Default"); scenarioDefault.setId(""); + + List scenAccept = new ArrayList(); + scenAccept.add("Y"); try{ User user = (User) SecurityContextHolder.getContext().getAuthentication().getPrincipal(); - lstScenarios = scenarioRepo.findByVisibleAndCategoryAndProjectOrApplication(user.getSelectedProject().getId(), user.getSelectedApplication().getId()); + if(user.getRole().name().equals("ADMIN")){ + scenAccept.add("DRAFT"); + } + lstScenarios = scenarioRepo.findByVisibleInAndCategoryAndProjectOrApplication(user.getSelectedProject().getId(), user.getSelectedApplication().getId(), scenAccept); lstScenarios.add(scenarioDefault); }catch(Exception e){ logger.error("Exception ScenarioRepository:", e.getMessage()); @@ -85,4 +124,20 @@ public class ScenarioService { return lstScenarios; } + + public String updateRating(String id, String rating) { + logger.info("updateRating function:"); + String result = "KO"; + try { + Optional o_scenarioExecution = scenarioExecutionRepository.findById(id); + if (o_scenarioExecution.isPresent()) { + o_scenarioExecution.get().setRating(rating); + scenarioExecutionRepository.save(o_scenarioExecution.get()); + result = "OK"; + } + } catch (Exception e) { + logger.error("Exception in updateRating: {}", e.getMessage()); + } + return result; + } } diff --git a/src/main/java/com/olympus/hermione/stepSolvers/DeleteDocTempSolver.java b/src/main/java/com/olympus/hermione/stepSolvers/DeleteDocTempSolver.java index 88c2007..2e6acb6 100644 --- a/src/main/java/com/olympus/hermione/stepSolvers/DeleteDocTempSolver.java +++ b/src/main/java/com/olympus/hermione/stepSolvers/DeleteDocTempSolver.java @@ -38,7 +38,7 @@ public class DeleteDocTempSolver extends StepSolver { if(this.step.getAttributes().containsKey("rag_topk")){ this.topk = (int) this.step.getAttributes().get("rag_topk"); }else{ - this.topk = 1000; + this.topk = 5; } if(this.step.getAttributes().containsKey("rag_threshold")){ diff --git a/src/main/java/com/olympus/hermione/stepSolvers/ExternalCodeGenieSolver.java b/src/main/java/com/olympus/hermione/stepSolvers/ExternalCodeGenieSolver.java index 644aaa2..ad8dff7 100644 --- a/src/main/java/com/olympus/hermione/stepSolvers/ExternalCodeGenieSolver.java +++ b/src/main/java/com/olympus/hermione/stepSolvers/ExternalCodeGenieSolver.java @@ -44,7 +44,7 @@ public class ExternalCodeGenieSolver extends StepSolver { @Autowired private ScenarioExecutionRepository scenarioExecutionRepo; - Logger logger = (Logger) LoggerFactory.getLogger(BasicQueryRagSolver.class); + Logger logger = (Logger) LoggerFactory.getLogger(ExternalCodeGenieSolver.class); private void loadParameters() { logger.info("Loading parameters"); @@ -75,12 +75,12 @@ public class ExternalCodeGenieSolver extends StepSolver { if (this.step.getAttributes().get("codegenie_username") != null) { this.codegenie_username = (String) this.step.getAttributes().get("codegenie_username"); - + } if (this.step.getAttributes().get("codegenie_password") != null) { this.codegenie_password = (String) this.step.getAttributes().get("codegenie_password"); - + } if (this.step.getAttributes().get("codegenie_output_variable") != null) { @@ -104,217 +104,218 @@ public class ExternalCodeGenieSolver extends StepSolver { @Override public ScenarioExecution solveStep() throws Exception { - try { + System.out.println("Solving step: " + this.step.getName()); - System.out.println("Solving step: " + this.step.getName()); + this.scenarioExecution.setCurrentStepId(this.step.getStepId()); - this.scenarioExecution.setCurrentStepId(this.step.getStepId()); + loadParameters(); - loadParameters(); + // token + HttpHeaders headersToken = new HttpHeaders(); + headersToken.setContentType(MediaType.APPLICATION_JSON); + // headers.set("Authorization", "Bearer "); + String auth = codegenie_username + ":" + codegenie_password; + String encodedAuth = java.util.Base64.getEncoder().encodeToString(auth.getBytes()); + headersToken.set("Authorization", "Basic " + encodedAuth); - // token - HttpHeaders headersToken = new HttpHeaders(); - headersToken.setContentType(MediaType.APPLICATION_JSON); - // headers.set("Authorization", "Bearer "); - String auth = codegenie_username + ":" + codegenie_password; - String encodedAuth = java.util.Base64.getEncoder().encodeToString(auth.getBytes()); - headersToken.set("Authorization", "Basic " + encodedAuth); + HttpEntity entityToken = new HttpEntity<>(headersToken); - HttpEntity entityToken = new HttpEntity<>(headersToken); + RestTemplate restTemplate = new RestTemplate(); - RestTemplate restTemplate = new RestTemplate(); + ResponseEntity response = restTemplate.exchange( + this.codegenie_base_url + "/token", + HttpMethod.POST, + entityToken, + String.class); + JSONObject jsonResponse = new JSONObject(response.getBody()); - ResponseEntity response = restTemplate.exchange( + List output_type = new ArrayList(); + output_type.add(this.codegenie_output_type); + + JSONObject requestBody = new JSONObject(); + requestBody.put("execution_id", this.scenarioExecution.getId()); + requestBody.put("docs_zip_file", this.codegenie_input); + requestBody.put("application", this.codegenie_application); + requestBody.put("project", this.codegenie_project); + requestBody.put("output_type", output_type); + requestBody.put("output_variable", this.codegenie_output_variable); + requestBody.put("model_provider", this.codegenie_model_provider); + + HttpHeaders headers = new HttpHeaders(); + headers.setContentType(MediaType.APPLICATION_JSON); + headers.set("Authorization", + jsonResponse.get("token_type").toString() + " " + jsonResponse.get("access_token").toString()); + HttpEntity request = new HttpEntity<>(requestBody.toString(), headers); + + response = restTemplate.exchange( + this.codegenie_base_url + "/execute", + HttpMethod.POST, + request, + String.class); + + jsonResponse = new JSONObject(response.getBody()); + + /* + * response = restTemplate.exchange( + * //this.codegenie_base_url + "/execution_status/" + + * this.scenarioExecution.getId(), + * this.codegenie_base_url + "/execution_status/679752f85189eb5621b48e17", + * HttpMethod.GET, + * request, + * String.class); + * jsonResponse = new JSONObject(response.getBody()); + */ + + int maxTries = 500; + // Pool the status GET api until it return the SUCCESS or FAILED message + while (!jsonResponse.get("status").equals("DONE") && !jsonResponse.get("status").equals("ERROR") + && maxTries > 0) { + + response = restTemplate.exchange( this.codegenie_base_url + "/token", HttpMethod.POST, entityToken, String.class); - JSONObject jsonResponse = new JSONObject(response.getBody()); + jsonResponse = new JSONObject(response.getBody()); - List output_type = new ArrayList(); - output_type.add(this.codegenie_output_type); - - JSONObject requestBody = new JSONObject(); - requestBody.put("execution_id", this.scenarioExecution.getId()); - requestBody.put("docs_zip_file", this.codegenie_input); - requestBody.put("application", this.codegenie_application); - requestBody.put("project", this.codegenie_project); - requestBody.put("output_type", output_type); - requestBody.put("output_variable", this.codegenie_output_variable); - requestBody.put("model_provider", this.codegenie_model_provider); - - HttpHeaders headers = new HttpHeaders(); - headers.setContentType(MediaType.APPLICATION_JSON); headers.set("Authorization", jsonResponse.get("token_type").toString() + " " + jsonResponse.get("access_token").toString()); - HttpEntity request = new HttpEntity<>(requestBody.toString(), headers); - - - response = restTemplate.exchange( - this.codegenie_base_url + "/execute", - HttpMethod.POST, - request, - String.class); - - jsonResponse = new JSONObject(response.getBody()); - - /* response = restTemplate.exchange( - //this.codegenie_base_url + "/execution_status/" + this.scenarioExecution.getId(), - this.codegenie_base_url + "/execution_status/679752f85189eb5621b48e17", - HttpMethod.GET, - request, - String.class); - jsonResponse = new JSONObject(response.getBody());*/ - - int maxTries = 500; - // Pool the status GET api until it return the SUCCESS or FAILED message - while (!jsonResponse.get("status").equals("DONE") && !jsonResponse.get("status").equals("ERROR") - && maxTries > 0) { - - response = restTemplate.exchange( - this.codegenie_base_url + "/token", - HttpMethod.POST, - entityToken, - String.class); - jsonResponse = new JSONObject(response.getBody()); - - headers.set("Authorization", - jsonResponse.get("token_type").toString() + " " + jsonResponse.get("access_token").toString()); - request = new HttpEntity<>(requestBody.toString(), headers); - - response = restTemplate.exchange( - this.codegenie_base_url + "/execution_status/" + this.scenarioExecution.getId(), - //this.codegenie_base_url + "/execution_status/679752f85189eb5621b48e17", - HttpMethod.GET, - request, - String.class); - jsonResponse = new JSONObject(response.getBody()); - - try { - Thread.sleep(120000); - } catch (InterruptedException e) { - Thread.currentThread().interrupt(); - logger.error("Thread was interrupted", e); - } - logger.info("Check status => Remaining tryes :" + maxTries); - logger.info("Percent: " + jsonResponse.get("percent")); - maxTries--; - - } - - logger.info("Stop pooling codegenies pod. Latest status = " + jsonResponse.get("status")); - - if (jsonResponse.get("status").equals("DONE")) { - - response = restTemplate.exchange( - this.codegenie_base_url + "/token", - HttpMethod.POST, - entityToken, - String.class); - jsonResponse = new JSONObject(response.getBody()); - - headers.set("Authorization", - jsonResponse.get("token_type").toString() + " " + jsonResponse.get("access_token").toString()); - request = new HttpEntity<>(requestBody.toString(), headers); - - response = restTemplate.exchange( - this.codegenie_base_url + "/execution_result/" +this.scenarioExecution.getId(), - //this.codegenie_base_url + "/execution_result/679752f85189eb5621b48e17", - HttpMethod.GET, - request, - String.class); - jsonResponse = new JSONObject(response.getBody()); - - if (jsonResponse.get("status").equals("DONE")) { - - // Accedi all'oggetto "outputs" - JSONObject outputs = jsonResponse.getJSONObject("outputs"); - - // Ottieni il valore del campo "json" - String jsonPath = ""; - - if (this.codegenie_output_type.equals("FILE")) { - - jsonPath = outputs.getString("file"); - - } else if (this.codegenie_output_type.equals("MARKDOWN")) { - jsonPath = outputs.getString("markdown"); - // this.scenarioExecution.getExecSharedMap().put(this.codegenie_output_variable, - // jsonResponse.get("OUTPUT").toString()); - } else if (this.codegenie_output_type.equals("JSON")) { - jsonPath = outputs.getString("json"); - // this.scenarioExecution.getExecSharedMap().put(this.codegenie_output_variable, - // jsonResponse.get("OUTPUT").toString()); - } - - response = restTemplate.exchange( - this.codegenie_base_url + "/token", - HttpMethod.POST, - entityToken, - String.class); - jsonResponse = new JSONObject(response.getBody()); - - headers.set("Authorization", - jsonResponse.get("token_type").toString() + " " + jsonResponse.get("access_token").toString()); - request = new HttpEntity<>(requestBody.toString(), headers); - - ResponseEntity responseFile = restTemplate.exchange( - this.codegenie_base_url + "/download/" + jsonPath, - HttpMethod.GET, - request, - byte[].class); - - // Salva i bytes in un file - byte[] fileBytes = responseFile.getBody(); - String fileBase64 = Base64.getEncoder().encodeToString(fileBytes); - - this.scenarioExecution.getExecSharedMap().put(this.codegenie_output_variable, - fileBase64); - this.scenarioExecution.setNextStepId(this.step.getNextStepId()); - this.scenarioExecution.getExecSharedMap().put("status", - "DONE"); - } else { - throw new Exception( - "codegenie execution failed with status: " + jsonResponse.get("status").toString()); - } - - } else { - logger.error("ERROR on pooling codegenies"); - - throw new Exception("codegenie execution failed with status: " + jsonResponse.get("status").toString()); - } + request = new HttpEntity<>(requestBody.toString(), headers); response = restTemplate.exchange( - this.codegenie_base_url + "/token", - HttpMethod.POST, - entityToken, - String.class); - jsonResponse = new JSONObject(response.getBody()); - - headers.set("Authorization", - jsonResponse.get("token_type").toString() + " " + jsonResponse.get("access_token").toString()); - request = new HttpEntity<>(requestBody.toString(), headers); - - response = restTemplate.exchange( - this.codegenie_base_url + "/execution_consumption/" + this.scenarioExecution.getId(), - //this.codegenie_base_url + "/execution_consumption/679752f85189eb5621b48e17", + this.codegenie_base_url + "/execution_status/" + this.scenarioExecution.getId(), + // this.codegenie_base_url + "/execution_status/679752f85189eb5621b48e17", HttpMethod.GET, request, String.class); jsonResponse = new JSONObject(response.getBody()); - this.scenarioExecution.getExecSharedMap().put("input_token", jsonResponse.get("input_token").toString()); - this.scenarioExecution.getExecSharedMap().put("output_token", jsonResponse.get("output_token").toString()); - this.scenarioExecution.getExecSharedMap().put("total_token", jsonResponse.get("total_token").toString()); - this.scenarioExecution.getExecSharedMap().put("used_model", jsonResponse.get("used_model").toString()); - // scenarioExecutionRepo.save(this.scenarioExecution); + try { + Thread.sleep(120000); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + logger.error("Thread was interrupted", e); + } + logger.info("Check status => Remaining tryes :" + maxTries); + logger.info("Percent: " + jsonResponse.get("percent")); + maxTries--; - } catch (Exception e) { - logger.error("Error in codegenie execution: " + e.getMessage(), e); - this.scenarioExecution.getExecSharedMap().put("status", - "ERROR"); } + logger.info("Stop pooling codegenies pod. Latest status = " + jsonResponse.get("status")); + + if (jsonResponse.get("status").equals("DONE")) { + + response = restTemplate.exchange( + this.codegenie_base_url + "/token", + HttpMethod.POST, + entityToken, + String.class); + jsonResponse = new JSONObject(response.getBody()); + + headers.set("Authorization", + jsonResponse.get("token_type").toString() + " " + jsonResponse.get("access_token").toString()); + request = new HttpEntity<>(requestBody.toString(), headers); + + response = restTemplate.exchange( + this.codegenie_base_url + "/execution_result/" + this.scenarioExecution.getId(), + // this.codegenie_base_url + "/execution_result/679752f85189eb5621b48e17", + HttpMethod.GET, + request, + String.class); + jsonResponse = new JSONObject(response.getBody()); + + if (jsonResponse.get("status").equals("DONE")) { + + // Accedi all'oggetto "outputs" + JSONObject outputs = jsonResponse.getJSONObject("outputs"); + + // Ottieni il valore del campo "json" + String jsonPath = ""; + + if (this.codegenie_output_type.equals("FILE")) { + + jsonPath = outputs.getString("file"); + + } else if (this.codegenie_output_type.equals("MARKDOWN")) { + jsonPath = outputs.getString("markdown"); + // this.scenarioExecution.getExecSharedMap().put(this.codegenie_output_variable, + // jsonResponse.get("OUTPUT").toString()); + } else if (this.codegenie_output_type.equals("JSON")) { + jsonPath = outputs.getString("json"); + // this.scenarioExecution.getExecSharedMap().put(this.codegenie_output_variable, + // jsonResponse.get("OUTPUT").toString()); + } + + response = restTemplate.exchange( + this.codegenie_base_url + "/token", + HttpMethod.POST, + entityToken, + String.class); + jsonResponse = new JSONObject(response.getBody()); + + headers.set("Authorization", + jsonResponse.get("token_type").toString() + " " + jsonResponse.get("access_token").toString()); + request = new HttpEntity<>(requestBody.toString(), headers); + + ResponseEntity responseFile = restTemplate.exchange( + this.codegenie_base_url + "/download/" + jsonPath, + HttpMethod.GET, + request, + byte[].class); + + // Salva i bytes in un file + byte[] fileBytes = responseFile.getBody(); + String fileBase64 = Base64.getEncoder().encodeToString(fileBytes); + + this.scenarioExecution.getExecSharedMap().put(this.codegenie_output_variable, + fileBase64); + this.scenarioExecution.setNextStepId(this.step.getNextStepId()); + this.scenarioExecution.getExecSharedMap().put("status", + "DONE"); + } else { + + this.scenarioExecution.getExecSharedMap().put(this.codegenie_output_variable, + "Error while generating file"); + throw new Exception( + "codegenie execution failed with status: " + jsonResponse.get("status").toString()); + } + + } else { + logger.error("ERROR on pooling codegenies"); + this.scenarioExecution.getExecSharedMap().put(this.codegenie_output_variable, + "Error while generating file"); + + throw new Exception("codegenie execution failed with status: " + jsonResponse.get("status").toString()); + } + + response = restTemplate.exchange( + this.codegenie_base_url + "/token", + HttpMethod.POST, + entityToken, + String.class); + jsonResponse = new JSONObject(response.getBody()); + + headers.set("Authorization", + jsonResponse.get("token_type").toString() + " " + jsonResponse.get("access_token").toString()); + request = new HttpEntity<>(requestBody.toString(), headers); + + response = restTemplate.exchange( + this.codegenie_base_url + "/execution_consumption/" + this.scenarioExecution.getId(), + // this.codegenie_base_url + "/execution_consumption/679752f85189eb5621b48e17", + HttpMethod.GET, + request, + String.class); + jsonResponse = new JSONObject(response.getBody()); + this.scenarioExecution.getExecSharedMap().put("input_token", jsonResponse.get("input_token").toString()); + this.scenarioExecution.getExecSharedMap().put("output_token", jsonResponse.get("output_token").toString()); + this.scenarioExecution.getExecSharedMap().put("total_token", jsonResponse.get("total_token").toString()); + this.scenarioExecution.getExecSharedMap().put("used_model", jsonResponse.get("used_model").toString()); + + // scenarioExecutionRepo.save(this.scenarioExecution); + + this.scenarioExecution.setNextStepId(this.step.getNextStepId()); + return this.scenarioExecution; } diff --git a/src/main/java/com/olympus/hermione/stepSolvers/SummarizeDocSolver.java b/src/main/java/com/olympus/hermione/stepSolvers/SummarizeDocSolver.java index 0df48ef..478a383 100644 --- a/src/main/java/com/olympus/hermione/stepSolvers/SummarizeDocSolver.java +++ b/src/main/java/com/olympus/hermione/stepSolvers/SummarizeDocSolver.java @@ -36,16 +36,18 @@ public class SummarizeDocSolver extends StepSolver { private String qai_system_prompt_template_chunk; private String qai_path_file; private String qai_output_variable; - private int chunk_size_token; + private int max_chunk_size_token; private String encoding_name; - private double charMax; private int tokenCount; private Optional encoding; - private int percent_summarize; - private Integer maxTokenChunk; + private Double percent_chunk_length; private String qai_custom_memory_id; private Integer max_output_token; - private String qai_system_prompt_template_minimum; + private Integer max_input_token_not_to_be_chunked; + private String qai_system_prompt_template_formatter; + private boolean isChunked = false; + private Double chunk_size_token_calc; + // private boolean qai_load_graph_schema=false; Logger logger = (Logger) LoggerFactory.getLogger(SummarizeDocSolver.class); @@ -64,17 +66,18 @@ public class SummarizeDocSolver extends StepSolver { this.qai_output_variable = (String) this.step.getAttributes().get("qai_output_variable"); this.encoding_name = (String) this.step.getAttributes().get("encoding_name"); - this.chunk_size_token = Integer.parseInt((String) this.step.getAttributes().get("chunk_size_token")); + this.max_chunk_size_token = Integer.parseInt((String) this.step.getAttributes().get("max_chunk_size_token")); - this.percent_summarize = Integer.parseInt((String) this.step.getAttributes().get("percent_summarize")); + this.percent_chunk_length = (Double) this.step.getAttributes().get("percent_chunk_length"); this.qai_custom_memory_id = (String) this.step.getAttributes().get("qai_custom_memory_id"); this.max_output_token = Integer.parseInt((String) this.step.getAttributes().get("max_output_token")); - this.qai_system_prompt_template_minimum = attributeParser - .parse((String) this.step.getAttributes().get("qai_system_prompt_template_minimum")); + this.max_input_token_not_to_be_chunked = Integer.parseInt((String) this.step.getAttributes().get("max_input_token_not_to_be_chunked")); + this.qai_system_prompt_template_formatter = attributeParser + .parse((String) this.step.getAttributes().get("qai_system_prompt_template_formatter")); } @Override @@ -96,35 +99,37 @@ public class SummarizeDocSolver extends StepSolver { EncodingRegistry registry = Encodings.newDefaultEncodingRegistry(); encoding = registry.getEncoding(this.encoding_name); - // Conta i token - // int tokenCount = encoding.get().encode(text).size(); tokenCount = encoding.get().countTokens(text); - int charCount = text.length(); + logger.info("token count input: " + tokenCount); - // Stima media caratteri per token - // double charPerToken = (double) charCount / tokenCount; + // Calcolo della dimensione del chunk in token + chunk_size_token_calc = (Double)((double)tokenCount * percent_chunk_length); + if(chunk_size_token_calc>max_chunk_size_token){ + chunk_size_token_calc = (double)max_chunk_size_token; - //Double output_char = (double) charCount * ((double) this.percent_summarize / 100.0); - Double min_output_token = (double) tokenCount * ((double) this.percent_summarize / 100.0); - String content = new String(""); - content = this.qai_system_prompt_template.replace("max_number_token", - max_output_token.toString()); - - if (min_output_token < max_output_token) { - content += this.qai_system_prompt_template_minimum.replace( - "min_number_token", - min_output_token.toString()); } - // **Fase di Summarization** - String summarizedText = summarize(text); // 🔹 Applica la funzione di riassunto - // String template = this.qai_system_prompt_template+" The output length should - // be of " + output_char + " characters"; - logger.info("template: " + content); + String summarizedText = text; + + if(tokenCount > max_input_token_not_to_be_chunked){ + summarizedText = summarize(text); + } + // Creazione dei messaggi per il modello AI Message userMessage = new UserMessage(summarizedText); - Message systemMessage = new SystemMessage(content); - logger.info("template: " + systemMessage.getText()); + Message systemMessage = null; + + int tokenCountSummary = encoding.get().countTokens(summarizedText); + + if(isChunked && tokenCountSummary < max_output_token){ + systemMessage = new SystemMessage(this.qai_system_prompt_template_formatter); + logger.info("template formatter: " + this.qai_system_prompt_template_formatter); + }else{ + //here + systemMessage = new SystemMessage(this.qai_system_prompt_template); + logger.info("template: " + this.qai_system_prompt_template); + } + CallResponseSpec resp = chatClient.prompt() .messages(userMessage, systemMessage) .advisors(advisor -> advisor @@ -144,6 +149,9 @@ public class SummarizeDocSolver extends StepSolver { logger.info("Token usage information is not available."); } + tokenCount = encoding.get().countTokens(output); + logger.info("token count output: " + tokenCount); + // Salvataggio dell'output nel contesto di esecuzione this.scenarioExecution.getExecSharedMap().put(this.qai_output_variable, output); this.scenarioExecution.setNextStepId(this.step.getNextStepId()); @@ -157,14 +165,15 @@ public class SummarizeDocSolver extends StepSolver { } private String summarize(String text) { - // Se il testo è già corto, non riassumere logger.info("length: " + text.length()); Double chunk_size_text; + + tokenCount = encoding.get().countTokens(text); int textLengthPlus = (int) (text.length() * 1.1); int tokenCountPlus = (int) (tokenCount * 1.1); - // chunk_size_token/(ratio+10%) - // Double ratio = Math.floor((textLengthPlus / charMax) + 1); - Double ratio = Math.floor((tokenCountPlus / chunk_size_token) + 1); + + Double ratio = Math.floor((tokenCountPlus / chunk_size_token_calc) + 1); + //Double ratio = Math.floor((tokenCountPlus / chunk_size_token) + 1); if (ratio == 1) { return text; } else { @@ -175,21 +184,19 @@ public class SummarizeDocSolver extends StepSolver { List chunks = chunkText(text, chunk_size_text.intValue()); List summarizedChunks = new ArrayList<>(); - Double maxTokenChunkD = Math.ceil(chunk_size_token / ratio); - maxTokenChunk = maxTokenChunkD.intValue(); - // Riassumere ogni chunk singolarmente for (String chunk : chunks) { logger.info("chunk length: " + chunk.length()); summarizedChunks.add(summarizeChunk(chunk)); } + isChunked = true; // Unire i riassunti String summarizedText = String.join(" ", summarizedChunks); int tokenCountSummarizedText = encoding.get().countTokens(summarizedText); // Se il riassunto è ancora troppo lungo, applicare ricorsione - if (tokenCountSummarizedText > chunk_size_token) { + if (tokenCountSummarizedText > max_output_token) { return summarize(summarizedText); } else { return summarizedText; @@ -208,21 +215,12 @@ public class SummarizeDocSolver extends StepSolver { // Metodo di riassunto per singolo chunk private String summarizeChunk(String chunk) { - String content = new String(""); - - if (maxTokenChunk < max_output_token) { - content = this.qai_system_prompt_template_chunk.replace("max_number_token", - maxTokenChunk.toString()); - }else{ - content = this.qai_system_prompt_template_chunk.replace("max_number_token", - max_output_token.toString()); - } Message chunkMessage = new UserMessage(chunk); Message systemMessage = new SystemMessage( - content); + this.qai_system_prompt_template_chunk); - logger.info("template chunk: " + content); + logger.info("template chunk: " + this.qai_system_prompt_template_chunk); CallResponseSpec resp = chatClient.prompt() .messages(chunkMessage, systemMessage) diff --git a/src/main/resources/application.properties b/src/main/resources/application.properties index 0b94c03..9955db8 100644 --- a/src/main/resources/application.properties +++ b/src/main/resources/application.properties @@ -67,7 +67,7 @@ spring.ai.vectorstore.chroma.collection-name=olympus spring.servlet.multipart.max-file-size=10MB spring.servlet.multipart.max-request-size=10MB -file.upload-dir=C:\\mnt\\hermione_storage\\documents\\file_input_scenarios\\ +file.upload-dir=/mnt/hermione_storage/documents/file_input_scenarios/