MLX ParaLLM

Batched KV caching for fast parallel inference on Apple Silicon devices, via MLX.

This repo borrows heavily from mlx_lm. Adding batched generation there as a non-breaking PR is being explored.

Usage

Requires mlx and mlx_lm to be installed.

from mlx_parallm.utils import load, batch_generate

model, tokenizer = load("google/gemma-1.1-2b-it")
prompts = ["prompt_0", ..., "prompt_k"]  # list of raw prompt strings
responses = batch_generate(model, tokenizer, prompts=prompts, max_tokens=100, verbose=True, format_prompts=True, temp=0.0)

Models

Models tested:

  • meta-llama/Meta-Llama-3-8B-Instruct
  • microsoft/Phi-3-mini-4k-instruct
  • google/gemma-1.1-2b-it
  • mlx-community/Meta-Llama-3-8B-Instruct-4bit
  • mlx-community/Phi-3-mini-4k-instruct-4bit
  • mlx-community/gemma-1.1-2b-it-4bit

Both quantized and float16 models are supported. float16 models generally seem to run faster when sufficient RAM is available (up to 1300+ tok/s throughput for gemma-2b on an M3 Max with 128GB).

Additional models can be added by copying architecture files from mlx_lm/models and replacing any references to KVCache with BatchedKVCache.
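For illustration, the edit involved looks roughly like the sketch below. The file name, the import path for BatchedKVCache, and the attention-module layout are assumptions; the only step this README actually prescribes is the KVCache -> BatchedKVCache substitution in the copied file.

# Hypothetical ported file: mlx_parallm/models/my_model.py
# (copied from mlx_lm/models/my_model.py, with every KVCache reference swapped)
from typing import Optional

import mlx.core as mx
import mlx.nn as nn

from mlx_parallm.models.base import BatchedKVCache  # was: from .base import KVCache


class Attention(nn.Module):
    def __call__(
        self,
        x: mx.array,
        mask: Optional[mx.array] = None,
        cache: Optional[BatchedKVCache] = None,  # was: Optional[KVCache]
    ) -> mx.array:
        # ... body unchanged from the mlx_lm original ...
        return x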

Features

Supported:

  • batch_generate method (tested with len(prompts) > 500)
  • Auto-padding
  • Auto-formatting with prompt templates (format_prompts=True)
  • temp = 0, temp > 0, top_p sampling
  • single-stream generate method (see the sketch after this list)
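
A minimal sketch of the sampling options and the single-stream path, assuming the top_p keyword and the generate signature mirror mlx_lm's API; only the batch_generate arguments shown under Usage are taken directly from this README.

from mlx_parallm.utils import load, batch_generate, generate

model, tokenizer = load("google/gemma-1.1-2b-it")
prompts = ["Write a haiku about the ocean.", "Explain KV caching in one sentence."]

# Greedy decoding (temp=0.0) versus sampling (temp>0 with top_p).
greedy = batch_generate(model, tokenizer, prompts=prompts, max_tokens=100, format_prompts=True, temp=0.0)
sampled = batch_generate(model, tokenizer, prompts=prompts, max_tokens=100, format_prompts=True, temp=0.8, top_p=0.95)

# Single-prompt, single-stream generation (assumed to mirror mlx_lm's generate).
response = generate(model, tokenizer, prompt=prompts[0], max_tokens=100, temp=0.0)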

Not (yet) supported:

  • Repetition penalties
  • Streaming outputs for batch_generate
  • Dynamic batching for async requests
