Skip to content

Commit

Permalink
Convert weights for provided models automatically.
Browse files Browse the repository at this point in the history
  • Loading branch information
lucasnewman committed Nov 30, 2024
1 parent d13850f commit bb672ab
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 8 deletions.
5 changes: 4 additions & 1 deletion f5_tts_mlx/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,10 @@ def generate(
seed: Optional[int] = None,
output_path: str = "output.wav",
):
f5tts = F5TTS.from_pretrained(model_name)
# the default model already has converted weights
convert_weights = model_name != "lucasnewman/f5-tts-mlx"

f5tts = F5TTS.from_pretrained(model_name, convert_weights=convert_weights)

if ref_audio_path is None:
data = pkgutil.get_data("f5_tts_mlx", "tests/test_en_1_ref_short.wav")
Expand Down
10 changes: 4 additions & 6 deletions f5_tts_mlx/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,7 @@ def __init__(
):
super().__init__()
base *= base_rescale_factor ** (dim / (dim - 2))

inv_freq = 1.0 / (base ** (mx.arange(0, dim, 2).astype(mx.float32) / dim))
self.inv_freq = inv_freq
self.inv_freq = 1.0 / (base ** (mx.arange(0, dim, 2).astype(mx.float32) / dim))

assert interpolation_factor >= 1.0
self.interpolation_factor = interpolation_factor
Expand Down Expand Up @@ -73,8 +71,8 @@ def precompute_freqs_cis(
freqs = 1.0 / (
theta ** (mx.arange(0, dim, 2)[: (dim // 2)].astype(mx.float32) / dim)
)
t = mx.arange(end) # type: ignore
freqs = mx.outer(t, freqs).astype(mx.float32) # type: ignore
t = mx.arange(end)
freqs = mx.outer(t, freqs).astype(mx.float32)
freqs_cos = freqs.cos() # real part
freqs_sin = freqs.sin() # imaginary part
return mx.concatenate([freqs_cos, freqs_sin], axis=-1)
Expand Down Expand Up @@ -522,7 +520,7 @@ def __call__(

if attn_mask is not None:
mask = rearrange(mask, "b n -> b n 1")
x = x.masked_fill(~mask, 0.0)
x = mx.where(mask, x, 0.0)

return x

Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "f5-tts-mlx"
version = "0.1.8"
version = "0.1.9"
authors = [{name = "Lucas Newman", email = "lucasnewman@me.com"}]
license = {text = "MIT"}
description = "F5-TTS - MLX"
Expand Down

0 comments on commit bb672ab

Please sign in to comment.