diff --git a/src/main/java/com/olympus/hermione/models/ScenarioExecution.java b/src/main/java/com/olympus/hermione/models/ScenarioExecution.java index cd6202c..066c054 100644 --- a/src/main/java/com/olympus/hermione/models/ScenarioExecution.java +++ b/src/main/java/com/olympus/hermione/models/ScenarioExecution.java @@ -37,4 +37,5 @@ public class ScenarioExecution { private Date startDate; private Date endDate; + private Long usedTokens; } diff --git a/src/main/java/com/olympus/hermione/services/ScenarioExecutionService.java b/src/main/java/com/olympus/hermione/services/ScenarioExecutionService.java index ba046f7..fe111be 100644 --- a/src/main/java/com/olympus/hermione/services/ScenarioExecutionService.java +++ b/src/main/java/com/olympus/hermione/services/ScenarioExecutionService.java @@ -174,17 +174,27 @@ public class ScenarioExecutionService { ScenarioStep startStep = steps.stream().filter(step -> step.getStepId().equals(startStepId)).findFirst().orElse(null); executeScenarioStep(startStep, scenarioExecution); + Long usedTokens = (scenarioExecution.getUsedTokens() != null) ? scenarioExecution.getUsedTokens() : Long.valueOf(0); + + scenarioExecution.setUsedTokens(usedTokens); + scenarioExecutionRepository.save(scenarioExecution); while (scenarioExecution.getNextStepId()!=null) { ScenarioStep step = steps.stream().filter(s -> s.getStepId().equals(scenarioExecution.getNextStepId())).findFirst().orElse(null); executeScenarioStep(step, scenarioExecution); + if (scenarioExecution.getUsedTokens() != null && scenarioExecution.getUsedTokens() != 0) { + usedTokens += scenarioExecution.getUsedTokens(); + scenarioExecution.setUsedTokens(usedTokens); + scenarioExecutionRepository.save(scenarioExecution); + } + if(scenarioExecution.getLatestStepStatus() != null && scenarioExecution.getLatestStepStatus().equals("ERROR")){ logger.error("Error while executing step: " + step.getStepId()); break; } } - + scenarioExecution.setEndDate(new java.util.Date()); scenarioExecutionRepository.save(scenarioExecution); } diff --git a/src/main/java/com/olympus/hermione/stepSolvers/AdvancedAIPromptSolver.java b/src/main/java/com/olympus/hermione/stepSolvers/AdvancedAIPromptSolver.java index 8abbdc2..4a005bf 100644 --- a/src/main/java/com/olympus/hermione/stepSolvers/AdvancedAIPromptSolver.java +++ b/src/main/java/com/olympus/hermione/stepSolvers/AdvancedAIPromptSolver.java @@ -1,10 +1,14 @@ package com.olympus.hermione.stepSolvers; import org.slf4j.LoggerFactory; +import org.springframework.ai.azure.openai.metadata.AzureOpenAiUsage; import org.springframework.ai.chat.client.ChatClient.CallResponseSpec; import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.messages.SystemMessage; import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.ai.chat.metadata.Usage; +import org.springframework.ai.chat.model.ChatResponse; + import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.databind.ObjectMapper; import com.olympus.hermione.dto.aientity.CiaOutputEntity; @@ -71,7 +75,14 @@ public class AdvancedAIPromptSolver extends StepSolver { .param("chat_memory_conversation_id", this.scenarioExecution.getId()+this.qai_custom_memory_id) .param("chat_memory_response_size", 100)) .call(); - + + Usage usage = resp.chatResponse().getMetadata().getUsage(); + if (usage != null) { + Long usedTokens = usage.getTotalTokens(); + this.scenarioExecution.setUsedTokens(usedTokens); + } else { + logger.info("Token usage information is not available."); + } if(qai_output_entityType!=null && qai_output_entityType.equals("CiaOutputEntity")){ logger.info("Output is of type CiaOutputEntity"); diff --git a/src/main/java/com/olympus/hermione/stepSolvers/BasicAIPromptSolver.java b/src/main/java/com/olympus/hermione/stepSolvers/BasicAIPromptSolver.java index 9f7e805..5ab8ad2 100644 --- a/src/main/java/com/olympus/hermione/stepSolvers/BasicAIPromptSolver.java +++ b/src/main/java/com/olympus/hermione/stepSolvers/BasicAIPromptSolver.java @@ -6,6 +6,7 @@ import org.springframework.ai.chat.client.ChatClient.CallResponseSpec; import org.springframework.ai.chat.messages.Message; import org.springframework.ai.chat.messages.SystemMessage; import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.ai.chat.metadata.Usage; import org.springframework.ai.chat.model.Generation; import org.springframework.ai.chat.prompt.Prompt; @@ -60,6 +61,14 @@ public class BasicAIPromptSolver extends StepSolver { .messages(userMessage,systemMessage) .call(); + Usage usage = resp.chatResponse().getMetadata().getUsage(); + if (usage != null) { + Long usedTokens = usage.getTotalTokens(); + this.scenarioExecution.setUsedTokens(usedTokens); + } else { + logger.info("Token usage information is not available."); + } + String output = resp.content(); this.scenarioExecution.getExecSharedMap().put(this.qai_output_variable, output);