From c66a4f48d4a06fbeae53dab016d6a7f069a65e84 Mon Sep 17 00:00:00 2001
From: TCord
Date: Tue, 5 Dec 2023 11:11:20 +0100
Subject: [PATCH] streamline skip connection norm

---
 padertorch/contrib/je/modules/conv.py                    | 9 +++++----
 padertorch/contrib/tcl/speaker_embeddings/dvectors.py    | 3 +--
 .../contrib/tcl/speaker_embeddings/student_embeddings.py | 3 +--
 3 files changed, 7 insertions(+), 8 deletions(-)

diff --git a/padertorch/contrib/je/modules/conv.py b/padertorch/contrib/je/modules/conv.py
index 7eea5fde..31fff6f0 100644
--- a/padertorch/contrib/je/modules/conv.py
+++ b/padertorch/contrib/je/modules/conv.py
@@ -434,8 +434,7 @@ def __init__(
             pool_stride=None,
             return_pool_indices=False,
             return_state=False,
-            skip_connection_pre_activation=False,
-            skip_connection_norm=False
+            normalize_skip_convs=False,
     ):
         """

@@ -564,6 +563,7 @@ def __init__(
             for dst_idx in destination_indices:
                 assert dst_idx > src_idx, (src_idx, dst_idx)
                 if layer_in_channels[dst_idx] != layer_in_channels[src_idx]:
+                    skip_norm = self.norm[src_idx if pre_activation else dst_idx] if normalize_skip_convs else None
                     residual_skip_convs[f'{src_idx}->{dst_idx}'] = self.conv_cls(
                         in_channels=layer_in_channels[src_idx],
                         out_channels=layer_in_channels[dst_idx],
@@ -572,9 +572,10 @@ def __init__(
                         dilation=1,
                         stride=1,
                         pad_type=None,
-                        norm=norm if skip_connection_norm else None,
+                        norm=skip_norm,
+                        norm_kwargs=None if skip_norm is None else norm_kwargs,
                         activation_fn='identity',
-                        pre_activation=skip_connection_pre_activation,
+                        pre_activation=pre_activation,
                         gated=False,
                     )
         self.residual_skip_convs = nn.ModuleDict(residual_skip_convs)
diff --git a/padertorch/contrib/tcl/speaker_embeddings/dvectors.py b/padertorch/contrib/tcl/speaker_embeddings/dvectors.py
index 233b3c14..dfae35e5 100644
--- a/padertorch/contrib/tcl/speaker_embeddings/dvectors.py
+++ b/padertorch/contrib/tcl/speaker_embeddings/dvectors.py
@@ -139,8 +139,7 @@ def __init__(
             activation_fn=activation_fn,
             pre_activation=pre_activation,
             norm=norm,
-            skip_connection_norm=True,
-            skip_connection_pre_activation=True
+            normalize_skip_convs=True
         )
         self.output_convolution = Conv2d(channels[-1], dvec_dim, kernel_size=3, stride=(2, 1), bias=False,
                                          activation_fn='relu', norm=norm, pre_activation=True)
diff --git a/padertorch/contrib/tcl/speaker_embeddings/student_embeddings.py b/padertorch/contrib/tcl/speaker_embeddings/student_embeddings.py
index d5bf9a85..76b1d0cf 100644
--- a/padertorch/contrib/tcl/speaker_embeddings/student_embeddings.py
+++ b/padertorch/contrib/tcl/speaker_embeddings/student_embeddings.py
@@ -60,8 +60,7 @@ def __init__(
             activation_fn=activation_fn,
             pre_activation=pre_activation,
             norm=norm,
-            skip_connection_norm=True,
-            skip_connection_pre_activation=True
+            normalize_skip_convs=True
         )
         self.output_convolution = Conv2d(channels[-1], dvec_dim * num_spk, kernel_size=3, stride=(2, 1), bias=False,
                                          activation_fn='relu', norm=norm, pre_activation=True)
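
Note (illustration, not part of the patch): with this change the skip-connection convolutions no longer take their own `skip_connection_norm` / `skip_connection_pre_activation` flags. A single `normalize_skip_convs` flag makes each skip conv reuse the per-layer norm of its source layer (when `pre_activation` is set) or its destination layer (otherwise), together with the model's `norm_kwargs`. A minimal usage sketch follows; the class name `CNN2d`, the `residual_connections` format, and all constructor arguments other than `norm`, `pre_activation`, and `normalize_skip_convs` are assumptions for illustration, not taken from the patch.

    from padertorch.contrib.je.modules.conv import CNN2d  # class name assumed

    cnn = CNN2d(
        in_channels=1,                           # assumed argument
        out_channels=[16, 32, 64],               # assumed argument
        kernel_size=3,                           # assumed argument
        residual_connections=[[2], None, None],  # assumed format: skip from layer 0 to layer 2
        norm='batch',
        pre_activation=True,
        # Old interface: skip_connection_norm=True, skip_connection_pre_activation=True.
        # New interface: the skip conv inherits the norm of layer 0 (the source layer,
        # because pre_activation=True; with pre_activation=False it would inherit the
        # norm of layer 2, the destination layer).
        normalize_skip_convs=True,
    )

One behavioural nuance at the two call sites in dvectors.py and student_embeddings.py: the skip convs previously always used pre-activation (`skip_connection_pre_activation=True`); after the patch they follow the model-level `pre_activation` argument instead.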