Skip to content

Commit

Permalink
Add option to convert weights from the pytorch format.
Browse files Browse the repository at this point in the history
  • Loading branch information
lucasnewman committed Nov 29, 2024
1 parent 52c940a commit 00da499
Showing 1 changed file with 36 additions and 1 deletion.
37 changes: 36 additions & 1 deletion f5_tts_mlx/cfm.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

from __future__ import annotations
from datetime import datetime
import os
from pathlib import Path
from random import random
from typing import Callable, Literal
Expand Down Expand Up @@ -377,7 +378,7 @@ def fn(t, x):
return out, trajectory

@classmethod
def from_pretrained(cls, hf_model_name_or_path: str) -> F5TTS:
def from_pretrained(cls, hf_model_name_or_path: str, convert_weights = False) -> F5TTS:
path = fetch_from_hub(hf_model_name_or_path)

if path is None:
Expand Down Expand Up @@ -435,6 +436,40 @@ def from_pretrained(cls, hf_model_name_or_path: str) -> F5TTS:
)

weights = mx.load(model_path.as_posix(), format="safetensors")

if convert_weights:
new_weights = {}
for k, v in weights.items():
k = k.replace('ema_model.', '')

# rename layers
if len(k) < 1 or 'mel_spec.' in k or k in ('initted', 'step'):
continue
elif '.to_out' in k:
k = k.replace('.to_out', '.to_out.layers')
elif '.text_blocks' in k:
k = k.replace('.text_blocks', '.text_blocks.layers')
elif '.ff.ff.0.0' in k:
k = k.replace('.ff.ff.0.0', '.ff.ff.layers.0.layers.0')
elif '.ff.ff.2' in k:
k = k.replace('.ff.ff.2', '.ff.ff.layers.2')
elif '.time_mlp' in k:
k = k.replace('.time_mlp', '.time_mlp.layers')
elif '.conv1d' in k:
k = k.replace('.conv1d', '.conv1d.layers')

# reshape weights
if '.dwconv.weight' in k:
v = v.swapaxes(1, 2)
elif '.conv1d.layers.0.weight' in k:
v = v.swapaxes(1, 2)
elif '.conv1d.layers.2.weight' in k:
v = v.swapaxes(1, 2)

new_weights[k] = v

weights = new_weights

f5tts.load_weights(list(weights.items()))
mx.eval(f5tts.parameters())

Expand Down

0 comments on commit 00da499

Please sign in to comment.