Skip to content

Commit

Permalink
Align the default values with the reference implementation.
Browse files Browse the repository at this point in the history
  • Loading branch information
lucasnewman committed Oct 14, 2024
1 parent a0952a5 commit 742a6f5
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 23 deletions.
42 changes: 31 additions & 11 deletions examples/generate.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import argparse
import datetime
import pkgutil
from typing import Optional

import mlx.core as mx
Expand All @@ -23,16 +24,30 @@ def generate(
generation_text: str,
duration: float,
model_name: str = "lucasnewman/f5-tts-mlx",
ref_audio_path: str = "tests/test_en_1_ref_short.wav",
ref_audio_text: str = "Some call me nature, others call me mother nature.",
sway_sampling_coef: float = 0.0,
ref_audio_path: Optional[str] = None,
ref_audio_text: Optional[str] = None,
cfg_strength: float = 2.0,
sway_sampling_coef: float = -1.0,
seed: Optional[int] = None,
output_path: str = "output.wav",
):
f5tts = F5TTS.from_pretrained(model_name)

# load reference audio
audio, sr = sf.read(ref_audio_path)
if ref_audio_path is None:
data = pkgutil.get_data("f5_tts_mlx", "tests/test_en_1_ref_short.wav")

# write to a temp file
tmp_ref_audio_file = "/tmp/ref.wav"
with open(tmp_ref_audio_file, "wb") as f:
f.write(data)

if data is not None:
audio, sr = sf.read(tmp_ref_audio_file)
ref_audio_text = "Some call me nature, others call me mother nature."
else:
# load reference audio
audio, sr = sf.read(ref_audio_path)

audio = mx.array(audio)
ref_audio_duration = audio.shape[0] / SAMPLE_RATE

Expand All @@ -54,7 +69,7 @@ def generate(
text=text,
duration=frame_duration,
steps=32,
cfg_strength=1,
cfg_strength=cfg_strength,
sway_sampling_coef=sway_sampling_coef,
seed=seed,
vocoder=vocos.decode,
Expand Down Expand Up @@ -93,13 +108,13 @@ def generate(
parser.add_argument(
"--ref-audio",
type=str,
default="tests/test_en_1_ref_short.wav",
default=None,
help="Path to the reference audio file",
)
parser.add_argument(
"--ref-text",
type=str,
default="Some call me nature, others call me mother nature.",
default=None,
help="Text spoken in the reference audio",
)
parser.add_argument(
Expand All @@ -108,14 +123,18 @@ def generate(
default="output.wav",
help="Path to save the generated audio output",
)

parser.add_argument(
"--cfg",
type=float,
default=2.0,
help="Strength of classifier-free guidance",
)
parser.add_argument(
"--sway-coef",
type=float,
default="0.0",
default=-1.0,
help="Coefficient for sway sampling",
)

parser.add_argument(
"--seed",
type=int,
Expand All @@ -131,6 +150,7 @@ def generate(
model_name=args.model,
ref_audio_path=args.ref_audio,
ref_audio_text=args.ref_text,
cfg_strength=args.cfg,
sway_sampling_coef=args.sway_coef,
seed=args.seed,
output_path=args.output,
Expand Down
30 changes: 18 additions & 12 deletions f5_tts_mlx/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,27 +26,28 @@ def generate(
model_name: str = "lucasnewman/f5-tts-mlx",
ref_audio_path: Optional[str] = None,
ref_audio_text: Optional[str] = None,
sway_sampling_coef: float = 0.0,
cfg_strength: float = 2.0,
sway_sampling_coef: float = -1.0,
seed: Optional[int] = None,
output_path: str = "output.wav",
):
f5tts = F5TTS.from_pretrained(model_name)

if ref_audio_path is None:
data = pkgutil.get_data('f5_tts_mlx', 'tests/test_en_1_ref_short.wav')
data = pkgutil.get_data("f5_tts_mlx", "tests/test_en_1_ref_short.wav")

# write to a temp file
tmp_ref_audio_file = '/tmp/ref.wav'
with open(tmp_ref_audio_file, 'wb') as f:
tmp_ref_audio_file = "/tmp/ref.wav"
with open(tmp_ref_audio_file, "wb") as f:
f.write(data)

if data is not None:
audio, sr = sf.read(tmp_ref_audio_file)
ref_audio_text = "Some call me nature, others call me mother nature."
else:
# load reference audio
audio, sr = sf.read(ref_audio_path)

audio = mx.array(audio)
ref_audio_duration = audio.shape[0] / SAMPLE_RATE

Expand All @@ -68,7 +69,7 @@ def generate(
text=text,
duration=frame_duration,
steps=32,
cfg_strength=1,
cfg_strength=cfg_strength,
sway_sampling_coef=sway_sampling_coef,
seed=seed,
vocoder=vocos.decode,
Expand Down Expand Up @@ -122,14 +123,18 @@ def generate(
default="output.wav",
help="Path to save the generated audio output",
)

parser.add_argument(
"--cfg",
type=float,
default=2.0,
help="Strength of classifier-free guidance",
)
parser.add_argument(
"--sway-coef",
type=float,
default="0.0",
default=-1.0,
help="Coefficient for sway sampling",
)

parser.add_argument(
"--seed",
type=int,
Expand All @@ -145,6 +150,7 @@ def generate(
model_name=args.model,
ref_audio_path=args.ref_audio,
ref_audio_text=args.ref_text,
cfg_strength=args.cfg,
sway_sampling_coef=args.sway_coef,
seed=args.seed,
output_path=args.output,
Expand Down

0 comments on commit 742a6f5

Please sign in to comment.