diff --git a/src/main/java/com/olympus/hermione/controllers/ScenarioController.java b/src/main/java/com/olympus/hermione/controllers/ScenarioController.java index 2d3225f..581a62f 100644 --- a/src/main/java/com/olympus/hermione/controllers/ScenarioController.java +++ b/src/main/java/com/olympus/hermione/controllers/ScenarioController.java @@ -99,22 +99,29 @@ public class ScenarioController { @GetMapping("/scenarios/getExecutionProgress/{id}") public ScenarioOutput getExecutionProgress(@PathVariable String id) { return scenarioExecutionService.getExecutionProgress(id); - } - + } @GetMapping("/scenarios/execute/{id}") public ScenarioExecution getScenarioExecution(@PathVariable String id) { - return scenarioExecutionRepository.findById(id).get(); + ScenarioExecution scenarioExecution = scenarioExecutionRepository.findById(id).get(); + String apiKey = scenarioExecution.getScenario().getAiModel().getApiKey(); + String endpoint = scenarioExecution.getScenario().getAiModel().getEndpoint(); + List availableForProject = scenarioExecution.getScenario().getAvailableForProjects(); + List availableForApplication = scenarioExecution.getScenario().getAvailableForApplications(); + if (apiKey != null) { + scenarioExecution.getScenario().getAiModel().setApiKey("**********"); } - - @GetMapping("/executions/based") - public Page getScenarioByUser( @RequestParam(defaultValue = "0") int page, - @RequestParam(defaultValue = "10") int size, - @RequestParam(required = false) String scenarioName, - @RequestParam(required = false) String executedBy, - @RequestParam(required = false) @DateTimeFormat(iso = DateTimeFormat.ISO.DATE) LocalDate startDate) { - return scenarioExecutionService.getListExecutionScenarioOptional(page, size, scenarioName, executedBy, startDate); + if (endpoint != null) { + scenarioExecution.getScenario().getAiModel().setEndpoint("**********"); } + if (availableForProject != null) { + scenarioExecution.getScenario().setAvailableForProjects(null); + } + if (availableForApplication != null) { + scenarioExecution.getScenario().setAvailableForApplications(null); + } + return scenarioExecution; +} @PostMapping("/executions") public Page getAdvancedExecutions( diff --git a/src/main/java/com/olympus/hermione/models/AiModel.java b/src/main/java/com/olympus/hermione/models/AiModel.java index aefeb7b..0dff9f1 100644 --- a/src/main/java/com/olympus/hermione/models/AiModel.java +++ b/src/main/java/com/olympus/hermione/models/AiModel.java @@ -19,5 +19,6 @@ public class AiModel { private Integer maxTokens; private String apiKey; private boolean isDefault = false; + private boolean isCanvasDefault = false; } diff --git a/src/main/java/com/olympus/hermione/models/Scenario.java b/src/main/java/com/olympus/hermione/models/Scenario.java index 51c8709..7d2b7a6 100644 --- a/src/main/java/com/olympus/hermione/models/Scenario.java +++ b/src/main/java/com/olympus/hermione/models/Scenario.java @@ -37,6 +37,6 @@ public class Scenario { private boolean useChatMemory=false; private String outputType; - + private boolean canvasEnabled=false; } diff --git a/src/main/java/com/olympus/hermione/repository/AiModelRepository.java b/src/main/java/com/olympus/hermione/repository/AiModelRepository.java index b22ef14..e4333ad 100644 --- a/src/main/java/com/olympus/hermione/repository/AiModelRepository.java +++ b/src/main/java/com/olympus/hermione/repository/AiModelRepository.java @@ -1,13 +1,11 @@ package com.olympus.hermione.repository; import org.springframework.stereotype.Repository; - -import java.util.Optional; - import org.springframework.data.mongodb.repository.MongoRepository; - import com.olympus.hermione.models.AiModel; +import com.olympus.hermione.models.AiModel; @Repository public interface AiModelRepository extends MongoRepository{ AiModel findByIsDefault(Boolean isDefault); + AiModel findByIsCanvasDefault(Boolean isCanvasDefault); } diff --git a/src/main/java/com/olympus/hermione/services/CanvasExecutionService.java b/src/main/java/com/olympus/hermione/services/CanvasExecutionService.java index ddd3732..17a5141 100644 --- a/src/main/java/com/olympus/hermione/services/CanvasExecutionService.java +++ b/src/main/java/com/olympus/hermione/services/CanvasExecutionService.java @@ -1,91 +1,115 @@ package com.olympus.hermione.services; -import java.util.List; - import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import org.springframework.ai.chat.messages.Message; -import org.springframework.ai.chat.messages.SystemMessage; -import org.springframework.ai.chat.messages.UserMessage; -import org.springframework.ai.chat.model.ChatModel; -import org.springframework.ai.chat.model.Generation; -import org.springframework.ai.chat.prompt.Prompt; import org.springframework.beans.factory.annotation.Autowired; -import org.springframework.beans.factory.annotation.Qualifier; import org.springframework.stereotype.Service; - +import com.olympus.hermione.client.OlympusChatClient; +import com.olympus.hermione.client.AzureOpenApiChatClient; +import com.olympus.hermione.client.GoogleGeminiChatClient; +import com.olympus.hermione.client.OlympusChatClientResponse; import com.olympus.hermione.dto.CanvasExecutionInput; import com.olympus.hermione.dto.CanvasOutput; +import com.olympus.hermione.models.AiModel; +import com.olympus.hermione.repository.AiModelRepository; @Service public class CanvasExecutionService { - - private final ChatModel chatModel; - @Autowired - public CanvasExecutionService(@Qualifier("openAiChatModel") ChatModel chatModel) { - this.chatModel = chatModel; - } + private AiModelRepository aiModelRepository; + + private OlympusChatClient olympusChatClient = null; private Logger logger = LoggerFactory.getLogger(CanvasExecutionService.class); - - public CanvasOutput askHermione(CanvasExecutionInput canvasExecutionInput){ - CanvasOutput canvasOutput = new CanvasOutput(); - String input = canvasExecutionInput.getInput(); - if(input != null){ - Message userMessage = new UserMessage(input); - Message systemMessage = new SystemMessage("Please answer the question accurately."); - Prompt prompt = new Prompt(List.of(userMessage,systemMessage)); - List response = chatModel.call(prompt).getResults(); - - canvasOutput.setStringOutput(response.get(0).getOutput().getText()); - return canvasOutput; - } else{ - logger.error("Input is not correct"); - canvasOutput.setStringOutput(null); + public OlympusChatClient createCanvasClient(AiModelRepository aiModelRepository){ + AiModel aiModel = aiModelRepository.findByIsCanvasDefault(true); + if (aiModel == null) { + logger.error("No default AI model found"); + throw new IllegalArgumentException("No default AI model found"); } - return canvasOutput; - + if (aiModel.getApiProvider().equals("AzureOpenAI")) { + olympusChatClient = new AzureOpenApiChatClient(); + } + if (aiModel.getApiProvider().equals("GoogleGemini")) { + olympusChatClient = new GoogleGeminiChatClient(); + } + olympusChatClient.init(aiModel.getEndpoint(), aiModel.getApiKey(), aiModel.getMaxTokens()); + + return olympusChatClient; + } + + public CanvasOutput askHermione(CanvasExecutionInput canvasExecutionInput){ + try{ + CanvasOutput canvasOutput = new CanvasOutput(); + String input = canvasExecutionInput.getInput(); + if(input != null){ + olympusChatClient = createCanvasClient(aiModelRepository); + String userText = input; + String systemText = "Answer the following question: "; + OlympusChatClientResponse resp = olympusChatClient.getChatCompletion(systemText +"\n"+ userText); + + canvasOutput.setStringOutput(resp.getContent()); + return canvasOutput; + } else{ + logger.error("Input is not correct"); + canvasOutput.setStringOutput(null); + } + return canvasOutput; + } catch (Exception e){ + logger.error("Error in askHermione: " + e.getMessage()); + return null; + } } public CanvasOutput rephraseText(CanvasExecutionInput canvasExecutionInput) { - CanvasOutput canvasOutput = new CanvasOutput(); - String input = canvasExecutionInput.getInput(); - if(input != null){ - Message userMessage = new UserMessage(input); - Message systemMessage = new SystemMessage("Please rephrase the following text: "); - Prompt prompt = new Prompt(List.of(userMessage,systemMessage)); - List response = chatModel.call(prompt).getResults(); - - canvasOutput.setStringOutput(response.get(0).getOutput().getText()); + try{ + CanvasOutput canvasOutput = new CanvasOutput(); + String input = canvasExecutionInput.getInput(); + if(input != null){ + olympusChatClient = createCanvasClient(aiModelRepository); + String userText = input; + String systemText = "Rephrase the following text: "; + OlympusChatClientResponse resp = olympusChatClient.getChatCompletion(systemText+"\n"+userText); + + canvasOutput.setStringOutput(resp.getContent()); + return canvasOutput; + } else{ + logger.error("Input is not correct"); + canvasOutput.setStringOutput(null); + } return canvasOutput; - } else{ - logger.error("Input is not correct"); - canvasOutput.setStringOutput(null); + } catch (Exception e){ + logger.error("Error in rephraseText: " + e.getMessage()); + return null; } - return canvasOutput; } + public CanvasOutput summarizeText(CanvasExecutionInput canvasExecutionInput) { - CanvasOutput canvasOutput = new CanvasOutput(); - String input = canvasExecutionInput.getInput(); - if(input != null){ - Message userMessage = new UserMessage(input); - Message systemMessage = new SystemMessage("Please summarize the following text: "); - Prompt prompt = new Prompt(List.of(userMessage,systemMessage)); - List response = chatModel.call(prompt).getResults(); - - canvasOutput.setStringOutput(response.get(0).getOutput().getText()); + try{ + CanvasOutput canvasOutput = new CanvasOutput(); + String input = canvasExecutionInput.getInput(); + if(input != null){ + olympusChatClient = createCanvasClient(aiModelRepository); + String userText = input; + String systemText = "Summarize the following text: "; + OlympusChatClientResponse resp = olympusChatClient.getChatCompletion(systemText+"\n"+userText); + + canvasOutput.setStringOutput(resp.getContent()); + return canvasOutput; + } else{ + logger.error("Input is not correct"); + canvasOutput.setStringOutput(null); + } return canvasOutput; - } else{ - logger.error("Input is not correct"); - canvasOutput.setStringOutput(null); + } catch (Exception e){ + logger.error("Error in summarizeText: " + e.getMessage()); + return null; } - return canvasOutput; } }