Skip to content

Commit

Permalink
优化java tool_call调用链路
Browse files Browse the repository at this point in the history
  • Loading branch information
userpj committed Oct 24, 2024
1 parent e722de8 commit e6cc6f4
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 44 deletions.
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package com.baidubce.appbuilder.model.appbuilderclient;

import java.util.Map;

import com.google.gson.Gson;
import com.google.gson.annotations.SerializedName;

public class AppBuilderClientRunRequest {
Expand All @@ -18,6 +20,25 @@ public class AppBuilderClientRunRequest {
@SerializedName("tool_choice")
private ToolChoice ToolChoice;

public AppBuilderClientRunRequest() {
}

public AppBuilderClientRunRequest(String appID) {
this.appId = appID;
}

public AppBuilderClientRunRequest(String appID, String conversationID) {
this.appId = appID;
this.conversationID = conversationID;
}

public AppBuilderClientRunRequest(String appID, String conversationID, String query, Boolean stream) {
this.appId = appID;
this.conversationID = conversationID;
this.query = query;
this.stream = stream;
}

public String getAppId() {
return appId;
}
Expand Down Expand Up @@ -66,6 +87,12 @@ public void setTools(Tool[] tools) {
this.tools = tools;
}

public void setTools(String toolJson) {
Gson gson = new Gson();
Tool tool = gson.fromJson(toolJson, Tool.class);
this.tools = new Tool[] {tool};
}

public ToolOutput[] getToolOutputs() {
return ToolOutputs;
}
Expand All @@ -74,6 +101,11 @@ public void setToolOutputs(ToolOutput[] toolOutputs) {
this.ToolOutputs = toolOutputs;
}

public void setToolOutputs(String toolCallID, String outputString) {
ToolOutput output = new ToolOutput(toolCallID, outputString);
this.ToolOutputs = new ToolOutput[] { output };
}

public ToolChoice getToolChoice() {
return ToolChoice;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import com.baidubce.appbuilder.model.appbuilderclient.AppBuilderClientIterator;
import com.baidubce.appbuilder.model.appbuilderclient.AppBuilderClientResult;
import com.baidubce.appbuilder.model.appbuilderclient.AppListRequest;
import com.google.gson.Gson;
import com.baidubce.appbuilder.model.appbuilderclient.AppBuilderClientRunRequest;
import org.junit.Before;
import org.junit.Test;
Expand Down Expand Up @@ -65,44 +66,37 @@ public void AppBuilderClientRunFuncTest() throws IOException, AppBuilderServerEx
AppBuilderClient builder = new AppBuilderClient(appId);
String conversationId = builder.createConversation();
assertNotNull(conversationId);
String fileId = builder.uploadLocalFile(conversationId,
"src/test/java/com/baidubce/appbuilder/files/test.pdf");
assertNotNull(fileId);
AppBuilderClientRunRequest request = new AppBuilderClientRunRequest();
request.setAppId(appId);
request.setConversationID(conversationId);
request.setQuery("今天北京的天气怎么样?");
request.setStream(false);

String name = "get_cur_whether";
String desc = "这是一个获得指定地点天气的工具";
Map<String, Object> parameters = new HashMap<>();

Map<String, Object> location = new HashMap<>();
location.put("type", "string");
location.put("description", "省,市名,例如:河北省");

Map<String, Object> unit = new HashMap<>();
unit.put("type", "string");
List<String> enumValues = new ArrayList<>();
enumValues.add("摄氏度");
enumValues.add("华氏度");
unit.put("enum", enumValues);

Map<String, Object> properties = new HashMap<>();
properties.put("location", location);
properties.put("unit", unit);

parameters.put("type", "object");
parameters.put("properties", properties);
List<String> required = new ArrayList<>();
required.add("location");
parameters.put("required", required);

AppBuilderClientRunRequest.Tool.Function func = new AppBuilderClientRunRequest.Tool.Function(name, desc,
parameters);
AppBuilderClientRunRequest.Tool tool = new AppBuilderClientRunRequest.Tool("function", func);
request.setTools(new AppBuilderClientRunRequest.Tool[] { tool });

AppBuilderClientRunRequest request = new AppBuilderClientRunRequest(appId, conversationId, "今天北京的天气怎么样?", false);

// 如果你本地的java版本更高,可以使用文本块特性,简化字符串构造
String toolJson = "{\n" +
" \"type\": \"function\",\n" +
" \"function\": {\n" +
" \"name\": \"get_cur_whether\",\n" +
" \"description\": \"这是一个获得指定地点天气的工具\",\n" +
" \"parameters\": {\n" +
" \"type\": \"object\",\n" +
" \"properties\": {\n" +
" \"location\": {\n" +
" \"type\": \"string\",\n" +
" \"description\": \"省,市名,例如:河北省\"\n" +
" },\n" +
" \"unit\": {\n" +
" \"type\": \"string\",\n" +
" \"enum\": [\n" +
" \"摄氏度\",\n" +
" \"华氏度\"\n" +
" ]\n" +
" }\n" +
" },\n" +
" \"required\": [\n" +
" \"location\"\n" +
" ]\n" +
" }\n" +
" }\n" +
"}";
request.setTools(toolJson);

AppBuilderClientIterator itor = builder.run(request);
assertTrue(itor.hasNext());
Expand All @@ -113,12 +107,8 @@ public void AppBuilderClientRunFuncTest() throws IOException, AppBuilderServerEx
System.out.println(result);
}

AppBuilderClientRunRequest request2 = new AppBuilderClientRunRequest();
request2.setAppId(appId);
request2.setConversationID(conversationId);

AppBuilderClientRunRequest.ToolOutput output = new AppBuilderClientRunRequest.ToolOutput(ToolCallID, "北京今天35度");
request2.setToolOutputs(new AppBuilderClientRunRequest.ToolOutput[] { output });
AppBuilderClientRunRequest request2 = new AppBuilderClientRunRequest(appId, conversationId);
request2.setToolOutputs(ToolCallID, "北京今天35度");
AppBuilderClientIterator itor2 = builder.run(request2);
assertTrue(itor2.hasNext());
while (itor2.hasNext()) {
Expand Down

0 comments on commit e6cc6f4

Please sign in to comment.