streamline skip connection norm
TCord committed Dec 5, 2023
1 parent b074b02 commit c66a4f4
Showing 3 changed files with 7 additions and 8 deletions.
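In short, the two constructor flags skip_connection_norm and skip_connection_pre_activation are merged into a single normalize_skip_convs flag: the residual skip convolutions now derive their norm from the encoder's own per-layer norm settings and follow its pre_activation argument. A minimal before/after sketch of the affected keyword arguments, assuming a caller that enabled both old flags (values other than the flags themselves are placeholders, not taken from the diff):

    # Hypothetical keyword arguments before this commit: two independent
    # flags controlled the residual skip convolutions.
    old_kwargs = dict(
        norm='batch',  # placeholder; the call sites below pass a `norm` variable
        pre_activation=True,
        skip_connection_norm=True,
        skip_connection_pre_activation=True,
    )

    # After this commit: one flag; the skip convs mirror the per-layer norm
    # configuration and the encoder's own pre_activation setting.
    new_kwargs = dict(
        norm='batch',
        pre_activation=True,
        normalize_skip_convs=True,
    )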
9 changes: 5 additions & 4 deletions padertorch/contrib/je/modules/conv.py
@@ -434,8 +434,7 @@ def __init__(
         pool_stride=None,
         return_pool_indices=False,
         return_state=False,
-        skip_connection_pre_activation=False,
-        skip_connection_norm=False
+        normalize_skip_convs=False,
     ):
         """
@@ -564,6 +563,7 @@ def __init__(
         for dst_idx in destination_indices:
             assert dst_idx > src_idx, (src_idx, dst_idx)
             if layer_in_channels[dst_idx] != layer_in_channels[src_idx]:
+                skip_norm = self.norm[src_idx if pre_activation else dst_idx] if normalize_skip_convs else None
                 residual_skip_convs[f'{src_idx}->{dst_idx}'] = self.conv_cls(
                     in_channels=layer_in_channels[src_idx],
                     out_channels=layer_in_channels[dst_idx],
@@ -572,9 +572,10 @@ def __init__(
                     dilation=1,
                     stride=1,
                     pad_type=None,
-                    norm=norm if skip_connection_norm else None,
+                    norm=skip_norm,
+                    norm_kwargs=None if skip_norm is None else norm_kwargs,
                     activation_fn='identity',
-                    pre_activation=skip_connection_pre_activation,
+                    pre_activation=pre_activation,
                     gated=False,
                 )
         self.residual_skip_convs = nn.ModuleDict(residual_skip_convs)
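The functional core of this file's change is the skip_norm selection above. A hedged reading: since norm_kwargs is forwarded alongside it, self.norm here appears to hold a per-layer norm specifier; with pre_activation the skip conv normalizes its input, which coincides with the source layer's input, so the source layer's norm is mirrored, while without pre_activation the normalized output feeds the destination layer, so the destination's norm is used. A minimal standalone sketch of that rule (the function name and the norms argument are illustrative, not from the source):

    def pick_skip_norm(norms, src_idx, dst_idx, pre_activation,
                       normalize_skip_convs):
        # Skip convolutions stay unnormalized unless explicitly requested.
        if not normalize_skip_convs:
            return None
        # Pre-activation: the norm acts on the skip conv's input, which is
        # the source layer's input, so mirror the source layer's norm.
        # Post-activation: the norm acts on the output that feeds the
        # destination layer, so mirror the destination's norm.
        return norms[src_idx if pre_activation else dst_idx]

    assert pick_skip_norm(['batch', 'batch', 'layer'], 0, 2, True, True) == 'batch'
    assert pick_skip_norm(['batch', 'batch', 'layer'], 0, 2, False, True) == 'layer'
    assert pick_skip_norm(['batch', 'batch', 'layer'], 0, 2, True, False) is None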
3 changes: 1 addition & 2 deletions padertorch/contrib/tcl/speaker_embeddings/dvectors.py
@@ -139,8 +139,7 @@ def __init__(
             activation_fn=activation_fn,
             pre_activation=pre_activation,
             norm=norm,
-            skip_connection_norm=True,
-            skip_connection_pre_activation=True
+            normalize_skip_convs=True
         )
         self.output_convolution = Conv2d(channels[-1], dvec_dim, kernel_size=3, stride=(2, 1), bias=False,
                                          activation_fn='relu', norm=norm, pre_activation=True)
3 changes: 1 addition & 2 deletions
@@ -60,8 +60,7 @@ def __init__(
             activation_fn=activation_fn,
             pre_activation=pre_activation,
             norm=norm,
-            skip_connection_norm=True,
-            skip_connection_pre_activation=True
+            normalize_skip_convs=True
         )
         self.output_convolution = Conv2d(channels[-1], dvec_dim * num_spk, kernel_size=3, stride=(2, 1), bias=False,
                                          activation_fn='relu', norm=norm, pre_activation=True)
