Aggiungi supporto per Google Gemini nel servizio di esecuzione degli scenari
This commit is contained in:
@@ -0,0 +1,100 @@
|
||||
package com.olympus.hermione.client;
|
||||
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import org.springframework.http.HttpEntity;
|
||||
import org.springframework.http.HttpHeaders;
|
||||
import org.springframework.http.HttpMethod;
|
||||
import org.springframework.http.MediaType;
|
||||
import org.springframework.http.ResponseEntity;
|
||||
import org.springframework.web.client.RestTemplate;
|
||||
|
||||
import com.fasterxml.jackson.core.JsonProcessingException;
|
||||
import com.fasterxml.jackson.databind.JsonNode;
|
||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||
|
||||
public class GoogleGeminiChatClient implements OlympusChatClient {
|
||||
|
||||
private String apiUrl;
|
||||
private String apiKey;
|
||||
private int maxOutputTokens;
|
||||
|
||||
Logger logger = LoggerFactory.getLogger(AzureOpenApiChatClient.class);
|
||||
|
||||
@Override
|
||||
public void init(String apiUrl, String apiKey, int maxOutputTokens) {
|
||||
this.apiUrl = apiUrl;
|
||||
this.apiKey = apiKey;
|
||||
this.maxOutputTokens = maxOutputTokens;
|
||||
|
||||
}
|
||||
|
||||
@Override
|
||||
public OlympusChatClientResponse getChatCompletion(String userInput) {
|
||||
|
||||
HttpHeaders headers = new HttpHeaders();
|
||||
headers.setContentType(MediaType.APPLICATION_JSON);
|
||||
|
||||
Map<String, Object> userMessage = new HashMap<>();
|
||||
userMessage.put("role", "user");
|
||||
userMessage.put("parts", new Object[]{Map.of("text", userInput)});
|
||||
|
||||
|
||||
Map<String, Object> conversation = new HashMap<>();
|
||||
conversation.put("contents", new Object[]{userMessage});
|
||||
|
||||
Map<String, Object> generationConfig = new HashMap<>();
|
||||
/* generationConfig.put("temperature", 1);
|
||||
generationConfig.put("topK", 40);
|
||||
generationConfig.put("topP", 0.95);*/
|
||||
generationConfig.put("maxOutputTokens", maxOutputTokens);
|
||||
generationConfig.put("responseMimeType", "text/plain");
|
||||
|
||||
conversation.put("generationConfig", generationConfig);
|
||||
|
||||
HttpEntity<Map<String, Object>> request = new HttpEntity<>(conversation, headers);
|
||||
RestTemplate restTemplate = new RestTemplate();
|
||||
|
||||
String fullUrl = apiUrl + "?key=" + apiKey;
|
||||
ResponseEntity<String> response = restTemplate.exchange(fullUrl, HttpMethod.POST, request, String.class);
|
||||
OlympusChatClientResponse olympusChatClientResponse = new OlympusChatClientResponse();
|
||||
|
||||
|
||||
try{
|
||||
Map<String, Object> responseBody = new ObjectMapper().readValue(response.getBody(), Map.class);
|
||||
List<Map<String, Object>> candidates = (List<Map<String, Object>>) responseBody.get("candidates");
|
||||
Map<String, Object> firstCandidate = candidates.get(0);
|
||||
Map<String, Object> content = (Map<String, Object>) firstCandidate.get("content");
|
||||
List<Map<String, Object>> parts = (List<Map<String, Object>>) content.get("parts");
|
||||
String text = (String) parts.get(0).get("text");
|
||||
|
||||
ObjectMapper objectMapper = new ObjectMapper();
|
||||
JsonNode rootNode;
|
||||
try {
|
||||
rootNode = objectMapper.readTree(response.getBody());
|
||||
} catch (JsonProcessingException e) {
|
||||
logger.error(e.getMessage());
|
||||
return null;
|
||||
}
|
||||
|
||||
olympusChatClientResponse.setContent(text);
|
||||
olympusChatClientResponse.setInputTokens(rootNode.path("usageMetadata").path("promptTokenCount").asInt());
|
||||
olympusChatClientResponse.setOutputTokens(rootNode.path("usageMetadata").path("candidatesTokenCount").asInt());
|
||||
olympusChatClientResponse.setTotalTokens(rootNode.path("usageMetadata").path("totalTokenCount").asInt());
|
||||
}catch(Exception e){
|
||||
logger.error("Error parsing response", e);
|
||||
return null;
|
||||
}
|
||||
|
||||
|
||||
return olympusChatClientResponse;
|
||||
|
||||
}
|
||||
|
||||
|
||||
|
||||
}
|
||||
@@ -445,7 +445,22 @@ public class ScenarioExecutionService {
|
||||
OpenAiChatModel openaichatModel = new OpenAiChatModel(openAiApi,openAiChatOptions);
|
||||
logger.info("AI model used: " + aiModel.getModel());
|
||||
return openaichatModel;
|
||||
|
||||
|
||||
|
||||
case "GoogleGemini":
|
||||
OpenAIClient tempopenAIClient = new OpenAIClientBuilder()
|
||||
.credential(new AzureKeyCredential(aiModel.getApiKey()))
|
||||
.endpoint(aiModel.getEndpoint())
|
||||
.buildClient();
|
||||
AzureOpenAiChatOptions tempopenAIChatOptions = AzureOpenAiChatOptions.builder()
|
||||
.withDeploymentName(aiModel.getModel())
|
||||
.withTemperature(aiModel.getTemperature())
|
||||
.withMaxTokens(aiModel.getMaxTokens())
|
||||
.build();
|
||||
|
||||
AzureOpenAiChatModel tempazureOpenaichatModel = new AzureOpenAiChatModel(tempopenAIClient, tempopenAIChatOptions);
|
||||
logger.info("AI model used: GoogleGemini");
|
||||
return tempazureOpenaichatModel;
|
||||
default:
|
||||
throw new IllegalArgumentException("Unsupported AI model: " + aiModel.getName());
|
||||
}
|
||||
|
||||
@@ -10,6 +10,7 @@ import org.springframework.ai.chat.metadata.Usage;
|
||||
import com.fasterxml.jackson.core.JsonProcessingException;
|
||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||
import com.olympus.hermione.client.AzureOpenApiChatClient;
|
||||
import com.olympus.hermione.client.GoogleGeminiChatClient;
|
||||
import com.olympus.hermione.client.OlympusChatClient;
|
||||
import com.olympus.hermione.client.OlympusChatClientResponse;
|
||||
import com.olympus.hermione.dto.aientity.CiaOutputEntity;
|
||||
@@ -53,7 +54,11 @@ public class OlynmpusChatClientSolver extends StepSolver{
|
||||
|
||||
if ( scenarioExecution.getScenario().getAiModel().getApiProvider().equals("AzureOpenAI")) {
|
||||
chatClient = new AzureOpenApiChatClient();
|
||||
} else {
|
||||
}
|
||||
if ( scenarioExecution.getScenario().getAiModel().getApiProvider().equals("GoogleGemini")) {
|
||||
chatClient = new GoogleGeminiChatClient();
|
||||
}
|
||||
else {
|
||||
chatClient = null;
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user