Add LLM choice for ChatModel

This commit is contained in:
2024-10-15 21:45:55 +02:00
parent b23521700f
commit a9bdbb4ed8
8 changed files with 154 additions and 8 deletions

View File

@@ -28,7 +28,8 @@
</scm>
<properties>
<java.version>21</java.version>
<spring-ai.version>1.0.0-SNAPSHOT</spring-ai.version>
<!--<spring-ai.version>1.0.0-SNAPSHOT</spring-ai.version>-->
<spring-ai.version>1.0.0-M2</spring-ai.version>
<spring-cloud.version>2023.0.3</spring-cloud.version>
</properties>
<dependencies>
@@ -49,6 +50,10 @@
<groupId>org.springframework.ai</groupId>
<artifactId>spring-ai-openai-spring-boot-starter</artifactId>
</dependency>
<dependency>
<groupId>org.springframework.ai</groupId>
<artifactId>spring-ai-azure-openai</artifactId>
</dependency>
<dependency>
<groupId>org.neo4j.driver</groupId>
<artifactId>neo4j-java-driver</artifactId>

View File

@@ -0,0 +1,35 @@
package com.olympus.hermione.configurations;
import org.springframework.ai.azure.openai.AzureOpenAiEmbeddingModel;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.openai.OpenAiEmbeddingModel;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.context.annotation.Primary;
@Configuration
public class HermioneConfig {
private final OpenAiEmbeddingModel myServiceLib1;
private final AzureOpenAiEmbeddingModel myServiceLib2;
// Autowiring beans from both libraries
public HermioneConfig(OpenAiEmbeddingModel myServiceLib1, AzureOpenAiEmbeddingModel myServiceLib2) {
this.myServiceLib1 = myServiceLib1;
this.myServiceLib2 = myServiceLib2;
}
// Re-declaring and marking one as primary
@Bean
@Primary
public EmbeddingModel primaryMyService() {
return myServiceLib1; // Choose the one you want as primary
}
// Optionally declare the other bean without primary
@Bean
public EmbeddingModel secondaryMyService() {
return myServiceLib2; // The other one
}
}

View File

@@ -0,0 +1,23 @@
package com.olympus.hermione.models;
import org.springframework.data.annotation.Id;
import org.springframework.data.mongodb.core.mapping.Document;
import lombok.Getter;
import lombok.Setter;
@Document(collection = "chat_model")
@Getter @Setter
public class AiModel {
@Id
private String id;
private String name;
private String endpoint;
private String apiProvider;
private String model;
private Float temperature;
private Integer maxTokens;
private String apiKey;
private boolean isDefault = false;
}

View File

@@ -18,6 +18,6 @@ public class Scenario {
private String startWithStepId;
private List<ScenarioStep> steps;
private List<ScenarioInputs> inputs;
private String modelId;
private boolean useChatMemory=false;
}

View File

@@ -0,0 +1,13 @@
package com.olympus.hermione.repository;
import org.springframework.stereotype.Repository;
import java.util.Optional;
import org.springframework.data.mongodb.repository.MongoRepository;
import com.olympus.hermione.models.AiModel;
@Repository
public interface AiModelRepository extends MongoRepository<AiModel, String>{
Optional <AiModel> findByIsDefault(Boolean isDefault);
}

View File

@@ -11,6 +11,7 @@ import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.model.Generation;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.stereotype.Service;
import com.olympus.hermione.dto.CanvasExecutionInput;
@@ -20,8 +21,13 @@ import com.olympus.hermione.dto.CanvasOutput;
@Service
public class CanvasExecutionService {
private final ChatModel chatModel;
@Autowired
ChatModel chatModel;
public CanvasExecutionService(@Qualifier("openAiChatModel") ChatModel chatModel) {
this.chatModel = chatModel;
}
private Logger logger = LoggerFactory.getLogger(CanvasExecutionService.class);

View File

@@ -5,13 +5,17 @@ import java.util.List;
import java.util.Optional;
import org.slf4j.LoggerFactory;
import org.springframework.ai.azure.openai.AzureOpenAiChatModel;
import org.springframework.ai.azure.openai.AzureOpenAiChatOptions;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.client.advisor.MessageChatMemoryAdvisor;
import org.springframework.ai.chat.client.advisor.SimpleLoggerAdvisor;
import org.springframework.ai.chat.memory.ChatMemory;
import org.springframework.ai.chat.memory.InMemoryChatMemory;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.openai.OpenAiChatModel;
import org.springframework.ai.openai.OpenAiChatOptions;
import org.springframework.ai.openai.api.OpenAiApi;
import org.springframework.ai.vectorstore.VectorStore;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.cloud.client.discovery.DiscoveryClient;
@@ -20,9 +24,11 @@ import org.springframework.stereotype.Service;
import com.olympus.hermione.dto.ScenarioExecutionInput;
import com.olympus.hermione.dto.ScenarioOutput;
import com.olympus.hermione.models.AiModel;
import com.olympus.hermione.models.Scenario;
import com.olympus.hermione.models.ScenarioExecution;
import com.olympus.hermione.models.ScenarioStep;
import com.olympus.hermione.repository.AiModelRepository;
import com.olympus.hermione.repository.ScenarioExecutionRepository;
import com.olympus.hermione.repository.ScenarioRepository;
import com.olympus.hermione.stepSolvers.AdvancedAIPromptSolver;
@@ -35,6 +41,13 @@ import com.olympus.hermione.stepSolvers.StepSolver;
import org.neo4j.driver.Driver;
import org.slf4j.Logger;
import com.azure.ai.openai.OpenAIClient;
import com.azure.ai.openai.OpenAIClientBuilder;
import com.azure.core.credential.AzureKeyCredential;
@Service
public class ScenarioExecutionService {
@@ -44,14 +57,17 @@ public class ScenarioExecutionService {
@Autowired
private ScenarioExecutionRepository scenarioExecutionRepository;
@Autowired
private AiModelRepository aiModelRepository;
@Autowired
VectorStore vectorStore;
@Autowired
DiscoveryClient discoveryClient;
@Autowired
ChatModel chatModel;
//@Autowired
//ChatModel chatModel;
@Autowired
Driver graphDriver;
@@ -60,7 +76,7 @@ public class ScenarioExecutionService {
Neo4JUitilityService neo4JUitilityService;
private ChatClient chatClient;
private ChatModel chatModel;
private Logger logger = LoggerFactory.getLogger(ScenarioExecutionService.class);
@@ -111,7 +127,15 @@ public class ScenarioExecutionService {
scenarioExecution.setExecSharedMap(execSharedMap);
scenarioExecutionRepository.save(scenarioExecution);
//TODO: Chatmodel should be initialized with the scenario specific infos
//TODO: should be initialized with the scenario specific infos
Optional<AiModel> o_aiModel;
if(scenario.getModelId() != null){
o_aiModel = aiModelRepository.findById(scenario.getModelId());
}else {
o_aiModel = aiModelRepository.findByIsDefault(true);
}
AiModel aiModel = o_aiModel.get();
ChatModel chatModel = createChatModel(aiModel);
if(scenario.isUseChatMemory()){
logger.info("Initializing chatClient with chat-memory advisor");
@@ -225,4 +249,39 @@ public class ScenarioExecutionService {
return scenarioOutput;
}
private ChatModel createChatModel(AiModel aiModel){
switch(aiModel.getApiProvider()){
case "AzureOpenAI":
OpenAIClient openAIClient = new OpenAIClientBuilder()
.credential(new AzureKeyCredential(aiModel.getApiKey()))
.endpoint(aiModel.getEndpoint())
.buildClient();
AzureOpenAiChatOptions openAIChatOptions = AzureOpenAiChatOptions.builder()
.withDeploymentName(aiModel.getModel())
.withTemperature(aiModel.getTemperature())
.withMaxTokens(aiModel.getMaxTokens())
.build();
AzureOpenAiChatModel azureOpenaichatModel = new AzureOpenAiChatModel(openAIClient, openAIChatOptions);
logger.info("AI model used: " + aiModel.getModel());
return azureOpenaichatModel;
case "OpenAI":
OpenAiApi openAiApi = new OpenAiApi(aiModel.getApiKey());
OpenAiChatOptions openAiChatOptions = OpenAiChatOptions.builder()
.withModel(aiModel.getModel())
.withTemperature(aiModel.getTemperature())
.withMaxTokens(aiModel.getMaxTokens())
.build();
OpenAiChatModel openaichatModel = new OpenAiChatModel(openAiApi,openAiChatOptions);
logger.info("AI model used: " + aiModel.getModel());
return openaichatModel;
default:
throw new IllegalArgumentException("Unsupported AI model: " + aiModel.getName());
}
}
}

View File

@@ -14,6 +14,11 @@ spring.ai.openai.api-key=YOUR_API_KEY
spring.ai.openai.chat.options.model=gpt-4o-mini
spring.main.allow-circular-references=true
spring.ai.azure.openai.api-key=YOUR_API_KEY
spring.ai.azure.openai.endpoint=YOUR_ENDPOINT
spring.ai.azure.openai.chat.options.deployment-name=gpt-4o
spring.ai.azure.openai.chat.options.temperature=0.7
neo4j.uri=neo4j+s://e17e6f08.databases.neo4j.io:7687
neo4j.username=neo4j
neo4j.password=YOUR_NEO4J_PASSWORD