Skip to content

Commit

Permalink
[debug] support flow cache, for sharper tts_mel output
Browse files Browse the repository at this point in the history
  • Loading branch information
boji123 committed Sep 20, 2024
1 parent 95051e5 commit 283e612
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 9 deletions.
10 changes: 8 additions & 2 deletions cosyvoice/cli/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ def __init__(self,
# dict used to store session related variable
self.tts_speech_token_dict = {}
self.llm_end_dict = {}
self.flow_cache_dict = {}
self.mel_overlap_dict = {}
self.hift_cache_dict = {}

Expand Down Expand Up @@ -95,13 +96,17 @@ def llm_job(self, text, prompt_text, llm_prompt_speech_token, llm_embedding, uui
self.llm_end_dict[uuid] = True

def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid, finalize=False, speed=1.0):
tts_mel = self.flow.inference(token=token.to(self.device),
tts_mel, flow_cache = self.flow.inference(token=token.to(self.device),
token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
prompt_token=prompt_token.to(self.device),
prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
prompt_feat=prompt_feat.to(self.device),
prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
embedding=embedding.to(self.device))
embedding=embedding.to(self.device),
required_cache_size=self.mel_overlap_len,
flow_cache=self.flow_cache_dict[uuid])
self.flow_cache_dict[uuid] = flow_cache

# mel overlap fade in out
if self.mel_overlap_dict[uuid] is not None:
tts_mel = fade_in_out(tts_mel, self.mel_overlap_dict[uuid], self.mel_window)
Expand Down Expand Up @@ -140,6 +145,7 @@ def inference(self, text, flow_embedding, llm_embedding=torch.zeros(0, 192),
this_uuid = str(uuid.uuid1())
with self.lock:
self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = [], False
self.flow_cache_dict[this_uuid] = None
self.mel_overlap_dict[this_uuid], self.hift_cache_dict[this_uuid] = None, None
p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid))
p.start()
Expand Down
12 changes: 8 additions & 4 deletions cosyvoice/flow/flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,9 @@ def inference(self,
prompt_token_len,
prompt_feat,
prompt_feat_len,
embedding):
embedding,
required_cache_size=0,
flow_cache=None):
assert token.shape[0] == 1
# xvec projection
embedding = F.normalize(embedding, dim=1)
Expand All @@ -134,13 +136,15 @@ def inference(self,

# mask = (~make_pad_mask(feat_len)).to(h)
mask = (~make_pad_mask(torch.tensor([mel_len1 + mel_len2]))).to(h)
feat = self.decoder(
feat, flow_cache = self.decoder(
mu=h.transpose(1, 2).contiguous(),
mask=mask.unsqueeze(1),
spks=embedding,
cond=conds,
n_timesteps=10
n_timesteps=10,
required_cache_size=required_cache_size,
flow_cache=flow_cache
)
feat = feat[:, :, mel_len1:]
assert feat.shape[2] == mel_len2
return feat
return feat, flow_cache
21 changes: 18 additions & 3 deletions cosyvoice/flow/flow_matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def __init__(self, in_channels, cfm_params, n_spks=1, spk_emb_dim=64, estimator:
self.estimator = estimator

@torch.inference_mode()
def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None):
def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None, required_cache_size=0, flow_cache=None):
"""Forward diffusion
Args:
Expand All @@ -50,11 +50,26 @@ def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None):
sample: generated mel-spectrogram
shape: (batch_size, n_feats, mel_timesteps)
"""
z = torch.randn_like(mu) * temperature

if flow_cache is not None:
z_cache = flow_cache[0]
mu_cache = flow_cache[1]
z = torch.randn((mu.size(0), mu.size(1), mu.size(2) - z_cache.size(2)), dtype=mu.dtype, device=mu.device) * temperature
z = torch.cat((z_cache, z), dim=2) # [B, 80, T]
mu = torch.cat((mu_cache, mu[..., mu_cache.size(2):]), dim=2) # [B, 80, T]
else:
z = torch.randn_like(mu) * temperature

next_cache_start = max(z.size(2) - required_cache_size, 0)
flow_cache = [
z[..., next_cache_start:],
mu[..., next_cache_start:]
]

t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
if self.t_scheduler == 'cosine':
t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond)
return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond), flow_cache

def solve_euler(self, x, t_span, mu, mask, spks, cond):
"""
Expand Down

0 comments on commit 283e612

Please sign in to comment.