Skip to content

Commit

Permalink
Add ODE steps as a parameter.
Browse files Browse the repository at this point in the history
  • Loading branch information
lucasnewman committed Oct 21, 2024
1 parent bd6f4d8 commit b2b070c
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 10 deletions.
27 changes: 20 additions & 7 deletions examples/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,14 @@ Provide the text that you want to generate.

## Optional Parameters

`-duration`
`--duration`

float

Specify the length of the generated audio in seconds.


`-speed`
`--speed`

float, default: 1.0

Expand All @@ -45,28 +45,41 @@ string, default: "tests/test_en_1_ref_short.wav"
Provide a reference audio file path to help guide the generation.


`-ref-text`
`--ref-text`

string, default: "Some call me nature, others call me mother nature."

Provide a caption for the reference audio.


`-output`
`--output`

string, default: "output.wav"

Specify the output path where the generated audio will be saved. If not specified, the script will save the output to a default location.

`--cfg`

`-–sway-coef`
float, default: 2.0

float, default: 0.0
Specify the strength used for classifier-free guidance.


`--steps`

int, default: 32

Specify the number of steps used to sample the neural ODE. Fewer steps reduce latency at the cost of quality.


`--sway-coef`

float, default: -1.0

Set the sway sampling coefficient. According to the paper, the best values are in the range [-1.0, 1.0].


`-seed`
`--seed`

int, default: None (random)

Expand Down
10 changes: 9 additions & 1 deletion examples/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ def generate(
model_name: str = "lucasnewman/f5-tts-mlx",
ref_audio_path: Optional[str] = None,
ref_audio_text: Optional[str] = None,
steps: int = 32,
cfg_strength: float = 2.0,
sway_sampling_coef: float = -1.0,
speed: float = 1.0, # used when duration is None as part of the duration heuristic
Expand Down Expand Up @@ -83,7 +84,7 @@ def generate(
mx.expand_dims(audio, axis=0),
text=text,
duration=frame_duration,
steps=32,
steps=steps,
cfg_strength=cfg_strength,
sway_sampling_coef=sway_sampling_coef,
seed=seed,
Expand Down Expand Up @@ -138,6 +139,12 @@ def generate(
default="output.wav",
help="Path to save the generated audio output",
)
parser.add_argument(
"--steps",
type=int,
default=32,
help="Number of steps to take when sampling the neural ODE",
)
parser.add_argument(
"--cfg",
type=float,
Expand Down Expand Up @@ -171,6 +178,7 @@ def generate(
model_name=args.model,
ref_audio_path=args.ref_audio,
ref_audio_text=args.ref_text,
steps=args.steps,
cfg_strength=args.cfg,
sway_sampling_coef=args.sway_coef,
speed=args.speed,
Expand Down
2 changes: 1 addition & 1 deletion f5_tts_mlx/cfm.py
Original file line number Diff line number Diff line change
Expand Up @@ -479,4 +479,4 @@ def from_pretrained(
f5tts.load_weights(list(weights.items()))
mx.eval(f5tts.parameters())

return f5tts
return f5tts
10 changes: 9 additions & 1 deletion f5_tts_mlx/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ def generate(
model_name: str = "lucasnewman/f5-tts-mlx",
ref_audio_path: Optional[str] = None,
ref_audio_text: Optional[str] = None,
steps: int = 32,
cfg_strength: float = 2.0,
sway_sampling_coef: float = -1.0,
speed: float = 1.0, # used when duration is None as part of the duration heuristic
Expand Down Expand Up @@ -83,7 +84,7 @@ def generate(
mx.expand_dims(audio, axis=0),
text=text,
duration=frame_duration,
steps=32,
steps=steps,
cfg_strength=cfg_strength,
sway_sampling_coef=sway_sampling_coef,
seed=seed,
Expand Down Expand Up @@ -138,6 +139,12 @@ def generate(
default="output.wav",
help="Path to save the generated audio output",
)
parser.add_argument(
"--steps",
type=int,
default=32,
help="Number of steps to take when sampling the neural ODE",
)
parser.add_argument(
"--cfg",
type=float,
Expand Down Expand Up @@ -171,6 +178,7 @@ def generate(
model_name=args.model,
ref_audio_path=args.ref_audio,
ref_audio_text=args.ref_text,
steps=args.steps,
cfg_strength=args.cfg,
sway_sampling_coef=args.sway_coef,
speed=args.speed,
Expand Down

0 comments on commit b2b070c

Please sign in to comment.