feat(vLLM): support async generator #746

Open · wants to merge 2 commits into base: dev
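
With this change, Chat.infer returns an async generator, so a caller would consume the streamed audio roughly like this (a sketch only: the ChatTTS import path, the chat.load() setup, and the input text are assumptions, not taken from this PR):

import asyncio

import ChatTTS

async def main():
    chat = ChatTTS.Chat()
    chat.load()  # assumed standard ChatTTS setup
    # stream_batch_size is the new parameter introduced by this PR
    async for wavs in chat.infer(["Hello, world."], stream=True, stream_batch_size=16):
        ...  # e.g. push each wav chunk to a playback or file buffer

asyncio.run(main())
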
174 changes: 88 additions & 86 deletions ChatTTS/core.py
@@ -1,6 +1,7 @@
import os
import logging
import tempfile
import uuid
from dataclasses import dataclass, asdict
from typing import Literal, Optional, List, Tuple, Dict, Union
from json import load
@@ -200,9 +201,10 @@ def infer(
do_homophone_replacement=True,
params_refine_text=RefineTextParams(),
params_infer_code=InferCodeParams(),
stream_batch_size=16,
):
self.context.set(False)
res_gen = self._infer(
return self._infer(
text,
stream,
lang,
@@ -213,11 +215,8 @@ def infer(
do_homophone_replacement,
params_refine_text,
params_infer_code,
stream_batch_size,
)
if stream:
return res_gen
else:
return next(res_gen)

def interrupt(self):
self.context.set(True)
@@ -339,7 +338,7 @@ def _load(

return self.has_loaded()

def _infer(
async def _infer(
self,
text,
stream=False,
@@ -351,6 +350,7 @@ def _infer(
do_homophone_replacement=True,
params_refine_text=RefineTextParams(),
params_infer_code=InferCodeParams(),
stream_batch_size=16,
):

assert self.has_loaded(use_decoder=use_decoder)
@@ -384,41 +384,39 @@
yield text
return

if stream:
length = 0
pass_batch_count = 0
for result in self._infer_code(
length = 0
async for result in self._infer_code(
text,
stream,
self.device,
use_decoder,
params_infer_code,
stream_batch_size,
):
wavs = self._decode_to_wavs(
result.hiddens if use_decoder else result.ids,
use_decoder,
)
result.destroy()
if stream:
pass_batch_count += 1
if pass_batch_count <= params_infer_code.pass_first_n_batches:
continue
a = length
b = a + params_infer_code.stream_speed
if b > wavs.shape[1]:
b = wavs.shape[1]
new_wavs = wavs[:, a:b]
length = b
yield new_wavs
if result.finished:
yield wavs[:, length:]
else:
yield wavs
if stream:
new_wavs = wavs[:, length:]
# Identify rows with non-zero elements using np.any
# keep_rows = np.any(array != 0, axis=1)
keep_cols = np.sum(new_wavs != 0, axis=0) > 0
# Filter both rows and columns using slicing
yield new_wavs[:][:, keep_cols]
# Hack: check if there are any silent segments; if so, take the last segment; otherwise, wait for another loop.
keep_cols = np.sum(abs(wavs[0][length:]) > 1e-6, axis=0) > 0

import librosa
Review comment (Member):

We don't want to introduce librosa. Modifying the original code makes more sense, e.g. replace

keep_cols = np.sum(new_wavs != 0, axis=0) > 0

with

# pseudo code without testing, just a hint
keep_cols = np.sum(abs(new_wavs) > 1e-6, axis=0) > 0
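
A minimal numpy-only sketch of that hint (the drop_silent_columns name and the (channels, samples) shape of new_wavs are assumptions for illustration, not code from this PR):

import numpy as np

def drop_silent_columns(new_wavs: np.ndarray, eps: float = 1e-6) -> np.ndarray:
    # Keep only sample columns whose absolute amplitude exceeds eps in at
    # least one channel; new_wavs is assumed to be a (channels, samples) array.
    keep_cols = np.sum(np.abs(new_wavs) > eps, axis=0) > 0
    return new_wavs[:, keep_cols]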


silence_intervals = librosa.effects.split(wavs[0][length:], top_db=10)
silence_left = 0
if len(silence_intervals) == 0:
silence_left = len(wavs[0])
else:
for i in range(len(silence_intervals)):
silence_left = silence_intervals[i][0]
if silence_left <= 0:
continue
new_wavs = wavs[:, length : length + silence_left]
length += len(new_wavs[0])
yield new_wavs

@torch.inference_mode()
def _vocos_decode(self, spec: torch.Tensor) -> np.ndarray:
@@ -457,13 +455,14 @@ def _decode_to_wavs(
return wavs

@torch.no_grad()
def _infer_code(
async def _infer_code(
self,
text: Tuple[List[str], str],
stream: bool,
device: torch.device,
return_hidden: bool,
params: InferCodeParams,
stream_batch_size: int,
):

gpt = self.gpt
@@ -504,6 +503,17 @@ def _infer_code(
repetition_penalty=params.repetition_penalty,
)

speaker_embedding_param = self.embed(input_ids, text_mask)
del text_mask
if params.spk_emb is not None:
self.speaker.apply(
speaker_embedding_param,
params.spk_emb,
input_ids,
self.tokenizer.spk_emb_ids,
self.gpt.device_gpt,
)

if gpt.is_vllm:
from .model.velocity import SamplingParams

@@ -520,64 +530,56 @@
input_ids = [i.tolist() for i in input_ids]

result = gpt.llm.generate(
None,
sample_params,
input_ids,
None, sample_params, uuid.uuid4(), speaker_embedding_param, input_ids[0]
)

token_ids = []
hidden_states = []
for i in result:
token_ids.append(torch.tensor(i.outputs[0].token_ids))
hidden_states.append(
i.outputs[0].hidden_states.to(torch.float32).to(self.device)
)

del text_mask, input_ids

return [
GPT.GenerationOutputs(
ids=token_ids,
hiddens=hidden_states,
attentions=[],
),
]

emb = self.embed(input_ids, text_mask)
Review comment (Member):

If you are only moving a line up, don't rename the variable when the original name is still fine; renaming makes it harder to see your actual changes.


del text_mask

if params.spk_emb is not None:
self.speaker.apply(
emb,
params.spk_emb,
async for i in result:
token_ids = []
hidden_states = []
if (
stream and len(i.outputs[0].token_ids) % stream_batch_size == 0
) or i.finished:
token_ids.append(torch.tensor(i.outputs[0].token_ids))
hidden_states.append(
i.outputs[0].hidden_states.to(torch.float32).to(self.device)
)
yield GPT.GenerationOutputs(
ids=token_ids,
finished=i.finished,
hiddens=hidden_states,
attentions=[],
)
else:
result = gpt.generate(
speaker_embedding_param,
input_ids,
self.tokenizer.spk_emb_ids,
self.gpt.device_gpt,
temperature=torch.tensor(temperature, device=device),
eos_token=num_code,
attention_mask=attention_mask,
max_new_token=params.max_new_token,
min_new_token=params.min_new_token,
logits_processors=(*logits_processors, *logits_warpers),
infer_text=False,
return_hidden=return_hidden,
stream=stream,
show_tqdm=params.show_tqdm,
ensure_non_empty=params.ensure_non_empty,
stream_batch=params.stream_batch,
manual_seed=params.manual_seed,
context=self.context,
)

result = gpt.generate(
emb,
input_ids,
temperature=torch.tensor(temperature, device=device),
eos_token=num_code,
attention_mask=attention_mask,
max_new_token=params.max_new_token,
min_new_token=params.min_new_token,
logits_processors=(*logits_processors, *logits_warpers),
infer_text=False,
return_hidden=return_hidden,
stream=stream,
show_tqdm=params.show_tqdm,
ensure_non_empty=params.ensure_non_empty,
stream_batch=params.stream_batch,
manual_seed=params.manual_seed,
context=self.context,
)

del emb, input_ids

return result
del speaker_embedding_param, input_ids
async for i in result:
token_ids = []
hidden_states = []
if (stream and len(i.ids[0]) % stream_batch_size == 0) or i.finished:
token_ids.append(i.ids[0])
hidden_states.append(i.hiddens[0].to(torch.float32).to(self.device))
yield GPT.GenerationOutputs(
ids=token_ids,
finished=i.finished,
hiddens=hidden_states,
attentions=[],
)

@torch.no_grad()
def _refine_text(
14 changes: 7 additions & 7 deletions ChatTTS/model/gpt.py
@@ -68,6 +68,7 @@ def from_pretrained(
num_audio_tokens=self.num_audio_tokens,
num_text_tokens=self.num_text_tokens,
post_model_path=embed_file_path,
dtype="float32",
)
self.logger.info("vLLM model loaded")
return
@@ -273,6 +274,7 @@ class GenerationOutputs:
ids: List[torch.Tensor]
attentions: List[Optional[Tuple[torch.FloatTensor, ...]]]
hiddens: List[torch.Tensor]
finished: bool

def destroy(self):
del_all(self.ids)
@@ -288,6 +290,7 @@ def _prepare_generation_outputs(
attentions: List[Optional[Tuple[torch.FloatTensor, ...]]],
hiddens: List[torch.Tensor],
infer_text: bool,
finished: bool,
) -> GenerationOutputs:
inputs_ids = [
inputs_ids[idx].narrow(0, start_idx, i) for idx, i in enumerate(end_idx)
@@ -305,10 +308,11 @@
ids=inputs_ids,
attentions=attentions,
hiddens=hiddens,
finished=finished,
)

@torch.no_grad()
def generate(
async def generate(
self,
emb: torch.Tensor,
inputs_ids: torch.Tensor,
@@ -581,6 +585,7 @@ def generate(
attentions,
hiddens,
infer_text,
False,
)
del not_finished

@@ -604,10 +609,5 @@ def generate(
del finish, inputs_ids_buf

yield self._prepare_generation_outputs(
inputs_ids,
start_idx,
end_idx,
attentions,
hiddens,
infer_text,
inputs_ids, start_idx, end_idx, attentions, hiddens, infer_text, True
)