Add quantized model support.
lucasnewman committed Dec 13, 2024
1 parent 3cbb052 commit 7d6ab0a
Showing 3 changed files with 42 additions and 9 deletions.
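At a glance, the new quantization_bits option threads from generate.py through F5TTS.from_pretrained into fetch_from_hub, selecting a pre-quantized checkpoint and quantizing the matching layers at load time. A minimal usage sketch, assuming a quantized weights file (e.g. model_4b.safetensors) has been published to the target repo; the repo id and bit width below are illustrative:

    from f5_tts_mlx.cfm import F5TTS

    # Fetches model_4b.safetensors and quantizes the eligible nn.Linear
    # layers to 4 bits before the weights are loaded.
    f5tts = F5TTS.from_pretrained("lucasnewman/f5-tts-mlx", quantization_bits=4)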
30 changes: 24 additions & 6 deletions f5_tts_mlx/cfm.py
@@ -180,7 +180,7 @@ def __call__(
assert inp.shape[-1] == self.num_channels

batch, seq_len, dtype = *inp.shape[:2], inp.dtype

# handle text as string
if isinstance(text, list):
if exists(self._vocab_char_map):
@@ -230,7 +230,7 @@ def __call__(
drop_audio_cond = rand_audio_drop < self.audio_drop_prob
drop_text = rand_cond_drop < self.cond_drop_prob
drop_audio_cond = drop_audio_cond | drop_text

pred = self.transformer(
x=φ,
cond=cond,
@@ -241,7 +241,7 @@
)

# flow matching loss

loss = nn.losses.mse_loss(pred, flow, reduction="none")

rand_span_mask = repeat(rand_span_mask, "b n -> b n d", d=self.num_channels)
@@ -405,9 +405,14 @@ def fn(t, x):

@classmethod
def from_pretrained(
- cls, hf_model_name_or_path: str, convert_weights=False
+ cls,
+ hf_model_name_or_path: str,
+ convert_weights = False,
+ quantization_bits: int | None = None,
) -> F5TTS:
- path = fetch_from_hub(hf_model_name_or_path)
+ path = fetch_from_hub(
+     hf_model_name_or_path, quantization_bits=quantization_bits
+ )

if path is None:
raise ValueError(f"Could not find model {hf_model_name_or_path}")
@@ -446,7 +451,11 @@ def from_pretrained(

# model

model_path = path / "model.safetensors"
model_filename = "model.safetensors"
if exists(quantization_bits):
model_filename = f"model_{quantization_bits}b.safetensors"

model_path = path / model_filename

f5tts = F5TTS(
transformer=DiT(
@@ -498,6 +507,15 @@ def from_pretrained(

weights = new_weights

+ if quantization_bits is not None:
+     nn.quantize(
+         f5tts,
+         bits=quantization_bits,
+         class_predicate=lambda p, m: (
+             isinstance(m, nn.Linear) and m.weight.shape[1] % 64 == 0
+         ),
+     )
+
f5tts.load_weights(list(weights.items()))
mx.eval(f5tts.parameters())

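The class_predicate above limits quantization to nn.Linear modules whose input dimension is a multiple of MLX's default quantization group size (64); anything else stays in full precision so the remaining weights still load normally. A self-contained sketch of that mechanism (the Tiny module is illustrative, not part of this repo):

    import mlx.nn as nn

    class Tiny(nn.Module):
        def __init__(self):
            super().__init__()
            self.proj = nn.Linear(128, 64)  # input dim 128 % 64 == 0 -> quantized
            self.head = nn.Linear(50, 10)   # input dim 50 % 64 != 0 -> skipped

    model = Tiny()
    nn.quantize(
        model,
        bits=4,
        class_predicate=lambda p, m: isinstance(m, nn.Linear) and m.weight.shape[1] % 64 == 0,
    )
    print(type(model.proj).__name__)  # QuantizedLinear
    print(type(model.head).__name__)  # Linear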
10 changes: 9 additions & 1 deletion f5_tts_mlx/generate.py
@@ -112,14 +112,15 @@ def generate(
sway_sampling_coef: float = -1.0,
speed: float = 1.0, # used when duration is None as part of the duration heuristic
seed: Optional[int] = None,
+ quantization_bits: Optional[int] = None,
output_path: Optional[str] = None,
):
player = AudioPlayer(sample_rate=SAMPLE_RATE) if output_path is None else None

# the default model already has converted weights
convert_weights = model_name != "lucasnewman/f5-tts-mlx"

- f5tts = F5TTS.from_pretrained(model_name, convert_weights=convert_weights)
+ f5tts = F5TTS.from_pretrained(model_name, convert_weights=convert_weights, quantization_bits=quantization_bits)

if ref_audio_path is None:
data = pkgutil.get_data("f5_tts_mlx", "tests/test_en_1_ref_short.wav")
@@ -312,6 +313,12 @@ def generate(
default=None,
help="Seed for noise generation",
)
+ parser.add_argument(
+     "--q",
+     type=int,
+     default=None,
+     help="Number of bits to use for quantization. 4 and 8 are supported.",
+ )

args = parser.parse_args()

@@ -334,5 +341,6 @@
sway_sampling_coef=args.sway_coef,
speed=args.speed,
seed=args.seed,
+ quantization_bits=args.q,
output_path=args.output,
)
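End to end, the new --q flag simply forwards to the quantization_bits keyword of generate. A usage sketch; the name of the text argument (generation_text below) sits outside the hunks shown here and is assumed:

    from f5_tts_mlx.generate import generate

    # generation_text is an assumed parameter name; quantization_bits and
    # output_path are the parameters visible in the diff above. Writing to
    # output_path skips the live audio playback path.
    generate(
        generation_text="Testing the quantized model.",
        quantization_bits=4,
        output_path="output.wav",
    )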
11 changes: 9 additions & 2 deletions f5_tts_mlx/utils.py
@@ -9,6 +9,7 @@

from __future__ import annotations
from pathlib import Path
+ from typing import Optional

import mlx.core as mx

@@ -182,11 +183,17 @@ def convert_char_to_pinyin(text_list, polyphone=True):
# fetch model from hub


- def fetch_from_hub(hf_repo: str) -> Path:
+ def fetch_from_hub(hf_repo: str, quantization_bits: Optional[int] = None) -> Path:
+     model_filename = "model.safetensors"
+     if exists(quantization_bits):
+         model_filename = f"model_{quantization_bits}b.safetensors"
+
+     duration_predictor_path = "duration_v2.safetensors"
+
model_path = Path(
snapshot_download(
repo_id=hf_repo,
allow_patterns=["*.safetensors", "*.txt"],
allow_patterns=[model_filename, duration_predictor_path, "*.txt"],
)
)
return model_path
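With the narrower allow_patterns, only the requested checkpoint variant (plus the duration predictor and the *.txt vocab files) is downloaded, instead of every safetensors file in the repo. A small sketch of the two call paths, assuming the quantized file exists on the Hub:

    from f5_tts_mlx.utils import fetch_from_hub

    # Full-precision path: downloads model.safetensors, duration_v2.safetensors,
    # and the *.txt files.
    path = fetch_from_hub("lucasnewman/f5-tts-mlx")

    # Quantized path: downloads model_4b.safetensors in place of the
    # full-precision weights (assuming that file has been uploaded).
    path_4b = fetch_from_hub("lucasnewman/f5-tts-mlx", quantization_bits=4)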
