From 742a6f5859e23ca844d6ddc211feb9d0581ae0f9 Mon Sep 17 00:00:00 2001 From: Lucas Newman Date: Mon, 14 Oct 2024 15:31:03 -0700 Subject: [PATCH] Align the default values with the reference implementation. --- examples/generate.py | 42 +++++++++++++++++++++++++++++++----------- f5_tts_mlx/generate.py | 30 ++++++++++++++++++------------ 2 files changed, 49 insertions(+), 23 deletions(-) diff --git a/examples/generate.py b/examples/generate.py index 04e5887..f3e935f 100644 --- a/examples/generate.py +++ b/examples/generate.py @@ -1,5 +1,6 @@ import argparse import datetime +import pkgutil from typing import Optional import mlx.core as mx @@ -23,16 +24,30 @@ def generate( generation_text: str, duration: float, model_name: str = "lucasnewman/f5-tts-mlx", - ref_audio_path: str = "tests/test_en_1_ref_short.wav", - ref_audio_text: str = "Some call me nature, others call me mother nature.", - sway_sampling_coef: float = 0.0, + ref_audio_path: Optional[str] = None, + ref_audio_text: Optional[str] = None, + cfg_strength: float = 2.0, + sway_sampling_coef: float = -1.0, seed: Optional[int] = None, output_path: str = "output.wav", ): f5tts = F5TTS.from_pretrained(model_name) - # load reference audio - audio, sr = sf.read(ref_audio_path) + if ref_audio_path is None: + data = pkgutil.get_data("f5_tts_mlx", "tests/test_en_1_ref_short.wav") + + # write to a temp file + tmp_ref_audio_file = "/tmp/ref.wav" + with open(tmp_ref_audio_file, "wb") as f: + f.write(data) + + if data is not None: + audio, sr = sf.read(tmp_ref_audio_file) + ref_audio_text = "Some call me nature, others call me mother nature." + else: + # load reference audio + audio, sr = sf.read(ref_audio_path) + audio = mx.array(audio) ref_audio_duration = audio.shape[0] / SAMPLE_RATE @@ -54,7 +69,7 @@ def generate( text=text, duration=frame_duration, steps=32, - cfg_strength=1, + cfg_strength=cfg_strength, sway_sampling_coef=sway_sampling_coef, seed=seed, vocoder=vocos.decode, @@ -93,13 +108,13 @@ def generate( parser.add_argument( "--ref-audio", type=str, - default="tests/test_en_1_ref_short.wav", + default=None, help="Path to the reference audio file", ) parser.add_argument( "--ref-text", type=str, - default="Some call me nature, others call me mother nature.", + default=None, help="Text spoken in the reference audio", ) parser.add_argument( @@ -108,14 +123,18 @@ def generate( default="output.wav", help="Path to save the generated audio output", ) - + parser.add_argument( + "--cfg", + type=float, + default=2.0, + help="Strength of classifer free guidance", + ) parser.add_argument( "--sway-coef", type=float, - default="0.0", + default=-1.0, help="Coefficient for sway sampling", ) - parser.add_argument( "--seed", type=int, @@ -131,6 +150,7 @@ def generate( model_name=args.model, ref_audio_path=args.ref_audio, ref_audio_text=args.ref_text, + cfg_strength=args.cfg, sway_sampling_coef=args.sway_coef, seed=args.seed, output_path=args.output, diff --git a/f5_tts_mlx/generate.py b/f5_tts_mlx/generate.py index 15ab0f9..f3e935f 100644 --- a/f5_tts_mlx/generate.py +++ b/f5_tts_mlx/generate.py @@ -26,27 +26,28 @@ def generate( model_name: str = "lucasnewman/f5-tts-mlx", ref_audio_path: Optional[str] = None, ref_audio_text: Optional[str] = None, - sway_sampling_coef: float = 0.0, + cfg_strength: float = 2.0, + sway_sampling_coef: float = -1.0, seed: Optional[int] = None, output_path: str = "output.wav", ): f5tts = F5TTS.from_pretrained(model_name) - + if ref_audio_path is None: - data = pkgutil.get_data('f5_tts_mlx', 'tests/test_en_1_ref_short.wav') - + data = pkgutil.get_data("f5_tts_mlx", "tests/test_en_1_ref_short.wav") + # write to a temp file - tmp_ref_audio_file = '/tmp/ref.wav' - with open(tmp_ref_audio_file, 'wb') as f: + tmp_ref_audio_file = "/tmp/ref.wav" + with open(tmp_ref_audio_file, "wb") as f: f.write(data) - + if data is not None: audio, sr = sf.read(tmp_ref_audio_file) ref_audio_text = "Some call me nature, others call me mother nature." else: # load reference audio audio, sr = sf.read(ref_audio_path) - + audio = mx.array(audio) ref_audio_duration = audio.shape[0] / SAMPLE_RATE @@ -68,7 +69,7 @@ def generate( text=text, duration=frame_duration, steps=32, - cfg_strength=1, + cfg_strength=cfg_strength, sway_sampling_coef=sway_sampling_coef, seed=seed, vocoder=vocos.decode, @@ -122,14 +123,18 @@ def generate( default="output.wav", help="Path to save the generated audio output", ) - + parser.add_argument( + "--cfg", + type=float, + default=2.0, + help="Strength of classifer free guidance", + ) parser.add_argument( "--sway-coef", type=float, - default="0.0", + default=-1.0, help="Coefficient for sway sampling", ) - parser.add_argument( "--seed", type=int, @@ -145,6 +150,7 @@ def generate( model_name=args.model, ref_audio_path=args.ref_audio, ref_audio_text=args.ref_text, + cfg_strength=args.cfg, sway_sampling_coef=args.sway_coef, seed=args.seed, output_path=args.output,