Skip to content

Commit

Permalink
Add a script entrypoint for easier invocation.
Browse files Browse the repository at this point in the history
  • Loading branch information
lucasnewman committed Oct 14, 2024
1 parent 9401d4f commit 808f13c
Show file tree
Hide file tree
Showing 7 changed files with 176 additions and 12 deletions.
19 changes: 14 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,21 +16,30 @@ F5 is an evolution of [E2 TTS](https://arxiv.org/abs/2406.18009v2) and improves
pip install f5-tts-mlx
```

Pretrained model weights are available [on Hugging Face](https://huggingface.co/lucasnewman/f5-tts-mlx).

## Usage

See [examples/generate.py](./examples) for an example of generation.
```bash
python -m f5_tts_mlx.generate \
--text "The quick brown fox jumped over the lazy dog." \
--duration 3.5
```

See [examples/generate.py](./examples) for more options.


You can load a pretrained model from Python like this:

```python
from f5_tts_mlx.cfm import CFM
from f5_tts_mlx import F5TTS

f5tts = F5TTS.from_pretrained("lucasnewman/f5-tts-mlx")

f5tts = CFM.from_pretrained("lucasnewman/f5-tts-mlx")
audio = f5tts.sample(...)
```

Pretrained model weights are also available [on Hugging Face](https://huggingface.co/lucasnewman/f5-tts-mlx).

## Appreciation

[Yushen Chen](https://github.com/SWivid) for the original Pytorch implementation of F5 TTS and pretrained model.
Expand Down
4 changes: 2 additions & 2 deletions examples/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import numpy as np

from f5_tts_mlx.cfm import CFM
from f5_tts_mlx.cfm import F5TTS
from f5_tts_mlx.utils import convert_char_to_pinyin

from vocos_mlx import Vocos
Expand All @@ -29,7 +29,7 @@ def generate(
seed: Optional[int] = None,
output_path: str = "output.wav",
):
f5tts = CFM.from_pretrained(model_name)
f5tts = F5TTS.from_pretrained(model_name)

# load reference audio
audio, sr = sf.read(ref_audio_path)
Expand Down
1 change: 1 addition & 0 deletions f5_tts_mlx/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .cfm import F5TTS
6 changes: 3 additions & 3 deletions f5_tts_mlx/cfm.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ def list_str_to_idx(
# conditional flow matching


class CFM(nn.Module):
class F5TTS(nn.Module):
def __init__(
self,
transformer: nn.Module,
Expand Down Expand Up @@ -450,7 +450,7 @@ def fn(t, x):
def from_pretrained(
cls,
hf_model_name_or_path: str
) -> CFM:
) -> F5TTS:
path = fetch_from_hub(hf_model_name_or_path)

if path is None:
Expand All @@ -460,7 +460,7 @@ def from_pretrained(
vocab_path = path / "vocab.txt"
vocab = {v: i for i, v in enumerate(Path(vocab_path).read_text().split("\n"))}

f5tts = CFM(
f5tts = F5TTS(
transformer=DiT(
dim=1024,
depth=22,
Expand Down
151 changes: 151 additions & 0 deletions f5_tts_mlx/generate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
import argparse
import datetime
import pkgutil
from typing import Optional

import mlx.core as mx

import numpy as np

from f5_tts_mlx.cfm import F5TTS
from f5_tts_mlx.utils import convert_char_to_pinyin

from vocos_mlx import Vocos

import soundfile as sf

SAMPLE_RATE = 24_000
HOP_LENGTH = 256
FRAMES_PER_SEC = SAMPLE_RATE / HOP_LENGTH
TARGET_RMS = 0.1


def generate(
generation_text: str,
duration: float,
model_name: str = "lucasnewman/f5-tts-mlx",
ref_audio_path: Optional[str] = None,
ref_audio_text: Optional[str] = None,
sway_sampling_coef: float = 0.0,
seed: Optional[int] = None,
output_path: str = "output.wav",
):
f5tts = F5TTS.from_pretrained(model_name)

if ref_audio_path is None:
data = pkgutil.get_data('f5_tts_mlx', 'tests/test_en_1_ref_short.wav')

# write to a temp file
tmp_ref_audio_file = '/tmp/ref.wav'
with open(tmp_ref_audio_file, 'wb') as f:
f.write(data)

if data is not None:
audio, sr = sf.read(tmp_ref_audio_file)
ref_audio_text = "Some call me nature, others call me mother nature."
else:
# load reference audio
audio, sr = sf.read(ref_audio_path)

audio = mx.array(audio)
ref_audio_duration = audio.shape[0] / SAMPLE_RATE

rms = mx.sqrt(mx.mean(mx.square(audio)))
if rms < TARGET_RMS:
audio = audio * TARGET_RMS / rms

# generate the audio for the given text
text = convert_char_to_pinyin([ref_audio_text + " " + generation_text])

frame_duration = int((ref_audio_duration + duration) * FRAMES_PER_SEC)
print(f"Generating {frame_duration} total frames of audio...")

start_date = datetime.datetime.now()
vocos = Vocos.from_pretrained("lucasnewman/vocos-mel-24khz")

wave, _ = f5tts.sample(
mx.expand_dims(audio, axis=0),
text=text,
duration=frame_duration,
steps=32,
cfg_strength=1,
sway_sampling_coef=sway_sampling_coef,
seed=seed,
vocoder=vocos.decode,
)

# trim the reference audio
wave = wave[audio.shape[0]:]
generated_duration = wave.shape[0] / SAMPLE_RATE
elapsed_time = datetime.datetime.now() - start_date

print(f"Generated {generated_duration:.2f} seconds of audio in {elapsed_time}.")

sf.write(output_path, np.array(wave), SAMPLE_RATE)


if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Generate audio from text using f5-tts-mlx"
)

parser.add_argument(
"--model",
type=str,
default="lucasnewman/f5-tts-mlx",
help="Name of the model to use",
)
parser.add_argument(
"--text", type=str, required=True, help="Text to generate speech from"
)
parser.add_argument(
"--duration",
type=float,
required=True,
help="Duration of the generated audio in seconds",
)
parser.add_argument(
"--ref-audio",
type=str,
default=None,
help="Path to the reference audio file",
)
parser.add_argument(
"--ref-text",
type=str,
default=None,
help="Text spoken in the reference audio",
)
parser.add_argument(
"--output",
type=str,
default="output.wav",
help="Path to save the generated audio output",
)

parser.add_argument(
"--sway-coef",
type=float,
default="0.0",
help="Coefficient for sway sampling",
)

parser.add_argument(
"--seed",
type=int,
default=None,
help="Seed for noise generation",
)

args = parser.parse_args()

generate(
generation_text=args.text,
duration=args.duration,
model_name=args.model,
ref_audio_path=args.ref_audio,
ref_audio_text=args.ref_text,
sway_sampling_coef=args.sway_coef,
seed=args.seed,
output_path=args.output,
)
Binary file added f5_tts_mlx/tests/test_en_1_ref_short.wav
Binary file not shown.
7 changes: 5 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "f5-tts-mlx"
version = "0.0.5"
version = "0.0.6"
authors = [{name = "Lucas Newman", email = "lucasnewman@me.com"}]
license = {text = "MIT"}
description = "F5-TTS - MLX"
Expand Down Expand Up @@ -46,4 +46,7 @@ Homepage = "https://github.com/lucasnewman/f5-tts-mlx"
packages = ["f5_tts_mlx"]

[tool.setuptools.package-data]
f5_tts_mlx = ["assets/mel_filters.npz"]
f5_tts_mlx = ["assets/mel_filters.npz", "tests/test_en_1_ref_short.wav"]

[project.scripts]
f5-tts-mlx = "f5_tts_mlx.generate:main"

0 comments on commit 808f13c

Please sign in to comment.