def collate_fn(batch, sep_token_id=None, padding_value=0.0):
    """Expand a batch of ``prompt <SEP> continuation`` sequences into prefix/suffix pairs.

    Each item in ``batch`` is split just after the FIRST occurrence of
    ``sep_token_id``. This gives a prompt ``x`` (which keeps the separator)
    and a continuation ``y``. For every ``i`` in ``range(len(y) - 1)``, one
    training pair is emitted:

    * input:  ``x`` followed by the first ``i`` continuation tokens
    * target: the remaining continuation ``y[i:]``

    The loop stops at ``len(y) - 1``, so every target has at least two
    tokens. An item whose continuation is shorter than two tokens contributes
    no pairs.

    Args:
        batch: Iterable of tensors to split. Assumed to be 1-D token-id
            tensors (TODO confirm against the dataset).
        sep_token_id: Token id to split on. When ``None`` (the default), the
            module-level ``SEP_TOKEN_ID`` is used, looked up at call time.
            Existing ``DataLoader(collate_fn=collate_fn)`` usage is therefore
            unchanged.
        padding_value: Fill value forwarded to ``pad_sequence``. The default
            ``0.0`` matches the previous hard-coded behavior.

    Returns:
        ``(inputs, targets)``. Each is built with
        ``torch.nn.utils.rnn.pad_sequence`` with its default
        ``batch_first=False``, so the shape is ``(max_len, num_pairs)``.

    Raises:
        IndexError: if an item contains no ``sep_token_id``.

    Note:
        If the whole batch yields no pairs, ``pad_sequence`` receives an
        empty list. Presumably it raises in that case; this function does
        not handle it.
    """
    if sep_token_id is None:
        # Module-level constant defined elsewhere in this file.
        sep_token_id = SEP_TOKEN_ID

    inputs = []
    targets = []
    for item_idx, item in enumerate(batch):
        sep_positions = torch.where(item == sep_token_id)[0]
        if sep_positions.numel() == 0:
            # Same exception type the bare ``[0][0]`` lookup raised before,
            # but with a message that identifies the offending item.
            raise IndexError(
                f"batch item {item_idx} contains no separator token "
                f"(sep_token_id={sep_token_id})"
            )
        # +1 keeps the separator at the end of the prompt.
        split = int(sep_positions[0]) + 1
        prompt, continuation = item[:split], item[split:]
        for i in range(len(continuation) - 1):
            inputs.append(torch.cat((prompt, continuation[:i])))
            targets.append(continuation[i:])

    pad = torch.nn.utils.rnn.pad_sequence
    return (
        pad(inputs, padding_value=padding_value),
        pad(targets, padding_value=padding_value),
    )
-
Notifications
You must be signed in to change notification settings - Fork 0
License
mayura-ai/sarika
Folders and files
| Name | Last commit message | Last commit date |
|---|---|---|
Repository files navigation
About
No description, website, or topics provided.
Resources
License
Stars
Watchers
Forks
Releases
No releases published
Packages 0
No packages published