Add support for tool calling (#21)
* refactor vlm_stream_generate

* refactor custom types

* Add llama 3.1 tool calling

* fix opening config file

* add xlam tool calling

* add arcee-agent and firefunction

* bump mlx-vlm

* add function call tests

* fix vlm call

* add function calling

* add command-r-plus

* sort
Blaizzy authored Aug 6, 2024
1 parent 61b495e commit 6361d75
Showing 12 changed files with 761 additions and 167 deletions.
84 changes: 75 additions & 9 deletions README.md
@@ -31,15 +31,15 @@

Start the FastMLX server:
```bash
fastmlx
```
or

```bash
uvicorn fastmlx:app --reload --workers 0
```

> [!WARNING]
> The `--reload` flag should not be used in production. It is only intended for development purposes.
### Running with Multiple Workers (Parallel Processing)
@@ -49,7 +49,7 @@
You can also set the `FASTMLX_NUM_WORKERS` environment variable to specify the number of workers or the fraction of CPU cores to use. `workers` defaults to 2 if not passed explicitly or set via the environment variable.

In order of precedence (highest to lowest), the number of workers is determined by the following:
- Explicitly passed as a command-line argument
- `--workers 4` will set the number of workers to 4
- `--workers 0.5` will set the number of workers to half the number of CPU cores available (minimum of 1)
- Set via the `FASTMLX_NUM_WORKERS` environment variable
@@ -59,7 +59,7 @@

Example:
```bash
fastmlx --workers 4
```
or

@@ -68,7 +68,7 @@
```

> [!NOTE]
> - `--reload` flag is not compatible with multiple workers
> - The number of workers should typically not exceed the number of CPU cores available on your machine for optimal performance.
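
For example, the same setting can be applied through the `FASTMLX_NUM_WORKERS` environment variable instead of the flag (a minimal sketch; the value follows the same int-or-fraction rules described above):

```bash
# Use half of the available CPU cores (minimum of 1)
FASTMLX_NUM_WORKERS=0.5 fastmlx
```
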
### Considerations for Multi-Worker Setup
@@ -222,7 +222,73 @@
process_sse_stream(url, headers, data)
```

4. **Listing Available Models**
4. **Function Calling**

FastMLX now supports tool calling in accordance with the OpenAI API specification. This feature is available for the following models:

- Llama 3.1
- Arcee Agent
- C4ai-Command-R-Plus
- Firefunction
- xLAM

Supported modes:
- Without Streaming
- Parallel Tool Calling

> Note: Tool choice and OpenAI-compliant streaming for function calling are currently under development.

Here's an example of how to use function calling with FastMLX:

```python
import requests
import json

url = "http://localhost:8000/v1/chat/completions"
headers = {"Content-Type": "application/json"}
data = {
"model": "mlx-community/Meta-Llama-3.1-8B-Instruct-8bit",
"messages": [
{
"role": "user",
"content": "What's the weather like in San Francisco and Washington?"
}
],
"tools": [
{
"name": "get_current_weather",
"description": "Get the current weather",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA"
},
"format": {
"type": "string",
"enum": ["celsius", "fahrenheit"],
"description": "The temperature unit to use. Infer this from the user's location."
}
},
"required": ["location", "format"]
}
}
],
"max_tokens": 150,
"temperature": 0.7,
"stream": False,
}

response = requests.post(url, headers=headers, data=json.dumps(data))
print(response.json())
```

This example demonstrates how to use the `get_current_weather` tool with the Llama 3.1 model. The API processes the user's question and responds with a call to the provided tool, which your application can then execute to fetch the requested information.

Please note that while streaming is available for regular text generation, the streaming implementation for function calling is still in development and does not yet fully comply with the OpenAI specification.
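
Below is a minimal sketch of dispatching such a tool call on the client side. The exact response layout (`tool_calls`, `function.arguments`, and so on) is assumed here to follow the OpenAI convention and may differ from what your FastMLX version returns, and `get_current_weather` is a stand-in implementation:

```python
import json

def get_current_weather(location: str, format: str = "celsius") -> str:
    # Stand-in implementation; replace with a real weather lookup.
    return json.dumps({"location": location, "temperature": "22", "unit": format})

# `response` is the requests.Response object from the example above.
message = response.json()["choices"][0]["message"]

# Field names assume an OpenAI-style `tool_calls` structure; adjust them to
# whatever structure your FastMLX version actually returns.
for tool_call in message.get("tool_calls", []):
    name = tool_call["function"]["name"]
    args = json.loads(tool_call["function"]["arguments"])
    if name == "get_current_weather":
        print(get_current_weather(**args))
```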

5. **Listing Supported Models**

To see all vision and language models supported by MLX:

@@ -234,7 +300,7 @@
print(response.json())
```

5. **List Available Models**
6. **Adding New Models**

You can add new models to the API:

@@ -250,7 +316,7 @@
print(response.json())
```

6. **Listing Available Models**
7. **Listing Available Models**

To see all available models:

@@ -262,7 +328,7 @@
print(response.json())
```

7. **Delete Models**
8. **Delete Models**

To remove any models loaded to memory:

135 changes: 56 additions & 79 deletions fastmlx/fastmlx.py
@@ -3,26 +3,33 @@
import argparse
import asyncio
import os
import time
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List
from urllib.parse import unquote

from fastapi import FastAPI, HTTPException, Response
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, StreamingResponse
from pydantic import BaseModel, Field

from .types.chat.chat_completion import (
ChatCompletionRequest,
ChatCompletionResponse,
ChatMessage,
)
from .types.model import SupportedModels

try:
import mlx.core as mx
from mlx_lm import generate as lm_generate
from mlx_vlm import generate as vlm_generate
from mlx_vlm.prompt_utils import get_message_json
from mlx_vlm.prompt_utils import apply_chat_template as apply_vlm_chat_template
from mlx_vlm.utils import load_config

from .utils import (
MODEL_REMAPPING,
MODELS,
SupportedModels,
apply_lm_chat_template,
get_eom_token,
get_tool_prompt,
handle_function_calls,
lm_generate,
lm_stream_generator,
load_lm_model,
load_vlm_model,
@@ -63,28 +63,70 @@ async def get_available_models(self):
return list(self.models.keys())


class ChatMessage(BaseModel):
role: str
content: str


class ChatCompletionRequest(BaseModel):
model: str
messages: List[ChatMessage]
image: Optional[str] = Field(default=None)
max_tokens: Optional[int] = Field(default=100)
stream: Optional[bool] = Field(default=False)
temperature: Optional[float] = Field(default=0.2)


class ChatCompletionResponse(BaseModel):
id: str
object: str = "chat.completion"
created: int
model: str
choices: List[dict]


app = FastAPI()


@@ -134,6 +119,7 @@ async def chat_completion(request: ChatCompletionRequest):
model = model_data["model"]
config = model_data["config"]
model_type = MODEL_REMAPPING.get(config["model_type"], config["model_type"])
stop_words = get_eom_token(request.model)

if model_type in MODELS["vlm"]:
processor = model_data["processor"]
@@ -145,29 +131,15 @@

for msg in request.messages:
if msg.role == "user":
chat_messages.append(
get_message_json(config["model_type"], msg.content)
)
chat_messages.append(msg.content)
else:
chat_messages.append({"role": msg.role, "content": msg.content})

prompt = ""
if "chat_template" in processor.__dict__.keys():
prompt = processor.apply_chat_template(
chat_messages,
tokenize=False,
add_generation_prompt=True,
)

elif "tokenizer" in processor.__dict__.keys():
if model.config.model_type != "paligemma":
prompt = processor.tokenizer.apply_chat_template(
chat_messages,
tokenize=False,
add_generation_prompt=True,
)
else:
prompt = request.messages[-1].content
if model.config.model_type != "paligemma":
prompt = apply_vlm_chat_template(processor, config, chat_messages)
else:
prompt = request.messages[-1].content

if stream:
return StreamingResponse(
@@ -197,20 +169,33 @@ async def chat_completion(request: ChatCompletionRequest):
)

else:
# Add function calling information to the prompt
if request.tools and "firefunction-v2" not in request.model:
# Handle system prompt
if request.messages and request.messages[0].role == "system":
pass
else:
# Generate system prompt based on model and tools
prompt, user_role = get_tool_prompt(
request.model,
[tool.model_dump() for tool in request.tools],
request.messages[-1].content,
)

if user_role:
request.messages[-1].content = prompt
else:
# Insert the system prompt at the beginning of the messages
request.messages.insert(
0, ChatMessage(role="system", content=prompt)
)

tokenizer = model_data["tokenizer"]

chat_messages = [
{"role": msg.role, "content": msg.content} for msg in request.messages
]
if tokenizer.chat_template is not None and hasattr(
tokenizer, "apply_chat_template"
):
prompt = tokenizer.apply_chat_template(
chat_messages,
tokenize=False,
add_generation_prompt=True,
)
else:
prompt = request.messages[-1].content
prompt = apply_lm_chat_template(tokenizer, chat_messages, request)

if stream:
return StreamingResponse(
@@ -221,29 +206,22 @@
prompt,
request.max_tokens,
request.temperature,
stop_words=stop_words,
),
media_type="text/event-stream",
)
else:
output = lm_generate(
model, tokenizer, prompt, request.max_tokens, False, request.temperature
model,
tokenizer,
prompt,
request.max_tokens,
temp=request.temperature,
stop_words=stop_words,
)

# Prepare the response
response = ChatCompletionResponse(
id=f"chatcmpl-{os.urandom(4).hex()}",
created=int(time.time()),
model=request.model,
choices=[
{
"index": 0,
"message": {"role": "assistant", "content": output},
"finish_reason": "stop",
}
],
)

return response
# Parse the output to check for function calls
return handle_function_calls(output, request)


@app.get("/v1/supported_models", response_model=SupportedModels)
@@ -299,14 +277,14 @@ def run():
parser.add_argument(
"--workers",
type=int_or_float,
default=calculate_default_workers,
help="""Number of workers. Overrides the `FASTMLX_NUM_WORKERS` env variable.
Can be either an int or a float.
default=calculate_default_workers(),
help="""Number of workers. Overrides the `FASTMLX_NUM_WORKERS` env variable.
Can be either an int or a float.
If an int, it will be the number of workers to use.
If a float, number of workers will be this fraction of the number of CPU cores available, with a minimum of 1.
Defaults to the `FASTMLX_NUM_WORKERS` env variable if set and to 2 if not.
To use all available CPU cores, set it to 1.0.
Examples:
--workers 1 (will use 1 worker)
--workers 1.0 (will use all available CPU cores)
@@ -315,7 +293,6 @@ def run():
)

args = parser.parse_args()

if isinstance(args.workers, float):
args.workers = max(1, int(os.cpu_count() * args.workers))

18 changes: 18 additions & 0 deletions fastmlx/tools/arcee_agent.j2
@@ -0,0 +1,18 @@
{% if tools %}
In this environment, you have access to a set of tools you can use to answer the user's question.

You may call them like this:
<function_calls>
<invoke>
<tool_name>$TOOL_NAME</tool_name>
<parameters>
<$PARAMETER_NAME>$PARAMETER_VALUE</$PARAMETER_NAME>
...
</parameters>
</invoke>
</function_calls>

Here are the tools available:
{{ tools }}

{% endif %}
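
A quick sketch of how a prompt template like this could be rendered, assuming Jinja2 and a hypothetical tool list (FastMLX's own `get_tool_prompt` helper may load and render it differently):

```python
import json
from jinja2 import Template

# Render the Arcee Agent tool-prompt template with a sample tool definition.
with open("fastmlx/tools/arcee_agent.j2") as f:
    template = Template(f.read())

tools = [
    {
        "name": "get_current_weather",
        "description": "Get the current weather",
        "parameters": {
            "type": "object",
            "properties": {"location": {"type": "string"}},
            "required": ["location"],
        },
    }
]

prompt = template.render(tools=json.dumps(tools, indent=2))
print(prompt)
```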