Aggiungi supporto per Google Gemini nel servizio di esecuzione degli scenari

This commit is contained in:
andrea.terzani
2025-02-20 19:12:48 +01:00
parent 170a2dcfc9
commit 255e65af5e
3 changed files with 122 additions and 2 deletions

View File

@@ -0,0 +1,100 @@
package com.olympus.hermione.client;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.http.HttpEntity;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.http.MediaType;
import org.springframework.http.ResponseEntity;
import org.springframework.web.client.RestTemplate;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
public class GoogleGeminiChatClient implements OlympusChatClient {
private String apiUrl;
private String apiKey;
private int maxOutputTokens;
Logger logger = LoggerFactory.getLogger(AzureOpenApiChatClient.class);
@Override
public void init(String apiUrl, String apiKey, int maxOutputTokens) {
this.apiUrl = apiUrl;
this.apiKey = apiKey;
this.maxOutputTokens = maxOutputTokens;
}
@Override
public OlympusChatClientResponse getChatCompletion(String userInput) {
HttpHeaders headers = new HttpHeaders();
headers.setContentType(MediaType.APPLICATION_JSON);
Map<String, Object> userMessage = new HashMap<>();
userMessage.put("role", "user");
userMessage.put("parts", new Object[]{Map.of("text", userInput)});
Map<String, Object> conversation = new HashMap<>();
conversation.put("contents", new Object[]{userMessage});
Map<String, Object> generationConfig = new HashMap<>();
/* generationConfig.put("temperature", 1);
generationConfig.put("topK", 40);
generationConfig.put("topP", 0.95);*/
generationConfig.put("maxOutputTokens", maxOutputTokens);
generationConfig.put("responseMimeType", "text/plain");
conversation.put("generationConfig", generationConfig);
HttpEntity<Map<String, Object>> request = new HttpEntity<>(conversation, headers);
RestTemplate restTemplate = new RestTemplate();
String fullUrl = apiUrl + "?key=" + apiKey;
ResponseEntity<String> response = restTemplate.exchange(fullUrl, HttpMethod.POST, request, String.class);
OlympusChatClientResponse olympusChatClientResponse = new OlympusChatClientResponse();
try{
Map<String, Object> responseBody = new ObjectMapper().readValue(response.getBody(), Map.class);
List<Map<String, Object>> candidates = (List<Map<String, Object>>) responseBody.get("candidates");
Map<String, Object> firstCandidate = candidates.get(0);
Map<String, Object> content = (Map<String, Object>) firstCandidate.get("content");
List<Map<String, Object>> parts = (List<Map<String, Object>>) content.get("parts");
String text = (String) parts.get(0).get("text");
ObjectMapper objectMapper = new ObjectMapper();
JsonNode rootNode;
try {
rootNode = objectMapper.readTree(response.getBody());
} catch (JsonProcessingException e) {
logger.error(e.getMessage());
return null;
}
olympusChatClientResponse.setContent(text);
olympusChatClientResponse.setInputTokens(rootNode.path("usageMetadata").path("promptTokenCount").asInt());
olympusChatClientResponse.setOutputTokens(rootNode.path("usageMetadata").path("candidatesTokenCount").asInt());
olympusChatClientResponse.setTotalTokens(rootNode.path("usageMetadata").path("totalTokenCount").asInt());
}catch(Exception e){
logger.error("Error parsing response", e);
return null;
}
return olympusChatClientResponse;
}
}

View File

@@ -445,7 +445,22 @@ public class ScenarioExecutionService {
OpenAiChatModel openaichatModel = new OpenAiChatModel(openAiApi,openAiChatOptions);
logger.info("AI model used: " + aiModel.getModel());
return openaichatModel;
case "GoogleGemini":
OpenAIClient tempopenAIClient = new OpenAIClientBuilder()
.credential(new AzureKeyCredential(aiModel.getApiKey()))
.endpoint(aiModel.getEndpoint())
.buildClient();
AzureOpenAiChatOptions tempopenAIChatOptions = AzureOpenAiChatOptions.builder()
.withDeploymentName(aiModel.getModel())
.withTemperature(aiModel.getTemperature())
.withMaxTokens(aiModel.getMaxTokens())
.build();
AzureOpenAiChatModel tempazureOpenaichatModel = new AzureOpenAiChatModel(tempopenAIClient, tempopenAIChatOptions);
logger.info("AI model used: GoogleGemini");
return tempazureOpenaichatModel;
default:
throw new IllegalArgumentException("Unsupported AI model: " + aiModel.getName());
}

View File

@@ -10,6 +10,7 @@ import org.springframework.ai.chat.metadata.Usage;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.olympus.hermione.client.AzureOpenApiChatClient;
import com.olympus.hermione.client.GoogleGeminiChatClient;
import com.olympus.hermione.client.OlympusChatClient;
import com.olympus.hermione.client.OlympusChatClientResponse;
import com.olympus.hermione.dto.aientity.CiaOutputEntity;
@@ -53,7 +54,11 @@ public class OlynmpusChatClientSolver extends StepSolver{
if ( scenarioExecution.getScenario().getAiModel().getApiProvider().equals("AzureOpenAI")) {
chatClient = new AzureOpenApiChatClient();
} else {
}
if ( scenarioExecution.getScenario().getAiModel().getApiProvider().equals("GoogleGemini")) {
chatClient = new GoogleGeminiChatClient();
}
else {
chatClient = null;
}