diff --git a/src/main/java/org/mvnsearch/chatgpt/model/completion/chat/ChatMessage.java b/src/main/java/org/mvnsearch/chatgpt/model/completion/chat/ChatMessage.java index f1d06fc..8778865 100644 --- a/src/main/java/org/mvnsearch/chatgpt/model/completion/chat/ChatMessage.java +++ b/src/main/java/org/mvnsearch/chatgpt/model/completion/chat/ChatMessage.java @@ -13,130 +13,133 @@ @JsonInclude(JsonInclude.Include.NON_NULL) public class ChatMessage { - private ChatMessageRole role; - - private String content; - - /** - * the name of the author of this message - */ - private String name; - - @JsonProperty("tool_calls") - private List toolCalls; - - public ChatMessage() { - } - - public ChatMessage(ChatMessageRole role, String content) { - this.role = role; - this.content = content; - } - - - public ChatMessageRole getRole() { - return role; - } - - public void setRole(ChatMessageRole role) { - this.role = role; - } - - public String getContent() { - return content; - } - - public void setContent(String content) { - this.content = content; - } - - public String getName() { - return name; - } - - public void setName(String name) { - this.name = name; - } - - public List getToolCalls() { - return toolCalls; - } - - public void setToolCalls(List toolCalls) { - this.toolCalls = toolCalls; - } - - public FunctionCall findFunctionCall() { - if (toolCalls != null && !toolCalls.isEmpty()) { - for (ToolCall toolCall : toolCalls) { - if (Objects.equals(toolCall.getType(), "function")) { - return toolCall.getFunction(); - } - } - } - return null; - } - - @JsonIgnore - public Mono getReplyCombinedText() { - if (content != null) { - return Mono.just(content); - } - if (toolCalls != null) { - for (ToolCall toolCall : toolCalls) { - final FunctionCall functionCall = toolCall.getFunction(); - if (functionCall.getFunctionStub() != null) { - try { - final Object result = functionCall.getFunctionStub().call(); - if (result != null) { - if (result instanceof Mono) { - return ((Mono) result).map(GPTFunctionUtils::toTextPlain); - } else { - return Mono.just(result).map(GPTFunctionUtils::toTextPlain); - } - } - } catch (Exception e) { - return Mono.error(e); - } - } - } - - } - return Mono.empty(); - } - - @JsonIgnore - public Mono getFunctionResult() { - if (toolCalls != null) { - final FunctionCall functionCall = toolCalls.get(0).getFunction(); - if (functionCall.getFunctionStub() != null) { - try { - final Object result = functionCall.getFunctionStub().call(); - if (result != null) { - if (result instanceof Mono) { - return (Mono) result; - } else { - return (Mono) Mono.just(result); - } - } - } catch (Exception e) { - return Mono.error(e); - } - } - } - return Mono.empty(); - } - - public static ChatMessage systemMessage(@Nonnull String content) { - return new ChatMessage(ChatMessageRole.system, content); - } - - public static ChatMessage userMessage(@Nonnull String content) { - return new ChatMessage(ChatMessageRole.user, content); - } - - public static ChatMessage assistantMessage(@Nonnull String content) { - return new ChatMessage(ChatMessageRole.assistant, content); - } + private ChatMessageRole role; + + private String content; + + /** + * the name of the author of this message + */ + private String name; + + @JsonProperty("tool_calls") + private List toolCalls; + + public ChatMessage() { + } + + public ChatMessage(ChatMessageRole role, String content) { + this.role = role; + this.content = content; + } + + public ChatMessageRole getRole() { + return role; + } + + public void setRole(ChatMessageRole role) { + this.role = role; + } + + public String getContent() { + return content; + } + + public void setContent(String content) { + this.content = content; + } + + public String getName() { + return name; + } + + public void setName(String name) { + this.name = name; + } + + public List getToolCalls() { + return toolCalls; + } + + public void setToolCalls(List toolCalls) { + this.toolCalls = toolCalls; + } + + public FunctionCall findFunctionCall() { + if (toolCalls != null && !toolCalls.isEmpty()) { + for (ToolCall toolCall : toolCalls) { + if (Objects.equals(toolCall.getType(), "function")) { + return toolCall.getFunction(); + } + } + } + return null; + } + + @JsonIgnore + public Mono getReplyCombinedText() { + if (content != null) { + return Mono.just(content); + } + if (toolCalls != null) { + for (ToolCall toolCall : toolCalls) { + final FunctionCall functionCall = toolCall.getFunction(); + if (functionCall.getFunctionStub() != null) { + try { + final Object result = functionCall.getFunctionStub().call(); + if (result != null) { + if (result instanceof Mono) { + return ((Mono) result).map(GPTFunctionUtils::toTextPlain); + } + else { + return Mono.just(result).map(GPTFunctionUtils::toTextPlain); + } + } + } + catch (Exception e) { + return Mono.error(e); + } + } + } + + } + return Mono.empty(); + } + + @JsonIgnore + public Mono getFunctionResult() { + if (toolCalls != null) { + final FunctionCall functionCall = toolCalls.get(0).getFunction(); + if (functionCall.getFunctionStub() != null) { + try { + final Object result = functionCall.getFunctionStub().call(); + if (result != null) { + if (result instanceof Mono) { + return (Mono) result; + } + else { + return (Mono) Mono.just(result); + } + } + } + catch (Exception e) { + return Mono.error(e); + } + } + } + return Mono.empty(); + } + + public static ChatMessage systemMessage(@Nonnull String content) { + return new ChatMessage(ChatMessageRole.system, content); + } + + public static ChatMessage userMessage(@Nonnull String content) { + return new ChatMessage(ChatMessageRole.user, content); + } + + public static ChatMessage assistantMessage(@Nonnull String content) { + return new ChatMessage(ChatMessageRole.assistant, content); + } } diff --git a/src/main/java/org/mvnsearch/chatgpt/model/completion/chat/ChatTool.java b/src/main/java/org/mvnsearch/chatgpt/model/completion/chat/ChatTool.java index 0b498ad..9e6ba16 100644 --- a/src/main/java/org/mvnsearch/chatgpt/model/completion/chat/ChatTool.java +++ b/src/main/java/org/mvnsearch/chatgpt/model/completion/chat/ChatTool.java @@ -1,30 +1,32 @@ package org.mvnsearch.chatgpt.model.completion.chat; - public class ChatTool { - private String type = "function"; - private ChatFunction function; - public ChatTool() { - } + private String type = "function"; + + private ChatFunction function; + + public ChatTool() { + } + + public ChatTool(ChatFunction function) { + this.function = function; + } - public ChatTool(ChatFunction function) { - this.function = function; - } + public String getType() { + return type; + } - public String getType() { - return type; - } + public void setType(String type) { + this.type = type; + } - public void setType(String type) { - this.type = type; - } + public ChatFunction getFunction() { + return function; + } - public ChatFunction getFunction() { - return function; - } + public void setFunction(ChatFunction function) { + this.function = function; + } - public void setFunction(ChatFunction function) { - this.function = function; - } } diff --git a/src/main/java/org/mvnsearch/chatgpt/model/completion/chat/ToolCall.java b/src/main/java/org/mvnsearch/chatgpt/model/completion/chat/ToolCall.java index c8898fe..9883e2e 100644 --- a/src/main/java/org/mvnsearch/chatgpt/model/completion/chat/ToolCall.java +++ b/src/main/java/org/mvnsearch/chatgpt/model/completion/chat/ToolCall.java @@ -1,34 +1,38 @@ package org.mvnsearch.chatgpt.model.completion.chat; public class ToolCall { - private String id; - private String type; - private FunctionCall function; - public ToolCall() { - } + private String id; - public String getId() { - return id; - } + private String type; - public void setId(String id) { - this.id = id; - } + private FunctionCall function; - public String getType() { - return type; - } + public ToolCall() { + } - public void setType(String type) { - this.type = type; - } + public String getId() { + return id; + } - public FunctionCall getFunction() { - return function; - } + public void setId(String id) { + this.id = id; + } + + public String getType() { + return type; + } + + public void setType(String type) { + this.type = type; + } + + public FunctionCall getFunction() { + return function; + } + + public void setFunction(FunctionCall function) { + this.function = function; + } - public void setFunction(FunctionCall function) { - this.function = function; - } } diff --git a/src/main/java/org/mvnsearch/chatgpt/spring/service/ChatGPTServiceImpl.java b/src/main/java/org/mvnsearch/chatgpt/spring/service/ChatGPTServiceImpl.java index f399bdd..5e19e0f 100644 --- a/src/main/java/org/mvnsearch/chatgpt/spring/service/ChatGPTServiceImpl.java +++ b/src/main/java/org/mvnsearch/chatgpt/spring/service/ChatGPTServiceImpl.java @@ -18,120 +18,122 @@ class ChatGPTServiceImpl implements ChatGPTService { - private final OpenAIChatAPI openAIChatAPI; - - private final PromptManager promptManager; - - private final GPTFunctionRegistry registry; - - private String model = ChatGPTConstants.DEFAULT_MODEL; - - ChatGPTServiceImpl(OpenAIChatAPI openAIChatAPI, PromptManager promptManager, GPTFunctionRegistry registry) - throws Exception { - this.openAIChatAPI = openAIChatAPI; - this.promptManager = promptManager; - this.registry = registry; - } - - public void setModel(String model) { - this.model = model; - } - - @Override - public Mono chat(ChatCompletionRequest request) { - buildChatCompletionRequest(request); - request.setStream(null); - boolean functionsIncluded = request.getTools() != null; - if (!functionsIncluded) { - return this.openAIChatAPI// - .chat(request); - } // - else { - return this.openAIChatAPI.chat(request)// - .doOnNext(response -> { - for (ChatMessage chatMessage : response.getReply()) { - injectFunctionCallLambda(chatMessage); - } - }); - } - } - - @Override - public Flux stream(ChatCompletionRequest request) { - buildChatCompletionRequest(request); - request.setStream(true); - boolean functionsIncluded = request.getTools() != null; - if (!functionsIncluded) { - return openAIChatAPI.stream(request).onErrorContinue((e, obj) -> { - }); - } else { - return openAIChatAPI.stream(request).onErrorContinue((e, obj) -> { - }).doOnNext(response -> { - for (ChatMessage chatMessage : response.getReply()) { - injectFunctionCallLambda(chatMessage); - } - }); - } - } - - @Override - public Mono embed(EmbeddingsRequest request) { - return this.openAIChatAPI.embed(request); - } - - @Override - public Mono complete(CompletionRequest request) { - return this.openAIChatAPI.complete(request); - } - - @Override - public Function> promptAsLambda(@PropertyKey(resourceBundle = PROMPTS_FQN) String promptKey) { - return promptAsLambda(promptKey, null); - } - - @Override - public Function> promptAsLambda(@PropertyKey(resourceBundle = PROMPTS_FQN) String promptKey, - String functionName) { - return obj -> { - String prompt = promptManager.prompt(promptKey, obj); - final ChatCompletionRequest request = ChatRequestBuilder.of(prompt).model(model).build(); - if (functionName != null && !functionName.isEmpty()) { - request.addFunction(functionName); - return (Mono) chat(request).flatMap(ChatCompletionResponse::getFunctionResult); - } else { - return (Mono) chat(request).map(ChatCompletionResponse::getReplyText); - } - }; - } - - private void buildChatCompletionRequest(ChatCompletionRequest request) { - if (request.getModel() == null) { - request.setModel(model); - } - injectFunctions(request); - } - - private void injectFunctions(ChatCompletionRequest request) { - final List functionNames = request.getFunctionNames(); - if (functionNames != null && !functionNames.isEmpty()) { - for (String functionName : functionNames) { - ChatFunction chatFunction = this.registry.getChatFunction(functionName); - if (chatFunction != null) { - request.addFunction(chatFunction); - } - } - } - } - - private void injectFunctionCallLambda(ChatMessage chatMessage) { - final FunctionCall functionCall = chatMessage.findFunctionCall(); - if (functionCall != null) { - final String functionName = functionCall.getName(); - ChatGPTJavaFunction jsonSchemaFunction = registry.getJsonSchemaFunction(functionName); - if (jsonSchemaFunction != null) { - functionCall.setFunctionStub(() -> jsonSchemaFunction.call(functionCall.getArguments())); - } - } - } + private final OpenAIChatAPI openAIChatAPI; + + private final PromptManager promptManager; + + private final GPTFunctionRegistry registry; + + private String model = ChatGPTConstants.DEFAULT_MODEL; + + ChatGPTServiceImpl(OpenAIChatAPI openAIChatAPI, PromptManager promptManager, GPTFunctionRegistry registry) + throws Exception { + this.openAIChatAPI = openAIChatAPI; + this.promptManager = promptManager; + this.registry = registry; + } + + public void setModel(String model) { + this.model = model; + } + + @Override + public Mono chat(ChatCompletionRequest request) { + buildChatCompletionRequest(request); + request.setStream(null); + boolean functionsIncluded = request.getTools() != null; + if (!functionsIncluded) { + return this.openAIChatAPI// + .chat(request); + } // + else { + return this.openAIChatAPI.chat(request)// + .doOnNext(response -> { + for (ChatMessage chatMessage : response.getReply()) { + injectFunctionCallLambda(chatMessage); + } + }); + } + } + + @Override + public Flux stream(ChatCompletionRequest request) { + buildChatCompletionRequest(request); + request.setStream(true); + boolean functionsIncluded = request.getTools() != null; + if (!functionsIncluded) { + return openAIChatAPI.stream(request).onErrorContinue((e, obj) -> { + }); + } + else { + return openAIChatAPI.stream(request).onErrorContinue((e, obj) -> { + }).doOnNext(response -> { + for (ChatMessage chatMessage : response.getReply()) { + injectFunctionCallLambda(chatMessage); + } + }); + } + } + + @Override + public Mono embed(EmbeddingsRequest request) { + return this.openAIChatAPI.embed(request); + } + + @Override + public Mono complete(CompletionRequest request) { + return this.openAIChatAPI.complete(request); + } + + @Override + public Function> promptAsLambda(@PropertyKey(resourceBundle = PROMPTS_FQN) String promptKey) { + return promptAsLambda(promptKey, null); + } + + @Override + public Function> promptAsLambda(@PropertyKey(resourceBundle = PROMPTS_FQN) String promptKey, + String functionName) { + return obj -> { + String prompt = promptManager.prompt(promptKey, obj); + final ChatCompletionRequest request = ChatRequestBuilder.of(prompt).model(model).build(); + if (functionName != null && !functionName.isEmpty()) { + request.addFunction(functionName); + return (Mono) chat(request).flatMap(ChatCompletionResponse::getFunctionResult); + } + else { + return (Mono) chat(request).map(ChatCompletionResponse::getReplyText); + } + }; + } + + private void buildChatCompletionRequest(ChatCompletionRequest request) { + if (request.getModel() == null) { + request.setModel(model); + } + injectFunctions(request); + } + + private void injectFunctions(ChatCompletionRequest request) { + final List functionNames = request.getFunctionNames(); + if (functionNames != null && !functionNames.isEmpty()) { + for (String functionName : functionNames) { + ChatFunction chatFunction = this.registry.getChatFunction(functionName); + if (chatFunction != null) { + request.addFunction(chatFunction); + } + } + } + } + + private void injectFunctionCallLambda(ChatMessage chatMessage) { + final FunctionCall functionCall = chatMessage.findFunctionCall(); + if (functionCall != null) { + final String functionName = functionCall.getName(); + ChatGPTJavaFunction jsonSchemaFunction = registry.getJsonSchemaFunction(functionName); + if (jsonSchemaFunction != null) { + functionCall.setFunctionStub(() -> jsonSchemaFunction.call(functionCall.getArguments())); + } + } + } } diff --git a/src/test/java/org/mvnsearch/chatgpt/spring/service/ChatGPTServiceImplTest.java b/src/test/java/org/mvnsearch/chatgpt/spring/service/ChatGPTServiceImplTest.java index 6036187..873bf73 100644 --- a/src/test/java/org/mvnsearch/chatgpt/spring/service/ChatGPTServiceImplTest.java +++ b/src/test/java/org/mvnsearch/chatgpt/spring/service/ChatGPTServiceImplTest.java @@ -19,101 +19,102 @@ class ChatGPTServiceImplTest extends ProjectBootBaseTest { - @Autowired - private ChatGPTService chatGPTService; - - @Autowired - private PromptManager promptManager; - - @Test - void testSimpleChat() { - ChatCompletionRequest request = ChatCompletionRequest.of("What's Java Language?"); - ChatCompletionResponse response = chatGPTService.chat(request).block(); - System.out.println(response.getReplyText()); - } - - @Test - public void testEmbed() { - final EmbeddingsResponse response = chatGPTService.embed(EmbeddingsRequest.of("Hello World")).block(); - assertThat(response).isNotNull(); - assertThat(response.getEmbeddings()).isNotEmpty(); - System.out.println(response.getEmbeddings()); - } - - @Test - public void testComplete() { - final CompletionResponse response = chatGPTService.complete(new CompletionRequest("text-davinci-003", "床前明月光")) - .block(); - assertThat(response).isNotNull(); - System.out.println(response.getReplyText()); - } - - @Test - void testExecuteSQLQuery() throws Exception { - String prompt = "Query all employees whose salary is greater than the average."; - ChatCompletionRequest request = ChatRequestBuilder.of("You are a helpful SQL writer with MySQL dialect.", prompt) - .function("execute_sql_query") - .build(); - String result = chatGPTService.chat(request).flatMap(ChatCompletionResponse::getReplyCombinedText).block(); - System.out.println(result); - } - - @Test - void testChatWithFunctions() throws Exception { - String prompt = "Hi Jackie, could you write an email to Libing(libing.chen@gmail.com) and Sam(linux_china@hotmail.com) and invite them to join Mike's birthday party at 4 pm tomorrow? Thanks!"; - ChatCompletionRequest request = ChatRequestBuilder.of(prompt).function("send_email").build(); - ChatCompletionResponse response = chatGPTService.chat(request).block(); - // display reply combined text with function call - System.out.println(response.getReplyCombinedText().block()); - // call function manually - /* - * for (ChatMessage chatMessage : response.getReply()) { final FunctionCall - * functionCall = chatMessage.getFunctionCall(); if (functionCall != null) { final - * Object result = functionCall.getFunctionStub().call(); - * System.out.println(result); } } - */ - } - - @Test - void testSmartSpeaker() throws Exception { - String prompt = "You are a smart speaker, and your name is Alexa. You can play music songs, answer questions and so on. \nAlexa, please play Hotel California."; - ChatCompletionRequest request = ChatRequestBuilder.of(prompt).function("play_songs").build(); - ChatCompletionResponse response = chatGPTService.chat(request).block(); - // display reply combined text with function call - System.out.println(response.getReplyCombinedText().block()); - } - - @Test - void testPromptAsFunction() { - Function> translateIntoChineseFunction = chatGPTService - .promptAsLambda("translate-into-chinese"); - Function> sendEmailFunction = chatGPTService.promptAsLambda("send-email", "send_email"); - String result = Mono.just( - "Hi Jackie, could you write an email to Libing(libing.chen@exaple.com) and Sam(linux_china@example.com) and invite them to join Mike's birthday party at 4 pm tomorrow? Thanks!") - .flatMap(translateIntoChineseFunction) - .flatMap(sendEmailFunction) - .block(); - System.out.println(result); - } - - record TranslateRequest(String from, String to, String text) { - } - - @Test - void testLambdaWithRecord() { - Function> translateFunction = chatGPTService.promptAsLambda("translate"); - String result = Mono.just(new TranslateRequest("Chinese", "English", "你好!")).flatMap(translateFunction).block(); - System.out.println(result); - } - - @Test - void testLambdaWithFunctionResult() { - Function>> executeSqlQuery = chatGPTService.promptAsLambda("sql-writer", - "execute_sql_query"); - List result = Mono.just("Query all employees whose salary is greater than the average.") - .flatMap(executeSqlQuery) - .block(); - assertThat(result).isNotEmpty(); - } + @Autowired + private ChatGPTService chatGPTService; + + @Autowired + private PromptManager promptManager; + + @Test + void testSimpleChat() { + ChatCompletionRequest request = ChatCompletionRequest.of("What's Java Language?"); + ChatCompletionResponse response = chatGPTService.chat(request).block(); + System.out.println(response.getReplyText()); + } + + @Test + public void testEmbed() { + final EmbeddingsResponse response = chatGPTService.embed(EmbeddingsRequest.of("Hello World")).block(); + assertThat(response).isNotNull(); + assertThat(response.getEmbeddings()).isNotEmpty(); + System.out.println(response.getEmbeddings()); + } + + @Test + public void testComplete() { + final CompletionResponse response = chatGPTService.complete(new CompletionRequest("text-davinci-003", "床前明月光")) + .block(); + assertThat(response).isNotNull(); + System.out.println(response.getReplyText()); + } + + @Test + void testExecuteSQLQuery() throws Exception { + String prompt = "Query all employees whose salary is greater than the average."; + ChatCompletionRequest request = ChatRequestBuilder + .of("You are a helpful SQL writer with MySQL dialect.", prompt) + .function("execute_sql_query") + .build(); + String result = chatGPTService.chat(request).flatMap(ChatCompletionResponse::getReplyCombinedText).block(); + System.out.println(result); + } + + @Test + void testChatWithFunctions() throws Exception { + String prompt = "Hi Jackie, could you write an email to Libing(libing.chen@gmail.com) and Sam(linux_china@hotmail.com) and invite them to join Mike's birthday party at 4 pm tomorrow? Thanks!"; + ChatCompletionRequest request = ChatRequestBuilder.of(prompt).function("send_email").build(); + ChatCompletionResponse response = chatGPTService.chat(request).block(); + // display reply combined text with function call + System.out.println(response.getReplyCombinedText().block()); + // call function manually + /* + * for (ChatMessage chatMessage : response.getReply()) { final FunctionCall + * functionCall = chatMessage.getFunctionCall(); if (functionCall != null) { final + * Object result = functionCall.getFunctionStub().call(); + * System.out.println(result); } } + */ + } + + @Test + void testSmartSpeaker() throws Exception { + String prompt = "You are a smart speaker, and your name is Alexa. You can play music songs, answer questions and so on. \nAlexa, please play Hotel California."; + ChatCompletionRequest request = ChatRequestBuilder.of(prompt).function("play_songs").build(); + ChatCompletionResponse response = chatGPTService.chat(request).block(); + // display reply combined text with function call + System.out.println(response.getReplyCombinedText().block()); + } + + @Test + void testPromptAsFunction() { + Function> translateIntoChineseFunction = chatGPTService + .promptAsLambda("translate-into-chinese"); + Function> sendEmailFunction = chatGPTService.promptAsLambda("send-email", "send_email"); + String result = Mono.just( + "Hi Jackie, could you write an email to Libing(libing.chen@exaple.com) and Sam(linux_china@example.com) and invite them to join Mike's birthday party at 4 pm tomorrow? Thanks!") + .flatMap(translateIntoChineseFunction) + .flatMap(sendEmailFunction) + .block(); + System.out.println(result); + } + + record TranslateRequest(String from, String to, String text) { + } + + @Test + void testLambdaWithRecord() { + Function> translateFunction = chatGPTService.promptAsLambda("translate"); + String result = Mono.just(new TranslateRequest("Chinese", "English", "你好!")).flatMap(translateFunction).block(); + System.out.println(result); + } + + @Test + void testLambdaWithFunctionResult() { + Function>> executeSqlQuery = chatGPTService.promptAsLambda("sql-writer", + "execute_sql_query"); + List result = Mono.just("Query all employees whose salary is greater than the average.") + .flatMap(executeSqlQuery) + .block(); + assertThat(result).isNotEmpty(); + } }