diff --git a/src/main/java/com/olympus/hermione/controllers/CanvasController.java b/src/main/java/com/olympus/hermione/controllers/CanvasController.java index b88b9f2..8defbb1 100644 --- a/src/main/java/com/olympus/hermione/controllers/CanvasController.java +++ b/src/main/java/com/olympus/hermione/controllers/CanvasController.java @@ -1,6 +1,37 @@ package com.olympus.hermione.controllers; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.web.bind.annotation.CrossOrigin; +import org.springframework.web.bind.annotation.PostMapping; +import org.springframework.web.bind.annotation.RequestBody; +import org.springframework.web.bind.annotation.RestController; +import com.olympus.hermione.dto.CanvasExecutionInput; +import com.olympus.hermione.dto.CanvasOutput; +import com.olympus.hermione.services.CanvasExecutionService; + + +@RestController +@CrossOrigin public class CanvasController { + + @Autowired + CanvasExecutionService canvasExecutionService; + + + @PostMapping("/askHermione") + public CanvasOutput askHermione(@RequestBody CanvasExecutionInput canvasExecutionInput) { + return canvasExecutionService.askHermione(canvasExecutionInput); + } + + @PostMapping("rephrase") + public CanvasOutput rephraseText(@RequestBody CanvasExecutionInput canvasExecutionInput) { + return canvasExecutionService.rephraseText(canvasExecutionInput); + } + + @PostMapping("summarize") + public CanvasOutput summarizeText(@RequestBody CanvasExecutionInput canvasExecutionInput) { + return canvasExecutionService.summarizeText(canvasExecutionInput); + } } diff --git a/src/main/java/com/olympus/hermione/dto/CanvasExecutionInput.java b/src/main/java/com/olympus/hermione/dto/CanvasExecutionInput.java new file mode 100644 index 0000000..02d307e --- /dev/null +++ b/src/main/java/com/olympus/hermione/dto/CanvasExecutionInput.java @@ -0,0 +1,10 @@ +package com.olympus.hermione.dto; + +import lombok.Getter; +import lombok.Setter; + +@Getter +@Setter +public class CanvasExecutionInput { + private String input; +} diff --git a/src/main/java/com/olympus/hermione/dto/CanvasOutput.java b/src/main/java/com/olympus/hermione/dto/CanvasOutput.java new file mode 100644 index 0000000..375cb8b --- /dev/null +++ b/src/main/java/com/olympus/hermione/dto/CanvasOutput.java @@ -0,0 +1,10 @@ +package com.olympus.hermione.dto; + +import lombok.Getter; +import lombok.Setter; + +@Getter +@Setter +public class CanvasOutput { + private String stringOutput; +} diff --git a/src/main/java/com/olympus/hermione/services/CanvasExecutionService.java b/src/main/java/com/olympus/hermione/services/CanvasExecutionService.java new file mode 100644 index 0000000..f886922 --- /dev/null +++ b/src/main/java/com/olympus/hermione/services/CanvasExecutionService.java @@ -0,0 +1,86 @@ +package com.olympus.hermione.services; + +import java.util.List; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.SystemMessage; +import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.ai.chat.model.ChatModel; +import org.springframework.ai.chat.model.Generation; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.stereotype.Service; + +import com.olympus.hermione.dto.CanvasExecutionInput; +import com.olympus.hermione.dto.CanvasOutput; + + +@Service +public class CanvasExecutionService { + + @Autowired + ChatModel chatModel; + + private Logger logger = LoggerFactory.getLogger(CanvasExecutionService.class); + + + public CanvasOutput askHermione(CanvasExecutionInput canvasExecutionInput){ + CanvasOutput canvasOutput = new CanvasOutput(); + String input = canvasExecutionInput.getInput(); + if(input != null){ + Message userMessage = new UserMessage(input); + Message systemMessage = new SystemMessage("Please answer the question accurately."); + Prompt prompt = new Prompt(List.of(userMessage,systemMessage)); + List response = chatModel.call(prompt).getResults(); + + canvasOutput.setStringOutput(response.get(0).getOutput().getContent()); + return canvasOutput; + } else{ + logger.error("Input is not correct"); + canvasOutput.setStringOutput(null); + } + return canvasOutput; + + + } + + public CanvasOutput rephraseText(CanvasExecutionInput canvasExecutionInput) { + CanvasOutput canvasOutput = new CanvasOutput(); + String input = canvasExecutionInput.getInput(); + if(input != null){ + Message userMessage = new UserMessage(input); + Message systemMessage = new SystemMessage("Please rephrase the following text: "); + Prompt prompt = new Prompt(List.of(userMessage,systemMessage)); + List response = chatModel.call(prompt).getResults(); + + canvasOutput.setStringOutput(response.get(0).getOutput().getContent()); + return canvasOutput; + } else{ + logger.error("Input is not correct"); + canvasOutput.setStringOutput(null); + } + return canvasOutput; + } + + public CanvasOutput summarizeText(CanvasExecutionInput canvasExecutionInput) { + CanvasOutput canvasOutput = new CanvasOutput(); + String input = canvasExecutionInput.getInput(); + if(input != null){ + Message userMessage = new UserMessage(input); + Message systemMessage = new SystemMessage("Please summarize the following text: "); + Prompt prompt = new Prompt(List.of(userMessage,systemMessage)); + List response = chatModel.call(prompt).getResults(); + + canvasOutput.setStringOutput(response.get(0).getOutput().getContent()); + return canvasOutput; + } else{ + logger.error("Input is not correct"); + canvasOutput.setStringOutput(null); + } + return canvasOutput; + } + +} +