Merge pull request #436 from FunAudioLLM/dev/lyuxiang.lx
Dev/lyuxiang.lx
aluminumbox committed Sep 26, 2024
2 parents f6d44af + ba3d969 commit 6920200
Showing 11 changed files with 59,190 additions and 36 deletions.
14 changes: 9 additions & 5 deletions README.md
@@ -22,12 +22,9 @@ For `SenseVoice`, visit [SenseVoice repo](https://github.com/FunAudioLLM/SenseVo
- [ ] 25hz cosyvoice base model
- [ ] 25hz cosyvoice voice conversion model

- [ ] 2024/10

- [ ] 50hz llama based llm model which supports lora finetune

- [ ] TBD

- [ ] 25hz llama based llm model which supports lora finetune
- [ ] Support more instruction mode
- [ ] Voice conversion
- [ ] Music generation
@@ -74,6 +71,7 @@ If you are expert in this field, and you are only interested in training your ow
# SDK model download
from modelscope import snapshot_download
snapshot_download('iic/CosyVoice-300M', local_dir='pretrained_models/CosyVoice-300M')
snapshot_download('iic/CosyVoice-300M-25Hz', local_dir='pretrained_models/CosyVoice-300M-25Hz')
snapshot_download('iic/CosyVoice-300M-SFT', local_dir='pretrained_models/CosyVoice-300M-SFT')
snapshot_download('iic/CosyVoice-300M-Instruct', local_dir='pretrained_models/CosyVoice-300M-Instruct')
snapshot_download('iic/CosyVoice-ttsfrd', local_dir='pretrained_models/CosyVoice-ttsfrd')
@@ -83,6 +81,7 @@ snapshot_download('iic/CosyVoice-ttsfrd', local_dir='pretrained_models/CosyVoice
# Model download via git; make sure git lfs is installed
mkdir -p pretrained_models
git clone https://www.modelscope.cn/iic/CosyVoice-300M.git pretrained_models/CosyVoice-300M
git clone https://www.modelscope.cn/iic/CosyVoice-300M-25Hz.git pretrained_models/CosyVoice-300M-25Hz
git clone https://www.modelscope.cn/iic/CosyVoice-300M-SFT.git pretrained_models/CosyVoice-300M-SFT
git clone https://www.modelscope.cn/iic/CosyVoice-300M-Instruct.git pretrained_models/CosyVoice-300M-Instruct
git clone https://www.modelscope.cn/iic/CosyVoice-ttsfrd.git pretrained_models/CosyVoice-ttsfrd
@@ -121,7 +120,7 @@ print(cosyvoice.list_avaliable_spks())
for i, j in enumerate(cosyvoice.inference_sft('你好,我是通义生成式语音大模型,请问有什么可以帮您的吗?', '中文女', stream=False)):
torchaudio.save('sft_{}.wav'.format(i), j['tts_speech'], 22050)

cosyvoice = CosyVoice('pretrained_models/CosyVoice-300M')
cosyvoice = CosyVoice('pretrained_models/CosyVoice-300M-25Hz') # or change to pretrained_models/CosyVoice-300M for 50Hz inference
# zero_shot usage, <|zh|><|en|><|jp|><|yue|><|ko|> for Chinese/English/Japanese/Cantonese/Korean
prompt_speech_16k = load_wav('zero_shot_prompt.wav', 16000)
for i, j in enumerate(cosyvoice.inference_zero_shot('收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。', '希望你以后能够做的比我还好呦。', prompt_speech_16k, stream=False)):
@@ -130,6 +129,11 @@ for i, j in enumerate(cosyvoice.inference_zero_shot('收到好友从远方寄来
prompt_speech_16k = load_wav('cross_lingual_prompt.wav', 16000)
for i, j in enumerate(cosyvoice.inference_cross_lingual('<|en|>And then later on, fully acquiring that company. So keeping management in line, interest in line with the asset that\'s coming into the family is a reason why sometimes we don\'t buy the whole thing.', prompt_speech_16k, stream=False)):
torchaudio.save('cross_lingual_{}.wav'.format(i), j['tts_speech'], 22050)
# vc usage
prompt_speech_16k = load_wav('zero_shot_prompt.wav', 16000)
source_speech_16k = load_wav('cross_lingual_prompt.wav', 16000)
for i, j in enumerate(cosyvoice.inference_vc(source_speech_16k, prompt_speech_16k, stream=False)):
torchaudio.save('vc_{}.wav'.format(i), j['tts_speech'], 22050)

cosyvoice = CosyVoice('pretrained_models/CosyVoice-300M-Instruct')
# instruct usage, support <laughter></laughter><strong></strong>[laughter][breath]
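All of the README examples above pass stream=False and receive one complete waveform per text segment; with stream=True each call instead yields short chunks as they are synthesized. Below is a minimal sketch of stitching streamed chunks back together, assuming the same CosyVoice API shown above (the output file name and the use of torch.cat are illustrative, not part of the repository):

import torch
import torchaudio
from cosyvoice.cli.cosyvoice import CosyVoice
from cosyvoice.utils.file_utils import load_wav

cosyvoice = CosyVoice('pretrained_models/CosyVoice-300M-25Hz')
prompt_speech_16k = load_wav('zero_shot_prompt.wav', 16000)
chunks = []
for out in cosyvoice.inference_cross_lingual('<|en|>This is a short streaming demo.', prompt_speech_16k, stream=True):
    chunks.append(out['tts_speech'])  # each chunk is a (1, N) waveform tensor at 22050 Hz
torchaudio.save('streamed.wav', torch.cat(chunks, dim=1), 22050)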
17 changes: 13 additions & 4 deletions cosyvoice/cli/cosyvoice.py
@@ -58,7 +58,7 @@ def inference_sft(self, tts_text, spk_id, stream=False, speed=1.0):
model_input = self.frontend.frontend_sft(i, spk_id)
start_time = time.time()
logging.info('synthesis text {}'.format(i))
for model_output in self.model.inference(**model_input, stream=stream, speed=speed):
for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
speech_len = model_output['tts_speech'].shape[1] / 22050
logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
yield model_output
@@ -70,7 +70,7 @@ def inference_zero_shot(self, tts_text, prompt_text, prompt_speech_16k, stream=F
model_input = self.frontend.frontend_zero_shot(i, prompt_text, prompt_speech_16k)
start_time = time.time()
logging.info('synthesis text {}'.format(i))
for model_output in self.model.inference(**model_input, stream=stream, speed=speed):
for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
speech_len = model_output['tts_speech'].shape[1] / 22050
logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
yield model_output
@@ -83,7 +83,7 @@ def inference_cross_lingual(self, tts_text, prompt_speech_16k, stream=False, spe
model_input = self.frontend.frontend_cross_lingual(i, prompt_speech_16k)
start_time = time.time()
logging.info('synthesis text {}'.format(i))
for model_output in self.model.inference(**model_input, stream=stream, speed=speed):
for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
speech_len = model_output['tts_speech'].shape[1] / 22050
logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
yield model_output
@@ -97,8 +97,17 @@ def inference_instruct(self, tts_text, spk_id, instruct_text, stream=False, spee
model_input = self.frontend.frontend_instruct(i, spk_id, instruct_text)
start_time = time.time()
logging.info('synthesis text {}'.format(i))
for model_output in self.model.inference(**model_input, stream=stream, speed=speed):
for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
speech_len = model_output['tts_speech'].shape[1] / 22050
logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
yield model_output
start_time = time.time()

def inference_vc(self, source_speech_16k, prompt_speech_16k, stream=False, speed=1.0):
model_input = self.frontend.frontend_vc(source_speech_16k, prompt_speech_16k)
start_time = time.time()
for model_output in self.model.vc(**model_input, stream=stream, speed=speed):
speech_len = model_output['tts_speech'].shape[1] / 22050
logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
yield model_output
start_time = time.time()
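The new inference_vc generator mirrors the other inference_* wrappers but skips the LLM entirely: frontend_vc tokenizes the source speech and model.vc resynthesizes those tokens in the prompt speaker's voice. A minimal usage sketch, including the same real-time-factor (RTF) computation the logging lines above perform (the input file names and the print format are illustrative assumptions):

import time
import torchaudio
from cosyvoice.cli.cosyvoice import CosyVoice
from cosyvoice.utils.file_utils import load_wav

cosyvoice = CosyVoice('pretrained_models/CosyVoice-300M-25Hz')
source_speech_16k = load_wav('source.wav', 16000)   # content to keep
prompt_speech_16k = load_wav('prompt.wav', 16000)   # voice to copy
start_time = time.time()
for i, out in enumerate(cosyvoice.inference_vc(source_speech_16k, prompt_speech_16k, stream=False)):
    speech_len = out['tts_speech'].shape[1] / 22050  # output duration in seconds
    print('chunk {} rtf {:.3f}'.format(i, (time.time() - start_time) / speech_len))
    torchaudio.save('vc_{}.wav'.format(i), out['tts_speech'], 22050)
    start_time = time.time()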
14 changes: 14 additions & 0 deletions cosyvoice/cli/frontend.py
@@ -55,6 +55,8 @@ def __init__(self,
"CPUExecutionProvider"])
if os.path.exists(spk2info):
self.spk2info = torch.load(spk2info, map_location=self.device)
else:
self.spk2info = {}
self.instruct = instruct
self.allowed_special = allowed_special
self.inflect_parser = inflect.engine()
@@ -172,3 +174,15 @@ def frontend_instruct(self, tts_text, spk_id, instruct_text):
model_input['prompt_text'] = instruct_text_token
model_input['prompt_text_len'] = instruct_text_token_len
return model_input

def frontend_vc(self, source_speech_16k, prompt_speech_16k):
prompt_speech_token, prompt_speech_token_len = self._extract_speech_token(prompt_speech_16k)
prompt_speech_22050 = torchaudio.transforms.Resample(orig_freq=16000, new_freq=22050)(prompt_speech_16k)
prompt_speech_feat, prompt_speech_feat_len = self._extract_speech_feat(prompt_speech_22050)
embedding = self._extract_spk_embedding(prompt_speech_16k)
source_speech_token, source_speech_token_len = self._extract_speech_token(source_speech_16k)
model_input = {'source_speech_token': source_speech_token, 'source_speech_token_len': source_speech_token_len,
'flow_prompt_speech_token': prompt_speech_token, 'flow_prompt_speech_token_len': prompt_speech_token_len,
'prompt_speech_feat': prompt_speech_feat, 'prompt_speech_feat_len': prompt_speech_feat_len,
'flow_embedding': embedding}
return model_input
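frontend_vc works at two sample rates: speech tokens and the speaker embedding are extracted from the 16 kHz input, while the prompt mel features fed to the flow decoder are computed from a 22.05 kHz resample. A small standalone sketch of that resampling step (the random input tensor is a placeholder):

import torch
import torchaudio

prompt_speech_16k = torch.randn(1, 16000)  # one second of placeholder 16 kHz audio
resample = torchaudio.transforms.Resample(orig_freq=16000, new_freq=22050)
prompt_speech_22050 = resample(prompt_speech_16k)
print(prompt_speech_22050.shape)  # torch.Size([1, 22050]): same duration at the mel-feature rate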
76 changes: 64 additions & 12 deletions cosyvoice/cli/model.py
@@ -35,7 +35,7 @@ def __init__(self,
self.token_max_hop_len = 200
self.token_overlap_len = 20
# mel fade in out
self.mel_overlap_len = 34
self.mel_overlap_len = int(self.token_overlap_len / self.flow.input_frame_rate * 22050 / 256)
self.mel_window = np.hamming(2 * self.mel_overlap_len)
# hift cache
self.mel_cache_len = 20
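The overlap was previously hard-coded to 34 mel frames, which only matches a 50 Hz speech tokenizer; the new expression derives it from the flow model's input frame rate, assuming the 22050 Hz output rate and 256-sample mel hop used elsewhere in this diff. A quick numeric check of the formula:

# int(token_overlap_len / input_frame_rate * 22050 / 256)
int(20 / 50 * 22050 / 256)  # = 34, reproducing the old constant for the 50 Hz models
int(20 / 25 * 22050 / 256)  # = 68, the overlap used for the new 25 Hz models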
@@ -63,11 +63,11 @@ def load(self, llm_model, flow_model, hift_model):
self.hift.to(self.device).eval()

def load_jit(self, llm_text_encoder_model, llm_llm_model, flow_encoder_model):
llm_text_encoder = torch.jit.load(llm_text_encoder_model)
llm_text_encoder = torch.jit.load(llm_text_encoder_model, map_location=self.device)
self.llm.text_encoder = llm_text_encoder
llm_llm = torch.jit.load(llm_llm_model)
llm_llm = torch.jit.load(llm_llm_model, map_location=self.device)
self.llm.llm = llm_llm
flow_encoder = torch.jit.load(flow_encoder_model)
flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device)
self.flow.encoder = flow_encoder

def load_onnx(self, flow_decoder_estimator_model):
@@ -131,11 +131,11 @@ def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid, finalize=
tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
return tts_speech

def inference(self, text, flow_embedding, llm_embedding=torch.zeros(0, 192),
prompt_text=torch.zeros(1, 0, dtype=torch.int32),
llm_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
flow_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
prompt_speech_feat=torch.zeros(1, 0, 80), stream=False, speed=1.0, **kwargs):
def tts(self, text, flow_embedding, llm_embedding=torch.zeros(0, 192),
prompt_text=torch.zeros(1, 0, dtype=torch.int32),
llm_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
flow_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
prompt_speech_feat=torch.zeros(1, 0, 80), stream=False, speed=1.0, **kwargs):
# this_uuid is used to track variables related to this inference thread
this_uuid = str(uuid.uuid1())
with self.lock:
@@ -148,7 +148,8 @@ def inference(self, text, flow_embedding, llm_embedding=torch.zeros(0, 192),
while True:
time.sleep(0.1)
if len(self.tts_speech_token_dict[this_uuid]) >= token_hop_len + self.token_overlap_len:
this_tts_speech_token = torch.concat(self.tts_speech_token_dict[this_uuid][:token_hop_len + self.token_overlap_len], dim=1)
this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid][:token_hop_len + self.token_overlap_len]) \
.unsqueeze(dim=0)
this_tts_speech = self.token2wav(token=this_tts_speech_token,
prompt_token=flow_prompt_speech_token,
prompt_feat=prompt_speech_feat,
@@ -164,7 +165,7 @@ def inference(self, text, flow_embedding, llm_embedding=torch.zeros(0, 192),
break
p.join()
# deal with remain tokens, make sure inference remain token len equals token_hop_len when cache_speech is not None
this_tts_speech_token = torch.concat(self.tts_speech_token_dict[this_uuid], dim=1)
this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
this_tts_speech = self.token2wav(token=this_tts_speech_token,
prompt_token=flow_prompt_speech_token,
prompt_feat=prompt_speech_feat,
@@ -175,7 +176,58 @@ def inference(self, text, flow_embedding, llm_embedding=torch.zeros(0, 192),
else:
# deal with all tokens
p.join()
this_tts_speech_token = torch.concat(self.tts_speech_token_dict[this_uuid], dim=1)
this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
this_tts_speech = self.token2wav(token=this_tts_speech_token,
prompt_token=flow_prompt_speech_token,
prompt_feat=prompt_speech_feat,
embedding=flow_embedding,
uuid=this_uuid,
finalize=True,
speed=speed)
yield {'tts_speech': this_tts_speech.cpu()}
with self.lock:
self.tts_speech_token_dict.pop(this_uuid)
self.llm_end_dict.pop(this_uuid)
self.mel_overlap_dict.pop(this_uuid)
self.hift_cache_dict.pop(this_uuid)

def vc(self, source_speech_token, flow_prompt_speech_token, prompt_speech_feat, flow_embedding, stream=False, speed=1.0, **kwargs):
# this_uuid is used to track variables related to this inference thread
this_uuid = str(uuid.uuid1())
with self.lock:
self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = source_speech_token.flatten().tolist(), True
self.mel_overlap_dict[this_uuid], self.hift_cache_dict[this_uuid] = None, None
if stream is True:
token_hop_len = self.token_min_hop_len
while True:
if len(self.tts_speech_token_dict[this_uuid]) >= token_hop_len + self.token_overlap_len:
this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid][:token_hop_len + self.token_overlap_len]) \
.unsqueeze(dim=0)
this_tts_speech = self.token2wav(token=this_tts_speech_token,
prompt_token=flow_prompt_speech_token,
prompt_feat=prompt_speech_feat,
embedding=flow_embedding,
uuid=this_uuid,
finalize=False)
yield {'tts_speech': this_tts_speech.cpu()}
with self.lock:
self.tts_speech_token_dict[this_uuid] = self.tts_speech_token_dict[this_uuid][token_hop_len:]
# increase token_hop_len for better speech quality
token_hop_len = min(self.token_max_hop_len, int(token_hop_len * self.stream_scale_factor))
if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) < token_hop_len + self.token_overlap_len:
break
# deal with remain tokens, make sure inference remain token len equals token_hop_len when cache_speech is not None
this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
this_tts_speech = self.token2wav(token=this_tts_speech_token,
prompt_token=flow_prompt_speech_token,
prompt_feat=prompt_speech_feat,
embedding=flow_embedding,
uuid=this_uuid,
finalize=True)
yield {'tts_speech': this_tts_speech.cpu()}
else:
# deal with all tokens
this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
this_tts_speech = self.token2wav(token=this_tts_speech_token,
prompt_token=flow_prompt_speech_token,
prompt_feat=prompt_speech_feat,
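The vc generator reuses the streaming machinery of tts almost verbatim; the difference is where the speech tokens come from. A schematic comparison based on the code above (comments only, not repository code):

# tts(): a background job fills tts_speech_token_dict[uuid] while chunks are drained;
#        llm_end_dict[uuid] flips to True only when that job finishes (hence p.join()).
# vc():  the tokens are known up front, so the buffer is seeded directly and the
#        "generation finished" flag is set immediately:
#            self.tts_speech_token_dict[this_uuid] = source_speech_token.flatten().tolist()
#            self.llm_end_dict[this_uuid] = True
# Both paths then convert the token buffer to audio chunk by chunk via token2wav().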
5 changes: 2 additions & 3 deletions cosyvoice/flow/flow.py
@@ -124,15 +124,14 @@ def inference(self,
# text encode
h, h_lengths = self.encoder(token, token_len)
h = self.encoder_proj(h)
mel_len1, mel_len2 = prompt_feat.shape[1], int(token_len2 / 50 * 22050 / 256)
h, h_lengths = self.length_regulator.inference(h[:, :token_len1], h[:, token_len1:], mel_len1, mel_len2)
mel_len1, mel_len2 = prompt_feat.shape[1], int(token_len2 / self.input_frame_rate * 22050 / 256)
h, h_lengths = self.length_regulator.inference(h[:, :token_len1], h[:, token_len1:], mel_len1, mel_len2, self.input_frame_rate)

# get conditions
conds = torch.zeros([1, mel_len1 + mel_len2, self.output_size], device=token.device)
conds[:, :mel_len1] = prompt_feat
conds = conds.transpose(1, 2)

# mask = (~make_pad_mask(feat_len)).to(h)
mask = (~make_pad_mask(torch.tensor([mel_len1 + mel_len2]))).to(h)
feat = self.decoder(
mu=h.transpose(1, 2).contiguous(),
9 changes: 5 additions & 4 deletions cosyvoice/flow/length_regulator.py
@@ -49,13 +49,14 @@ def forward(self, x, ylens=None):
olens = ylens
return out * mask, olens

def inference(self, x1, x2, mel_len1, mel_len2):
def inference(self, x1, x2, mel_len1, mel_len2, input_frame_rate=50):
# in inference mode, interpolate prompt token and token (head/mid/tail) separately, so we can get a clear separation point of mel
# x in (B, T, D)
if x2.shape[1] > 40:
x2_head = F.interpolate(x2[:, :20].transpose(1, 2).contiguous(), size=34, mode='linear')
x2_mid = F.interpolate(x2[:, 20:-20].transpose(1, 2).contiguous(), size=mel_len2 - 34 * 2, mode='linear')
x2_tail = F.interpolate(x2[:, -20:].transpose(1, 2).contiguous(), size=34, mode='linear')
x2_head = F.interpolate(x2[:, :20].transpose(1, 2).contiguous(), size=int(20 / input_frame_rate * 22050 / 256), mode='linear')
x2_mid = F.interpolate(x2[:, 20:-20].transpose(1, 2).contiguous(), size=mel_len2 - int(20 / input_frame_rate * 22050 / 256) * 2,
mode='linear')
x2_tail = F.interpolate(x2[:, -20:].transpose(1, 2).contiguous(), size=int(20 / input_frame_rate * 22050 / 256), mode='linear')
x2 = torch.concat([x2_head, x2_mid, x2_tail], dim=2)
else:
x2 = F.interpolate(x2.transpose(1, 2).contiguous(), size=mel_len2, mode='linear')
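With the frame-rate-aware sizes, the 20-token head and tail of the sequence are interpolated separately from the middle so the prompt/synthesis boundary in the mel spectrogram stays sharp. A shape-only sketch of the new behaviour, assuming a 25 Hz input frame rate and dummy tensors:

import torch
import torch.nn.functional as F

input_frame_rate = 25
x2 = torch.randn(1, 100, 512)                          # (B, T, D) token features, T > 40
mel_len2 = int(100 / input_frame_rate * 22050 / 256)   # 344 target mel frames
edge = int(20 / input_frame_rate * 22050 / 256)        # 68 frames per 20-token edge
head = F.interpolate(x2[:, :20].transpose(1, 2).contiguous(), size=edge, mode='linear')
mid = F.interpolate(x2[:, 20:-20].transpose(1, 2).contiguous(), size=mel_len2 - 2 * edge, mode='linear')
tail = F.interpolate(x2[:, -20:].transpose(1, 2).contiguous(), size=edge, mode='linear')
print(torch.concat([head, mid, tail], dim=2).shape)    # torch.Size([1, 512, 344])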
2 changes: 1 addition & 1 deletion cosyvoice/llm/llm.py
@@ -206,7 +206,7 @@ def inference(
if top_ids == self.speech_token_size:
break
# in stream mode, yield token one by one
yield torch.tensor([[top_ids]], dtype=torch.int64, device=device)
yield top_ids
out_tokens.append(top_ids)
offset += lm_input.size(1)
lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)