From 00da4997aede1aebcf6b752d27d09a1307dc78cd Mon Sep 17 00:00:00 2001 From: Lucas Newman Date: Fri, 29 Nov 2024 14:40:06 -0800 Subject: [PATCH] Add option to convert weights from the pytorch format. --- f5_tts_mlx/cfm.py | 37 ++++++++++++++++++++++++++++++++++++- 1 file changed, 36 insertions(+), 1 deletion(-) diff --git a/f5_tts_mlx/cfm.py b/f5_tts_mlx/cfm.py index 2945fa1..05e1eab 100644 --- a/f5_tts_mlx/cfm.py +++ b/f5_tts_mlx/cfm.py @@ -9,6 +9,7 @@ from __future__ import annotations from datetime import datetime +import os from pathlib import Path from random import random from typing import Callable, Literal @@ -377,7 +378,7 @@ def fn(t, x): return out, trajectory @classmethod - def from_pretrained(cls, hf_model_name_or_path: str) -> F5TTS: + def from_pretrained(cls, hf_model_name_or_path: str, convert_weights = False) -> F5TTS: path = fetch_from_hub(hf_model_name_or_path) if path is None: @@ -435,6 +436,40 @@ def from_pretrained(cls, hf_model_name_or_path: str) -> F5TTS: ) weights = mx.load(model_path.as_posix(), format="safetensors") + + if convert_weights: + new_weights = {} + for k, v in weights.items(): + k = k.replace('ema_model.', '') + + # rename layers + if len(k) < 1 or 'mel_spec.' in k or k in ('initted', 'step'): + continue + elif '.to_out' in k: + k = k.replace('.to_out', '.to_out.layers') + elif '.text_blocks' in k: + k = k.replace('.text_blocks', '.text_blocks.layers') + elif '.ff.ff.0.0' in k: + k = k.replace('.ff.ff.0.0', '.ff.ff.layers.0.layers.0') + elif '.ff.ff.2' in k: + k = k.replace('.ff.ff.2', '.ff.ff.layers.2') + elif '.time_mlp' in k: + k = k.replace('.time_mlp', '.time_mlp.layers') + elif '.conv1d' in k: + k = k.replace('.conv1d', '.conv1d.layers') + + # reshape weights + if '.dwconv.weight' in k: + v = v.swapaxes(1, 2) + elif '.conv1d.layers.0.weight' in k: + v = v.swapaxes(1, 2) + elif '.conv1d.layers.2.weight' in k: + v = v.swapaxes(1, 2) + + new_weights[k] = v + + weights = new_weights + f5tts.load_weights(list(weights.items())) mx.eval(f5tts.parameters())