chore(format): run black on dev
github-actions[bot] committed Sep 12, 2024
1 parent ecc7b52 commit 5b7863f
Showing 5 changed files with 57 additions and 37 deletions.
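
Every hunk below is a mechanical Black rewrite: calls without a trailing comma that fit in the default 88-character line limit are collapsed onto one line, argument lists that end in a trailing comma are exploded to one argument per line, binary operators get surrounding spaces, and top-level definitions are separated by two blank lines. A minimal sketch of the collapse/explode rule follows; the render helper and its arguments are made up for illustration and are not from this repository.

def render(template: str, context: dict, strict: bool = False) -> str:
    # Toy stand-in so the sketch runs on its own.
    return template.format(**context) if strict else template


# Without a trailing comma, Black collapses a short multi-line call onto one
# line, because the whole call fits within 88 characters:
result = render("hello {name}", {"name": "world"}, strict=True)

# With a trailing "magic" comma, Black keeps one argument per line instead:
result = render(
    "hello {name}",
    {"name": "world"},
    strict=True,
)
print(result)  # hello world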
15 changes: 6 additions & 9 deletions ChatTTS/core.py
@@ -404,6 +404,7 @@ async def _infer(
         keep_cols = np.sum(abs(wavs[0][length:]) > 1e-6, axis=0) > 0

         import librosa
+
         silence_intervals = librosa.effects.split(wavs[0][length:], top_db=10)
         silence_left = 0
         if len(silence_intervals) == 0:
@@ -529,16 +530,14 @@ async def _infer_code(
         input_ids = [i.tolist() for i in input_ids]

         result = gpt.llm.generate(
-            None,
-            sample_params,
-            uuid.uuid4(),
-            speaker_embedding_param,
-            input_ids[0]
+            None, sample_params, uuid.uuid4(), speaker_embedding_param, input_ids[0]
         )
         async for i in result:
             token_ids = []
             hidden_states = []
-            if (stream and len(i.outputs[0].token_ids) % stream_batch_size == 0) or i.finished:
+            if (
+                stream and len(i.outputs[0].token_ids) % stream_batch_size == 0
+            ) or i.finished:
                 token_ids.append(torch.tensor(i.outputs[0].token_ids))
                 hidden_states.append(
                     i.outputs[0].hidden_states.to(torch.float32).to(self.device)
@@ -574,9 +573,7 @@ async def _infer_code(
             hidden_states = []
             if (stream and len(i.ids[0]) % stream_batch_size == 0) or i.finished:
                 token_ids.append(i.ids[0])
-                hidden_states.append(
-                    i.hiddens[0].to(torch.float32).to(self.device)
-                )
+                hidden_states.append(i.hiddens[0].to(torch.float32).to(self.device))
                 yield GPT.GenerationOutputs(
                     ids=token_ids,
                     finished=i.finished,
12 changes: 3 additions & 9 deletions ChatTTS/model/gpt.py
@@ -68,7 +68,7 @@ def from_pretrained(
                 num_audio_tokens=self.num_audio_tokens,
                 num_text_tokens=self.num_text_tokens,
                 post_model_path=embed_file_path,
-                dtype="float32"
+                dtype="float32",
             )
             self.logger.info("vLLM model loaded")
             return
@@ -585,7 +585,7 @@ async def generate(
                     attentions,
                     hiddens,
                     infer_text,
-                    False
+                    False,
                 )
                 del not_finished

@@ -609,11 +609,5 @@ async def generate(
         del finish, inputs_ids_buf

         yield self._prepare_generation_outputs(
-            inputs_ids,
-            start_idx,
-            end_idx,
-            attentions,
-            hiddens,
-            infer_text,
-            True
+            inputs_ids, start_idx, end_idx, attentions, hiddens, infer_text, True
         )
12 changes: 6 additions & 6 deletions ChatTTS/model/velocity/llm.py
@@ -112,12 +112,12 @@ def __init__(
         self.request_counter = Counter()

     async def generate(
         self,
         prompt: Optional[str],
         sampling_params: SamplingParams,
         request_id: str,
         speaker_embedding_param: torch.Tensor,
         prompt_token_ids: Optional[List[int]] = None,
     ) -> AsyncIterator[RequestOutput]:
         """Generate outputs for a request.
53 changes: 40 additions & 13 deletions ChatTTS/model/velocity/model_runner.py
@@ -106,7 +106,9 @@ def set_block_size(self, block_size: int) -> None:
     def _prepare_prompt(
         self,
         seq_group_metadata_list: List[SequenceGroupMetadata],
-    ) -> tuple[list[list[int]], list[list[int]], InputMetadata, list[int], list[Tensor]]:
+    ) -> tuple[
+        list[list[int]], list[list[int]], InputMetadata, list[int], list[Tensor]
+    ]:
         assert len(seq_group_metadata_list) > 0
         input_tokens: List[List[int]] = []
         input_positions: List[List[int]] = []
@@ -359,17 +361,23 @@ def _prepare_sample(
     def prepare_input_tensors(
         self,
         seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
-    ) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata, SamplingMetadata, list[torch.Tensor]]:
+    ) -> Tuple[
+        torch.Tensor, torch.Tensor, InputMetadata, SamplingMetadata, list[torch.Tensor]
+    ]:
         speaker_embedding = None
         if self.is_driver_worker:
             # NOTE: We assume that all sequences in the group are all prompts or
             # all decodes.
             is_prompt = seq_group_metadata_list[0].is_prompt
             # Prepare input tensors.
             if is_prompt:
-                (input_tokens, input_positions, input_metadata, prompt_lens, speaker_embedding) = (
-                    self._prepare_prompt(seq_group_metadata_list)
-                )
+                (
+                    input_tokens,
+                    input_positions,
+                    input_metadata,
+                    prompt_lens,
+                    speaker_embedding,
+                ) = self._prepare_prompt(seq_group_metadata_list)
             else:
                 (input_tokens, input_positions, input_metadata) = self._prepare_decode(
                     seq_group_metadata_list
@@ -461,7 +469,13 @@ def get_size_or_none(x: Optional[torch.Tensor]):
                 perform_sampling=False,
             )

-        return input_tokens, input_positions, input_metadata, sampling_metadata, speaker_embedding
+        return (
+            input_tokens,
+            input_positions,
+            input_metadata,
+            sampling_metadata,
+            speaker_embedding,
+        )

     @torch.inference_mode()
     def execute_model(
@@ -470,9 +484,13 @@ def execute_model(
         kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
     ) -> Optional[SamplerOutput]:

-        input_tokens, input_positions, input_metadata, sampling_metadata, speaker_embedding = (
-            self.prepare_input_tensors(seq_group_metadata_list)
-        )
+        (
+            input_tokens,
+            input_positions,
+            input_metadata,
+            sampling_metadata,
+            speaker_embedding,
+        ) = self.prepare_input_tensors(seq_group_metadata_list)
         # print(sampling_metadata.seq_data)
         seq_groups = []
         for i, rtn in enumerate(sampling_metadata.seq_groups):
@@ -521,7 +539,9 @@ def execute_model(
                 if speaker_embedding_params is None:
                     speaker_embedding_params = speaker_embedding[i]
                 else:
-                    speaker_embedding_params = torch.cat((speaker_embedding_params, speaker_embedding[i]))
+                    speaker_embedding_params = torch.cat(
+                        (speaker_embedding_params, speaker_embedding[i])
+                    )

         else:
             speaker_embedding_params = self.post_model(input_tokens, text_mask)
@@ -559,7 +579,7 @@ def execute_model(
        #     sampling_metadata=sampling_metadata,
        # )
        results = []
-       for i,val in enumerate(seq_groups):
+       for i, val in enumerate(seq_groups):
            idx_next_i = idx_next[i, 0, :].tolist()
            logprob_i = logprob[i].tolist()
            tmp_hidden_states = hidden_states[i]
@@ -780,7 +800,9 @@ def _make_tensor_with_pad(
     for x_i in x:
         pad_i = pad
         if isinstance(x[0][0], list):
-            pad_i = [0,] * len(x[0][0])
+            pad_i = [
+                0,
+            ] * len(x[0][0])
         elif isinstance(x[0][0], tuple):
             pad_i = (0,) * len(x[0][0])
         padded_x.append(_pad_to_max(x_i, max_len, pad_i))
@@ -790,6 +812,7 @@
         device=device,
     )

+
 def _make_with_pad(
     x: List[torch.Tensor],
     max_len: int,
@@ -804,11 +827,15 @@ def _make_with_pad(
             padded_x.append(x_i)
         else:
             padded_x.append(
-                torch.cat((torch.zeros(1, max_len-x_i.shape[-2], 768).to(device), x_i), dim=1)
+                torch.cat(
+                    (torch.zeros(1, max_len - x_i.shape[-2], 768).to(device), x_i),
+                    dim=1,
+                )
             )

     return padded_x

+
 def _get_graph_batch_size(batch_size: int) -> int:
     if batch_size <= 2:
         return batch_size
2 changes: 2 additions & 0 deletions tools/audio/np.py
@@ -12,9 +12,11 @@ def float_to_int16(audio: np.ndarray) -> np.ndarray:
     am = 32767 * 32768 // am
     return np.multiply(audio, am).astype(np.int16)

+
 def pcm_to_bytes(pcm_data: np.ndarray) -> bytes:
     return float_to_int16(pcm_data).tobytes()

+
 def pcm_to_wav_bytes(pcm_data: np.ndarray) -> bytes:
     buf = io.BytesIO()
     with wave.open(buf, "wb") as wf:
