Skip to content

Commit

Permalink
FEAT: ChatGLM3 tool calls (#701)
Browse files Browse the repository at this point in the history
  • Loading branch information
codingl2k1 authored Dec 1, 2023
1 parent 505c3ef commit 909a428
Show file tree
Hide file tree
Showing 14 changed files with 674 additions and 125 deletions.
1 change: 1 addition & 0 deletions .github/workflows/python.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ jobs:
MODULE: ${{ matrix.module }}
run: |
if [ "$MODULE" == "gpu" ]; then
${{ env.SELF_HOST_PYTHON }} -m pip install "openai>1"
${{ env.SELF_HOST_PYTHON }} -m pip install -U modelscope
${{ env.SELF_HOST_PYTHON }} -m pytest --timeout=1500 \
-W ignore::PendingDeprecationWarning \
Expand Down
200 changes: 200 additions & 0 deletions examples/FunctionCall.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,200 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Xinference is a powerful tool that supports function calling and is fully compatible with OpenAI's Tool Call API. It allows you to seamlessly integrate and utilize the functionality of OpenAI's tools within your own projects."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Preparation\n",
"\n",
"First, you need to install Xinference:\n",
"```shell\n",
"pip install xinference\n",
"```\n",
"\n",
"Currently, only ChatGLM3 model is available for tool calls. Here, we install a chatglm cpp for fast inference on Mac. If you are not using a Mac, please follow the install instruction: https://github.com/li-plus/chatglm.cpp#python-binding\n",
"\n",
"```shell\n",
"CMAKE_ARGS=\"-DGGML_METAL=ON\" pip install -U chatglm-cpp --no-cache-dir\n",
"```\n",
"\n",
"Then, start the Xinference server by the following command:\n",
"```shell\n",
"xinference\n",
"```\n",
"\n",
"The Xinference server will be started:\n",
"\n",
"```shell\n",
"2023-11-02 16:04:55,278 xinference 38878 INFO Xinference successfully started. Endpoint: http://127.0.0.1:9997\n",
"2023-11-02 16:04:55,280 xinference.core.supervisor 38878 INFO Worker 127.0.0.1:32187 has been added successfully\n",
"2023-11-02 16:04:55,281 xinference.deploy.worker 38878 INFO Xinference worker successfully started.\n",
"```\n",
"\n",
"Finally, we launch a ChatGLM3 model for tool calls.\n",
"```shell\n",
"xinference launch -u my_tool_model -n chatglm3 -f ggmlv3 -q q4_0\n",
"```\n",
"\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Function calling"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This is an exciting game that showcases the function calling feature. Our objective is to call the AI and obtain a set of lottery numbers to determine if we are lucky enough to win the lottery. To get started, we first need to define a simple function for generating lottery numbers."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import random\n",
"\n",
"list_red = list(range(1, 34))\n",
"list_blue = list(range(1, 17))\n",
"\n",
"\n",
"def get_lucky_lottery(num):\n",
" total = []\n",
" for _ in range(num):\n",
" res_red = random.sample(list_red, 6)\n",
" res_blue = random.sample(list_blue, 1)\n",
" total.append(res_red + res_blue)\n",
" return total"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Then, we describe two functions as a dict, more details refers to https://platform.openai.com/docs/api-reference/chat/create#chat-create-tools"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"tools = [\n",
" {\n",
" \"type\": \"function\",\n",
" \"function\": {\n",
" \"name\": \"get_lucky_lottery\",\n",
" \"description\": \"生成双色球彩票号码\",\n",
" \"parameters\": {\n",
" \"type\": \"int\",\n",
" \"properties\": {\"num\": {\"description\": \"生成的彩票号码组数\"}},\n",
" \"required\": [\"num\"],\n",
" },\n",
" },\n",
" },\n",
" {\n",
" \"type\": \"function\",\n",
" \"function\": {\n",
" \"name\": \"lottery_draw\",\n",
" \"description\": \"开奖双色球\",\n",
" \"parameters\": {\n",
" \"type\": \"object\",\n",
" \"properties\": {},\n",
" },\n",
" },\n",
" },\n",
"]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Define a function to call above two functions according to the response."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import functools\n",
"import json\n",
"\n",
"FUNCTION_TABLE = {\n",
" \"get_lucky_lottery\": get_lucky_lottery,\n",
" \"lottery_draw\": functools.partial(get_lucky_lottery, 1),\n",
"}\n",
"\n",
"\n",
"def handle_response(completion):\n",
" assert \"tool_calls\" == completion.choices[0].finish_reason\n",
" func_name = completion.choices[0].message.tool_calls[0].function.name\n",
" func_args = completion.choices[0].message.tool_calls[0].function.arguments\n",
" func_kwargs = json.loads(func_args)\n",
" return FUNCTION_TABLE.get(func_name)(**func_kwargs)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Finally, let's ask the LLM to generate 5 group lottery numbers and check if we win the game."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import openai\n",
"\n",
"client = openai.Client(api_key=\"not empty\", base_url=f\"http://127.0.0.1:9997/v1\")\n",
"completion = client.chat.completions.create(\n",
" model=\"my_tool_model\",\n",
" messages=[{\"role\": \"user\", \"content\": \"帮我生成5组双色球号码\"}],\n",
" tools=tools,\n",
")\n",
"print(f\"Lottery numbers: {handle_response(completion)}\")\n",
"completion = client.chat.completions.create(\n",
" model=\"my_tool_model\", messages=[{\"role\": \"user\", \"content\": \"开奖\"}], tools=tools\n",
")\n",
"print(f\"Lottery draw: {handle_response(completion)}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"No lucky, I lose the game.\n",
"\n",
"```\n",
"Lottery numbers: [[30, 9, 22, 26, 17, 27, 14], [8, 6, 27, 30, 21, 20, 13], [4, 33, 9, 32, 27, 22, 14], [19, 1, 30, 4, 28, 13, 7], [16, 7, 23, 17, 8, 30, 12]]\n",
"Lottery draw: [[29, 10, 27, 30, 15, 6, 1]]\n",
"```"
]
}
],
"metadata": {
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
3 changes: 3 additions & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ install_requires =
s3fs
modelscope
sse_starlette
openai>1 # For typing

[options.packages.find]
exclude =
Expand Down Expand Up @@ -70,6 +71,7 @@ dev =
orjson
sphinx-tabs
all =
chatglm-cpp>=0.3.0
ctransformers
llama-cpp-python>=0.2.0
transformers>=4.34.1
Expand All @@ -91,6 +93,7 @@ all =
ggml =
llama-cpp-python>=0.2.0
ctransformers
chatglm-cpp>=0.3.0
transformers =
transformers>=4.34.1
torch
Expand Down
83 changes: 21 additions & 62 deletions xinference/api/restful_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import pprint
import sys
import warnings
from typing import Any, Dict, List, Literal, Optional, Union
from typing import Any, List, Optional, Union

import gradio as gr
import xoscar as xo
Expand All @@ -43,28 +43,19 @@
from sse_starlette.sse import EventSourceResponse
from starlette.responses import JSONResponse as StarletteJSONResponse
from starlette.responses import RedirectResponse
from typing_extensions import NotRequired, TypedDict
from uvicorn import Config, Server
from xoscar.utils import get_next_port

from ..constants import XINFERENCE_DEFAULT_ENDPOINT_PORT
from ..core.supervisor import SupervisorActor
from ..core.utils import json_dumps
from ..fields import (
frequency_penalty_field,
max_tokens_field,
mirostat_eta_field,
mirostat_mode_field,
mirostat_tau_field,
presence_penalty_field,
repeat_penalty_field,
stop_field,
stream_field,
temperature_field,
top_k_field,
top_p_field,
from ..types import (
ChatCompletion,
Completion,
CreateChatCompletion,
CreateCompletion,
ImageList,
)
from ..types import ChatCompletion, Completion, CreateCompletion, ImageList

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -115,50 +106,6 @@ class TextToImageRequest(BaseModel):
user: Optional[str] = None


class ChatCompletionRequestMessage(TypedDict):
role: Literal["assistant", "user", "system"]
content: str
user: NotRequired[str]


class CreateChatCompletionRequest(BaseModel):
messages: List[ChatCompletionRequestMessage] = Field(
default=[], description="A list of messages to generate completions for."
)
max_tokens: int = max_tokens_field
temperature: float = temperature_field
top_p: float = top_p_field
mirostat_mode: int = mirostat_mode_field
mirostat_tau: float = mirostat_tau_field
mirostat_eta: float = mirostat_eta_field
stop: Optional[Union[str, List[str]]] = stop_field
stream: bool = stream_field
presence_penalty: Optional[float] = presence_penalty_field
frequency_penalty: Optional[float] = frequency_penalty_field
logit_bias: Optional[Dict[str, float]] = Field(None)

model: str
n: Optional[int] = 1
user: Optional[str] = Field(None)

# llama.cpp specific parameters
top_k: int = top_k_field
repeat_penalty: Optional[float] = repeat_penalty_field
logit_bias_type: Optional[Literal["input_ids", "tokens"]] = Field(None)
grammar: Optional[str] = Field(None)

class Config:
schema_extra = {
"example": {
"messages": [
{"role": "system", "content": "you are a helpful AI assistant"},
{"role": "user", "content": "Hello!"},
{"role": "assistant", "content": "Hi what can I help you?"},
]
}
}


class RegisterModelRequest(BaseModel):
model: str
persist: bool
Expand Down Expand Up @@ -308,6 +255,9 @@ def serve(self, logging_conf: Optional[dict] = None):
f"{pprint.pformat(invalid_routes)}"
)

for tp in [CreateChatCompletion, CreateCompletion]:
logger.debug("Dump request model fields:\n%s", tp.__fields__)

class SPAStaticFiles(StaticFiles):
async def get_response(self, path: str, scope):
response = await super().get_response(path, scope)
Expand Down Expand Up @@ -739,7 +689,7 @@ async def create_variations(
async def create_chat_completion(
self,
request: Request,
body: CreateChatCompletionRequest,
body: CreateChatCompletion,
) -> Response:
exclude = {
"prompt",
Expand All @@ -750,7 +700,7 @@ async def create_chat_completion(
"logit_bias_type",
"user",
}
kwargs = body.dict(exclude=exclude)
kwargs = body.dict(exclude_unset=True, exclude=exclude)

if body.logit_bias is not None:
raise HTTPException(status_code=501, detail="Not implemented")
Expand Down Expand Up @@ -809,6 +759,7 @@ async def create_chat_completion(
is_chatglm_ggml = desc.get(
"model_format"
) == "ggmlv3" and "chatglm" in desc.get("model_name", "")
is_chatglm3 = "chatglm3" == desc.get("model_name", "")

is_qwen = desc.get("model_format") == "ggmlv3" and "qwen" in desc.get(
"model_name", ""
Expand All @@ -818,6 +769,14 @@ async def create_chat_completion(
raise HTTPException(
status_code=400, detail="ChatGLM ggml does not have system prompt"
)
if is_chatglm3 and body.tools and body.stream:
raise HTTPException(
status_code=400, detail="ChatGLM3 tool calls does not support stream"
)
if body.tools and not is_chatglm3:
raise HTTPException(
status_code=400, detail="Only ChatGLM3 support tool calls"
)

if body.stream:

Expand Down
Loading

0 comments on commit 909a428

Please sign in to comment.