diff --git a/src/main/java/com/olympus/hermione/client/GoogleGeminiChatClient.java b/src/main/java/com/olympus/hermione/client/GoogleGeminiChatClient.java new file mode 100644 index 0000000..0082222 --- /dev/null +++ b/src/main/java/com/olympus/hermione/client/GoogleGeminiChatClient.java @@ -0,0 +1,100 @@ +package com.olympus.hermione.client; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.http.HttpEntity; +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpMethod; +import org.springframework.http.MediaType; +import org.springframework.http.ResponseEntity; +import org.springframework.web.client.RestTemplate; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; + +public class GoogleGeminiChatClient implements OlympusChatClient { + + private String apiUrl; + private String apiKey; + private int maxOutputTokens; + + Logger logger = LoggerFactory.getLogger(AzureOpenApiChatClient.class); + + @Override + public void init(String apiUrl, String apiKey, int maxOutputTokens) { + this.apiUrl = apiUrl; + this.apiKey = apiKey; + this.maxOutputTokens = maxOutputTokens; + + } + + @Override + public OlympusChatClientResponse getChatCompletion(String userInput) { + + HttpHeaders headers = new HttpHeaders(); + headers.setContentType(MediaType.APPLICATION_JSON); + + Map userMessage = new HashMap<>(); + userMessage.put("role", "user"); + userMessage.put("parts", new Object[]{Map.of("text", userInput)}); + + + Map conversation = new HashMap<>(); + conversation.put("contents", new Object[]{userMessage}); + + Map generationConfig = new HashMap<>(); + /* generationConfig.put("temperature", 1); + generationConfig.put("topK", 40); + generationConfig.put("topP", 0.95);*/ + generationConfig.put("maxOutputTokens", maxOutputTokens); + generationConfig.put("responseMimeType", "text/plain"); + + conversation.put("generationConfig", generationConfig); + + HttpEntity> request = new HttpEntity<>(conversation, headers); + RestTemplate restTemplate = new RestTemplate(); + + String fullUrl = apiUrl + "?key=" + apiKey; + ResponseEntity response = restTemplate.exchange(fullUrl, HttpMethod.POST, request, String.class); + OlympusChatClientResponse olympusChatClientResponse = new OlympusChatClientResponse(); + + + try{ + Map responseBody = new ObjectMapper().readValue(response.getBody(), Map.class); + List> candidates = (List>) responseBody.get("candidates"); + Map firstCandidate = candidates.get(0); + Map content = (Map) firstCandidate.get("content"); + List> parts = (List>) content.get("parts"); + String text = (String) parts.get(0).get("text"); + + ObjectMapper objectMapper = new ObjectMapper(); + JsonNode rootNode; + try { + rootNode = objectMapper.readTree(response.getBody()); + } catch (JsonProcessingException e) { + logger.error(e.getMessage()); + return null; + } + + olympusChatClientResponse.setContent(text); + olympusChatClientResponse.setInputTokens(rootNode.path("usageMetadata").path("promptTokenCount").asInt()); + olympusChatClientResponse.setOutputTokens(rootNode.path("usageMetadata").path("candidatesTokenCount").asInt()); + olympusChatClientResponse.setTotalTokens(rootNode.path("usageMetadata").path("totalTokenCount").asInt()); + }catch(Exception e){ + logger.error("Error parsing response", e); + return null; + } + + + return olympusChatClientResponse; + + } + + + +} diff --git a/src/main/java/com/olympus/hermione/services/ScenarioExecutionService.java b/src/main/java/com/olympus/hermione/services/ScenarioExecutionService.java index 4bfcaef..412f409 100644 --- a/src/main/java/com/olympus/hermione/services/ScenarioExecutionService.java +++ b/src/main/java/com/olympus/hermione/services/ScenarioExecutionService.java @@ -445,7 +445,22 @@ public class ScenarioExecutionService { OpenAiChatModel openaichatModel = new OpenAiChatModel(openAiApi,openAiChatOptions); logger.info("AI model used: " + aiModel.getModel()); return openaichatModel; - + + + case "GoogleGemini": + OpenAIClient tempopenAIClient = new OpenAIClientBuilder() + .credential(new AzureKeyCredential(aiModel.getApiKey())) + .endpoint(aiModel.getEndpoint()) + .buildClient(); + AzureOpenAiChatOptions tempopenAIChatOptions = AzureOpenAiChatOptions.builder() + .withDeploymentName(aiModel.getModel()) + .withTemperature(aiModel.getTemperature()) + .withMaxTokens(aiModel.getMaxTokens()) + .build(); + + AzureOpenAiChatModel tempazureOpenaichatModel = new AzureOpenAiChatModel(tempopenAIClient, tempopenAIChatOptions); + logger.info("AI model used: GoogleGemini"); + return tempazureOpenaichatModel; default: throw new IllegalArgumentException("Unsupported AI model: " + aiModel.getName()); } diff --git a/src/main/java/com/olympus/hermione/stepSolvers/OlynmpusChatClientSolver.java b/src/main/java/com/olympus/hermione/stepSolvers/OlynmpusChatClientSolver.java index 9d07732..b22d841 100644 --- a/src/main/java/com/olympus/hermione/stepSolvers/OlynmpusChatClientSolver.java +++ b/src/main/java/com/olympus/hermione/stepSolvers/OlynmpusChatClientSolver.java @@ -10,6 +10,7 @@ import org.springframework.ai.chat.metadata.Usage; import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.databind.ObjectMapper; import com.olympus.hermione.client.AzureOpenApiChatClient; +import com.olympus.hermione.client.GoogleGeminiChatClient; import com.olympus.hermione.client.OlympusChatClient; import com.olympus.hermione.client.OlympusChatClientResponse; import com.olympus.hermione.dto.aientity.CiaOutputEntity; @@ -53,7 +54,11 @@ public class OlynmpusChatClientSolver extends StepSolver{ if ( scenarioExecution.getScenario().getAiModel().getApiProvider().equals("AzureOpenAI")) { chatClient = new AzureOpenApiChatClient(); - } else { + } + if ( scenarioExecution.getScenario().getAiModel().getApiProvider().equals("GoogleGemini")) { + chatClient = new GoogleGeminiChatClient(); + } + else { chatClient = null; }