Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Whisper updates to allow HF models #923

Merged
merged 6 commits into from
Aug 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ Some more useful examples are listed below.

- Joint text and image embeddings with [CLIP](clip).
- Text generation from image and text inputs with [LLaVA](llava).
- Image segmentation with [Segment Anything (SAM)](segment_anything).

### Other Models

Expand Down
6 changes: 3 additions & 3 deletions llms/mlx_lm/gguf.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,15 +59,15 @@ def hf_tokens(self) -> Iterable[Tuple[bytes, float, TokenType]]:
for token_id in range(self.vocab_size_base):
if token_id in self.added_tokens_ids:
continue
token_text = reverse_vocab[token_id].encode("utf-8")
token_text = reverse_vocab[token_id]
yield token_text, self.get_token_score(token_id), self.get_token_type(
token_id, token_text, self.special_ids
)

def get_token_type(
self, token_id: int, token_text: bytes, special_ids: Set[int]
) -> TokenType:
if re.fullmatch(rb"<0x[0-9A-Fa-f]{2}>", token_text):
if re.fullmatch(rb"<0x[0-9A-Fa-f]{2}>", token_text.encode("utf-8")):
return TokenType.BYTE
return TokenType.CONTROL if token_id in special_ids else TokenType.NORMAL

Expand All @@ -84,7 +84,7 @@ def added_tokens(self) -> Iterable[Tuple[bytes, float, TokenType]]:
else:
toktype = TokenType.USER_DEFINED
score = -1000.0
yield text.encode("utf-8"), score, toktype
yield text, score, toktype

def has_newline_token(self):
return "<0x0A>" in self.tokenizer.vocab or "\n" in self.tokenizer.vocab
Expand Down
176 changes: 108 additions & 68 deletions whisper/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,88 @@ def available_models() -> List[str]:
return list(_MODELS.keys())


def hf_to_pt(weights, config):
config = {
"n_mels": config["num_mel_bins"],
"n_audio_ctx": config["max_source_positions"],
"n_audio_state": config["d_model"],
"n_audio_head": config["encoder_attention_heads"],
"n_audio_layer": config["encoder_layers"],
"n_vocab": config["vocab_size"],
"n_text_ctx": config["max_target_positions"],
"n_text_state": config["d_model"],
"n_text_head": config["decoder_attention_heads"],
"n_text_layer": config["decoder_layers"],
}

def remap(k):
k = k.replace("model.", "")
k = k.replace(".layers", ".blocks")
k = k.replace(".self_attn", ".attn")
k = k.replace(".attn_layer_norm", ".attn_ln")
k = k.replace(".encoder_attn.", ".cross_attn.")
k = k.replace(".encoder_attn_layer_norm", ".cross_attn_ln")
k = k.replace(".final_layer_norm", ".mlp_ln")
k = k.replace(".q_proj", ".query")
k = k.replace(".k_proj", ".key")
k = k.replace(".v_proj", ".value")
k = k.replace(".out_proj", ".out")
k = k.replace(".fc1", ".mlp1")
k = k.replace(".fc2", ".mlp2")
k = k.replace("embed_positions.weight", "positional_embedding")
k = k.replace("decoder.embed_tokens", "decoder.token_embedding")
k = k.replace("encoder.layer_norm", "encoder.ln_post")
k = k.replace("decoder.layer_norm", "decoder.ln")
return k

# token embeddings are shared with output projection
weights.pop("proj_out.weight", None)
weights = {remap(k): v for k, v in weights.items()}
return weights, config


def load_torch_weights_and_config(
name_or_path: str,
download_root: str = None,
):
if download_root is None:
download_root = os.path.join(os.path.expanduser("~"), ".cache/whisper")

# todo: accept alignment_heads of local Pytorch checkpoint
alignment_heads = None
if name_or_path in _MODELS:
alignment_heads = _ALIGNMENT_HEADS[name_or_path]
name_or_path = _download(_MODELS[name_or_path], download_root)
elif not Path(name_or_path).exists():
# Try downloading from HF
from huggingface_hub import snapshot_download

name_or_path = snapshot_download(
repo_id=name_or_path,
allow_patterns=["*.json", "pytorch_model.bin", "*.txt"],
)
else:
raise RuntimeError(
f"Model {name_or_path} is not found in {available_models()},"
"on Hugging Face or as a local path."
)

if name_or_path.endswith(".pt"):
checkpoint = torch.load(name_or_path, map_location="cpu")
weights, config = checkpoint["model_state_dict"], checkpoint["dims"]
else:
name_or_path = Path(name_or_path)
weights = torch.load(
name_or_path / "pytorch_model.bin",
map_location="cpu",
)
with open(name_or_path / "config.json", "r") as fp:
config = json.load(fp)
weights, config = hf_to_pt(weights, config)

return weights, config, alignment_heads


def load_torch_model(
name_or_path: str,
download_root: str = None,
Expand All @@ -115,7 +197,8 @@ def load_torch_model(
Parameters
----------
name_or_path : str
one of the official model names listed by `whisper.available_models()` or a local Pytorch checkpoint which is in the original OpenAI format
one of the official model names listed by `whisper.available_models()` or
a local Pytorch checkpoint which is in the original OpenAI format
download_root: str
path to download the model files; by default, it uses "~/.cache/whisper"

Expand All @@ -128,82 +211,39 @@ def load_torch_model(
if download_root is None:
download_root = os.path.join(os.path.expanduser("~"), ".cache/whisper")

# todo: accept alignment_heads of local Pytorch checkpoint
alignment_heads = None
if name_or_path in _MODELS:
alignment_heads = _ALIGNMENT_HEADS[name_or_path]
name_or_path = _download(_MODELS[name_or_path], download_root)
elif not Path(name_or_path).is_file():
raise RuntimeError(
f"Model {name_or_path} is neither found in {available_models()} nor as a local path"
)

with open(name_or_path, "rb") as fp:
checkpoint = torch.load(fp)

dims = torch_whisper.ModelDimensions(**checkpoint["dims"])
weights, config, alignment_heads = load_torch_weights_and_config(
name_or_path, download_root
)
dims = torch_whisper.ModelDimensions(**config)
model = torch_whisper.Whisper(dims)
model.load_state_dict(checkpoint["model_state_dict"])
model.load_state_dict(weights)

if alignment_heads is not None:
model.set_alignment_heads(alignment_heads)

return model


def convert(model, rules=None):
params = {}
if rules is not None and type(model) in rules:
out = rules[type(model)](model, rules)
return out
if isinstance(model, torch.Tensor):
return mx.array(model.detach().numpy())
if isinstance(model, torch.nn.ModuleList):
return [convert(n, rules) for n in model.children()]
if isinstance(model, torch.nn.Conv1d):
return {
"weight": convert(model.weight).transpose(0, 2, 1),
"bias": convert(model.bias),
}
for k, n in model.named_children():
if k in rules:
params.update(rules[k](n, rules))
else:
params[k] = convert(n, rules)
for k, p in model.named_parameters(recurse=False):
params[k] = convert(p)
return params


def torch_to_mlx(
torch_model: torch_whisper.Whisper,
dtype: mx.Dtype = mx.float16,
) -> Whisper:
def convert_rblock(model, rules):
children = dict(model.named_children())
mlp = list(children.pop("mlp").children())
params = {
"mlp1": convert(mlp[0], rules),
"mlp2": convert(mlp[-1], rules),
}
for k, n in children.items():
params[k] = convert(n, rules)
return params

rules = {
torch_whisper.ResidualAttentionBlock: convert_rblock,
}
def convert(name_or_path: str, dtype: mx.Dtype = mx.float16):
def remap(key, value):
key = key.replace("mlp.0", "mlp1")
key = key.replace("mlp.2", "mlp2")
if "conv" in key and value.ndim == 3:
value = value.swapaxes(1, 2)
return key, mx.array(value.detach()).astype(dtype)

params = convert(torch_model, rules)
weights, config, alignment_heads = load_torch_weights_and_config(name_or_path)
weights.pop("encoder.positional_embedding", None)
weights = dict(remap(k, v) for k, v in weights.items())

mlx_model = Whisper(torch_model.dims, dtype)
params = tree_map(lambda p: p.astype(dtype), params)
mlx_model.update(params)
model_dims = ModelDimensions(**config)
model = Whisper(model_dims, dtype)
model.load_weights(list(weights.items()), strict=False)

if (alignment_heads := getattr(torch_model, "alignment_heads", None)) is not None:
mlx_model.set_alignment_heads(alignment_heads.indices().T.numpy())
if alignment_heads is not None:
model.set_alignment_heads(alignment_heads)

return mlx_model
return model


def upload_to_hub(path: str, name: str, torch_name_or_path: str):
Expand Down Expand Up @@ -292,13 +332,13 @@ def quantize(weights, config, args):
action="store_true",
)
parser.add_argument(
"--q_group_size",
"--q-group-size",
help="Group size for quantization.",
type=int,
default=64,
)
parser.add_argument(
"--q_bits",
"--q-bits",
help="Bits per weight for quantization.",
type=int,
default=4,
Expand All @@ -318,7 +358,7 @@ def quantize(weights, config, args):
dtype = getattr(mx, args.dtype)

print("[INFO] Loading")
model = torch_to_mlx(load_torch_model(args.torch_name_or_path), dtype)
model = convert(args.torch_name_or_path, dtype)
config = asdict(model.dims)
weights = dict(tree_flatten(model.parameters()))

Expand Down
2 changes: 1 addition & 1 deletion whisper/mlx_whisper/version.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
# Copyright © 2023-2024 Apple Inc.

__version__ = "0.1.0"
__version__ = "0.2.0"
6 changes: 3 additions & 3 deletions whisper/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import mlx_whisper.load_models as load_models
import numpy as np
import torch
from convert import load_torch_model, quantize, torch_to_mlx
from convert import convert, load_torch_model, quantize
from mlx.utils import tree_flatten

MODEL_NAME = "tiny"
Expand Down Expand Up @@ -41,12 +41,12 @@ def _save_model(save_dir, weights, config):
def load_torch_and_mlx():
torch_model = load_torch_model(MODEL_NAME)

fp32_model = torch_to_mlx(torch_model, dtype=mx.float32)
fp32_model = convert(MODEL_NAME, dtype=mx.float32)
config = asdict(fp32_model.dims)
weights = dict(tree_flatten(fp32_model.parameters()))
_save_model(MLX_FP32_MODEL_PATH, weights, config)

fp16_model = torch_to_mlx(torch_model, dtype=mx.float16)
fp16_model = convert(MODEL_NAME, dtype=mx.float16)
config = asdict(fp16_model.dims)
weights = dict(tree_flatten(fp16_model.parameters()))
_save_model(MLX_FP16_MODEL_PATH, weights, config)
Expand Down