Showing 3 changed files with 247 additions and 1 deletion.
@@ -162,5 +162,4 @@ cython_debug/
.ipynb_checkpoints/
*.pyc
__pycache__/
notebooks/
__pycache__/

@@ -0,0 +1,246 @@
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Train a LongRoPE model on a given dataset\n",
    "from src.main import LongRoPEModel\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "from torch.utils.data import DataLoader, Dataset\n",
    "from torch.nn.utils.rnn import pad_sequence\n",
    "import gzip\n",
    "from transformers import GPT2Tokenizer\n",
    "from importlib import reload\n",
    "import src.main\n",
    "reload(src.main)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class CustomDataset(Dataset):\n",
    "    \"\"\"Custom dataset for handling sequences and targets.\"\"\"\n",
    "\n",
    "    def __init__(self, sequences, targets):\n",
    "        self.sequences = sequences\n",
    "        self.targets = targets\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.sequences)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        return self.sequences[idx], self.targets[idx]\n",
    "\n",
    "\n",
    "def load_data(filename):\n",
    "    \"\"\"Load data from a gzip file.\"\"\"\n",
    "    with gzip.open(filename, \"rt\", encoding=\"utf-8\") as f:\n",
    "        data = f.read()\n",
    "    return data\n",
    "\n",
    "\n",
    "def collate_fn(batch):\n",
    "    \"\"\"Custom collate function to pad data batches.\"\"\"\n",
    "    inputs, targets = zip(*batch)\n",
    "    padded_inputs = pad_sequence(\n",
    "        [torch.tensor(seq) for seq in inputs], batch_first=True, padding_value=0\n",
    "    )\n",
    "    padded_targets = pad_sequence(\n",
    "        [torch.tensor(tgt) for tgt in targets], batch_first=True, padding_value=-1\n",
    "    )\n",
    "    return padded_inputs, padded_targets\n",
    "\n",
    "def create_sliding_window_chunks(tokenized_data, max_length=65536, overlap=4096):\n",
    "    \"\"\"Create sliding window chunks from tokenized data.\"\"\"\n",
    "    sequences = []\n",
    "    start = 0\n",
    "    while start < len(tokenized_data):\n",
    "        end = start + max_length\n",
    "        if end >= len(tokenized_data):\n",
    "            # If the remaining sequence is shorter than max_length, append it as is\n",
    "            sequences.append(tokenized_data[start:])\n",
    "            break  # Stop here; otherwise start never advances and the loop repeats forever\n",
    "        else:\n",
    "            # Split the sequence into chunks of max_length with overlap\n",
    "            while start < end:\n",
    "                chunk_end = min(start + max_length, end)\n",
    "                sequences.append(tokenized_data[start:chunk_end])\n",
    "                start += max_length - overlap\n",
    "    return sequences\n",
    "\n",
    "def validate_targets(targets, vocab_size):\n",
    "    \"\"\"Validate that all target indices are within the vocabulary size.\"\"\"\n",
    "    for target_batch in targets:\n",
    "        if any(t >= vocab_size for t in target_batch):\n",
    "            raise ValueError(\"Target index out of vocabulary size range.\")\n",
    "    return True\n",
    "\n",
    "def preprocess_data(data, tokenizer, max_length, overlap):\n",
    "    \"\"\"\n",
    "    Preprocess the input data by tokenizing it in chunks and creating sliding window sequences.\n",
    "\n",
    "    Args:\n",
    "        data (str): Input data as a string.\n",
    "        tokenizer: Tokenizer object for encoding the data.\n",
    "        max_length (int): Maximum sequence length for each chunk.\n",
    "        overlap (int): Overlap size between consecutive chunks.\n",
    "\n",
    "    Returns:\n",
    "        list: List of preprocessed sequences.\n",
    "    \"\"\"\n",
    "    sequences = []\n",
    "    start = 0\n",
    "    while start < len(data):\n",
    "        end = start + max_length\n",
    "        chunk = data[start:end]\n",
    "        tokenized_chunk = tokenizer.encode(chunk)\n",
    "\n",
    "        # Create sliding window sequences from the tokenized chunk\n",
    "        chunk_sequences = create_sliding_window_chunks(\n",
    "            tokenized_chunk, max_length=max_length, overlap=overlap\n",
    "        )\n",
    "        sequences.extend(chunk_sequences)\n",
    "\n",
    "        start = end - overlap\n",
    "\n",
    "    return sequences\n",
    "\n",
    "\n",
    "def train(model, train_loader, val_loader, optimizer, criterion, device, epochs=10):\n",
    "    \"\"\"Training loop for the model.\"\"\"\n",
    "    model.train()\n",
    "    for epoch in range(epochs):\n",
    "        for inputs, targets in train_loader:\n",
    "            inputs, targets = inputs.to(device), targets.to(device)\n",
    "\n",
    "            print(f\"Input shape: {inputs.shape}\")\n",
    "            print(f\"Target shape: {targets.shape}\")\n",
    "\n",
    "            if inputs.size(1) > model.rope.max_len:\n",
    "                print(\n",
    "                    f\"Warning: Batch with input size {inputs.size(1)} exceeds the maximum length of {model.rope.max_len}.\"\n",
    "                )\n",
    "                continue  # Skip this batch\n",
    "\n",
    "            optimizer.zero_grad()\n",
    "            outputs = model(inputs)\n",
    "            loss = criterion(outputs.permute(0, 2, 1), targets)\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "\n",
    "        # Validation step\n",
    "        model.eval()\n",
    "        val_loss = 0\n",
    "        with torch.no_grad():\n",
    "            for inputs, targets in val_loader:\n",
    "                inputs, targets = inputs.to(device), targets.to(device)\n",
    "                outputs = model(inputs)\n",
    "                loss = criterion(outputs.permute(0, 2, 1), targets)\n",
    "                val_loss += loss.item()\n",
    "        print(\n",
    "            f\"Epoch {epoch+1}, Training Loss: {loss.item()}, Validation Loss: {val_loss / len(val_loader)}\"\n",
    "        )\n",
    "        model.train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def main():\n",
    "    \"\"\"Main function to setup and run training.\"\"\"\n",
    "    tokenizer = GPT2Tokenizer.from_pretrained(\"gpt2\")\n",
    "    tokenizer.model_max_length = 2048000  # Set the maximum sequence length for the tokenizer\n",
    "    data = load_data(\"../data/raw/enwik8.gz\")\n",
    "\n",
    "    max_length = 65536\n",
    "    overlap = 4096\n",
    "    sequences = preprocess_data(data, tokenizer, max_length, overlap)\n",
    "\n",
    "    targets = [seq[1:] + [tokenizer.eos_token_id] for seq in sequences]\n",
    "\n",
    "    validate_targets(targets, tokenizer.vocab_size)\n",
    "\n",
    "    print(f\"Validated: {validate_targets(targets, tokenizer.vocab_size)}\")\n",
    "\n",
    "    dataset = CustomDataset(sequences, targets)\n",
    "    train_size = int(0.8 * len(dataset))\n",
    "    val_size = len(dataset) - train_size\n",
    "    train_dataset, val_dataset = torch.utils.data.random_split(\n",
    "        dataset, [train_size, val_size]\n",
    "    )\n",
    "\n",
    "    train_loader = DataLoader(\n",
    "        train_dataset, batch_size=32, shuffle=True, collate_fn=collate_fn\n",
    "    )\n",
    "    val_loader = DataLoader(val_dataset, batch_size=32, collate_fn=collate_fn)\n",
    "\n",
    "    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "    model = LongRoPEModel(\n",
    "        d_model=4096,\n",
    "        n_heads=32,\n",
    "        num_layers=6,\n",
    "        vocab_size=tokenizer.vocab_size,\n",
    "        max_len=2048000,  # Set max_len to 2048k tokens\n",
    "    ).to(device)\n",
    "\n",
    "    extended_model = model.extend_context(\n",
    "        data_path=\"../data/raw/enwik8.gz\",\n",
    "        target_length=2048000,  # Set target_length to 2048k tokens\n",
    "        max_sequence_length=65536,\n",
    "        tokenizer=tokenizer,\n",
    "        population_size=64,\n",
    "        num_mutations=16,\n",
    "        num_crossovers=16,\n",
    "        max_iterations=10,\n",
    "    )\n",
    "\n",
    "    recovered_model = model.recover_short_context(\n",
    "        data_path=\"../data/raw/enwik8.gz\",\n",
    "        max_sequence_length=48192,\n",
    "        tokenizer=tokenizer,\n",
    "    )\n",
    "\n",
    "    optimizer = optim.Adam(recovered_model.parameters(), lr=1e-4)\n",
    "    criterion = nn.CrossEntropyLoss()\n",
    "\n",
    "    train(recovered_model, train_loader, val_loader, optimizer, criterion, device)\n",
    "\n",
    "\n",
    "if __name__ == \"__main__\":\n",
    "    main()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}