Skip to content

Commit

Permalink
linter changes
Browse files Browse the repository at this point in the history
  • Loading branch information
Pherkel committed Sep 18, 2023
1 parent eaec9a8 commit d9421ca
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 6 deletions.
2 changes: 1 addition & 1 deletion swr2_asr/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def main(config_path: str, file_path: str, target_path: Union[str, None] = None)
target = target.replace("!", "")

print("---------")
print(f"Prediction:\n\{preds}")
print(f"Prediction:\n{preds}")
print("---------")
print(f"Target:\n{target}")
print("---------")
Expand Down
8 changes: 4 additions & 4 deletions swr2_asr/model_deep_speech.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ class CNNLayerNorm(nn.Module):
"""Layer normalization built for cnns input"""

def __init__(self, n_feats):
super(CNNLayerNorm, self).__init__()
super().__init__()
self.layer_norm = nn.LayerNorm(n_feats)

def forward(self, data):
Expand All @@ -27,7 +27,7 @@ class ResidualCNN(nn.Module):
"""

def __init__(self, in_channels, out_channels, kernel, stride, dropout, n_feats):
super(ResidualCNN, self).__init__()
super().__init__()

self.cnn1 = nn.Conv2d(in_channels, out_channels, kernel, stride, padding=kernel // 2)
self.cnn2 = nn.Conv2d(out_channels, out_channels, kernel, stride, padding=kernel // 2)
Expand Down Expand Up @@ -55,7 +55,7 @@ class BidirectionalGRU(nn.Module):
"""Bidirectional GRU layer"""

def __init__(self, rnn_dim, hidden_size, dropout, batch_first):
super(BidirectionalGRU, self).__init__()
super().__init__()

self.BiGRU = nn.GRU( # pylint: disable=invalid-name
input_size=rnn_dim,
Expand All @@ -82,7 +82,7 @@ class SpeechRecognitionModel(nn.Module):
def __init__(
self, n_cnn_layers, n_rnn_layers, rnn_dim, n_class, n_feats, stride=2, dropout=0.1
):
super(SpeechRecognitionModel, self).__init__()
super().__init__()
n_feats = n_feats // 2
self.cnn = nn.Conv2d(
1, 32, 3, stride=stride, padding=3 // 2
Expand Down
4 changes: 3 additions & 1 deletion swr2_asr/utils/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@ def plot(path):
epoch = 5
while True:
try:
current_state = torch.load(path + str(epoch), map_location=torch.device("cpu"))
current_state = torch.load(
path + str(epoch), map_location=torch.device("cpu")
) # pylint: disable=no-member
except FileNotFoundError:
break
train_losses.append((epoch, current_state["train_loss"].item()))
Expand Down

0 comments on commit d9421ca

Please sign in to comment.