From b2b070c86b6af44990badcd01ec58f581dff8478 Mon Sep 17 00:00:00 2001 From: Lucas Newman Date: Mon, 21 Oct 2024 08:43:57 -0700 Subject: [PATCH] Add ODE steps as a parameter. --- examples/README.md | 27 ++++++++++++++++++++------- examples/generate.py | 10 +++++++++- f5_tts_mlx/cfm.py | 2 +- f5_tts_mlx/generate.py | 10 +++++++++- 4 files changed, 39 insertions(+), 10 deletions(-) diff --git a/examples/README.md b/examples/README.md index c7bdab1..19284a6 100644 --- a/examples/README.md +++ b/examples/README.md @@ -17,14 +17,14 @@ Provide the text that you want to generate. ## Optional Parameters -`-–duration` +`--duration` float Specify the length of the generated audio in seconds. -`-–speed` +`--speed` float, default: 1.0 @@ -45,28 +45,41 @@ string, default: "tests/test_en_1_ref_short.wav" Provide a reference audio file path to help guide the generation. -`–-ref-text` +`--ref-text` string, default: "Some call me nature, others call me mother nature." Provide a caption for the reference audio. -`-–output` +`--output` string, default: "output.wav" Specify the output path where the generated audio will be saved. If not specified, the script will save the output to a default location. +`--cfg` -`-–sway-coef` +float, default: 2.0 -float, default: 0.0 +Specifies the strength used for classifier free guidance + + +`--steps` + +int, default: 32 + +Specify the number of steps used to sample the neural ODE. Lower steps trade off quality for latency. + + +`--sway-coef` + +float, default: -1.0 Set the sway sampling coefficient. The best values according to the paper are in the range of [-1.0...1.0]. -`-–seed` +`--seed` int, default: None (random) diff --git a/examples/generate.py b/examples/generate.py index 1272c4a..c98dd61 100644 --- a/examples/generate.py +++ b/examples/generate.py @@ -27,6 +27,7 @@ def generate( model_name: str = "lucasnewman/f5-tts-mlx", ref_audio_path: Optional[str] = None, ref_audio_text: Optional[str] = None, + steps: int = 32, cfg_strength: float = 2.0, sway_sampling_coef: float = -1.0, speed: float = 1.0, # used when duration is None as part of the duration heuristic @@ -83,7 +84,7 @@ def generate( mx.expand_dims(audio, axis=0), text=text, duration=frame_duration, - steps=32, + steps=steps, cfg_strength=cfg_strength, sway_sampling_coef=sway_sampling_coef, seed=seed, @@ -138,6 +139,12 @@ def generate( default="output.wav", help="Path to save the generated audio output", ) + parser.add_argument( + "--steps", + type=int, + default=32, + help="Number of steps to take when sampling the neural ODE", + ) parser.add_argument( "--cfg", type=float, @@ -171,6 +178,7 @@ def generate( model_name=args.model, ref_audio_path=args.ref_audio, ref_audio_text=args.ref_text, + steps=args.steps, cfg_strength=args.cfg, sway_sampling_coef=args.sway_coef, speed=args.speed, diff --git a/f5_tts_mlx/cfm.py b/f5_tts_mlx/cfm.py index bd32c56..95ee68c 100644 --- a/f5_tts_mlx/cfm.py +++ b/f5_tts_mlx/cfm.py @@ -479,4 +479,4 @@ def from_pretrained( f5tts.load_weights(list(weights.items())) mx.eval(f5tts.parameters()) - return f5tts \ No newline at end of file + return f5tts diff --git a/f5_tts_mlx/generate.py b/f5_tts_mlx/generate.py index 1272c4a..c98dd61 100644 --- a/f5_tts_mlx/generate.py +++ b/f5_tts_mlx/generate.py @@ -27,6 +27,7 @@ def generate( model_name: str = "lucasnewman/f5-tts-mlx", ref_audio_path: Optional[str] = None, ref_audio_text: Optional[str] = None, + steps: int = 32, cfg_strength: float = 2.0, sway_sampling_coef: float = -1.0, speed: float = 1.0, # used when duration is None as part of the duration heuristic @@ -83,7 +84,7 @@ def generate( mx.expand_dims(audio, axis=0), text=text, duration=frame_duration, - steps=32, + steps=steps, cfg_strength=cfg_strength, sway_sampling_coef=sway_sampling_coef, seed=seed, @@ -138,6 +139,12 @@ def generate( default="output.wav", help="Path to save the generated audio output", ) + parser.add_argument( + "--steps", + type=int, + default=32, + help="Number of steps to take when sampling the neural ODE", + ) parser.add_argument( "--cfg", type=float, @@ -171,6 +178,7 @@ def generate( model_name=args.model, ref_audio_path=args.ref_audio, ref_audio_text=args.ref_text, + steps=args.steps, cfg_strength=args.cfg, sway_sampling_coef=args.sway_coef, speed=args.speed,