From def4bf17b9dc89ba5f5e683cf2c91597a758d731 Mon Sep 17 00:00:00 2001
From: Joshua David
Date: Thu, 13 Jun 2024 21:50:31 -0700
Subject: [PATCH] update max_len initialization in model

---
 src/main.py | 1 +
 train.py    | 1 +
 2 files changed, 2 insertions(+)

diff --git a/src/main.py b/src/main.py
index 4cc0499..82495bc 100644
--- a/src/main.py
+++ b/src/main.py
@@ -144,6 +144,7 @@ def __init__(self, d_model, n_heads, num_layers, vocab_size, max_len):
         self.n_heads = n_heads
         self.num_layers = num_layers
         self.vocab_size = vocab_size
+        self.max_len = max_len
         self.embedding = nn.Embedding(vocab_size, d_model)
         self.rope = RoPEPositionalEncoding(d_model, max_len)
         self.transformers = nn.ModuleList(
diff --git a/train.py b/train.py
index ac85d92..91bb1a8 100644
--- a/train.py
+++ b/train.py
@@ -8,6 +8,7 @@
 from torch.nn.utils.rnn import pad_sequence
 import gzip
 from transformers import GPT2Tokenizer
+from datasets import load_dataset
 from importlib import reload
 import src.main
 from accelerate import Accelerator
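
Storing self.max_len on the model makes the capacity of the RoPE table queryable
after construction, e.g. to truncate batches before the forward pass. A minimal
sketch of that usage, assuming a PyTorch LongTensor batch of token ids (the
helper name clamp_to_max_len is hypothetical, not from this repo):

    # Sketch only: assumes the model stores max_len as in the patch above.
    import torch

    def clamp_to_max_len(input_ids: torch.Tensor, max_len: int) -> torch.Tensor:
        """Truncate sequences longer than the model's RoPE table supports."""
        return input_ids[:, :max_len]

    batch = torch.randint(0, 50257, (2, 4096))          # (batch, seq_len)
    print(clamp_to_max_len(batch, max_len=2048).shape)  # torch.Size([2, 2048])

The new import in train.py suggests the training data comes from a Hugging Face
dataset; a typical load_dataset call looks like the following (the dataset name
"wikitext" is an illustrative assumption, the patch does not name one):

    # Sketch only: "wikitext" is a placeholder dataset choice.
    from datasets import load_dataset

    ds = load_dataset("wikitext", "wikitext-2-raw-v1", split="train")
    print(ds[0]["text"][:80])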