diff --git a/ChatTTS/core.py b/ChatTTS/core.py index c38ad8957..e20e47678 100644 --- a/ChatTTS/core.py +++ b/ChatTTS/core.py @@ -1,6 +1,7 @@ import os import logging import tempfile +import uuid from dataclasses import dataclass, asdict from typing import Literal, Optional, List, Tuple, Dict, Union from json import load @@ -200,9 +201,10 @@ def infer( do_homophone_replacement=True, params_refine_text=RefineTextParams(), params_infer_code=InferCodeParams(), + stream_batch_size=16, ): self.context.set(False) - res_gen = self._infer( + return self._infer( text, stream, lang, @@ -213,11 +215,8 @@ def infer( do_homophone_replacement, params_refine_text, params_infer_code, + stream_batch_size, ) - if stream: - return res_gen - else: - return next(res_gen) def interrupt(self): self.context.set(True) @@ -339,7 +338,7 @@ def _load( return self.has_loaded() - def _infer( + async def _infer( self, text, stream=False, @@ -351,6 +350,7 @@ def _infer( do_homophone_replacement=True, params_refine_text=RefineTextParams(), params_infer_code=InferCodeParams(), + stream_batch_size=16, ): assert self.has_loaded(use_decoder=use_decoder) @@ -384,41 +384,39 @@ def _infer( yield text return - if stream: - length = 0 - pass_batch_count = 0 - for result in self._infer_code( + length = 0 + async for result in self._infer_code( text, stream, self.device, use_decoder, params_infer_code, + stream_batch_size, ): wavs = self._decode_to_wavs( result.hiddens if use_decoder else result.ids, use_decoder, ) - result.destroy() - if stream: - pass_batch_count += 1 - if pass_batch_count <= params_infer_code.pass_first_n_batches: - continue - a = length - b = a + params_infer_code.stream_speed - if b > wavs.shape[1]: - b = wavs.shape[1] - new_wavs = wavs[:, a:b] - length = b - yield new_wavs + if result.finished: + yield wavs[:, length:] else: - yield wavs - if stream: - new_wavs = wavs[:, length:] - # Identify rows with non-zero elements using np.any - # keep_rows = np.any(array != 0, axis=1) - keep_cols = np.sum(new_wavs != 0, axis=0) > 0 - # Filter both rows and columns using slicing - yield new_wavs[:][:, keep_cols] + # Hacker:Check if there are any silent segments; if so, take the last segment. Otherwise, try waiting for another loop. 
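The streaming branch below flushes audio only up to the start of the last non-silent interval reported by `librosa.effects.split` (which returns *non-silent* spans, despite the `silence_intervals` name), so each emitted chunk ends on a silence boundary rather than mid-phoneme. A minimal sketch of that cut-point heuristic, assuming 24 kHz mono float32 PCM from the vocoder; `find_stream_cut` and its `top_db` default are illustrative, not part of ChatTTS:

```python
# Minimal sketch of the cut-point heuristic, not the actual ChatTTS code.
import numpy as np
import librosa

def find_stream_cut(wav: np.ndarray, emitted: int, top_db: int = 10) -> int:
    """Return how many samples after `emitted` are safe to flush now."""
    pending = wav[emitted:]
    if pending.size == 0:
        return 0
    # librosa.effects.split yields (start, end) pairs of NON-silent audio.
    intervals = librosa.effects.split(pending, top_db=top_db)
    if len(intervals) == 0:
        # Everything pending is silence: flush it all.
        return int(pending.size)
    # Stop just before the last non-silent interval, so the tail that is
    # still being generated stays buffered for the next iteration.
    return int(intervals[-1][0])
```

Computed this way, `length` only ever advances to a silence boundary; the final `result.finished` branch then flushes whatever remains.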
+ keep_cols = np.sum(abs(wavs[0][length:]) > 1e-6, axis=0) > 0 + + import librosa + + silence_intervals = librosa.effects.split(wavs[0][length:], top_db=10) + silence_left = 0 + if len(silence_intervals) == 0: + silence_left = len(wavs[0]) + else: + for i in range(len(silence_intervals)): + silence_left = silence_intervals[i][0] + if silence_left <= 0: + continue + new_wavs = wavs[:, length : length + silence_left] + length += len(new_wavs[0]) + yield new_wavs @torch.inference_mode() def _vocos_decode(self, spec: torch.Tensor) -> np.ndarray: @@ -457,13 +455,14 @@ def _decode_to_wavs( return wavs @torch.no_grad() - def _infer_code( + async def _infer_code( self, text: Tuple[List[str], str], stream: bool, device: torch.device, return_hidden: bool, params: InferCodeParams, + stream_batch_size: int, ): gpt = self.gpt @@ -504,6 +503,17 @@ def _infer_code( repetition_penalty=params.repetition_penalty, ) + speaker_embedding_param = self.embed(input_ids, text_mask) + del text_mask + if params.spk_emb is not None: + self.speaker.apply( + speaker_embedding_param, + params.spk_emb, + input_ids, + self.tokenizer.spk_emb_ids, + self.gpt.device_gpt, + ) + if gpt.is_vllm: from .model.velocity import SamplingParams @@ -520,64 +530,56 @@ def _infer_code( input_ids = [i.tolist() for i in input_ids] result = gpt.llm.generate( - None, - sample_params, - input_ids, + None, sample_params, uuid.uuid4(), speaker_embedding_param, input_ids[0] ) - - token_ids = [] - hidden_states = [] - for i in result: - token_ids.append(torch.tensor(i.outputs[0].token_ids)) - hidden_states.append( - i.outputs[0].hidden_states.to(torch.float32).to(self.device) - ) - - del text_mask, input_ids - - return [ - GPT.GenerationOutputs( - ids=token_ids, - hiddens=hidden_states, - attentions=[], - ), - ] - - emb = self.embed(input_ids, text_mask) - - del text_mask - - if params.spk_emb is not None: - self.speaker.apply( - emb, - params.spk_emb, + async for i in result: + token_ids = [] + hidden_states = [] + if ( + stream and len(i.outputs[0].token_ids) % stream_batch_size == 0 + ) or i.finished: + token_ids.append(torch.tensor(i.outputs[0].token_ids)) + hidden_states.append( + i.outputs[0].hidden_states.to(torch.float32).to(self.device) + ) + yield GPT.GenerationOutputs( + ids=token_ids, + finished=i.finished, + hiddens=hidden_states, + attentions=[], + ) + else: + result = gpt.generate( + speaker_embedding_param, input_ids, - self.tokenizer.spk_emb_ids, - self.gpt.device_gpt, + temperature=torch.tensor(temperature, device=device), + eos_token=num_code, + attention_mask=attention_mask, + max_new_token=params.max_new_token, + min_new_token=params.min_new_token, + logits_processors=(*logits_processors, *logits_warpers), + infer_text=False, + return_hidden=return_hidden, + stream=stream, + show_tqdm=params.show_tqdm, + ensure_non_empty=params.ensure_non_empty, + stream_batch=params.stream_batch, + manual_seed=params.manual_seed, + context=self.context, ) - - result = gpt.generate( - emb, - input_ids, - temperature=torch.tensor(temperature, device=device), - eos_token=num_code, - attention_mask=attention_mask, - max_new_token=params.max_new_token, - min_new_token=params.min_new_token, - logits_processors=(*logits_processors, *logits_warpers), - infer_text=False, - return_hidden=return_hidden, - stream=stream, - show_tqdm=params.show_tqdm, - ensure_non_empty=params.ensure_non_empty, - stream_batch=params.stream_batch, - manual_seed=params.manual_seed, - context=self.context, - ) - - del emb, input_ids - - return result + del 
speaker_embedding_param, input_ids + async for i in result: + token_ids = [] + hidden_states = [] + if (stream and len(i.ids[0]) % stream_batch_size == 0) or i.finished: + token_ids.append(i.ids[0]) + hidden_states.append(i.hiddens[0].to(torch.float32).to(self.device)) + yield GPT.GenerationOutputs( + ids=token_ids, + finished=i.finished, + hiddens=hidden_states, + attentions=[], + ) @torch.no_grad() def _refine_text( diff --git a/ChatTTS/model/gpt.py b/ChatTTS/model/gpt.py index 576ecdfc4..3b700f338 100644 --- a/ChatTTS/model/gpt.py +++ b/ChatTTS/model/gpt.py @@ -68,6 +68,7 @@ def from_pretrained( num_audio_tokens=self.num_audio_tokens, num_text_tokens=self.num_text_tokens, post_model_path=embed_file_path, + dtype="float32", ) self.logger.info("vLLM model loaded") return @@ -273,6 +274,7 @@ class GenerationOutputs: ids: List[torch.Tensor] attentions: List[Optional[Tuple[torch.FloatTensor, ...]]] hiddens: List[torch.Tensor] + finished: bool def destroy(self): del_all(self.ids) @@ -288,6 +290,7 @@ def _prepare_generation_outputs( attentions: List[Optional[Tuple[torch.FloatTensor, ...]]], hiddens: List[torch.Tensor], infer_text: bool, + finished: bool, ) -> GenerationOutputs: inputs_ids = [ inputs_ids[idx].narrow(0, start_idx, i) for idx, i in enumerate(end_idx) @@ -305,10 +308,11 @@ def _prepare_generation_outputs( ids=inputs_ids, attentions=attentions, hiddens=hiddens, + finished=finished, ) @torch.no_grad() - def generate( + async def generate( self, emb: torch.Tensor, inputs_ids: torch.Tensor, @@ -581,6 +585,7 @@ def generate( attentions, hiddens, infer_text, + False, ) del not_finished @@ -604,10 +609,5 @@ def generate( del finish, inputs_ids_buf yield self._prepare_generation_outputs( - inputs_ids, - start_idx, - end_idx, - attentions, - hiddens, - infer_text, + inputs_ids, start_idx, end_idx, attentions, hiddens, infer_text, True ) diff --git a/ChatTTS/model/velocity/async_llm_engine.py b/ChatTTS/model/velocity/async_llm_engine.py new file mode 100644 index 000000000..00089b722 --- /dev/null +++ b/ChatTTS/model/velocity/async_llm_engine.py @@ -0,0 +1,529 @@ +import asyncio +import time +from functools import partial +from typing import ( + Any, + Dict, + Iterable, + List, + Optional, + Set, + Tuple, + Type, + Union, + AsyncIterator, +) + +import torch +from vllm.config import ModelConfig +from vllm.engine.arg_utils import AsyncEngineArgs +from .llm_engine import LLMEngine +from vllm.engine.ray_utils import initialize_cluster, ray +from vllm.logger import init_logger +from .output import RequestOutput +from .sampling_params import SamplingParams + +logger = init_logger(__name__) + + +class AsyncEngineDeadError(RuntimeError): + pass + + +def _raise_exception_on_finish( + task: asyncio.Task, request_tracker: "RequestTracker" +) -> None: + msg = ( + "Task finished unexpectedly. This should never happen! " + "Please open an issue on Github." + ) + try: + try: + task.result() + except asyncio.CancelledError: + return + except Exception as exc: + raise AsyncEngineDeadError( + msg + " See stack trace above for the actual cause." 
+ ) from exc + raise AsyncEngineDeadError(msg) + except Exception as exc: + request_tracker.propagate_exception(exc) + raise exc + + +class AsyncStream: + """A stream of RequestOutputs for a request that can be + iterated over asynchronously.""" + + def __init__(self, request_id: str) -> None: + self.request_id = request_id + self._queue = asyncio.Queue() + self._finished = False + + def put(self, item: RequestOutput) -> None: + if self._finished: + return + self._queue.put_nowait(item) + + def finish(self) -> None: + self._queue.put_nowait(StopIteration) + self._finished = True + + @property + def finished(self) -> bool: + return self._finished + + def __aiter__(self): + return self + + async def __anext__(self) -> RequestOutput: + result = await self._queue.get() + if result is StopIteration: + raise StopAsyncIteration + elif isinstance(result, Exception): + raise result + return result + + +class RequestTracker: + """Synchronous abstraction for tracking requests.""" + + def __init__(self) -> None: + self._request_streams: Dict[str, AsyncStream] = {} + self._finished_requests: asyncio.Queue[str] = asyncio.Queue() + self._new_requests: asyncio.Queue[Tuple[AsyncStream, dict]] = asyncio.Queue() + self.new_requests_event = None + + def __contains__(self, item): + return item in self._request_streams + + def init_event(self): + self.new_requests_event = asyncio.Event() + + def propagate_exception( + self, exc: Exception, request_id: Optional[str] = None + ) -> None: + """Propagate an exception to request streams + (all if request_id is None).""" + if request_id is not None: + self._request_streams[request_id].put(exc) + else: + for stream in self._request_streams.values(): + stream.put(exc) + + def process_request_output( + self, request_output: RequestOutput, *, verbose: bool = False + ) -> None: + """Process a request output from the engine.""" + request_id = request_output.request_id + + self._request_streams[request_id].put(request_output) + if request_output.finished: + if verbose: + logger.info(f"Finished request {request_id}.") + self.abort_request(request_id) + + def add_request(self, request_id: str, **engine_add_request_kwargs) -> AsyncStream: + """Add a request to be sent to the engine on the next background + loop iteration.""" + if request_id in self._request_streams: + raise KeyError(f"Request {request_id} already exists.") + + stream = AsyncStream(request_id) + self._new_requests.put_nowait( + (stream, {"request_id": request_id, **engine_add_request_kwargs}) + ) + + self.new_requests_event.set() + + return stream + + def abort_request(self, request_id: str, *, verbose: bool = False) -> None: + """Abort a request during next background loop iteration.""" + if verbose: + logger.info(f"Aborted request {request_id}.") + + self._finished_requests.put_nowait(request_id) + + if ( + request_id not in self._request_streams + or self._request_streams[request_id].finished + ): + # The request has already finished or been aborted. 
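`AsyncStream` and `RequestTracker` follow the usual vLLM async pattern: the engine loop pushes each `RequestOutput` into a per-request `asyncio.Queue`, and the caller drains that queue with `async for`. The sketch below shows just that core pattern in isolation (it uses a private sentinel where the class above reuses `StopIteration`); it is illustrative, not the actual class:

```python
import asyncio

class TinyStream:
    """Stripped-down version of the AsyncStream idea, for illustration only."""

    _DONE = object()  # sentinel marking end-of-stream

    def __init__(self) -> None:
        self._queue: asyncio.Queue = asyncio.Queue()

    def put(self, item) -> None:
        self._queue.put_nowait(item)

    def finish(self) -> None:
        self._queue.put_nowait(self._DONE)

    def __aiter__(self):
        return self

    async def __anext__(self):
        item = await self._queue.get()
        if item is self._DONE:
            raise StopAsyncIteration
        return item

async def demo() -> None:
    stream = TinyStream()
    for token in ("chunk-1", "chunk-2"):
        stream.put(token)   # what the engine loop does per step
    stream.finish()
    async for token in stream:  # what chat.infer() callers do
        print(token)

# asyncio.run(demo())
```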
+ return + + self._request_streams[request_id].finish() + + def get_new_and_finished_requests(self) -> Tuple[List[Dict], Set[str]]: + """Get the new requests and finished requests to be + sent to the engine.""" + new_requests: List[Dict] = [] + finished_requests: Set[str] = set() + + while not self._finished_requests.empty(): + request_id = self._finished_requests.get_nowait() + finished_requests.add(request_id) + self._request_streams.pop(request_id, None) + + while not self._new_requests.empty(): + stream, new_request = self._new_requests.get_nowait() + if stream.request_id in finished_requests: + # The request has already been aborted. + stream.finish() + continue + self._request_streams[stream.request_id] = stream + new_requests.append(new_request) + + self.new_requests_event.clear() + + return new_requests, finished_requests + + async def wait_for_new_requests(self): + await self.new_requests_event.wait() + + +class _AsyncLLMEngine(LLMEngine): + """Extension of LLMEngine to add async methods.""" + + async def step_async(self) -> List[RequestOutput]: + """Performs one decoding iteration and returns newly generated results. + The workers are ran asynchronously if possible. + + This function performs one decoding iteration of the engine. It first + schedules the sequences to be executed in the next iteration and the + token blocks to be swapped in/out/copy. Then, it executes the model + and updates the scheduler with the model outputs. Finally, it decodes + the sequences and returns the newly generated results. + """ + seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule() + + if not scheduler_outputs.is_empty(): + # Execute the model. + all_outputs = await self._run_workers_async( + "execute_model", + driver_kwargs={ + "seq_group_metadata_list": seq_group_metadata_list, + "blocks_to_swap_in": scheduler_outputs.blocks_to_swap_in, + "blocks_to_swap_out": scheduler_outputs.blocks_to_swap_out, + "blocks_to_copy": scheduler_outputs.blocks_to_copy, + }, + ) + + # Only the driver worker returns the sampling results. + output = all_outputs[0] + else: + output = [] + + return self._process_model_outputs(output, scheduler_outputs) + + async def _run_workers_async( + self, + method: str, + *args, + driver_args: Optional[List[Any]] = None, + driver_kwargs: Optional[Dict[str, Any]] = None, + **kwargs, + ) -> Any: + """Runs the given method on all workers.""" + coros = [] + + if driver_args is None: + driver_args = args + if driver_kwargs is None: + driver_kwargs = kwargs + + # Run the driver worker asynchronously. + driver_executor = getattr(self.driver_worker, method) + coros.append( + asyncio.get_event_loop().run_in_executor( + None, partial(driver_executor, *driver_args, **driver_kwargs) + ) + ) + + # Run the ray workers asynchronously. + for worker in self.workers: + coros.append(worker.execute_method.remote(method, *args, **kwargs)) + + all_outputs = await asyncio.gather(*coros) + return all_outputs + + +class AsyncLLMEngine: + """An asynchronous wrapper for LLMEngine. + + This class is used to wrap the LLMEngine class to make it asynchronous. It + uses asyncio to create a background loop that keeps processing incoming + requests. The LLMEngine is kicked by the generate method when there + are requests in the waiting queue. The generate method yields the outputs + from the LLMEngine to the caller. + + NOTE: For the comprehensive list of arguments, see `LLMEngine`. + + Args: + worker_use_ray: Whether to use Ray for model workers. Required for + distributed execution. 
Should be the same as + `parallel_config.worker_use_ray`. + engine_use_ray: Whether to make LLMEngine a Ray actor. If so, the + async frontend will be executed in a separate process as the + model workers. + log_requests: Whether to log the requests. + start_engine_loop: If True, the background task to run the engine + will be automatically started in the generate call. + *args, *kwargs: Arguments for LLMEngine. + """ + + _engine_class: Type[_AsyncLLMEngine] = _AsyncLLMEngine + + def __init__( + self, + worker_use_ray: bool, + engine_use_ray: bool, + *args, + log_requests: bool = False, + max_log_len: Optional[int] = None, + start_engine_loop: bool = True, + **kwargs, + ) -> None: + self.worker_use_ray = worker_use_ray + self.engine_use_ray = engine_use_ray + self.log_requests = log_requests + self.max_log_len = max_log_len + self.engine = self._init_engine(*args, **kwargs) + + self.background_loop = None + # We need to keep a reference to unshielded + # task as well to prevent it from being garbage + # collected + self._background_loop_unshielded = None + self.start_engine_loop = start_engine_loop + self._request_tracker = RequestTracker() + + @property + def is_running(self) -> bool: + return self.background_loop is not None and not self.background_loop.done() + + def start_background_loop(self) -> None: + """Start the background loop.""" + if self.is_running: + raise RuntimeError("Background loop is already running.") + self._request_tracker.init_event() + + self._background_loop_unshielded = asyncio.get_event_loop().create_task( + self.run_engine_loop() + ) + self._background_loop_unshielded.add_done_callback( + partial(_raise_exception_on_finish, request_tracker=self._request_tracker) + ) + self.background_loop = asyncio.shield(self._background_loop_unshielded) + + def _init_engine(self, *args, **kwargs) -> Union[_AsyncLLMEngine, "ray.ObjectRef"]: + if not self.engine_use_ray: + engine_class = self._engine_class + elif self.worker_use_ray: + engine_class = ray.remote(num_cpus=0)(self._engine_class).remote + else: + # FIXME(woosuk): This is a bit hacky. Be careful when changing the + # order of the arguments. + cache_config = args[1] + parallel_config = args[2] + if parallel_config.tensor_parallel_size == 1: + num_gpus = cache_config.gpu_memory_utilization + else: + num_gpus = 1 + engine_class = ray.remote(num_gpus=num_gpus)(self._engine_class).remote + return engine_class(*args, **kwargs) + + async def engine_step(self) -> bool: + """Kick the engine to process the waiting requests. + + Returns True if there are in-progress requests.""" + + new_requests, finished_requests = ( + self._request_tracker.get_new_and_finished_requests() + ) + + for new_request in new_requests: + # Add the request into the vLLM engine's waiting queue. + # TODO: Maybe add add_request_batch to reduce Ray overhead + if self.engine_use_ray: + await self.engine.add_request.remote(**new_request) + else: + self.engine.add_request(**new_request) + + if finished_requests: + await self._engine_abort(finished_requests) + + if self.engine_use_ray: + request_outputs = await self.engine.step.remote() + else: + request_outputs = await self.engine.step_async() + + # Put the outputs into the corresponding streams. 
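`start_background_loop` and `_raise_exception_on_finish` above use a common asyncio idiom: keep an unshielded reference so the task is not garbage collected, hand out an `asyncio.shield`-ed handle, and attach a done callback so a crash in the loop surfaces instead of dying silently. A condensed sketch of that idiom under the same assumptions; names and the `print` are illustrative:

```python
import asyncio
from functools import partial

def _on_loop_done(task: asyncio.Task, *, name: str) -> None:
    # Done callback: surface crashes instead of letting the task die silently.
    if task.cancelled():
        return
    exc = task.exception()
    if exc is not None:
        # AsyncLLMEngine propagates this to every pending request stream
        # before re-raising it as AsyncEngineDeadError.
        print(f"background loop {name!r} died: {exc!r}")

async def engine_step_forever() -> None:
    while True:
        # Stand-in for engine_step(): schedule, run the model, route outputs.
        await asyncio.sleep(0)

async def start_background_loop() -> asyncio.Future:
    unshielded = asyncio.get_running_loop().create_task(engine_step_forever())
    unshielded.add_done_callback(partial(_on_loop_done, name="engine"))
    # Shielding means a caller cancelling its await does not kill the loop;
    # keeping the unshielded reference stops the task from being GC'd.
    return asyncio.shield(unshielded)
```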
+ for request_output in request_outputs: + self._request_tracker.process_request_output( + request_output, verbose=self.log_requests + ) + + return len(request_outputs) > 0 + + async def _engine_abort(self, request_ids: Iterable[str]): + if self.engine_use_ray: + await self.engine.abort_request.remote(request_ids) + else: + self.engine.abort_request(request_ids) + + async def run_engine_loop(self): + # Initialize the RequestTracker here so it uses the right event loop. + has_requests_in_progress = False + while True: + if not has_requests_in_progress: + await self._request_tracker.wait_for_new_requests() + has_requests_in_progress = await self.engine_step() + await asyncio.sleep(0) + + async def add_request( + self, + request_id: str, + prompt: Optional[str], + sampling_params: SamplingParams, + speaker_embedding_param: torch.Tensor, + prompt_token_ids: Optional[List[int]] = None, + arrival_time: Optional[float] = None, + ) -> AsyncStream: + + if not self.is_running: + if self.start_engine_loop: + self.start_background_loop() + else: + raise AsyncEngineDeadError( + "Background loop is not running. If it was running, " + "inspect the output to find the stacktrace of the " + "error that caused the background loop to stop " + "(AsyncEngineDeadError)." + ) + + stream = self._request_tracker.add_request( + request_id, + prompt=prompt, + sampling_params=sampling_params, + speaker_embedding_param=speaker_embedding_param, + prompt_token_ids=prompt_token_ids, + arrival_time=arrival_time, + ) + + return stream + + async def generate( + self, + prompt: Optional[str], + sampling_params: SamplingParams, + request_id: str, + speaker_embedding_param: torch.Tensor, + prompt_token_ids: Optional[List[int]] = None, + ) -> AsyncIterator[RequestOutput]: + """Generate outputs for a request. + + Generate outputs for a request. This method is a coroutine. It adds the + request into the waiting queue of the LLMEngine and streams the outputs + from the LLMEngine to the caller. + + Args: + prompt: The prompt string. Can be None if prompt_token_ids is + provided. + sampling_params: The sampling parameters of the request. + request_id: The unique id of the request. + prompt_token_ids: The token IDs of the prompt. If None, we + use the tokenizer to convert the prompts to token IDs. + + Yields: + The output `RequestOutput` objects from the LLMEngine for the + request. + """ + # Preprocess the request. + # This should not be used for logging, as it is monotonic time. + arrival_time = time.monotonic() + + try: + stream = await self.add_request( + request_id, + prompt, + sampling_params, + speaker_embedding_param=speaker_embedding_param, + prompt_token_ids=prompt_token_ids, + arrival_time=arrival_time, + ) + + async for request_output in stream: + yield request_output + except (Exception, asyncio.CancelledError) as e: + # If there is an exception or coroutine is cancelled, abort the + # request. + self._abort(request_id) + raise e + + async def abort(self, request_id: str) -> None: + """Abort a request. + + Abort a submitted request. If the request is finished or not found, + this method will be a no-op. + + Args: + request_id: The unique id of the request. + """ + if not self.is_running: + raise AsyncEngineDeadError( + "Background loop is not running. If it was running, " + "inspect the output to find the stacktrace of the " + "error that caused the background loop to stop " + "(AsyncEngineDeadError)." + ) + + return self._abort(request_id) + + def _abort(self, request_id: str) -> None: + """Abort a request. 
+ + Abort a submitted request. If the request is finished or not found, + this method will be a no-op. + + Args: + request_id: The unique id of the request. + """ + self._request_tracker.abort_request(request_id, verbose=self.log_requests) + + async def get_model_config(self) -> ModelConfig: + """Get the model configuration of the vLLM engine.""" + if self.engine_use_ray: + return await self.engine.get_model_config.remote() + else: + return self.engine.get_model_config() + + @classmethod + def from_engine_args( + cls, + engine_args: AsyncEngineArgs, + post_model_path: str = None, + start_engine_loop: bool = True, + ) -> "AsyncLLMEngine": + """Creates an async LLM engine from the engine arguments.""" + # Create the engine configs. + engine_configs = engine_args.create_engine_configs() + parallel_config = engine_configs[2] + # Initialize the cluster. + placement_group = initialize_cluster( + parallel_config, engine_args.engine_use_ray + ) + # Create the async LLM engine. + engine = cls( + parallel_config.worker_use_ray, + engine_args.engine_use_ray, + *engine_configs, + placement_group, + log_requests=not engine_args.disable_log_requests, + log_stats=True, + max_log_len=engine_args.max_log_len, + start_engine_loop=start_engine_loop, + post_model_path=post_model_path, + ) + return engine diff --git a/ChatTTS/model/velocity/configs.py b/ChatTTS/model/velocity/configs.py index c578f468a..2b8e2d76b 100644 --- a/ChatTTS/model/velocity/configs.py +++ b/ChatTTS/model/velocity/configs.py @@ -578,6 +578,9 @@ class EngineArgs: max_context_len_to_capture: int = 8192 num_audio_tokens: int = 1024 num_text_tokens: int = 80 + engine_use_ray: bool = False + disable_log_requests: bool = False + max_log_len: int = 8192 def __post_init__(self): if self.tokenizer is None: diff --git a/ChatTTS/model/velocity/llm.py b/ChatTTS/model/velocity/llm.py index a37f5cb34..e5330ba42 100644 --- a/ChatTTS/model/velocity/llm.py +++ b/ChatTTS/model/velocity/llm.py @@ -1,11 +1,12 @@ -from typing import List, Optional, Union +import asyncio +import time +from typing import List, Optional, Union, AsyncIterator -from tqdm import tqdm -from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast from vllm.utils import Counter +import torch +from .async_llm_engine import AsyncLLMEngine from .configs import EngineArgs -from .llm_engine import LLMEngine from .output import RequestOutput from .sampling_params import SamplingParams @@ -107,107 +108,53 @@ def __init__( num_text_tokens=num_text_tokens, **kwargs, ) - self.llm_engine = LLMEngine.from_engine_args(engine_args, post_model_path) + self.llm_engine = AsyncLLMEngine.from_engine_args(engine_args, post_model_path) self.request_counter = Counter() - def get_tokenizer(self) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]: - return self.llm_engine.tokenizer - - def set_tokenizer( - self, - tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], - ) -> None: - self.llm_engine.tokenizer = tokenizer - - def generate( + async def generate( self, - prompts: Optional[Union[str, List[str]]] = None, - sampling_params: Optional[SamplingParams] = None, - prompt_token_ids: Optional[List[List[int]]] = None, - use_tqdm: bool = True, - ) -> List[RequestOutput]: - """Generates the completions for the input prompts. + prompt: Optional[str], + sampling_params: SamplingParams, + request_id: str, + speaker_embedding_param: torch.Tensor, + prompt_token_ids: Optional[List[int]] = None, + ) -> AsyncIterator[RequestOutput]: + """Generate outputs for a request. 
- NOTE: This class automatically batches the given prompts, considering - the memory constraint. For the best performance, put all of your prompts - into a single list and pass it to this method. + Generate outputs for a request. This method is a coroutine. It adds the + request into the waiting queue of the LLMEngine and streams the outputs + from the LLMEngine to the caller. Args: - prompts: A list of prompts to generate completions for. - sampling_params: The sampling parameters for text generation. If - None, we use the default sampling parameters. - prompt_token_ids: A list of token IDs for the prompts. If None, we + prompt: The prompt string. Can be None if prompt_token_ids is + provided. + sampling_params: The sampling parameters of the request. + request_id: The unique id of the request. + prompt_token_ids: The token IDs of the prompt. If None, we use the tokenizer to convert the prompts to token IDs. - use_tqdm: Whether to use tqdm to display the progress bar. - Returns: - A list of `RequestOutput` objects containing the generated - completions in the same order as the input prompts. + Yields: + The output `RequestOutput` objects from the LLMEngine for the + request. """ - if prompts is None and prompt_token_ids is None: - raise ValueError("Either prompts or prompt_token_ids must be " "provided.") - if isinstance(prompts, str): - # Convert a single prompt to a list. - prompts = [prompts] - if ( - prompts is not None - and prompt_token_ids is not None - and len(prompts) != len(prompt_token_ids) - ): - raise ValueError( - "The lengths of prompts and prompt_token_ids " "must be the same." + # Preprocess the request. + # This should not be used for logging, as it is monotonic time. + arrival_time = time.monotonic() + + try: + stream = await self.llm_engine.add_request( + request_id, + prompt, + sampling_params, + speaker_embedding_param=speaker_embedding_param, + prompt_token_ids=prompt_token_ids, + arrival_time=arrival_time, ) - if sampling_params is None: - # Use default sampling params. - sampling_params = SamplingParams() - - # Add requests to the engine. - num_requests = len(prompts) if prompts is not None else len(prompt_token_ids) - for i in range(num_requests): - prompt = prompts[i] if prompts is not None else None - token_ids = None if prompt_token_ids is None else prompt_token_ids[i] - self._add_request(prompt, sampling_params, token_ids) - - rtns = self._run_engine(use_tqdm) - for i, rtn in enumerate(rtns): - token_ids = rtn.outputs[0].token_ids - for j, token_id in enumerate(token_ids): - if len(token_id) == 1: - token_ids[j] = token_id[0] - else: - token_ids[j] = list(token_id) - - return rtns - - def _add_request( - self, - prompt: Optional[str], - sampling_params: SamplingParams, - prompt_token_ids: Optional[List[int]], - ) -> None: - request_id = str(next(self.request_counter)) - self.llm_engine.add_request( - request_id, prompt, sampling_params, prompt_token_ids - ) - def _run_engine(self, use_tqdm: bool) -> List[RequestOutput]: - # Initialize tqdm. - if use_tqdm: - num_requests = self.llm_engine.get_num_unfinished_requests() - pbar = tqdm(total=num_requests, desc="Processed prompts") - # Run the engine. - outputs: List[RequestOutput] = [] - while self.llm_engine.has_unfinished_requests(): - step_outputs = self.llm_engine.step() - for output in step_outputs: - if output.finished: - outputs.append(output) - if use_tqdm: - pbar.update(1) - if use_tqdm: - pbar.close() - # Sort the outputs by request ID. 
- # This is necessary because some requests may be finished earlier than - # its previous requests. - outputs = sorted(outputs, key=lambda x: int(x.request_id)) - return outputs + async for request_output in stream: + yield request_output + except (Exception, asyncio.CancelledError) as e: + # If there is an exception or coroutine is cancelled, abort the + # request. + self.llm_engine._abort(request_id) + raise e diff --git a/ChatTTS/model/velocity/llm_engine.py b/ChatTTS/model/velocity/llm_engine.py index 0d144d0fd..b4894fc11 100644 --- a/ChatTTS/model/velocity/llm_engine.py +++ b/ChatTTS/model/velocity/llm_engine.py @@ -22,7 +22,7 @@ ) from vllm.transformers_utils.tokenizer import detokenize_incrementally, get_tokenizer from vllm.utils import Counter, set_cuda_visible_devices, get_ip, get_open_port -import numpy as np +import torch if ray: from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy @@ -328,6 +328,7 @@ def add_request( request_id: str, prompt: Optional[str], sampling_params: SamplingParams, + speaker_embedding_param: torch.Tensor, prompt_token_ids: Optional[List[int]] = None, arrival_time: Optional[float] = None, ) -> None: @@ -342,6 +343,7 @@ def add_request( prompt: The prompt string. Can be None if prompt_token_ids is provided. sampling_params: The sampling parameters for text generation. + speaker_embedding_param: The speaker embedding parameter prompt_token_ids: The token IDs of the prompt. If None, we use the tokenizer to convert the prompts to token IDs. arrival_time: The arrival time of the request. If None, we use @@ -354,10 +356,14 @@ def add_request( # Create the sequences. block_size = self.cache_config.block_size seq_id = next(self.seq_counter) - seq = Sequence(seq_id, prompt, prompt_token_ids, block_size) + seq = Sequence( + seq_id, prompt, prompt_token_ids, speaker_embedding_param, block_size + ) # Create the sequence group. - seq_group = SequenceGroup(request_id, [seq], sampling_params, arrival_time) + seq_group = SequenceGroup( + request_id, [seq], sampling_params, speaker_embedding_param, arrival_time + ) # Add the sequence group to the scheduler. self.scheduler.add_seq_group(seq_group) diff --git a/ChatTTS/model/velocity/model_runner.py b/ChatTTS/model/velocity/model_runner.py index e5c703e09..e0cad0f0e 100644 --- a/ChatTTS/model/velocity/model_runner.py +++ b/ChatTTS/model/velocity/model_runner.py @@ -4,6 +4,7 @@ import numpy as np import torch import torch.nn as nn +from torch import Tensor from .configs import ModelConfig, ParallelConfig, SchedulerConfig from vllm.logger import init_logger @@ -105,11 +106,14 @@ def set_block_size(self, block_size: int) -> None: def _prepare_prompt( self, seq_group_metadata_list: List[SequenceGroupMetadata], - ) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata, List[int]]: + ) -> tuple[ + list[list[int]], list[list[int]], InputMetadata, list[int], list[Tensor] + ]: assert len(seq_group_metadata_list) > 0 input_tokens: List[List[int]] = [] input_positions: List[List[int]] = [] slot_mapping: List[List[int]] = [] + embedding: List[torch.Tensor] = [] prompt_lens: List[int] = [] for seq_group_metadata in seq_group_metadata_list: @@ -127,7 +131,7 @@ def _prepare_prompt( # NOTE(woosuk): Here we assume that the first token in the prompt # is always the first token in the sequence. input_positions.append(list(range(prompt_len))) - + embedding.append(seq_group_metadata.speaker_embedding_param) if seq_group_metadata.block_tables is None: # During memory profiling, the block tables are not initialized # yet. 
In this case, we just use a dummy slot mapping. @@ -166,6 +170,10 @@ def _prepare_prompt( slot_mapping, max_prompt_len, pad=_PAD_SLOT_ID, dtype=torch.long ) + embedding = _make_with_pad( + embedding, max_prompt_len, pad=0, dtype=torch.float32 + ) + input_metadata = InputMetadata( is_prompt=True, slot_mapping=slot_mapping, @@ -174,7 +182,7 @@ def _prepare_prompt( block_tables=None, use_cuda_graph=False, ) - return input_tokens, input_positions, input_metadata, prompt_lens + return input_tokens, input_positions, input_metadata, prompt_lens, embedding def _prepare_decode( self, @@ -353,16 +361,23 @@ def _prepare_sample( def prepare_input_tensors( self, seq_group_metadata_list: Optional[List[SequenceGroupMetadata]], - ) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata, SamplingMetadata]: + ) -> Tuple[ + torch.Tensor, torch.Tensor, InputMetadata, SamplingMetadata, list[torch.Tensor] + ]: + speaker_embedding = None if self.is_driver_worker: # NOTE: We assume that all sequences in the group are all prompts or # all decodes. is_prompt = seq_group_metadata_list[0].is_prompt # Prepare input tensors. if is_prompt: - (input_tokens, input_positions, input_metadata, prompt_lens) = ( - self._prepare_prompt(seq_group_metadata_list) - ) + ( + input_tokens, + input_positions, + input_metadata, + prompt_lens, + speaker_embedding, + ) = self._prepare_prompt(seq_group_metadata_list) else: (input_tokens, input_positions, input_metadata) = self._prepare_decode( seq_group_metadata_list @@ -454,7 +469,13 @@ def get_size_or_none(x: Optional[torch.Tensor]): perform_sampling=False, ) - return input_tokens, input_positions, input_metadata, sampling_metadata + return ( + input_tokens, + input_positions, + input_metadata, + sampling_metadata, + speaker_embedding, + ) @torch.inference_mode() def execute_model( @@ -462,39 +483,23 @@ def execute_model( seq_group_metadata_list: Optional[List[SequenceGroupMetadata]], kv_caches: List[Tuple[torch.Tensor, torch.Tensor]], ) -> Optional[SamplerOutput]: - input_tokens, input_positions, input_metadata, sampling_metadata = ( - self.prepare_input_tensors(seq_group_metadata_list) - ) + + ( + input_tokens, + input_positions, + input_metadata, + sampling_metadata, + speaker_embedding, + ) = self.prepare_input_tensors(seq_group_metadata_list) # print(sampling_metadata.seq_data) seq_groups = [] - input_tokens_history = [] for i, rtn in enumerate(sampling_metadata.seq_groups): seq_groups.append(rtn[0][0]) - tokens_history = sampling_metadata.seq_data[rtn[0][0]].output_token_ids - if len(tokens_history) >= 1: - if len(tokens_history[0]) == 1: - tokens_history = [token[0] for token in tokens_history] - else: - tokens_history = [list(token) for token in tokens_history] - input_tokens_history.append(tokens_history) - input_tokens_history = torch.tensor(input_tokens_history).to( - input_tokens.device - ) - # token_ids = rtn.outputs[0].token_ids - # for j, token_id in enumerate(token_ids): - # if len(token_id) == 1: - # token_ids[j] = token_id[0] - # else: - # token_ids[j] = list(token_id) # Execute the model. 
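`_prepare_prompt` now threads a per-sequence speaker embedding through to `execute_model`, and the new `_make_with_pad` left-pads each `(1, seq_len, 768)` tensor with zeros up to the longest prompt in the batch (the companion change to `_pad_to_max` switches token padding to the left as well). A minimal sketch of that left-padding step; `left_pad_embeddings` is an illustrative stand-in and takes the hidden size from the tensor rather than hard-coding 768:

```python
import torch
from typing import List

def left_pad_embeddings(embs: List[torch.Tensor], max_len: int) -> List[torch.Tensor]:
    """Left-pad each (1, seq_len, dim) embedding with zeros to max_len."""
    padded: List[torch.Tensor] = []
    for emb in embs:
        missing = max_len - emb.shape[-2]
        assert missing >= 0, "max_len must cover the longest prompt"
        if missing == 0:
            padded.append(emb)
        else:
            zeros = torch.zeros(
                1, missing, emb.shape[-1], dtype=emb.dtype, device=emb.device
            )
            padded.append(torch.cat((zeros, emb), dim=1))
    return padded

# e.g. left_pad_embeddings([torch.randn(1, 5, 768), torch.randn(1, 9, 768)], 9)
```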
- # print("it1",input_tokens) if len(input_tokens.shape) == 2: input_tokens = input_tokens.unsqueeze(2).repeat(1, 1, 4) - if len(input_tokens_history.shape) == 2: - input_tokens_history = input_tokens_history.unsqueeze(2).repeat(1, 1, 4) - # print(input_tokens_history.shape) - # print("it2",input_tokens.shape) + text_mask = input_tokens != 0 text_mask = text_mask[:, :, 0] @@ -514,10 +519,10 @@ def execute_model( # print(logits_processors, logits_warpers) min_new_token = sampling_metadata.seq_groups[0][1].min_new_token eos_token = sampling_metadata.seq_groups[0][1].eos_token - start_idx = sampling_metadata.seq_groups[0][1].start_idx + start_idx = input_tokens[0].shape[0] if input_tokens.shape[-2] == 1: if infer_text: - input_emb: torch.Tensor = self.post_model.emb_text( + speaker_embedding_params: torch.Tensor = self.post_model.emb_text( input_tokens[:, :, 0] ) else: @@ -525,32 +530,34 @@ def execute_model( self.post_model.emb_code[i](input_tokens[:, :, i]) for i in range(self.post_model.num_vq) ] - input_emb = torch.stack(code_emb, 3).sum(3) - start_idx = ( - input_tokens_history.shape[-2] - 1 - if input_tokens_history.shape[-2] > 0 - else 0 - ) + speaker_embedding_params = torch.stack(code_emb, 3).sum(3) else: - input_emb = self.post_model(input_tokens, text_mask) - # print(input_emb.shape) + # 通过for循环,拼接成一个tensor + if seq_group_metadata_list[0].speaker_embedding_param is not None: + speaker_embedding_params = None + for i in range(input_tokens.shape[0]): + if speaker_embedding_params is None: + speaker_embedding_params = speaker_embedding[i] + else: + speaker_embedding_params = torch.cat( + (speaker_embedding_params, speaker_embedding[i]) + ) + + else: + speaker_embedding_params = self.post_model(input_tokens, text_mask) + hidden_states = model_executable( - input_emb=input_emb, + input_emb=speaker_embedding_params, positions=input_positions, kv_caches=kv_caches, input_metadata=input_metadata, ) # print(hidden_states.shape) # print(input_tokens) - B_NO_PAD = input_tokens_history.shape[0] - input_tokens = input_tokens[:B_NO_PAD, :, :] - hidden_states = hidden_states[:B_NO_PAD, :, :] + input_tokens = input_tokens[:, :, :] + hidden_states = hidden_states[:, :, :] idx_next, logprob, finish = self.sampler.sample( - inputs_ids=( - input_tokens - if input_tokens_history.shape[-2] == 0 - else input_tokens_history - ), + inputs_ids=input_tokens, hidden_states=hidden_states, infer_text=infer_text, temperature=temperture, @@ -572,7 +579,7 @@ def execute_model( # sampling_metadata=sampling_metadata, # ) results = [] - for i in range(idx_next.shape[0]): + for i, val in enumerate(seq_groups): idx_next_i = idx_next[i, 0, :].tolist() logprob_i = logprob[i].tolist() tmp_hidden_states = hidden_states[i] @@ -618,6 +625,7 @@ def profile_run(self) -> None: is_prompt=True, seq_data={group_id: seq_data}, sampling_params=sampling_params, + speaker_embedding_param=torch.zeros(1, seq_len, 768).to("cuda"), block_tables=None, ) seqs.append(seq) @@ -773,11 +781,11 @@ def __call__(self, *args, **kwargs): return self.forward(*args, **kwargs) -def _pad_to_max(x: List[int], max_len: int, pad: int) -> List[int]: +def _pad_to_max(x: List[int], max_len: int, pad: List[int]) -> List[int]: assert len(x) <= max_len if len(x) == max_len: return list(x) - return list(x) + [pad] * (max_len - len(x)) + return [pad] * (max_len - len(x)) + list(x) def _make_tensor_with_pad( @@ -791,18 +799,43 @@ def _make_tensor_with_pad( padded_x = [] for x_i in x: pad_i = pad - if isinstance(x[0][0], tuple): + if isinstance(x[0][0], list): + 
pad_i = [ + 0, + ] * len(x[0][0]) + elif isinstance(x[0][0], tuple): pad_i = (0,) * len(x[0][0]) padded_x.append(_pad_to_max(x_i, max_len, pad_i)) - return torch.tensor( padded_x, dtype=dtype, device=device, - pin_memory=pin_memory and str(device) == "cpu", ) +def _make_with_pad( + x: List[torch.Tensor], + max_len: int, + pad: int, + dtype: torch.dtype, + device: Union[str, torch.device] = "cuda", +) -> torch.Tensor: + padded_x = [] + for x_i in x: + assert x_i.shape[-2] <= max_len + if x_i.shape[-2] == max_len: + padded_x.append(x_i) + else: + padded_x.append( + torch.cat( + (torch.zeros(1, max_len - x_i.shape[-2], 768).to(device), x_i), + dim=1, + ) + ) + + return padded_x + + def _get_graph_batch_size(batch_size: int) -> int: if batch_size <= 2: return batch_size diff --git a/ChatTTS/model/velocity/scheduler.py b/ChatTTS/model/velocity/scheduler.py index 97d9cb450..f47bd4a7c 100644 --- a/ChatTTS/model/velocity/scheduler.py +++ b/ChatTTS/model/velocity/scheduler.py @@ -313,6 +313,7 @@ def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]: is_prompt=scheduler_outputs.prompt_run, seq_data=seq_data, sampling_params=seq_group.sampling_params, + speaker_embedding_param=seq_group.speaker_embedding_param, block_tables=block_tables, ) seq_group_metadata_list.append(seq_group_metadata) diff --git a/ChatTTS/model/velocity/sequence.py b/ChatTTS/model/velocity/sequence.py index 76f9cf4e7..d3052a938 100644 --- a/ChatTTS/model/velocity/sequence.py +++ b/ChatTTS/model/velocity/sequence.py @@ -131,10 +131,12 @@ def __init__( seq_id: int, prompt: str, prompt_token_ids: List[int], + speaker_embedding_param: torch.Tensor, block_size: int, ) -> None: self.seq_id = seq_id self.prompt = prompt + self.speaker_embedding_param = speaker_embedding_param self.block_size = block_size self.data = SequenceData(prompt_token_ids) @@ -260,11 +262,13 @@ def __init__( request_id: str, seqs: List[Sequence], sampling_params: SamplingParams, + speaker_embedding_param: torch.Tensor, arrival_time: float, ) -> None: self.request_id = request_id self.seqs_dict = {seq.seq_id: seq for seq in seqs} self.sampling_params = sampling_params + self.speaker_embedding_param = speaker_embedding_param self.arrival_time = arrival_time self.prompt_logprobs: Optional[PromptLogprobs] = None @@ -366,12 +370,14 @@ def __init__( is_prompt: bool, seq_data: Dict[int, SequenceData], sampling_params: SamplingParams, + speaker_embedding_param: torch.Tensor, block_tables: Dict[int, List[int]], ) -> None: self.request_id = request_id self.is_prompt = is_prompt self.seq_data = seq_data self.sampling_params = sampling_params + self.speaker_embedding_param = speaker_embedding_param self.block_tables = block_tables diff --git a/examples/api/main.py b/examples/api/main.py index 1542b1740..a0e1affa7 100644 --- a/examples/api/main.py +++ b/examples/api/main.py @@ -1,11 +1,11 @@ -import io import os import sys -import zipfile +import numpy as np from fastapi import FastAPI -from fastapi.responses import StreamingResponse +from fastapi.responses import Response, StreamingResponse +from tools.audio.np import pcm_to_wav_bytes, pcm_to_bytes if sys.platform == "darwin": os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" @@ -13,11 +13,10 @@ now_dir = os.getcwd() sys.path.append(now_dir) -from typing import Optional +from typing import Optional, AsyncGenerator import ChatTTS -from tools.audio import pcm_arr_to_mp3_view from tools.logger import get_logger import torch @@ -36,7 +35,7 @@ async def startup_event(): chat = 
ChatTTS.Chat(get_logger("ChatTTS")) logger.info("Initializing ChatTTS...") - if chat.load(): + if chat.load(use_vllm=True): logger.info("Models loaded successfully.") else: logger.error("Models load failed.") @@ -52,8 +51,9 @@ class ChatTTSParams(BaseModel): use_decoder: bool = True do_text_normalization: bool = True do_homophone_replacement: bool = False - params_refine_text: ChatTTS.Chat.RefineTextParams - params_infer_code: ChatTTS.Chat.InferCodeParams + params_refine_text: Optional[ChatTTS.Chat.RefineTextParams] = None + params_infer_code: Optional[ChatTTS.Chat.InferCodeParams] = None + stream_batch_size: int = 16 @app.post("/generate_voice") @@ -66,10 +66,11 @@ async def generate_voice(params: ChatTTSParams): params.params_infer_code.spk_emb = chat.sample_random_speaker() # text seed for text refining - if params.params_refine_text: - text = chat.infer( + if params.params_refine_text and params.skip_refine_text is False: + results_generator = chat.infer( text=params.text, skip_refine_text=False, refine_text_only=True ) + text = await next(results_generator) logger.info(f"Refined text: {text}") else: # no text refining @@ -79,7 +80,8 @@ async def generate_voice(params: ChatTTSParams): logger.info(params.params_infer_code.spk_emb) logger.info("Start voice inference.") - wavs = chat.infer( + + results_generator = chat.infer( text=text, stream=params.stream, lang=params.lang, @@ -90,18 +92,24 @@ async def generate_voice(params: ChatTTSParams): params_infer_code=params.params_infer_code, params_refine_text=params.params_refine_text, ) - logger.info("Inference completed.") - - # zip all of the audio files together - buf = io.BytesIO() - with zipfile.ZipFile( - buf, "a", compression=zipfile.ZIP_DEFLATED, allowZip64=False - ) as f: - for idx, wav in enumerate(wavs): - f.writestr(f"{idx}.mp3", pcm_arr_to_mp3_view(wav)) - logger.info("Audio generation successful.") - buf.seek(0) - - response = StreamingResponse(buf, media_type="application/zip") - response.headers["Content-Disposition"] = "attachment; filename=audio_files.zip" - return response + + if params.stream: + + async def stream_results() -> AsyncGenerator[bytes, None]: + async for output in results_generator: + yield pcm_to_bytes(output[0]) + + return StreamingResponse( + content=stream_results(), media_type="text/event-stream" + ) + + output = None + async for request_output in results_generator: + if output is None: + output = request_output[0] + else: + output = np.concatenate((output, request_output[0]), axis=0) + output = pcm_to_wav_bytes(output) + return Response( + content=output, media_type="audio/wav", headers={"Cache-Control": "no-cache"} + ) diff --git a/tools/audio/np.py b/tools/audio/np.py index a1aee2047..cf96be937 100644 --- a/tools/audio/np.py +++ b/tools/audio/np.py @@ -1,4 +1,6 @@ +import io import math +import wave import numpy as np from numba import jit @@ -9,3 +11,20 @@ def float_to_int16(audio: np.ndarray) -> np.ndarray: am = int(math.ceil(float(np.abs(audio).max())) * 32768) am = 32767 * 32768 // am return np.multiply(audio, am).astype(np.int16) + + +def pcm_to_bytes(pcm_data: np.ndarray) -> bytes: + return float_to_int16(pcm_data).tobytes() + + +def pcm_to_wav_bytes(pcm_data: np.ndarray) -> bytes: + buf = io.BytesIO() + with wave.open(buf, "wb") as wf: + wf.setnchannels(1) # Mono channel + wf.setsampwidth(2) # Sample width in bytes + wf.setframerate(24000) # Sample rate in Hz + wf.writeframes(float_to_int16(pcm_data)) + buf.seek(0, 0) + wav_data = buf.getvalue() + buf.close() + return wav_data
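End to end, the reworked `/generate_voice` endpoint now returns either a complete WAV (non-streaming) or a streamed body of raw int16 PCM chunks at 24 kHz, as produced by `pcm_to_bytes` above. A rough client-side sketch for the streaming branch; the host/port, the `requests` dependency, and the exact request body (which must satisfy `ChatTTSParams` defaults) are assumptions, not documented behaviour:

```python
# Rough client sketch for the streaming branch of /generate_voice.
import wave
import requests  # third-party HTTP client, assumed available

body = {
    "text": ["Hello from the streaming endpoint."],
    "stream": True,
    "skip_refine_text": True,
    "stream_batch_size": 16,
}

with requests.post(
    "http://localhost:8000/generate_voice", json=body, stream=True
) as resp:
    resp.raise_for_status()
    with wave.open("streamed.wav", "wb") as wf:
        wf.setnchannels(1)      # mono
        wf.setsampwidth(2)      # int16, matching float_to_int16
        wf.setframerate(24000)  # sample rate used by pcm_to_wav_bytes
        for chunk in resp.iter_content(chunk_size=4096):
            if chunk:
                wf.writeframes(chunk)
```

The non-streaming path instead concatenates all yielded chunks server-side and returns a single `audio/wav` response built by `pcm_to_wav_bytes`.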