From 31cd71d5e9245cbad9d0fcf3029fca26923d7646 Mon Sep 17 00:00:00 2001
From: Dinghao Zhou
Date: Mon, 12 Aug 2024 16:10:01 +0800
Subject: [PATCH] [ssl] bestrq: make output bias default to False

---
 wenet/ssl/bestrq/bestrq_model.py | 12 ++++++++----
 1 file changed, 8 insertions(+), 4 deletions(-)

diff --git a/wenet/ssl/bestrq/bestrq_model.py b/wenet/ssl/bestrq/bestrq_model.py
index 60f20105e..d75bc6de9 100644
--- a/wenet/ssl/bestrq/bestrq_model.py
+++ b/wenet/ssl/bestrq/bestrq_model.py
@@ -67,6 +67,7 @@ def __init__(
         mask_length: int = 10,
         min_masks: int = 2,
         norm_epsilon: float = 1e-5,
+        out_bias: bool = False,
         features_regularization_weight: float = 0.01,
     ) -> None:
         super().__init__()
@@ -86,9 +87,11 @@ def __init__(
             torch.empty(self.num_codebooks, self.encoder.output_size(),
                         num_embeddings))
         torch.nn.init.trunc_normal_(self.encoder_top_n_out, std=0.02)
-        self.encoder_top_n_out_bias = torch.nn.parameter.Parameter(
-            torch.empty(self.num_codebooks, num_embeddings))
-        torch.nn.init.zeros_(self.encoder_top_n_out_bias)
+        self.out_bias = out_bias
+        if self.out_bias:
+            self.encoder_top_n_out_bias = torch.nn.parameter.Parameter(
+                torch.empty(self.num_codebooks, num_embeddings))
+            torch.nn.init.zeros_(self.encoder_top_n_out_bias)
 
         # stack input: eg: fbank
         self.stack_frames = self.encoder.embed.right_context + 1
@@ -189,7 +192,8 @@ def forward(
             0)  # [1, num_codebooks, dim, num_embeddings]
         out = torch.matmul(out,
                            top_n_out)  # [B, num_codebooks, T', num_embeddings]
-        out = out + self.encoder_top_n_out_bias.unsqueeze(0).unsqueeze(2)
+        if self.out_bias:
+            out = out + self.encoder_top_n_out_bias.unsqueeze(0).unsqueeze(2)
 
         # 5 compute loss
         masks = out_mask.squeeze(1) * code_ids_mask
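
For reviewers, a minimal usage sketch of the new flag. Everything here other
than out_bias itself (the ConformerEncoder arguments, the encoder keyword, and
reliance on defaults for the remaining BestRQModel parameters) is an
illustrative assumption, not something this patch defines:

    # Sketch: constructing BestRQModel with the new out_bias flag.
    # Encoder config and the encoder kwarg are assumed for illustration.
    from wenet.transformer.encoder import ConformerEncoder
    from wenet.ssl.bestrq.bestrq_model import BestRQModel

    encoder = ConformerEncoder(input_size=80)

    # Default after this patch: no per-codebook output bias is created,
    # so encoder_top_n_out_bias never appears in the state_dict.
    model = BestRQModel(encoder=encoder)

    # Pre-patch behaviour: the bias parameter is created, zero-initialized,
    # and added to the projected logits before the loss is computed.
    model_biased = BestRQModel(encoder=encoder, out_bias=True)

One compatibility note: with out_bias=False the bias tensor is no longer
registered as a parameter, so strictly loading a checkpoint trained before
this patch would need either out_bias=True or load_state_dict(..., strict=False).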