// hermione/src/main/java/com/olympus/hermione/stepSolvers/ExternalCodeGenieSolver.java
// Snapshot: 2025-01-29 11:04:38 +01:00 (347 lines, 15 KiB, Java)

package com.olympus.hermione.stepSolvers;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.Base64;
import java.util.List;
import java.util.zip.ZipEntry;
import java.util.zip.ZipOutputStream;
import org.json.JSONObject;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.http.HttpEntity;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.http.MediaType;
import org.springframework.http.ResponseEntity;
import org.springframework.web.client.RestTemplate;
import com.olympus.hermione.models.ScenarioExecution;
import com.olympus.hermione.repository.ScenarioExecutionRepository;
import com.olympus.hermione.utility.AttributeParser;
import ch.qos.logback.classic.Logger;
/**
 * Step solver that delegates the work of a scenario step to an external CodeGenie service over
 * HTTP and stores the produced artifact, Base64-encoded, in the scenario's shared map.
 *
 * <p>The interaction implemented by {@link #solveStep()} is:
 * <ol>
 *   <li>{@code POST <base_url>/token} with HTTP Basic credentials to obtain a token;</li>
 *   <li>{@code POST <base_url>/execute} with the job description;</li>
 *   <li>{@code GET <base_url>/execution_status/<executionId>} until the status is {@code DONE}
 *       or {@code ERROR}, or {@link #MAX_STATUS_POLLS} polls have been made;</li>
 *   <li>on {@code DONE}: {@code GET /execution_result/<executionId>}, then
 *       {@code GET /download/<path>} for the output selected by {@code codegenie_output_type};</li>
 *   <li>{@code GET /execution_consumption/<executionId>} for token-usage metrics.</li>
 * </ol>
 * A fresh token is requested before every service call. NOTE(review): presumably tokens are
 * short-lived given the up-to-1000-minute polling window; confirm whether reuse is possible.
 */
public class ExternalCodeGenieSolver extends StepSolver {

    /** Maximum number of status polls before giving up on the remote execution. */
    private static final int MAX_STATUS_POLLS = 500;

    /** Pause between two consecutive status polls, in milliseconds (two minutes). */
    private static final long POLL_INTERVAL_MS = 120000L;

    private static final String STATUS_DONE = "DONE";
    private static final String STATUS_ERROR = "ERROR";

    /** Keys copied, in this order, from the consumption endpoint into the shared map. */
    private static final String[] CONSUMPTION_KEYS = {
        "input_token", "output_token", "total_token", "used_model"
    };

    // Step configuration, filled from the step attributes by loadParameters().
    private String codegenie_input;
    private String codegenie_base_url;
    private String codegenie_username;
    private String codegenie_password;
    private String codegenie_output_variable;
    private String codegenie_application;
    private String codegenie_project;
    private String codegenie_model_provider;
    private String codegenie_output_type;

    /** Injected repository; not used by this solver at the moment. */
    @Autowired
    private ScenarioExecutionRepository scenarioExecutionRepo;

    // Named after this class so that log lines are attributed to the right solver.
    Logger logger = (Logger) LoggerFactory.getLogger(ExternalCodeGenieSolver.class);

    /**
     * Reads the {@code codegenie_*} step attributes into the corresponding fields.
     *
     * <p>Absent attributes leave the field unchanged; credentials are never logged. The input,
     * application and project values are then run through {@link AttributeParser} against the
     * current scenario execution (presumably to resolve placeholders — TODO confirm) and the
     * parsed values replace the raw ones.
     */
    private void loadParameters() {
        logger.info("Loading parameters");
        this.codegenie_model_provider =
                readAttribute("codegenie_model_provider", this.codegenie_model_provider, true);
        this.codegenie_input = readAttribute("codegenie_input", this.codegenie_input, true);
        this.codegenie_application =
                readAttribute("codegenie_application", this.codegenie_application, true);
        this.codegenie_project = readAttribute("codegenie_project", this.codegenie_project, true);
        this.codegenie_base_url = readAttribute("codegenie_base_url", this.codegenie_base_url, true);
        this.codegenie_username = readAttribute("codegenie_username", this.codegenie_username, false);
        this.codegenie_password = readAttribute("codegenie_password", this.codegenie_password, false);
        this.codegenie_output_variable =
                readAttribute("codegenie_output_variable", this.codegenie_output_variable, true);
        this.codegenie_output_type =
                readAttribute("codegenie_output_type", this.codegenie_output_type, true);

        AttributeParser attributeParser = new AttributeParser(this.scenarioExecution);
        this.codegenie_input =
                attributeParser.parse((String) this.step.getAttributes().get("codegenie_input"));
        this.codegenie_application =
                attributeParser.parse((String) this.step.getAttributes().get("codegenie_application"));
        this.codegenie_project =
                attributeParser.parse((String) this.step.getAttributes().get("codegenie_project"));
    }

    /**
     * Returns step attribute {@code key} cast to {@code String}, or {@code current} when the
     * attribute is absent. When {@code loggable} is true the value is logged as
     * {@code "<key>: <value>"}.
     */
    private String readAttribute(String key, String current, boolean loggable) {
        Object value = this.step.getAttributes().get(key);
        if (value == null) {
            return current;
        }
        String text = (String) value;
        if (loggable) {
            logger.info(key + ": " + text);
        }
        return text;
    }

    /**
     * Runs the remote CodeGenie job for this step and stores its output.
     *
     * <p>On success the downloaded artifact is stored Base64-encoded under
     * {@code codegenie_output_variable}, the next step id is set, {@code "status"} is set to
     * {@code DONE} and the consumption metrics ({@link #CONSUMPTION_KEYS}) are copied into the
     * shared map. Any failure is logged and recorded as {@code "status" = "ERROR"}; it is not
     * rethrown, and the next step id is not set when the failure happens before the download.
     *
     * @return the (mutated) scenario execution
     */
    @Override
    public ScenarioExecution solveStep() throws Exception {
        try {
            logger.info("Solving step: " + this.step.getName());
            this.scenarioExecution.setCurrentStepId(this.step.getStepId());
            loadParameters();
            if (this.codegenie_base_url == null) {
                throw new IllegalStateException("codegenie_base_url is not configured");
            }

            RestTemplate restTemplate = new RestTemplate();
            HttpEntity<Void> tokenRequest = buildTokenRequest();

            // Start the remote job; its response already carries the first status.
            HttpEntity<String> executeRequest = new HttpEntity<>(buildExecuteRequestBody().toString(),
                    authorizedHeaders(fetchAuthorization(restTemplate, tokenRequest)));
            ResponseEntity<String> executeResponse = restTemplate.exchange(
                    this.codegenie_base_url + "/execute", HttpMethod.POST, executeRequest, String.class);

            JSONObject status = pollUntilFinished(restTemplate, tokenRequest,
                    new JSONObject(executeResponse.getBody()));
            logger.info("Stop pooling codegenies pod. Latest status = " + statusOf(status));
            if (!STATUS_DONE.equals(statusOf(status))) {
                logger.error("ERROR on pooling codegenies");
                throw new IllegalStateException(
                        "codegenie execution failed with status: " + statusOf(status));
            }

            JSONObject result = getJson(restTemplate, tokenRequest, executionUrl("/execution_result/"));
            if (!STATUS_DONE.equals(statusOf(result))) {
                throw new IllegalStateException(
                        "codegenie execution failed with status: " + statusOf(result));
            }
            byte[] fileBytes = downloadOutput(restTemplate, tokenRequest,
                    resolveOutputPath(result.getJSONObject("outputs")));
            this.scenarioExecution.getExecSharedMap().put(this.codegenie_output_variable,
                    Base64.getEncoder().encodeToString(fileBytes));
            this.scenarioExecution.setNextStepId(this.step.getNextStepId());
            this.scenarioExecution.getExecSharedMap().put("status", STATUS_DONE);

            JSONObject consumption =
                    getJson(restTemplate, tokenRequest, executionUrl("/execution_consumption/"));
            for (String key : CONSUMPTION_KEYS) {
                this.scenarioExecution.getExecSharedMap().put(key, consumption.get(key).toString());
            }
        } catch (Exception e) {
            logger.error("Error in codegenie execution: " + e.getMessage(), e);
            this.scenarioExecution.getExecSharedMap().put("status", STATUS_ERROR);
        }
        return this.scenarioExecution;
    }

    /**
     * Polls the status endpoint until the execution reaches {@code DONE}/{@code ERROR} or the
     * poll budget is exhausted, and returns the latest status document.
     *
     * <p>The first poll happens immediately; the {@link #POLL_INTERVAL_MS} pause is only taken
     * when another poll will follow, so a finished job is not followed by a useless wait.
     *
     * @throws IllegalStateException if the thread is interrupted while waiting (the interrupt
     *     flag is restored), so an interrupted step stops instead of spinning through its budget
     */
    private JSONObject pollUntilFinished(RestTemplate restTemplate, HttpEntity<Void> tokenRequest,
            JSONObject initialStatus) {
        JSONObject current = initialStatus;
        int remainingTries = MAX_STATUS_POLLS;
        while (!isFinished(current) && remainingTries > 0) {
            current = getJson(restTemplate, tokenRequest, executionUrl("/execution_status/"));
            logger.info("Check status => Remaining tryes :" + remainingTries);
            // opt(): a missing progress field must not abort the whole execution.
            logger.info("Percent: " + current.opt("percent"));
            remainingTries--;
            if (!isFinished(current) && remainingTries > 0) {
                sleepBetweenPolls();
            }
        }
        return current;
    }

    /** Waits {@link #POLL_INTERVAL_MS}; on interruption restores the flag and aborts polling. */
    private void sleepBetweenPolls() {
        try {
            Thread.sleep(POLL_INTERVAL_MS);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            logger.error("Thread was interrupted", e);
            throw new IllegalStateException("Interrupted while polling codegenie execution status", e);
        }
    }

    /** Returns true when the document's {@code status} is {@code DONE} or {@code ERROR}. */
    private static boolean isFinished(JSONObject json) {
        String status = statusOf(json);
        return STATUS_DONE.equals(status) || STATUS_ERROR.equals(status);
    }

    /** Returns the {@code status} field as text; throws if the field is missing. */
    private static String statusOf(JSONObject json) {
        return String.valueOf(json.get("status"));
    }

    /** Builds {@code <base_url><endpoint><executionId>} for the current scenario execution. */
    private String executionUrl(String endpoint) {
        return this.codegenie_base_url + endpoint + this.scenarioExecution.getId();
    }

    /** Builds the HTTP Basic-authenticated (reusable) request sent to the {@code /token} endpoint. */
    private HttpEntity<Void> buildTokenRequest() {
        HttpHeaders headers = new HttpHeaders();
        headers.setContentType(MediaType.APPLICATION_JSON);
        String credentials = codegenie_username + ":" + codegenie_password;
        String encoded = Base64.getEncoder()
                .encodeToString(credentials.getBytes(java.nio.charset.StandardCharsets.UTF_8));
        headers.set("Authorization", "Basic " + encoded);
        return new HttpEntity<>(headers);
    }

    /**
     * Requests a fresh token and returns the matching {@code Authorization} header value,
     * {@code "<token_type> <access_token>"}.
     */
    private String fetchAuthorization(RestTemplate restTemplate, HttpEntity<Void> tokenRequest) {
        ResponseEntity<String> tokenResponse = restTemplate.exchange(
                this.codegenie_base_url + "/token", HttpMethod.POST, tokenRequest, String.class);
        JSONObject token = new JSONObject(tokenResponse.getBody());
        return token.get("token_type").toString() + " " + token.get("access_token").toString();
    }

    /** JSON content-type headers carrying the given {@code Authorization} value. */
    private static HttpHeaders authorizedHeaders(String authorization) {
        HttpHeaders headers = new HttpHeaders();
        headers.setContentType(MediaType.APPLICATION_JSON);
        headers.set("Authorization", authorization);
        return headers;
    }

    /**
     * Performs an authenticated GET on {@code url} and parses the response as a JSON object.
     * GETs carry headers only; the previous version attached the {@code /execute} payload as a
     * body, which a GET endpoint has no use for.
     */
    private JSONObject getJson(RestTemplate restTemplate, HttpEntity<Void> tokenRequest, String url) {
        HttpEntity<Void> request =
                new HttpEntity<>(authorizedHeaders(fetchAuthorization(restTemplate, tokenRequest)));
        ResponseEntity<String> response =
                restTemplate.exchange(url, HttpMethod.GET, request, String.class);
        return new JSONObject(response.getBody());
    }

    /** Builds the JSON job description posted to {@code /execute}. */
    private JSONObject buildExecuteRequestBody() {
        List<String> outputTypes = new ArrayList<String>();
        outputTypes.add(this.codegenie_output_type);
        JSONObject body = new JSONObject();
        body.put("execution_id", this.scenarioExecution.getId());
        body.put("docs_zip_file", this.codegenie_input);
        body.put("application", this.codegenie_application);
        body.put("project", this.codegenie_project);
        body.put("output_type", outputTypes);
        body.put("output_variable", this.codegenie_output_variable);
        body.put("model_provider", this.codegenie_model_provider);
        return body;
    }

    /**
     * Picks the artifact path from the execution result's {@code outputs} object according to
     * {@code codegenie_output_type}: {@code FILE -> "file"}, {@code MARKDOWN -> "markdown"},
     * {@code JSON -> "json"}.
     *
     * @throws IllegalStateException for a missing or unsupported output type, instead of
     *     downloading {@code /download/} with an empty path
     */
    private String resolveOutputPath(JSONObject outputs) {
        if ("FILE".equals(this.codegenie_output_type)) {
            return outputs.getString("file");
        }
        if ("MARKDOWN".equals(this.codegenie_output_type)) {
            return outputs.getString("markdown");
        }
        if ("JSON".equals(this.codegenie_output_type)) {
            return outputs.getString("json");
        }
        throw new IllegalStateException(
                "Unsupported codegenie_output_type: " + this.codegenie_output_type);
    }

    /**
     * Downloads {@code <base_url>/download/<path>} and returns the raw bytes.
     *
     * @throws IllegalStateException if the response has no body
     */
    private byte[] downloadOutput(RestTemplate restTemplate, HttpEntity<Void> tokenRequest, String path) {
        HttpEntity<Void> request =
                new HttpEntity<>(authorizedHeaders(fetchAuthorization(restTemplate, tokenRequest)));
        ResponseEntity<byte[]> response = restTemplate.exchange(
                this.codegenie_base_url + "/download/" + path, HttpMethod.GET, request, byte[].class);
        byte[] body = response.getBody();
        if (body == null) {
            throw new IllegalStateException("codegenie download returned no content for: " + path);
        }
        return body;
    }

    /**
     * Zips the local file at {@code filePath} into a temporary archive (single entry named after
     * the file) and returns the archive's bytes Base64-encoded. Not called within this class.
     * The temporary archive is always deleted before returning.
     */
    private String downloadAndConvertToBase64(String filePath) throws Exception {
        File zipFile = File.createTempFile("temp", ".zip");
        try {
            try (ZipOutputStream zos = new ZipOutputStream(new FileOutputStream(zipFile));
                    FileInputStream fis = new FileInputStream(filePath)) {
                zos.putNextEntry(new ZipEntry(new File(filePath).getName()));
                byte[] buffer = new byte[1024];
                int length;
                while ((length = fis.read(buffer)) != -1) {
                    zos.write(buffer, 0, length);
                }
                zos.closeEntry();
            }
            byte[] zipBytes = Files.readAllBytes(zipFile.toPath());
            return Base64.getEncoder().encodeToString(zipBytes);
        } finally {
            Files.deleteIfExists(zipFile.toPath());
        }
    }
}