From 202bb8a5e1e33030130d1a3f8d02af2e6cc81e97 Mon Sep 17 00:00:00 2001 From: Emanuele Ferrelli Date: Fri, 15 Nov 2024 16:45:45 +0100 Subject: [PATCH] fixed usedTokens count --- .../services/ScenarioExecutionService.java | 8 +++----- .../stepSolvers/AdvancedAIPromptSolver.java | 18 ++++++++---------- .../stepSolvers/BasicAIPromptSolver.java | 3 ++- 3 files changed, 13 insertions(+), 16 deletions(-) diff --git a/src/main/java/com/olympus/hermione/services/ScenarioExecutionService.java b/src/main/java/com/olympus/hermione/services/ScenarioExecutionService.java index 9761855..a3240de 100644 --- a/src/main/java/com/olympus/hermione/services/ScenarioExecutionService.java +++ b/src/main/java/com/olympus/hermione/services/ScenarioExecutionService.java @@ -177,17 +177,14 @@ public class ScenarioExecutionService { executeScenarioStep(startStep, scenarioExecution); Long usedTokens = (scenarioExecution.getUsedTokens() != null) ? scenarioExecution.getUsedTokens() : Long.valueOf(0); - scenarioExecution.setUsedTokens(usedTokens); - scenarioExecutionRepository.save(scenarioExecution); - while (scenarioExecution.getNextStepId()!=null) { ScenarioStep step = steps.stream().filter(s -> s.getStepId().equals(scenarioExecution.getNextStepId())).findFirst().orElse(null); executeScenarioStep(step, scenarioExecution); if (scenarioExecution.getUsedTokens() != null && scenarioExecution.getUsedTokens() != 0) { usedTokens += scenarioExecution.getUsedTokens(); - scenarioExecution.setUsedTokens(usedTokens); - scenarioExecutionRepository.save(scenarioExecution); + + scenarioExecution.setUsedTokens(Long.valueOf(0)); //resetting value for next step if is not an AI step } if(scenarioExecution.getLatestStepStatus() != null && scenarioExecution.getLatestStepStatus().equals("ERROR")){ @@ -196,6 +193,7 @@ public class ScenarioExecutionService { } } + scenarioExecution.setUsedTokens(usedTokens); scenarioExecution.setEndDate(new java.util.Date()); scenarioExecutionRepository.save(scenarioExecution); } diff --git a/src/main/java/com/olympus/hermione/stepSolvers/AdvancedAIPromptSolver.java b/src/main/java/com/olympus/hermione/stepSolvers/AdvancedAIPromptSolver.java index 7e9dcf5..d3d9d70 100644 --- a/src/main/java/com/olympus/hermione/stepSolvers/AdvancedAIPromptSolver.java +++ b/src/main/java/com/olympus/hermione/stepSolvers/AdvancedAIPromptSolver.java @@ -72,16 +72,6 @@ public class AdvancedAIPromptSolver extends StepSolver { .param("chat_memory_response_size", 100)) .call(); - - /*Usage usage = resp.chatResponse().getMetadata().getUsage(); - - if (usage != null) { - Long usedTokens = usage.getTotalTokens(); - this.scenarioExecution.setUsedTokens(usedTokens); - } else { - logger.info("Token usage information is not available."); - }*/ - if(qai_output_entityType!=null && qai_output_entityType.equals("CiaOutputEntity")){ logger.info("Output is of type CiaOutputEntity"); @@ -99,6 +89,14 @@ public class AdvancedAIPromptSolver extends StepSolver { String output = resp.content(); this.scenarioExecution.getExecSharedMap().put(this.qai_output_variable, output); } + + Usage usage = resp.chatResponse().getMetadata().getUsage(); + if (usage != null) { + Long usedTokens = usage.getTotalTokens(); + this.scenarioExecution.setUsedTokens(usedTokens); + } else { + logger.info("Token usage information is not available."); + } this.scenarioExecution.setNextStepId(this.step.getNextStepId()); diff --git a/src/main/java/com/olympus/hermione/stepSolvers/BasicAIPromptSolver.java b/src/main/java/com/olympus/hermione/stepSolvers/BasicAIPromptSolver.java index 5ab8ad2..65874b0 100644 --- a/src/main/java/com/olympus/hermione/stepSolvers/BasicAIPromptSolver.java +++ b/src/main/java/com/olympus/hermione/stepSolvers/BasicAIPromptSolver.java @@ -61,6 +61,8 @@ public class BasicAIPromptSolver extends StepSolver { .messages(userMessage,systemMessage) .call(); + String output = resp.content(); + Usage usage = resp.chatResponse().getMetadata().getUsage(); if (usage != null) { Long usedTokens = usage.getTotalTokens(); @@ -69,7 +71,6 @@ public class BasicAIPromptSolver extends StepSolver { logger.info("Token usage information is not available."); } - String output = resp.content(); this.scenarioExecution.getExecSharedMap().put(this.qai_output_variable, output); this.scenarioExecution.setNextStepId(this.step.getNextStepId());