From 06d4989cc1ffdda09b4da2eaf62ca1b5507b611a Mon Sep 17 00:00:00 2001
From: "andrea.terzani"
Date: Thu, 14 Nov 2024 09:53:05 +0100
Subject: [PATCH] Implement model-specific prompt handling in AdvancedAIPromptSolver

---
 .../stepSolvers/AdvancedAIPromptSolver.java | 39 +++++++++++++++----
 1 file changed, 31 insertions(+), 8 deletions(-)

diff --git a/src/main/java/com/olympus/hermione/stepSolvers/AdvancedAIPromptSolver.java b/src/main/java/com/olympus/hermione/stepSolvers/AdvancedAIPromptSolver.java
index 4a005bf..92566f8 100644
--- a/src/main/java/com/olympus/hermione/stepSolvers/AdvancedAIPromptSolver.java
+++ b/src/main/java/com/olympus/hermione/stepSolvers/AdvancedAIPromptSolver.java
@@ -66,17 +66,40 @@ public class AdvancedAIPromptSolver extends StepSolver {
 
         String userText = this.qai_user_input;
 
-        Message userMessage = new UserMessage(userText);
-        Message systemMessage = new SystemMessage(this.qai_system_prompt_template);
-        CallResponseSpec resp = chatClient.prompt()
-                .messages(userMessage,systemMessage)
-                .advisors(advisor -> advisor
-                        .param("chat_memory_conversation_id", this.scenarioExecution.getId()+this.qai_custom_memory_id)
-                        .param("chat_memory_response_size", 100))
-                .call();
+        CallResponseSpec resp = null;
+
+
+        if (chatModel.getDefaultOptions().getModel().contains("o1-")) {
+            logger.info("Using o1- model => " + chatModel.getDefaultOptions().getModel());
+
+            String completePrompt = this.qai_system_prompt_template + userText;
+
+            resp = chatClient.prompt()
+                    .messages(new UserMessage(completePrompt))
+                    .advisors(advisor -> advisor
+                            .param("chat_memory_conversation_id", this.scenarioExecution.getId() + this.qai_custom_memory_id)
+                            .param("chat_memory_response_size", 100))
+                    .call();
+
+        } else {
+            logger.info("Using standard model => " + chatModel.getDefaultOptions().getModel());
+
+            Message userMessage = new UserMessage(userText);
+            Message systemMessage = new SystemMessage(this.qai_system_prompt_template);
+
+            resp = chatClient.prompt()
+                    .messages(userMessage, systemMessage)
+                    .advisors(advisor -> advisor
+                            .param("chat_memory_conversation_id", this.scenarioExecution.getId() + this.qai_custom_memory_id)
+                            .param("chat_memory_response_size", 100))
+                    .call();
+        }
+
+
         Usage usage = resp.chatResponse().getMetadata().getUsage();
+
         if (usage != null) {
             Long usedTokens = usage.getTotalTokens();
             this.scenarioExecution.setUsedTokens(usedTokens);
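
Note (not part of the patch itself): the new branch keys off chatModel.getDefaultOptions().getModel().contains("o1-") and, for o1-style models, prepends the system prompt to the user text instead of sending a separate SystemMessage; the likely motivation is that the o1 models available at the time did not accept a system-role message. Below is a minimal standalone sketch of the same branching idea, assuming the Spring AI message types already used in the patch; the class and helper names are illustrative, not part of the change.

    import java.util.List;

    import org.springframework.ai.chat.messages.Message;
    import org.springframework.ai.chat.messages.SystemMessage;
    import org.springframework.ai.chat.messages.UserMessage;

    // Hypothetical helper illustrating the model-specific prompt handling:
    // o1-* models get a single user message with the system prompt prepended,
    // all other models keep separate user and system messages (as before the patch).
    class PromptMessageFactory {

        static List<Message> buildMessages(String modelName, String systemPrompt, String userText) {
            if (modelName != null && modelName.contains("o1-")) {
                return List.of(new UserMessage(systemPrompt + userText));
            }
            return List.of(new UserMessage(userText), new SystemMessage(systemPrompt));
        }
    }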