diff --git a/.gitignore b/.gitignore
index 8f41c54..5ada1eb 100644
--- a/.gitignore
+++ b/.gitignore
@@ -162,5 +162,4 @@ cython_debug/
 .ipynb_checkpoints/
 *.pyc
 __pycache__/
-notebooks/
 __pycache__/
diff --git a/notebooks/01_LongRoPE_training.ipynb b/notebooks/01_LongRoPE_training.ipynb
new file mode 100644
index 0000000..596d3ff
--- /dev/null
+++ b/notebooks/01_LongRoPE_training.ipynb
@@ -0,0 +1,246 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Train a LongRoPE model on a given dataset\n",
+    "from src.main import LongRoPEModel\n",
+    "import torch\n",
+    "import torch.nn as nn\n",
+    "import torch.optim as optim\n",
+    "from torch.utils.data import DataLoader, Dataset\n",
+    "from torch.nn.utils.rnn import pad_sequence\n",
+    "import gzip\n",
+    "from transformers import GPT2Tokenizer\n",
+    "from importlib import reload\n",
+    "import src.main\n",
+    "reload(src.main)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "class CustomDataset(Dataset):\n",
+    "    \"\"\"Custom dataset for handling sequences and targets.\"\"\"\n",
+    "\n",
+    "    def __init__(self, sequences, targets):\n",
+    "        self.sequences = sequences\n",
+    "        self.targets = targets\n",
+    "\n",
+    "    def __len__(self):\n",
+    "        return len(self.sequences)\n",
+    "\n",
+    "    def __getitem__(self, idx):\n",
+    "        return self.sequences[idx], self.targets[idx]\n",
+    "\n",
+    "\n",
+    "def load_data(filename):\n",
+    "    \"\"\"Load data from a gzip file.\"\"\"\n",
+    "    with gzip.open(filename, \"rt\", encoding=\"utf-8\") as f:\n",
+    "        data = f.read()\n",
+    "    return data\n",
+    "\n",
+    "\n",
+    "def collate_fn(batch):\n",
+    "    \"\"\"Custom collate function to pad data batches.\"\"\"\n",
+    "    inputs, targets = zip(*batch)\n",
+    "    padded_inputs = pad_sequence(\n",
+    "        [torch.tensor(seq) for seq in inputs], batch_first=True, padding_value=0\n",
+    "    )\n",
+    "    padded_targets = pad_sequence(\n",
+    "        [torch.tensor(tgt) for tgt in targets], batch_first=True, padding_value=-1\n",
+    "    )\n",
+    "    return padded_inputs, padded_targets\n",
+    "\n",
+    "\n",
+    "def create_sliding_window_chunks(tokenized_data, max_length=65536, overlap=4096):\n",
+    "    \"\"\"Create sliding window chunks from tokenized data.\"\"\"\n",
+    "    sequences = []\n",
+    "    stride = max_length - overlap\n",
+    "    start = 0\n",
+    "    while start < len(tokenized_data):\n",
+    "        end = start + max_length\n",
+    "        # The final window may be shorter than max_length; append it as is\n",
+    "        sequences.append(tokenized_data[start:end])\n",
+    "        if end >= len(tokenized_data):\n",
+    "            break\n",
+    "        # Slide the window forward by max_length - overlap tokens\n",
+    "        start += stride\n",
+    "    return sequences\n",
+    "\n",
+    "\n",
+    "def validate_targets(targets, vocab_size):\n",
+    "    \"\"\"Validate that all target indices are within the vocabulary size.\"\"\"\n",
+    "    for target_batch in targets:\n",
+    "        if any(t >= vocab_size for t in target_batch):\n",
+    "            raise ValueError(\"Target index out of vocabulary size range.\")\n",
+    "    return True\n",
+    "\n",
+    "\n",
+    "def preprocess_data(data, tokenizer, max_length, overlap):\n",
+    "    \"\"\"\n",
+    "    Preprocess the input data by tokenizing it in chunks and creating sliding window sequences.\n",
+    "\n",
+    "    Args:\n",
+    "        data (str): Input data as a string.\n",
+    "        tokenizer: Tokenizer object for encoding the data.\n",
+    "        max_length (int): Maximum sequence length for each chunk.\n",
+    "        overlap (int): Overlap size between consecutive chunks.\n",
+    "\n",
+    "    Returns:\n",
+    "        list: List of preprocessed sequences.\n",
+    "    \"\"\"\n",
+    "    sequences = []\n",
+    "    start = 0\n",
+    "    while start < len(data):\n",
+    "        end = start + max_length\n",
+    "        # Slice the raw text (in characters) and tokenize the slice\n",
+    "        chunk = data[start:end]\n",
+    "        tokenized_chunk = tokenizer.encode(chunk)\n",
+    "\n",
+    "        # Create sliding window sequences from the tokenized chunk\n",
+    "        chunk_sequences = create_sliding_window_chunks(\n",
+    "            tokenized_chunk, max_length=max_length, overlap=overlap\n",
+    "        )\n",
+    "        sequences.extend(chunk_sequences)\n",
+    "\n",
+    "        start = end - overlap\n",
+    "\n",
+    "    return sequences\n",
+    "\n",
+    "\n",
+    "def train(model, train_loader, val_loader, optimizer, criterion, device, epochs=10):\n",
+    "    \"\"\"Training loop for the model.\"\"\"\n",
+    "    model.train()\n",
+    "    for epoch in range(epochs):\n",
+    "        for inputs, targets in train_loader:\n",
+    "            inputs, targets = inputs.to(device), targets.to(device)\n",
+    "\n",
+    "            print(f\"Input shape: {inputs.shape}\")\n",
+    "            print(f\"Target shape: {targets.shape}\")\n",
+    "\n",
+    "            if inputs.size(1) > model.rope.max_len:\n",
+    "                print(\n",
+    "                    f\"Warning: Batch with input size {inputs.size(1)} exceeds the maximum length of {model.rope.max_len}.\"\n",
+    "                )\n",
+    "                continue  # Skip this batch\n",
+    "\n",
+    "            optimizer.zero_grad()\n",
+    "            outputs = model(inputs)\n",
+    "            loss = criterion(outputs.permute(0, 2, 1), targets)\n",
+    "            loss.backward()\n",
+    "            optimizer.step()\n",
+    "\n",
+    "        # Validation step\n",
+    "        model.eval()\n",
+    "        val_loss = 0\n",
+    "        with torch.no_grad():\n",
+    "            for inputs, targets in val_loader:\n",
+    "                inputs, targets = inputs.to(device), targets.to(device)\n",
+    "                outputs = model(inputs)\n",
+    "                loss = criterion(outputs.permute(0, 2, 1), targets)\n",
+    "                val_loss += loss.item()\n",
+    "        print(\n",
+    "            f\"Epoch {epoch+1}, Training Loss: {loss.item()}, Validation Loss: {val_loss / len(val_loader)}\"\n",
+    "        )\n",
+    "        model.train()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def main():\n",
+    "    \"\"\"Main function to set up and run training.\"\"\"\n",
+    "    tokenizer = GPT2Tokenizer.from_pretrained(\"gpt2\")\n",
+    "    tokenizer.model_max_length = 2048000  # Set the maximum sequence length for the tokenizer\n",
+    "    data = load_data(\"../data/raw/enwik8.gz\")\n",
+    "\n",
+    "    max_length = 65536\n",
+    "    overlap = 4096\n",
+    "    sequences = preprocess_data(data, tokenizer, max_length, overlap)\n",
+    "\n",
+    "    targets = [seq[1:] + [tokenizer.eos_token_id] for seq in sequences]\n",
+    "\n",
+    "    print(f\"Validated: {validate_targets(targets, tokenizer.vocab_size)}\")\n",
+    "\n",
+    "    dataset = CustomDataset(sequences, targets)\n",
+    "    train_size = int(0.8 * len(dataset))\n",
+    "    val_size = len(dataset) - train_size\n",
+    "    train_dataset, val_dataset = torch.utils.data.random_split(\n",
+    "        dataset, [train_size, val_size]\n",
+    "    )\n",
+    "\n",
+    "    train_loader = DataLoader(\n",
+    "        train_dataset, batch_size=32, shuffle=True, collate_fn=collate_fn\n",
+    "    )\n",
+    "    val_loader = DataLoader(val_dataset, batch_size=32, collate_fn=collate_fn)\n",
+    "\n",
+    "    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
+    "\n",
+    "    model = LongRoPEModel(\n",
+    "        d_model=4096,\n",
+    "        n_heads=32,\n",
+    "        num_layers=6,\n",
+    "        vocab_size=tokenizer.vocab_size,\n",
+    "        max_len=2048000,  # Set max_len to 2048k tokens\n",
+    "    ).to(device)\n",
+ "\n", + " extended_model = model.extend_context(\n", + " data_path=\"../data/raw/enwik8.gz\",\n", + " target_length=2048000, # Set target_length to 2048k tokens\n", + " max_sequence_length=65536,\n", + " tokenizer=tokenizer,\n", + " population_size=64,\n", + " num_mutations=16,\n", + " num_crossovers=16,\n", + " max_iterations=10,\n", + " )\n", + "\n", + " recovered_model = model.recover_short_context(\n", + " data_path=\"../data/raw/enwik8.gz\",\n", + " max_sequence_length=48192,\n", + " tokenizer=tokenizer,\n", + " )\n", + "\n", + " optimizer = optim.Adam(recovered_model.parameters(), lr=1e-4)\n", + " criterion = nn.CrossEntropyLoss()\n", + "\n", + " train(recovered_model, train_loader, val_loader, optimizer, criterion, device)\n", + "\n", + "\n", + "if __name__ == \"__main__\":\n", + " main()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.9" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/src/main.py b/src/main.py index c508380..857efb9 100644 --- a/src/main.py +++ b/src/main.py @@ -268,6 +268,7 @@ def recover_short_context(self, data_path, max_sequence_length, tokenizer): Returns: LongRoPEModel: Recovered LongRoPE model. """ + if tokenizer is None: raise ValueError("Tokenizer is required for recovering short context.")