Skip to content

Commit

Permalink
Correct the collate_fn function to ensure that the targets are proper…
Browse files Browse the repository at this point in the history
…ly padded using the correct variable
  • Loading branch information
jshuadvd committed Jun 13, 2024
1 parent ef76867 commit 445983a
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion train.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def collate_fn(batch):
[torch.tensor(seq) for seq in inputs], batch_first=True, padding_value=0
)
padded_targets = pad_sequence(
[torch.tensor(tgt) for seq in targets], batch_first=True, padding_value=-1
[torch.tensor(tgt) for tgt in targets], batch_first=True, padding_value=-1
)
return padded_inputs, padded_targets

Expand Down

0 comments on commit 445983a

Please sign in to comment.