Remove notebooks from .gitignore
jshuadvd committed Jun 20, 2024
1 parent 867dded commit 6fbcfb7
Showing 3 changed files with 247 additions and 1 deletion.
1 change: 0 additions & 1 deletion .gitignore
@@ -162,5 +162,4 @@ cython_debug/
.ipynb_checkpoints/
*.pyc
__pycache__/
notebooks/
__pycache__/
246 changes: 246 additions & 0 deletions notebooks/01_LongRoPE_training.ipynb
@@ -0,0 +1,246 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Train a LongRoPE model on a given dataset\n",
"from src.main import LongRoPEModel\n",
"import torch\n",
"import torch.nn as nn\n",
"import torch.optim as optim\n",
"from torch.utils.data import DataLoader, Dataset\n",
"from torch.nn.utils.rnn import pad_sequence\n",
"import gzip\n",
"from transformers import GPT2Tokenizer\n",
"from importlib import reload\n",
"import src.main\n",
"reload(src.main)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"class CustomDataset(Dataset):\n",
" \"\"\"Custom dataset for handling sequences and targets.\"\"\"\n",
"\n",
" def __init__(self, sequences, targets):\n",
" self.sequences = sequences\n",
" self.targets = targets\n",
"\n",
" def __len__(self):\n",
" return len(self.sequences)\n",
"\n",
" def __getitem__(self, idx):\n",
" return self.sequences[idx], self.targets[idx]\n",
"\n",
"\n",
"def load_data(filename):\n",
" \"\"\"Load data from a gzip file.\"\"\"\n",
" with gzip.open(filename, \"rt\", encoding=\"utf-8\") as f:\n",
" data = f.read()\n",
" return data\n",
"\n",
"\n",
"def collate_fn(batch):\n",
" \"\"\"Custom collate function to pad data batches.\"\"\"\n",
" inputs, targets = zip(*batch)\n",
" padded_inputs = pad_sequence(\n",
" [torch.tensor(seq) for seq in inputs], batch_first=True, padding_value=0\n",
" )\n",
" padded_targets = pad_sequence(\n",
" [torch.tensor(tgt) for tgt in targets], batch_first=True, padding_value=-1\n",
" )\n",
" return padded_inputs, padded_targets\n",
"\n",
"def create_sliding_window_chunks(tokenized_data, max_length=65536, overlap=4096):\n",
" \"\"\"Create sliding window chunks from tokenized data.\"\"\"\n",
" sequences = []\n",
" start = 0\n",
" while start < len(tokenized_data):\n",
" end = start + max_length\n",
" if end >= len(tokenized_data):\n",
" # If the remaining sequence is shorter than max_length, append it as is\n",
" sequences.append(tokenized_data[start:])\n",
" else:\n",
" # Split the sequence into chunks of max_length with overlap\n",
" while start < end:\n",
" chunk_end = min(start + max_length, end)\n",
" sequences.append(tokenized_data[start:chunk_end])\n",
" start += max_length - overlap\n",
" return sequences\n",
"\n",
"def validate_targets(targets, vocab_size):\n",
" \"\"\"Validate that all target indices are within the vocabulary size.\"\"\"\n",
" for target_batch in targets:\n",
" if any(t >= vocab_size for t in target_batch):\n",
" raise ValueError(\"Target index out of vocabulary size range.\")\n",
" return True\n",
"\n",
"def preprocess_data(data, tokenizer, max_length, overlap):\n",
" \"\"\"\n",
" Preprocess the input data by tokenizing it in chunks and creating sliding window sequences.\n",
"\n",
" Args:\n",
" data (str): Input data as a string.\n",
" tokenizer: Tokenizer object for encoding the data.\n",
" max_length (int): Maximum sequence length for each chunk.\n",
" overlap (int): Overlap size between consecutive chunks.\n",
"\n",
" Returns:\n",
" list: List of preprocessed sequences.\n",
" \"\"\"\n",
" sequences = []\n",
" start = 0\n",
" while start < len(data):\n",
" end = start + max_length\n",
" chunk = data[start:end]\n",
" tokenized_chunk = tokenizer.encode(chunk)\n",
" \n",
" # Create sliding window sequences from the tokenized chunk\n",
" chunk_sequences = create_sliding_window_chunks(\n",
" tokenized_chunk, max_length=max_length, overlap=overlap\n",
" )\n",
" sequences.extend(chunk_sequences)\n",
" \n",
" start = end - overlap\n",
"\n",
" return sequences \n",
"\n",
"\n",
"def train(model, train_loader, val_loader, optimizer, criterion, device, epochs=10):\n",
" \"\"\"Training loop for the model.\"\"\"\n",
" model.train()\n",
" for epoch in range(epochs):\n",
" for inputs, targets in train_loader:\n",
" inputs, targets = inputs.to(device), targets.to(device)\n",
"\n",
" print(f\"Input shape: {inputs.shape}\")\n",
" print(f\"Target shape: {targets.shape}\")\n",
"\n",
" if inputs.size(1) > model.rope.max_len:\n",
" print(\n",
" f\"Warning: Batch with input size {inputs.size(1)} exceeds the maximum length of {model.rope.max_len}.\"\n",
" )\n",
" continue # Skip this batch\n",
"\n",
" optimizer.zero_grad()\n",
" outputs = model(inputs)\n",
" loss = criterion(outputs.permute(0, 2, 1), targets)\n",
" loss.backward()\n",
" optimizer.step()\n",
"\n",
" # Validation step\n",
" model.eval()\n",
" val_loss = 0\n",
" with torch.no_grad():\n",
" for inputs, targets in val_loader:\n",
" inputs, targets = inputs.to(device), targets.to(device)\n",
" outputs = model(inputs)\n",
" loss = criterion(outputs.permute(0, 2, 1), targets)\n",
" val_loss += loss.item()\n",
" print(\n",
" f\"Epoch {epoch+1}, Training Loss: {loss.item()}, Validation Loss: {val_loss / len(val_loader)}\"\n",
" )\n",
" model.train()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def main():\n",
" \"\"\"Main function to setup and run training.\"\"\"\n",
" tokenizer = GPT2Tokenizer.from_pretrained(\"gpt2\")\n",
" tokenizer.model_max_length = 2048000 # Set the maximum sequence length for the tokenizer\n",
" data = load_data(\"../data/raw/enwik8.gz\")\n",
"\n",
" max_length = 65536\n",
" overlap = 4096\n",
" sequences = preprocess_data(data, tokenizer, max_length, overlap)\n",
"\n",
" targets = [seq[1:] + [tokenizer.eos_token_id] for seq in sequences]\n",
"\n",
" validate_targets(targets, tokenizer.vocab_size)\n",
"\n",
" print(f\"Validated: {validate_targets(targets, tokenizer.vocab_size)}\")\n",
"\n",
" dataset = CustomDataset(sequences, targets)\n",
" train_size = int(0.8 * len(dataset))\n",
" val_size = len(dataset) - train_size\n",
" train_dataset, val_dataset = torch.utils.data.random_split(\n",
" dataset, [train_size, val_size]\n",
" )\n",
"\n",
" train_loader = DataLoader(\n",
" train_dataset, batch_size=32, shuffle=True, collate_fn=collate_fn\n",
" )\n",
" val_loader = DataLoader(val_dataset, batch_size=32, collate_fn=collate_fn)\n",
"\n",
" device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
" \n",
" model = LongRoPEModel(\n",
" d_model=4096,\n",
" n_heads=32,\n",
" num_layers=6,\n",
" vocab_size=tokenizer.vocab_size,\n",
" max_len=2048000, # Set max_len to 2048k tokens\n",
" ).to(device)\n",
"\n",
" extended_model = model.extend_context(\n",
" data_path=\"../data/raw/enwik8.gz\",\n",
" target_length=2048000, # Set target_length to 2048k tokens\n",
" max_sequence_length=65536,\n",
" tokenizer=tokenizer,\n",
" population_size=64,\n",
" num_mutations=16,\n",
" num_crossovers=16,\n",
" max_iterations=10,\n",
" )\n",
"\n",
" recovered_model = model.recover_short_context(\n",
" data_path=\"../data/raw/enwik8.gz\",\n",
" max_sequence_length=48192,\n",
" tokenizer=tokenizer,\n",
" )\n",
"\n",
" optimizer = optim.Adam(recovered_model.parameters(), lr=1e-4)\n",
" criterion = nn.CrossEntropyLoss()\n",
"\n",
" train(recovered_model, train_loader, val_loader, optimizer, criterion, device)\n",
"\n",
"\n",
"if __name__ == \"__main__\":\n",
" main()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.9"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
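
A minimal sketch (toy tensors, standard PyTorch only) of why the target padding value and the loss's ignore_index have to agree: the notebook's collate_fn pads target batches with -1, so the criterion is built with ignore_index=-1; otherwise the padded positions would be read as (invalid) class labels.

import torch
import torch.nn as nn
from torch.nn.utils.rnn import pad_sequence

# Two variable-length token sequences and their shifted next-token targets
inputs = [torch.tensor([5, 6, 7, 8]), torch.tensor([9, 10])]
targets = [torch.tensor([6, 7, 8, 1]), torch.tensor([10, 1])]

padded_inputs = pad_sequence(inputs, batch_first=True, padding_value=0)     # shape (2, 4)
padded_targets = pad_sequence(targets, batch_first=True, padding_value=-1)  # shape (2, 4)

vocab_size = 16
logits = torch.randn(2, 4, vocab_size)  # stand-in for model output: (batch, seq, vocab)

# ignore_index matches the -1 padding, so padded positions do not contribute to the loss
criterion = nn.CrossEntropyLoss(ignore_index=-1)
loss = criterion(logits.permute(0, 2, 1), padded_targets)
print(padded_inputs.shape, padded_targets.shape, loss.item())
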
1 change: 1 addition & 0 deletions src/main.py
@@ -268,6 +268,7 @@ def recover_short_context(self, data_path, max_sequence_length, tokenizer):
        Returns:
            LongRoPEModel: Recovered LongRoPE model.
        """

        if tokenizer is None:
            raise ValueError("Tokenizer is required for recovering short context.")

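For context, a sketch of how the guarded method is called from the notebook above; with the added check, invoking recover_short_context without a tokenizer fails fast with a clear error. The LongRoPEModel constructor arguments and the method signature are assumed from the notebook's usage, not from the diff itself.

from transformers import GPT2Tokenizer
from src.main import LongRoPEModel

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

# Hyperparameters copied from the notebook cell above; a smaller
# configuration would also do for exercising the guard.
model = LongRoPEModel(
    d_model=4096,
    n_heads=32,
    num_layers=6,
    vocab_size=tokenizer.vocab_size,
    max_len=2048000,
)

# The new guard raises immediately instead of failing later in the routine.
try:
    model.recover_short_context(
        data_path="../data/raw/enwik8.gz",
        max_sequence_length=48192,
        tokenizer=None,
    )
except ValueError as err:
    print(err)  # Tokenizer is required for recovering short context.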