Aggiungi supporto per Google Gemini nel servizio di esecuzione degli scenari
This commit is contained in:
@@ -0,0 +1,100 @@
|
|||||||
|
package com.olympus.hermione.client;
|
||||||
|
|
||||||
|
import java.util.HashMap;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
|
|
||||||
|
import org.slf4j.Logger;
|
||||||
|
import org.slf4j.LoggerFactory;
|
||||||
|
import org.springframework.http.HttpEntity;
|
||||||
|
import org.springframework.http.HttpHeaders;
|
||||||
|
import org.springframework.http.HttpMethod;
|
||||||
|
import org.springframework.http.MediaType;
|
||||||
|
import org.springframework.http.ResponseEntity;
|
||||||
|
import org.springframework.web.client.RestTemplate;
|
||||||
|
|
||||||
|
import com.fasterxml.jackson.core.JsonProcessingException;
|
||||||
|
import com.fasterxml.jackson.databind.JsonNode;
|
||||||
|
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||||
|
|
||||||
|
public class GoogleGeminiChatClient implements OlympusChatClient {
|
||||||
|
|
||||||
|
private String apiUrl;
|
||||||
|
private String apiKey;
|
||||||
|
private int maxOutputTokens;
|
||||||
|
|
||||||
|
Logger logger = LoggerFactory.getLogger(AzureOpenApiChatClient.class);
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void init(String apiUrl, String apiKey, int maxOutputTokens) {
|
||||||
|
this.apiUrl = apiUrl;
|
||||||
|
this.apiKey = apiKey;
|
||||||
|
this.maxOutputTokens = maxOutputTokens;
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public OlympusChatClientResponse getChatCompletion(String userInput) {
|
||||||
|
|
||||||
|
HttpHeaders headers = new HttpHeaders();
|
||||||
|
headers.setContentType(MediaType.APPLICATION_JSON);
|
||||||
|
|
||||||
|
Map<String, Object> userMessage = new HashMap<>();
|
||||||
|
userMessage.put("role", "user");
|
||||||
|
userMessage.put("parts", new Object[]{Map.of("text", userInput)});
|
||||||
|
|
||||||
|
|
||||||
|
Map<String, Object> conversation = new HashMap<>();
|
||||||
|
conversation.put("contents", new Object[]{userMessage});
|
||||||
|
|
||||||
|
Map<String, Object> generationConfig = new HashMap<>();
|
||||||
|
/* generationConfig.put("temperature", 1);
|
||||||
|
generationConfig.put("topK", 40);
|
||||||
|
generationConfig.put("topP", 0.95);*/
|
||||||
|
generationConfig.put("maxOutputTokens", maxOutputTokens);
|
||||||
|
generationConfig.put("responseMimeType", "text/plain");
|
||||||
|
|
||||||
|
conversation.put("generationConfig", generationConfig);
|
||||||
|
|
||||||
|
HttpEntity<Map<String, Object>> request = new HttpEntity<>(conversation, headers);
|
||||||
|
RestTemplate restTemplate = new RestTemplate();
|
||||||
|
|
||||||
|
String fullUrl = apiUrl + "?key=" + apiKey;
|
||||||
|
ResponseEntity<String> response = restTemplate.exchange(fullUrl, HttpMethod.POST, request, String.class);
|
||||||
|
OlympusChatClientResponse olympusChatClientResponse = new OlympusChatClientResponse();
|
||||||
|
|
||||||
|
|
||||||
|
try{
|
||||||
|
Map<String, Object> responseBody = new ObjectMapper().readValue(response.getBody(), Map.class);
|
||||||
|
List<Map<String, Object>> candidates = (List<Map<String, Object>>) responseBody.get("candidates");
|
||||||
|
Map<String, Object> firstCandidate = candidates.get(0);
|
||||||
|
Map<String, Object> content = (Map<String, Object>) firstCandidate.get("content");
|
||||||
|
List<Map<String, Object>> parts = (List<Map<String, Object>>) content.get("parts");
|
||||||
|
String text = (String) parts.get(0).get("text");
|
||||||
|
|
||||||
|
ObjectMapper objectMapper = new ObjectMapper();
|
||||||
|
JsonNode rootNode;
|
||||||
|
try {
|
||||||
|
rootNode = objectMapper.readTree(response.getBody());
|
||||||
|
} catch (JsonProcessingException e) {
|
||||||
|
logger.error(e.getMessage());
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
olympusChatClientResponse.setContent(text);
|
||||||
|
olympusChatClientResponse.setInputTokens(rootNode.path("usageMetadata").path("promptTokenCount").asInt());
|
||||||
|
olympusChatClientResponse.setOutputTokens(rootNode.path("usageMetadata").path("candidatesTokenCount").asInt());
|
||||||
|
olympusChatClientResponse.setTotalTokens(rootNode.path("usageMetadata").path("totalTokenCount").asInt());
|
||||||
|
}catch(Exception e){
|
||||||
|
logger.error("Error parsing response", e);
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
return olympusChatClientResponse;
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
}
|
||||||
@@ -446,6 +446,21 @@ public class ScenarioExecutionService {
|
|||||||
logger.info("AI model used: " + aiModel.getModel());
|
logger.info("AI model used: " + aiModel.getModel());
|
||||||
return openaichatModel;
|
return openaichatModel;
|
||||||
|
|
||||||
|
|
||||||
|
case "GoogleGemini":
|
||||||
|
OpenAIClient tempopenAIClient = new OpenAIClientBuilder()
|
||||||
|
.credential(new AzureKeyCredential(aiModel.getApiKey()))
|
||||||
|
.endpoint(aiModel.getEndpoint())
|
||||||
|
.buildClient();
|
||||||
|
AzureOpenAiChatOptions tempopenAIChatOptions = AzureOpenAiChatOptions.builder()
|
||||||
|
.withDeploymentName(aiModel.getModel())
|
||||||
|
.withTemperature(aiModel.getTemperature())
|
||||||
|
.withMaxTokens(aiModel.getMaxTokens())
|
||||||
|
.build();
|
||||||
|
|
||||||
|
AzureOpenAiChatModel tempazureOpenaichatModel = new AzureOpenAiChatModel(tempopenAIClient, tempopenAIChatOptions);
|
||||||
|
logger.info("AI model used: GoogleGemini");
|
||||||
|
return tempazureOpenaichatModel;
|
||||||
default:
|
default:
|
||||||
throw new IllegalArgumentException("Unsupported AI model: " + aiModel.getName());
|
throw new IllegalArgumentException("Unsupported AI model: " + aiModel.getName());
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ import org.springframework.ai.chat.metadata.Usage;
|
|||||||
import com.fasterxml.jackson.core.JsonProcessingException;
|
import com.fasterxml.jackson.core.JsonProcessingException;
|
||||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||||
import com.olympus.hermione.client.AzureOpenApiChatClient;
|
import com.olympus.hermione.client.AzureOpenApiChatClient;
|
||||||
|
import com.olympus.hermione.client.GoogleGeminiChatClient;
|
||||||
import com.olympus.hermione.client.OlympusChatClient;
|
import com.olympus.hermione.client.OlympusChatClient;
|
||||||
import com.olympus.hermione.client.OlympusChatClientResponse;
|
import com.olympus.hermione.client.OlympusChatClientResponse;
|
||||||
import com.olympus.hermione.dto.aientity.CiaOutputEntity;
|
import com.olympus.hermione.dto.aientity.CiaOutputEntity;
|
||||||
@@ -53,7 +54,11 @@ public class OlynmpusChatClientSolver extends StepSolver{
|
|||||||
|
|
||||||
if ( scenarioExecution.getScenario().getAiModel().getApiProvider().equals("AzureOpenAI")) {
|
if ( scenarioExecution.getScenario().getAiModel().getApiProvider().equals("AzureOpenAI")) {
|
||||||
chatClient = new AzureOpenApiChatClient();
|
chatClient = new AzureOpenApiChatClient();
|
||||||
} else {
|
}
|
||||||
|
if ( scenarioExecution.getScenario().getAiModel().getApiProvider().equals("GoogleGemini")) {
|
||||||
|
chatClient = new GoogleGeminiChatClient();
|
||||||
|
}
|
||||||
|
else {
|
||||||
chatClient = null;
|
chatClient = null;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user