Merged PR 71: Clean Canvas and resolve bug

Clean Canvas and resolve bug
This commit is contained in:
2025-03-03 12:57:34 +00:00
4 changed files with 87 additions and 64 deletions

View File

@@ -19,5 +19,6 @@ public class AiModel {
private Integer maxTokens; private Integer maxTokens;
private String apiKey; private String apiKey;
private boolean isDefault = false; private boolean isDefault = false;
private boolean isCanvasDefault = false;
} }

View File

@@ -37,6 +37,6 @@ public class Scenario {
private boolean useChatMemory=false; private boolean useChatMemory=false;
private String outputType; private String outputType;
private boolean canvasEnabled=false;
} }

View File

@@ -1,13 +1,11 @@
package com.olympus.hermione.repository; package com.olympus.hermione.repository;
import org.springframework.stereotype.Repository; import org.springframework.stereotype.Repository;
import java.util.Optional;
import org.springframework.data.mongodb.repository.MongoRepository; import org.springframework.data.mongodb.repository.MongoRepository;
import com.olympus.hermione.models.AiModel; import com.olympus.hermione.models.AiModel;
@Repository @Repository
public interface AiModelRepository extends MongoRepository<AiModel, String>{ public interface AiModelRepository extends MongoRepository<AiModel, String>{
AiModel findByIsDefault(Boolean isDefault); AiModel findByIsDefault(Boolean isDefault);
AiModel findByIsCanvasDefault(Boolean isCanvasDefault);
} }

View File

@@ -1,91 +1,115 @@
package com.olympus.hermione.services; package com.olympus.hermione.services;
import java.util.List;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.SystemMessage;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.model.Generation;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;
import com.olympus.hermione.client.OlympusChatClient;
import com.olympus.hermione.client.AzureOpenApiChatClient;
import com.olympus.hermione.client.GoogleGeminiChatClient;
import com.olympus.hermione.client.OlympusChatClientResponse;
import com.olympus.hermione.dto.CanvasExecutionInput; import com.olympus.hermione.dto.CanvasExecutionInput;
import com.olympus.hermione.dto.CanvasOutput; import com.olympus.hermione.dto.CanvasOutput;
import com.olympus.hermione.models.AiModel;
import com.olympus.hermione.repository.AiModelRepository;
@Service @Service
public class CanvasExecutionService { public class CanvasExecutionService {
private final ChatModel chatModel;
@Autowired @Autowired
public CanvasExecutionService(@Qualifier("openAiChatModel") ChatModel chatModel) { private AiModelRepository aiModelRepository;
this.chatModel = chatModel;
} private OlympusChatClient olympusChatClient = null;
private Logger logger = LoggerFactory.getLogger(CanvasExecutionService.class); private Logger logger = LoggerFactory.getLogger(CanvasExecutionService.class);
public OlympusChatClient createCanvasClient(AiModelRepository aiModelRepository){
public CanvasOutput askHermione(CanvasExecutionInput canvasExecutionInput){ AiModel aiModel = aiModelRepository.findByIsCanvasDefault(true);
CanvasOutput canvasOutput = new CanvasOutput(); if (aiModel == null) {
String input = canvasExecutionInput.getInput(); logger.error("No default AI model found");
if(input != null){ throw new IllegalArgumentException("No default AI model found");
Message userMessage = new UserMessage(input);
Message systemMessage = new SystemMessage("Please answer the question accurately.");
Prompt prompt = new Prompt(List.of(userMessage,systemMessage));
List<Generation> response = chatModel.call(prompt).getResults();
canvasOutput.setStringOutput(response.get(0).getOutput().getText());
return canvasOutput;
} else{
logger.error("Input is not correct");
canvasOutput.setStringOutput(null);
} }
return canvasOutput; if (aiModel.getApiProvider().equals("AzureOpenAI")) {
olympusChatClient = new AzureOpenApiChatClient();
}
if (aiModel.getApiProvider().equals("GoogleGemini")) {
olympusChatClient = new GoogleGeminiChatClient();
}
olympusChatClient.init(aiModel.getEndpoint(), aiModel.getApiKey(), aiModel.getMaxTokens());
return olympusChatClient;
}
public CanvasOutput askHermione(CanvasExecutionInput canvasExecutionInput){
try{
CanvasOutput canvasOutput = new CanvasOutput();
String input = canvasExecutionInput.getInput();
if(input != null){
olympusChatClient = createCanvasClient(aiModelRepository);
String userText = input;
String systemText = "Answer the following question using the language below using markdown: ";
OlympusChatClientResponse resp = olympusChatClient.getChatCompletion(systemText +"\n"+ userText);
canvasOutput.setStringOutput(resp.getContent());
return canvasOutput;
} else{
logger.error("Input is not correct");
canvasOutput.setStringOutput(null);
}
return canvasOutput;
} catch (Exception e){
logger.error("Error in askHermione: " + e.getMessage());
return null;
}
} }
public CanvasOutput rephraseText(CanvasExecutionInput canvasExecutionInput) { public CanvasOutput rephraseText(CanvasExecutionInput canvasExecutionInput) {
CanvasOutput canvasOutput = new CanvasOutput(); try{
String input = canvasExecutionInput.getInput(); CanvasOutput canvasOutput = new CanvasOutput();
if(input != null){ String input = canvasExecutionInput.getInput();
Message userMessage = new UserMessage(input); if(input != null){
Message systemMessage = new SystemMessage("Please rephrase the following text: "); olympusChatClient = createCanvasClient(aiModelRepository);
Prompt prompt = new Prompt(List.of(userMessage,systemMessage)); String userText = input;
List<Generation> response = chatModel.call(prompt).getResults(); String systemText = "Rephrase the following text using the language below using markdown: ";
OlympusChatClientResponse resp = olympusChatClient.getChatCompletion(systemText+"\n"+userText);
canvasOutput.setStringOutput(response.get(0).getOutput().getText());
canvasOutput.setStringOutput(resp.getContent());
return canvasOutput;
} else{
logger.error("Input is not correct");
canvasOutput.setStringOutput(null);
}
return canvasOutput; return canvasOutput;
} else{ } catch (Exception e){
logger.error("Input is not correct"); logger.error("Error in rephraseText: " + e.getMessage());
canvasOutput.setStringOutput(null); return null;
} }
return canvasOutput;
} }
public CanvasOutput summarizeText(CanvasExecutionInput canvasExecutionInput) { public CanvasOutput summarizeText(CanvasExecutionInput canvasExecutionInput) {
CanvasOutput canvasOutput = new CanvasOutput(); try{
String input = canvasExecutionInput.getInput(); CanvasOutput canvasOutput = new CanvasOutput();
if(input != null){ String input = canvasExecutionInput.getInput();
Message userMessage = new UserMessage(input); if(input != null){
Message systemMessage = new SystemMessage("Please summarize the following text: "); olympusChatClient = createCanvasClient(aiModelRepository);
Prompt prompt = new Prompt(List.of(userMessage,systemMessage)); String userText = input;
List<Generation> response = chatModel.call(prompt).getResults(); String systemText = "Summarize the following text using the language below using markdown: ";
OlympusChatClientResponse resp = olympusChatClient.getChatCompletion(systemText+"\n"+userText);
canvasOutput.setStringOutput(response.get(0).getOutput().getText());
canvasOutput.setStringOutput(resp.getContent());
return canvasOutput;
} else{
logger.error("Input is not correct");
canvasOutput.setStringOutput(null);
}
return canvasOutput; return canvasOutput;
} else{ } catch (Exception e){
logger.error("Input is not correct"); logger.error("Error in summarizeText: " + e.getMessage());
canvasOutput.setStringOutput(null); return null;
} }
return canvasOutput;
} }
} }