diff --git a/examples/industrial_data_pretraining/seaco_paraformer/demo.py b/examples/industrial_data_pretraining/seaco_paraformer/demo.py
index e9e226d1c..85d989e44 100644
--- a/examples/industrial_data_pretraining/seaco_paraformer/demo.py
+++ b/examples/industrial_data_pretraining/seaco_paraformer/demo.py
@@ -25,6 +25,7 @@
 # example2
 import torchaudio
+import os
 wav_file = os.path.join(model.model_path, "example/asr_example.wav")
 input_tensor, sample_rate = torchaudio.load(wav_file)
 input_tensor = input_tensor.mean(0)
@@ -33,7 +34,7 @@
 # example3
 import soundfile
-import os
+
 wav_file = os.path.join(model.model_path, "example/asr_example.wav")
 speech, sample_rate = soundfile.read(wav_file)
 res = model.generate(input=[speech], batch_size_s=300, is_final=True)
diff --git a/funasr/bin/train.py b/funasr/bin/train.py
index d9d4d6241..8ea0c0db5 100644
--- a/funasr/bin/train.py
+++ b/funasr/bin/train.py
@@ -154,7 +154,7 @@ def main(**kwargs):
     if batch_sampler is not None:
         batch_sampler_class = tables.batch_sampler_classes.get(batch_sampler)
         batch_sampler = batch_sampler_class(dataset_tr, **kwargs.get("dataset_conf"))
-        batch_sampler_val = batch_sampler_class(dataset_tr, is_training=False, **kwargs.get("dataset_conf"))
+        batch_sampler_val = batch_sampler_class(dataset_val, is_training=False, **kwargs.get("dataset_conf"))
 
     dataloader_tr = torch.utils.data.DataLoader(dataset_tr,
                                                 collate_fn=dataset_tr.collator,
                                                 batch_sampler=batch_sampler,
diff --git a/funasr/datasets/audio_datasets/samplers.py b/funasr/datasets/audio_datasets/samplers.py
index 0d9309814..535df5d05 100644
--- a/funasr/datasets/audio_datasets/samplers.py
+++ b/funasr/datasets/audio_datasets/samplers.py
@@ -26,6 +26,8 @@ def __init__(self, dataset,
         self.max_token_length = kwargs.get("max_token_length", 5000)
         self.shuffle_idx = np.arange(self.total_samples)
         self.shuffle = shuffle and is_training
+        self.length_scale_source = kwargs.get("length_scale_source", 1.0)
+
     def __len__(self):
         return (self.total_samples-1) // self.batch_size + 1
@@ -53,8 +55,10 @@ def __iter__(self):
             idx_map = self.shuffle_idx[idx]
             # prompt = self.dataset.indexed_dataset[idx_map]["prompt"]
-            sample_len_cur = self.dataset.get_source_len(idx_map) + \
-                             self.dataset.get_target_len(idx_map)
+            target_len = self.dataset.get_target_len(idx_map) if self.batch_type == 'length' else 0.0
+            source_len = self.dataset.get_source_len(idx_map) / self.length_scale_source
+            sample_len_cur = source_len + target_len
+
             datalen_with_index.append([idx, sample_len_cur])
@@ -66,7 +70,7 @@ def __iter__(self):
             max_token_cur = max(max_token, sample_len_cur_raw)
             max_token_padding = 1 + num_sample
-            if self.batch_type == 'length':
+            if self.batch_type != 'example':
                 max_token_padding *= max_token_cur
             if max_token_padding <= self.batch_size:
                 batch.append(idx)
diff --git a/funasr/models/whisper/model.py b/funasr/models/whisper/model.py
index 6a6d47c6e..381a50161 100644
--- a/funasr/models/whisper/model.py
+++ b/funasr/models/whisper/model.py
@@ -10,6 +10,8 @@
 from funasr.models.whisper.utils.decoding import detect_language as detect_language_function, decode as decode_function
 
+from funasr.register import tables
+
 
 @dataclass
 class ModelDimensions:
@@ -128,6 +130,8 @@ def forward(
         return x
 
+
+@tables.register("encoder_classes", "WhisperEncoder")
 class AudioEncoder(nn.Module):
     def __init__(self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int):
         super().__init__()
@@ -158,7 +162,7 @@ def forward(self, x: Tensor):
         x = self.ln_post(x)
         return x
 
-
+@tables.register("decoder_classes", "WhisperDecoder")
 class TextDecoder(nn.Module):
     def __init__(self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int):
         super().__init__()
@@ -193,7 +197,7 @@ def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None):
         return logits
 
-
+@tables.register("model_classes", "Whisper")
 class Whisper(nn.Module):
     def __init__(self, dims: dict):
         super().__init__()