diff --git a/java/src/main/java/com/baidubce/appbuilder/model/appbuilderclient/AppBuilderClientRunRequest.java b/java/src/main/java/com/baidubce/appbuilder/model/appbuilderclient/AppBuilderClientRunRequest.java index 54d0e1ef..185c61d4 100644 --- a/java/src/main/java/com/baidubce/appbuilder/model/appbuilderclient/AppBuilderClientRunRequest.java +++ b/java/src/main/java/com/baidubce/appbuilder/model/appbuilderclient/AppBuilderClientRunRequest.java @@ -1,6 +1,8 @@ package com.baidubce.appbuilder.model.appbuilderclient; import java.util.Map; + +import com.google.gson.Gson; import com.google.gson.annotations.SerializedName; public class AppBuilderClientRunRequest { @@ -18,6 +20,25 @@ public class AppBuilderClientRunRequest { @SerializedName("tool_choice") private ToolChoice ToolChoice; + public AppBuilderClientRunRequest() { + } + + public AppBuilderClientRunRequest(String appID) { + this.appId = appID; + } + + public AppBuilderClientRunRequest(String appID, String conversationID) { + this.appId = appID; + this.conversationID = conversationID; + } + + public AppBuilderClientRunRequest(String appID, String conversationID, String query, Boolean stream) { + this.appId = appID; + this.conversationID = conversationID; + this.query = query; + this.stream = stream; + } + public String getAppId() { return appId; } @@ -66,6 +87,12 @@ public void setTools(Tool[] tools) { this.tools = tools; } + public void setTools(String toolJson) { + Gson gson = new Gson(); + Tool tool = gson.fromJson(toolJson, Tool.class); + this.tools = new Tool[] {tool}; + } + public ToolOutput[] getToolOutputs() { return ToolOutputs; } @@ -74,6 +101,11 @@ public void setToolOutputs(ToolOutput[] toolOutputs) { this.ToolOutputs = toolOutputs; } + public void setToolOutputs(String toolCallID, String outputString) { + ToolOutput output = new ToolOutput(toolCallID, outputString); + this.ToolOutputs = new ToolOutput[] { output }; + } + public ToolChoice getToolChoice() { return ToolChoice; } diff --git a/java/src/test/java/com/baidubce/appbuilder/AppBuilderClientTest.java b/java/src/test/java/com/baidubce/appbuilder/AppBuilderClientTest.java index bd2c640d..3e8cad3d 100644 --- a/java/src/test/java/com/baidubce/appbuilder/AppBuilderClientTest.java +++ b/java/src/test/java/com/baidubce/appbuilder/AppBuilderClientTest.java @@ -12,6 +12,7 @@ import com.baidubce.appbuilder.model.appbuilderclient.AppBuilderClientIterator; import com.baidubce.appbuilder.model.appbuilderclient.AppBuilderClientResult; import com.baidubce.appbuilder.model.appbuilderclient.AppListRequest; +import com.google.gson.Gson; import com.baidubce.appbuilder.model.appbuilderclient.AppBuilderClientRunRequest; import org.junit.Before; import org.junit.Test; @@ -65,44 +66,37 @@ public void AppBuilderClientRunFuncTest() throws IOException, AppBuilderServerEx AppBuilderClient builder = new AppBuilderClient(appId); String conversationId = builder.createConversation(); assertNotNull(conversationId); - String fileId = builder.uploadLocalFile(conversationId, - "src/test/java/com/baidubce/appbuilder/files/test.pdf"); - assertNotNull(fileId); - AppBuilderClientRunRequest request = new AppBuilderClientRunRequest(); - request.setAppId(appId); - request.setConversationID(conversationId); - request.setQuery("今天北京的天气怎么样?"); - request.setStream(false); - - String name = "get_cur_whether"; - String desc = "这是一个获得指定地点天气的工具"; - Map parameters = new HashMap<>(); - - Map location = new HashMap<>(); - location.put("type", "string"); - location.put("description", "省,市名,例如:河北省"); - - Map unit = new HashMap<>(); - unit.put("type", "string"); - List enumValues = new ArrayList<>(); - enumValues.add("摄氏度"); - enumValues.add("华氏度"); - unit.put("enum", enumValues); - - Map properties = new HashMap<>(); - properties.put("location", location); - properties.put("unit", unit); - - parameters.put("type", "object"); - parameters.put("properties", properties); - List required = new ArrayList<>(); - required.add("location"); - parameters.put("required", required); - - AppBuilderClientRunRequest.Tool.Function func = new AppBuilderClientRunRequest.Tool.Function(name, desc, - parameters); - AppBuilderClientRunRequest.Tool tool = new AppBuilderClientRunRequest.Tool("function", func); - request.setTools(new AppBuilderClientRunRequest.Tool[] { tool }); + + AppBuilderClientRunRequest request = new AppBuilderClientRunRequest(appId, conversationId, "今天北京的天气怎么样?", false); + + // 如果你本地的java版本更高,可以使用文本块特性,简化字符串构造 + String toolJson = "{\n" + + " \"type\": \"function\",\n" + + " \"function\": {\n" + + " \"name\": \"get_cur_whether\",\n" + + " \"description\": \"这是一个获得指定地点天气的工具\",\n" + + " \"parameters\": {\n" + + " \"type\": \"object\",\n" + + " \"properties\": {\n" + + " \"location\": {\n" + + " \"type\": \"string\",\n" + + " \"description\": \"省,市名,例如:河北省\"\n" + + " },\n" + + " \"unit\": {\n" + + " \"type\": \"string\",\n" + + " \"enum\": [\n" + + " \"摄氏度\",\n" + + " \"华氏度\"\n" + + " ]\n" + + " }\n" + + " },\n" + + " \"required\": [\n" + + " \"location\"\n" + + " ]\n" + + " }\n" + + " }\n" + + "}"; + request.setTools(toolJson); AppBuilderClientIterator itor = builder.run(request); assertTrue(itor.hasNext()); @@ -113,12 +107,8 @@ public void AppBuilderClientRunFuncTest() throws IOException, AppBuilderServerEx System.out.println(result); } - AppBuilderClientRunRequest request2 = new AppBuilderClientRunRequest(); - request2.setAppId(appId); - request2.setConversationID(conversationId); - - AppBuilderClientRunRequest.ToolOutput output = new AppBuilderClientRunRequest.ToolOutput(ToolCallID, "北京今天35度"); - request2.setToolOutputs(new AppBuilderClientRunRequest.ToolOutput[] { output }); + AppBuilderClientRunRequest request2 = new AppBuilderClientRunRequest(appId, conversationId); + request2.setToolOutputs(ToolCallID, "北京今天35度"); AppBuilderClientIterator itor2 = builder.run(request2); assertTrue(itor2.hasNext()); while (itor2.hasNext()) {