diff --git a/base/AIVA_500m.ipynb b/base/AIVA_500m.ipynb
index f633c80..bad4cb5 100644
--- a/base/AIVA_500m.ipynb
+++ b/base/AIVA_500m.ipynb
@@ -4,7 +4,9 @@
"metadata": {
"colab": {
"provenance": [],
- "authorship_tag": "ABX9TyMgvNpTgHlkB38JeUIUdd7l",
+ "machine_shape": "hm",
+ "gpuType": "T4",
+ "authorship_tag": "ABX9TyOeYX5zp+reGmNxsWXca/e6",
"include_colab_link": true
},
"kernelspec": {
@@ -13,7 +15,8 @@
},
"language_info": {
"name": "python"
- }
+ },
+ "accelerator": "GPU"
},
"cells": [
{
@@ -23,16 +26,28 @@
"colab_type": "text"
},
"source": [
- ""
+ ""
]
},
{
"cell_type": "code",
- "execution_count": null,
+ "execution_count": 1,
"metadata": {
- "id": "owjS3-sBcdBh"
+ "id": "owjS3-sBcdBh",
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "outputId": "7986e6e4-75e4-4f10-e868-dfbda7a0d3e7"
},
- "outputs": [],
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Mounted at /content/drive\n"
+ ]
+ }
+ ],
"source": [
"from google.colab import drive\n",
"drive.mount('/content/drive')"
@@ -44,25 +59,59 @@
"!pip install tiktoken"
],
"metadata": {
- "id": "keNA2G8xfroc"
+ "id": "keNA2G8xfroc",
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "outputId": "f3204533-15b0-4b94-df77-93925fa224b1"
},
- "execution_count": null,
- "outputs": []
+ "execution_count": 2,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Collecting tiktoken\n",
+ " Downloading tiktoken-0.6.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.8 MB)\n",
+ "\u001b[?25l \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m0.0/1.8 MB\u001b[0m \u001b[31m?\u001b[0m eta \u001b[36m-:--:--\u001b[0m\r\u001b[2K \u001b[91m━━━━━━\u001b[0m\u001b[91m╸\u001b[0m\u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m0.3/1.8 MB\u001b[0m \u001b[31m9.6 MB/s\u001b[0m eta \u001b[36m0:00:01\u001b[0m\r\u001b[2K \u001b[91m━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[90m╺\u001b[0m\u001b[90m━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.1/1.8 MB\u001b[0m \u001b[31m16.0 MB/s\u001b[0m eta \u001b[36m0:00:01\u001b[0m\r\u001b[2K \u001b[91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[91m╸\u001b[0m \u001b[32m1.8/1.8 MB\u001b[0m \u001b[31m20.2 MB/s\u001b[0m eta \u001b[36m0:00:01\u001b[0m\r\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.8/1.8 MB\u001b[0m \u001b[31m17.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
+ "\u001b[?25hRequirement already satisfied: regex>=2022.1.18 in /usr/local/lib/python3.10/dist-packages (from tiktoken) (2023.12.25)\n",
+ "Requirement already satisfied: requests>=2.26.0 in /usr/local/lib/python3.10/dist-packages (from tiktoken) (2.31.0)\n",
+ "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests>=2.26.0->tiktoken) (3.3.2)\n",
+ "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests>=2.26.0->tiktoken) (3.6)\n",
+ "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests>=2.26.0->tiktoken) (2.0.7)\n",
+ "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests>=2.26.0->tiktoken) (2024.2.2)\n",
+ "Installing collected packages: tiktoken\n",
+ "Successfully installed tiktoken-0.6.0\n"
+ ]
+ }
+ ]
},
{
"cell_type": "code",
"source": [
"# data for model\n",
- "with open('/content/drive/MyDrive/new_training_data.txt', 'r', encoding='utf-8') as file:\n",
- " captions = file.read()\n",
+ "with open('/content/drive/MyDrive/training data/consolidated_350m.txt', 'r', encoding='utf-8') as file:\n",
+ " train_data = file.read()\n",
"\n",
- "print(len(captions)/1e6, 'million words')"
+ "print(len(train_data)/1e6, 'million words')"
],
"metadata": {
- "id": "BSh3yuTGfu21"
+ "id": "BSh3yuTGfu21",
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "outputId": "86e17e89-11a4-45cf-bfab-6f68002ef9bc"
},
- "execution_count": null,
- "outputs": []
+ "execution_count": 3,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "2274.16219 million words\n"
+ ]
+ }
+ ]
},
{
"cell_type": "code",
@@ -71,16 +120,15 @@
"tokenizer = tiktoken.get_encoding(\"p50k_base\")\n",
"tokenizer = tiktoken.encoding_for_model(\"text-davinci-003\")\n",
"\n",
- "input_data = tokenizer.encode(captions)\n",
- "end_time = timeit.default_timer()\n",
+ "input_data = tokenizer.encode(train_data)\n",
"\n",
"print(\"total tokens\", len(input_data)/1e6, 'million')\n",
- "print(f\"time taken to train the tokenizer {total_time}mins\")\n",
"\n",
"n = int(0.9*len(input_data)) # first 90% will be train, rest val\n",
"train_data = input_data[:n]\n",
"val_data = input_data[n:]\n",
- "print(f\"train data {len(train_data) / 1e6} million'\\n'validation data {len(val_data) / 1e6} million\")"
+ "\n",
+ "del input_data, n"
],
"metadata": {
"id": "VmBZRVhqfyn2"
@@ -97,243 +145,327 @@
"train_data = torch.tensor(train_data, dtype=torch.long)\n",
"val_data = torch.tensor(val_data, dtype=torch.long)\n",
"\n",
+ "print(f\"train data {(len(train_data) / 1e6):.0f} million\\nvalidation data {(len(val_data) / 1e6):.0f} million\")\n",
"print(f\"train data = {train_data[:10]}, \\nval data = {val_data[:10]}\")"
],
"metadata": {
- "id": "rtfFsfgdf8rw"
+ "id": "rtfFsfgdf8rw",
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "outputId": "9aaaea88-25d5-4a97-ace1-7c817fca7270"
},
- "execution_count": null,
- "outputs": []
+ "execution_count": 9,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stderr",
+ "text": [
+ ":4: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
+ " train_data = torch.tensor(train_data, dtype=torch.long)\n"
+ ]
+ },
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "train data 588 million\n",
+ "validation data 65 million\n",
+ "train data = tensor([ 3886, 25, 7443, 13, 785, 48073, 19433, 25, 2932, 860]), \n",
+ "val data = tensor([ 7579, 2885, 17941, 1847, 7446, 8696, 2389, 18310, 13, 3336])\n"
+ ]
+ },
+ {
+ "output_type": "stream",
+ "name": "stderr",
+ "text": [
+ ":5: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
+ " val_data = torch.tensor(val_data, dtype=torch.long)\n"
+ ]
+ }
+ ]
},
{
"cell_type": "code",
"source": [
- "import torch.nn as nn\n",
- "from torch.nn import functional as F\n",
- "\n",
"# hyperparameters\n",
- "batch_size = 32\n",
- "block_size = 512\n",
+ "batch_size = 10\n",
+ "block_size = 256\n",
"max_iters = 1000\n",
"eval_interval = 100\n",
- "learning_rate = 3e-4\n",
+ "learning_rate = 3e-5\n",
"device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
- "eval_iters = 50\n",
- "d_embd = 512\n",
- "n_head = 16\n",
- "n_layer = 16\n",
+ "eval_iters = 10\n",
+ "d_model = 512\n",
+ "n_head = 20\n",
+ "n_layers = 18\n",
"dropout = 0.2\n",
- "norm_eps = 1e-05\n",
- "# ------------\n",
- "\n",
- "torch.manual_seed(1400)\n",
- "\n",
- "# data loading\n",
- "def get_batch(split):\n",
- " # generate a small batch of data of inputs x and targets y\n",
- " data = train_data if split == 'train' else val_data\n",
- " ix = torch.randint(len(data) - block_size, (batch_size,))\n",
- " x = torch.stack([data[i:i+block_size] for i in ix])\n",
- " y = torch.stack([data[i+1:i+block_size+1] for i in ix])\n",
- " x, y = x.to(device), y.to(device)\n",
- " return x, y\n",
+ "norm_eps = 1e-05"
+ ],
+ "metadata": {
+ "id": "tJuCsc1QPdts"
+ },
+ "execution_count": 10,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "import torch\n",
+ "import torch.nn as nn\n",
+ "from torch.nn import functional as F\n",
"\n",
- "@torch.no_grad()\n",
- "def estimate_loss():\n",
- " out = {}\n",
- " model.eval()\n",
- " for split in ['train', 'val']:\n",
- " losses = torch.zeros(eval_iters)\n",
- " for k in range(eval_iters):\n",
- " X, Y = get_batch(split)\n",
- " logits, loss = model(X, Y)\n",
- " losses[k] = loss.item()\n",
- " out[split] = losses.mean()\n",
- " model.train()\n",
- " return out\n",
+ "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
"\n",
- "class Head(nn.Module):\n",
- " \"\"\" one head of self-attention \"\"\"\n",
+ "class RMSNorm(nn.Module):\n",
+ " def __init__(self, dim: int, eps: float = 1e-6):\n",
+ " \"\"\"\n",
+ " Initialize the RMSNorm normalization layer.\n",
+ " Args:\n",
+ " dim (int): The dimension of the input tensor.\n",
+ " eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.\n",
+ " Attributes:\n",
+ " eps (float): A small value added to the denominator for numerical stability.\n",
+ " weight (nn.Parameter): Learnable scaling parameter.\n",
+ " \"\"\"\n",
+ " super().__init__()\n",
+ " self.eps = eps\n",
+ " self.weight = nn.Parameter(torch.ones(dim))\n",
+ "\n",
+ " def _norm(self, x):\n",
+ " \"\"\"\n",
+ " Apply the RMSNorm normalization to the input tensor.\n",
+ " Args:\n",
+ " x (torch.Tensor): The input tensor.\n",
+ " Returns:\n",
+ " torch.Tensor: The normalized tensor.\n",
+ " \"\"\"\n",
+ " return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)\n",
"\n",
- " def __init__(self, d_embd, n_head, dropout, block_size):\n",
- " head_size = d_embd // n_head\n",
+ " def forward(self, x):\n",
+ " \"\"\"\n",
+ " Forward pass through the RMSNorm layer.\n",
+ " Args:\n",
+ " x (torch.Tensor): The input tensor.\n",
+ " Returns:\n",
+ " torch.Tensor: The output tensor after applying RMSNorm.\n",
+ " \"\"\"\n",
+ " output = self._norm(x.float()).type_as(x)\n",
+ " return output * self.weight\n",
+ "\n",
+ "class UnMaskedHead(nn.Module):\n",
+ " def __init__(self, head_size, d_model, block_size, dropout):\n",
" super().__init__()\n",
- " self.key = nn.Linear(d_embd, head_size, bias=True)\n",
- " self.query = nn.Linear(d_embd, head_size, bias=True)\n",
- " self.value = nn.Linear(d_embd, head_size, bias=True)\n",
- " self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))\n",
+ " self.key = nn.Linear(d_model, head_size, bias=True)\n",
+ " self.query = nn.Linear(d_model, head_size, bias=True)\n",
+ " self.value = nn.Linear(d_model, head_size, bias=True)\n",
" self.dropout = nn.Dropout(dropout)\n",
+ " self.rel_pos_embd = nn.Parameter(torch.randn(block_size, block_size, head_size))\n",
"\n",
" def forward(self, x):\n",
- " B,T,C = x.shape\n",
- " key = self.key(x) # (B,T,hs)\n",
- " query = self.query(x) # (B,T,hs)\n",
- "\n",
- " # compute attention scores (\"affinities\")\n",
- " weights = query @ key.transpose(-2,-1) * key.shape[-1]**-0.5 # (B, T, hs) @ (B, hs, T) -> (B, T, T)\n",
- " weights = weights.masked_fill(self.tril[:T, :T] == 0, float('-inf')) # (B, T, T)\n",
- " weights = F.softmax(weights, dim=-1) # (B, T, T)\n",
- " weights = self.dropout(weights)\n",
- "\n",
- " # perform the weighted aggregation of the values\n",
- " value = self.value(x) # (B,T,hs)\n",
- " out = weights @ value # (B, T, T) @ (B, T, hs) -> (B, T, hs)\n",
- " return out\n",
- "class MultiHeadAttention(nn.Module):\n",
- " \"\"\" multiple heads of self-attention in parallel \"\"\"\n",
+ " B, T, C = x.shape\n",
+ " key = self.key(x)\n",
+ " query = self.query(x)\n",
+ "\n",
+ "        scores = torch.matmul(query, key.transpose(-2, -1)) * (key.shape[-1] ** -0.5)  # scaled dot-product: multiply by 1/sqrt(head_size); dividing by d**-0.5 multiplied by sqrt(d)\n",
+ " rel_pos_scores = torch.einsum('btc,tvc->btv', query, self.rel_pos_embd[:T, :T])\n",
+ " scores = scores + rel_pos_scores\n",
+ "\n",
+ " att_mat = F.softmax(scores, dim=-1)\n",
+ " att_mat = self.dropout(att_mat)\n",
+ " value = self.value(x)\n",
+ " output = torch.matmul(att_mat, value)\n",
+ " return output\n",
"\n",
- " def __init__(self, d_embd, n_head, dropout, block_size):\n",
+ "class UnMaskedAttention(nn.Module):\n",
+ " def __init__(self, d_model, block_size, dropout, n_head):\n",
+ " head_size = d_model // n_head\n",
" super().__init__()\n",
- " self.heads = nn.ModuleList([Head(d_embd=d_embd, n_head=n_head, dropout=dropout, block_size=block_size) for _ in range(n_head)])\n",
- " self.proj = nn.Linear(n_head * (d_embd // n_head), d_embd)\n",
+ " self.heads = nn.ModuleList([UnMaskedHead(d_model=d_model, dropout=dropout, block_size=block_size, head_size=head_size) for _ in range(n_head)])\n",
+ " self.proj = nn.Linear(n_head * head_size, d_model)\n",
" self.dropout = nn.Dropout(dropout)\n",
"\n",
" def forward(self, x):\n",
" out = torch.cat([h(x) for h in self.heads], dim=-1)\n",
- " out = self.dropout(out)\n",
- "\n",
+ " out = self.dropout(self.proj(out))\n",
" return out\n",
"\n",
- "class FeedForward:\n",
- " \"\"\" dual linear layer with GELU function \"\"\"\n",
- " def __init__(self, d_embd):\n",
+ "class MaskedHead(nn.Module):\n",
+ " def __init__(self, d_model, head_size, dropout, block_size):\n",
" super().__init__()\n",
- " self.fc1 = nn.Linear(d_embd, 4*d_embd) # n_ff = 4*d_embd\n",
- " self.fc2 = nn.Linear(4*d_embd, d_embd) # n_ff = 4*d_embd\n",
+ " self.key = nn.Linear(d_model, head_size, bias=False)\n",
+ " self.query = nn.Linear(d_model, head_size, bias=False)\n",
+ " self.value = nn.Linear(d_model, head_size, bias=False)\n",
+ " self.dropout = nn.Dropout(dropout)\n",
+ " self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))\n",
"\n",
" def forward(self, x):\n",
- " x = F.gelu(self.fc1(x)) # GELU insted of ReLU\n",
- " x = self.fc2(x)\n",
- " return x\n",
+ " B, T, C = x.shape\n",
+ " key = self.key(x)\n",
+ " query = self.query(x)\n",
+ "\n",
+ "        scores = torch.matmul(query, key.transpose(-2, -1)) * (key.shape[-1] ** -0.5)  # scaled dot-product: multiply by 1/sqrt(head_size); dividing by d**-0.5 multiplied by sqrt(d)\n",
+ " scores = scores.masked_fill(self.tril[:T, :T] == 0, float('-inf'))\n",
"\n",
- "class EncoderDecoderAttention(nn.Module):\n",
- " \"\"\" separate attention layer for decoder layer \"\"\"\n",
+ " att_mat = F.softmax(scores, dim=-1)\n",
+ " att_mat = self.dropout(att_mat)\n",
+ " value = self.value(x)\n",
+ " output = torch.matmul(att_mat, value)\n",
+ " return output\n",
"\n",
- " def __init__(self, d_embd, n_head, dropout, block_size):\n",
+ "class CasualMaskedAttention(nn.Module):\n",
+ " def __init__(self, d_model, block_size, dropout, n_head):\n",
+ " head_size = d_model // n_head\n",
" super().__init__()\n",
- " self.heads = nn.ModuleList([Head(d_embd, n_head, dropout, block_size) for _ in range(n_head)])\n",
- " self.proj = nn.Linear(n_head * (d_embd // n_head), d_embd)\n",
+ " self.heads = nn.ModuleList([MaskedHead(d_model=d_model, dropout=dropout, block_size=block_size, head_size=head_size) for _ in range(n_head)])\n",
+ " self.proj = nn.Linear(n_head * head_size, d_model)\n",
" self.dropout = nn.Dropout(dropout)\n",
"\n",
- " def forward(self, query, key, value, mask=None):\n",
- " x = torch.cat((key, query, value), dim=-1)\n",
- " energies = []\n",
- " for head in self.heads:\n",
- " energy = head(x)\n",
- " energies.append(energy.unsqueeze(1))\n",
- " energy = torch.cat(energies, dim=1)\n",
- " energy = self.proj(energy)\n",
- " energy = self.dropout(energy)\n",
+ " def forward(self, x):\n",
+ " out = torch.cat([h(x) for h in self.heads], dim=-1)\n",
+ " out = self.dropout(self.proj(out))\n",
+ " return out\n",
+ "\n",
+ "class FinalHead(nn.Module):\n",
+ " def __init__(self, d_model, head_size, dropout, block_size):\n",
+ " super().__init__()\n",
+ " self.key = nn.Linear(d_model, head_size, bias=True)\n",
+ " self.query = nn.Linear(d_model, head_size, bias=True)\n",
+ " self.value = nn.Linear(d_model, head_size, bias=True)\n",
+ " self.dropout = nn.Dropout(dropout)\n",
"\n",
- " if mask is not None:\n",
- " energy = energy.masked_fill(mask == 0, float('-inf'))\n",
+ " def forward(self, x, att):\n",
+ " B, T, C = x.shape\n",
+ " key = self.key(att)\n",
+ " query = self.query(att)\n",
"\n",
- " attention = F.softmax(energy, dim=-1)\n",
- " output = torch.matmul(attention, value)\n",
+ "        scores = torch.matmul(query, key.transpose(-2, -1)) * (key.shape[-1] ** -0.5)  # scaled dot-product: multiply by 1/sqrt(head_size); dividing by d**-0.5 multiplied by sqrt(d)\n",
"\n",
+ " att_mat = F.softmax(scores, dim=-1)\n",
+ " att_mat = self.dropout(att_mat)\n",
+ " value = self.value(x)\n",
+ " output = torch.matmul(att_mat, value)\n",
" return output\n",
"\n",
- "class EncoderLayer(nn.Module):\n",
- " \"\"\" Encoder Layer \"\"\"\n",
- "\n",
- " def __init__(self, d_embd, n_head, dropout, block_size):\n",
+ "class FinalAttention(nn.Module):\n",
+ " def __init__(self, d_model, block_size, dropout, n_head):\n",
+ " head_size = d_model // n_head\n",
" super().__init__()\n",
- " self.s_att = MultiHeadAttention(d_embd=d_embd, n_head=n_head, block_size=block_size, dropout=dropout)\n",
- " self.ffwd = FeedForward(d_embd=d_embd)\n",
+ " self.heads = nn.ModuleList([FinalHead(d_model=d_model, dropout=dropout, block_size=block_size, head_size=head_size) for _ in range(n_head)])\n",
+ " self.proj = nn.Linear(n_head * head_size, d_model)\n",
" self.dropout = nn.Dropout(dropout)\n",
- " self.norm1 = nn.LayerNorm(d_embd)\n",
- " self.norm2 = nn.LayerNorm(d_embd)\n",
"\n",
- " def forward(self, src, src_mask=None):\n",
- " src2 = self.s_att(src)\n",
- " src = src + self.dropout(src2)\n",
- " src = self.norm1(src)\n",
+ " def forward(self, x, att):\n",
+ " out = torch.cat([h(x, att) for h in self.heads], dim=-1)\n",
+ " out = self.dropout(self.proj(out))\n",
+ " return out\n",
+ "\n",
+ "class FeedForward(nn.Module):\n",
+ " def __init__(self, d_model, dropout):\n",
+ " super().__init__()\n",
+ " self.net = nn.Sequential(\n",
+ " nn.Linear(d_model, 10*d_model),\n",
+ " nn.GELU(),\n",
+ " nn.Linear(10*d_model, d_model),\n",
+ " nn.Dropout(dropout)\n",
+ " )\n",
+ "\n",
+ " def forward(self, x):\n",
+ " return self.net(x)\n",
"\n",
- " src2 = self.ffwd(src)\n",
- " src = src + self.dropout(src2)\n",
- " src = self.norm2(src)\n",
+ "class EncoderNetwork(nn.Module):\n",
+ " def __init__(self, d_model, n_head, norm_eps, dropout, block_size):\n",
+ " super().__init__()\n",
+ " self.s_att = UnMaskedAttention(n_head=n_head, d_model=d_model, dropout=dropout, block_size=block_size)\n",
+ " self.ffwd = FeedForward(d_model, dropout)\n",
+ " self.dropout = nn.Dropout(dropout)\n",
+ " self.norm = RMSNorm(d_model, eps=norm_eps)\n",
+ "\n",
+ " def forward(self, src):\n",
+ " src_att = self.s_att(self.norm(src))\n",
+ " src_out = src + self.dropout(src_att)\n",
"\n",
- " return src\n",
+ " src = self.ffwd(self.norm(src_out))\n",
+ " src_f = src_out + self.dropout(src)\n",
"\n",
- "class DecoderLayer(nn.Module):\n",
- " \"\"\" Decoder Layer \"\"\"\n",
+ " del src_att, src_out, src\n",
+ " return src_f\n",
"\n",
- " def __init__(self, d_embd, n_head, dropout, block_size) -> None:\n",
+ "class DecoderNetwork(nn.Module):\n",
+ " def __init__(self, d_model, n_head, norm_eps, dropout, block_size):\n",
" super().__init__()\n",
- " self.s_att = MultiHeadAttention(d_embd=d_embd, n_head=n_head, block_size=block_size, dropout=dropout)\n",
- " self.enc_att = EncoderDecoderAttention(d_embd=d_embd, n_head=n_head, block_size=block_size, dropout=dropout)\n",
- " self.ffwd = FeedForward(d_embd=d_embd)\n",
+ " self.m_att = CasualMaskedAttention(n_head=n_head, d_model=d_model, dropout=dropout, block_size=block_size)\n",
+ " self.f_att = FinalAttention(d_model=d_model, n_head=n_head, dropout=dropout, block_size=block_size)\n",
+ " self.ffwd = FeedForward(d_model, dropout)\n",
" self.dropout = nn.Dropout(dropout)\n",
- " self.norm1 = nn.LayerNorm(d_embd)\n",
- " self.norm2 = nn.LayerNorm(d_embd)\n",
- " self.norm3 = nn.LayerNorm(d_embd)\n",
+ " self.norm = RMSNorm(d_model, eps=norm_eps)\n",
"\n",
- " def forward(self, trg, enc_src, trg_mask=None, src_mask=None):\n",
- " trg2 = self.s_att(trg)\n",
- " trg = trg2 + self.dropout(trg2)\n",
- " trg = self.norm1(trg)\n",
+ " def forward(self, src, att):\n",
+ " m_att_out = self.m_att(self.norm(src))\n",
+ " m_out = src + self.dropout(m_att_out)\n",
"\n",
- " trg2 = self.enc_att(trg, enc_src, enc_src)\n",
- " trg = trg + self.dropout(trg2)\n",
- " trg = self.norm2(trg)\n",
+ " f_out = self.f_att(self.norm(m_out), self.norm(att))\n",
+ " f_out = m_out + self.dropout(f_out)\n",
"\n",
- " trg2 = self.ffwd(trg)\n",
- " trg = trg + self.dropout(trg2)\n",
- " trg = self.norm3(trg)\n",
+ " src_f = self.ffwd(self.norm(f_out))\n",
+ " src_f = f_out + self.dropout(src_f)\n",
"\n",
- " return trg\n",
+ " del f_out, m_out, m_att_out, src, att\n",
+ " return src_f\n",
"\n",
"class Transformer(nn.Module):\n",
- " def __init__(self):\n",
+ " def __init__(self, vocab_size):\n",
" super().__init__()\n",
- " self.d_embd = d_embd\n",
" self.block_size = block_size\n",
- "\n",
- " self.token_embd = nn.Embedding(vocab_size, d_embd)\n",
- " self.pos_embd = nn.Embedding(block_size, d_embd)\n",
- " self.enc_layer = nn.ModuleList([EncoderLayer(n_head=n_head, block_size=block_size, dropout=dropout, d_embd=d_embd) for _ in range(n_layers)])\n",
- " self.dec_layer = nn.ModuleList([DecoderLayer(n_head=n_head, block_size=block_size, dropout=dropout, d_embd=d_embd) for _ in range(n_layers)])\n",
- "\n",
- " self.norm_final = nn.LayerNorm(d_embd)\n",
- " self.lm_head = nn.Linear(d_embd, vocab_size)\n",
- " self.fc_out = nn.Linear(d_embd, vocab_size)\n",
+ " self.toked_model = nn.Embedding(vocab_size, d_model)\n",
+ " self.pos_encod = nn.Embedding(block_size, d_model)\n",
+ " self.enc_layer = nn.ModuleList([EncoderNetwork(n_head=n_head, norm_eps=norm_eps, block_size=block_size, dropout=dropout, d_model=d_model) for _ in range(n_layers)])\n",
+ " self.dec_layer = nn.ModuleList([DecoderNetwork(n_head=n_head, norm_eps=norm_eps, block_size=block_size, dropout=dropout, d_model=d_model) for _ in range(n_layers)])\n",
+ " self.norm_final = RMSNorm(d_model, eps=norm_eps)\n",
+ " self.linear_final = nn.Linear(d_model, vocab_size)\n",
" self.dropout = nn.Dropout(dropout)\n",
" self.apply(self._init_weights)\n",
"\n",
" def _init_weights(self, module):\n",
+ " \"\"\"\n",
+ " initialize weights of linear and embedding layers\n",
+ "\n",
+ " Args:\n",
+ " - module (nn.Module): the module to initialize weights for\n",
+ " \"\"\"\n",
" if isinstance(module, nn.Linear):\n",
" torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)\n",
" if module.bias is not None:\n",
- " torch.nn.init.zeros_(module.bias)\n",
- " elif isinstance(module, nn.Embedding) and module.weight.numel() > 0:\n",
+ " torch.nn.init.zeros_(module.bias.data)\n",
+ " elif isinstance(module, nn.Embedding):\n",
" torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)\n",
"\n",
- " def make_src_mask(self, src):\n",
- " src_mask = (src != self.src_pad_idx).unsqueeze(1).unsqueeze(2)\n",
- " return src_mask\n",
+ " def forward(self, idx, targets=None):\n",
+ " \"\"\"\n",
+ " forward pass of the transformer model\n",
"\n",
- " def make_trg_mask(self, trg):\n",
- " trg_pad_mask = (trg != self.trg_pad_idx).unsqueeze(1).unsqueeze(2)\n",
- " trg_len = trg.shape[1]\n",
- " trg_sub_mask = torch.tril(torch.ones((trg_len, trg_len), device=trg.device)).bool()\n",
- " trg_mask = trg_pad_mask & trg_sub_mask\n",
- " return trg_mask\n",
+ " Args:\n",
+ " - idx (Tensor): input tensor representing token indices\n",
+ " - targets (Tensor): target tensor for computing loss during training\n",
"\n",
- " def forward(self, idx, targets=None):\n",
+ " Returns:\n",
+ " - logits (Tensor): output logits from the final linear layer\n",
+ " - loss (Tensor): optional. computed cross-entropy loss if targets are provided, else None\n",
+ " \"\"\"\n",
" B, T = idx.shape\n",
"\n",
- " tok_embd = self.token_embd(idx)\n",
- " pos_embd = self.pos_embd(torch.arange(T, device=device))\n",
- " x = tok_embd + pos_embd\n",
+ "        positions = torch.arange(T, device=idx.device)  # follow the input's device instead of the notebook-level global\n",
+ "        x = self.toked_model(idx) + self.pos_encod(positions)  # token + learned positional embeddings\n",
+ "        x_out = x  # encoder stream starts from the embeddings\n",
"\n",
" for layer in self.enc_layer:\n",
- " x = layer(x, None)\n",
+ "            x_out = layer(x_out)  # chain the encoder stack; previously every layer re-read the raw embeddings and only the last one mattered\n",
"\n",
" for layer in self.dec_layer:\n",
- " x = layer(x, x)\n",
+ "            x = layer(x, x_out)  # chain the decoder stack; each layer attends to the final encoder output\n",
"\n",
- " x = self.norm_final(x)\n",
- " logits = self.lm_head(x)\n",
+ "        x_final = self.norm_final(x)\n",
+ " logits = self.linear_final(x_final)\n",
"\n",
" if targets is None:\n",
" loss = None\n",
@@ -342,34 +474,156 @@
" B, T, C = logits.shape\n",
" logits = logits.view(B*T, C)\n",
" targets = targets.view(B*T)\n",
- " loss = F.cross_entropy(logits, targets, ignore_index=-52, reduction='mean')\n",
+ " loss = F.cross_entropy(logits, targets)\n",
"\n",
" return logits, loss\n",
"\n",
- " def generate(self, idx, max_tokens=50):\n",
- " for _ in range(max_tokens):\n",
- " idx_cond = idx[:, -self.block_size: ]\n",
- " logits, loss = self(idx_cond)\n",
+ " def generate(self, idx, max_new_tokens, temperature=1.0, top_k=0):\n",
+ " \"\"\"\n",
+ " generate new tokens using the trained model\n",
+ "\n",
+ " Args:\n",
+ " - idx (Tensor): input tensor representing initial token indices\n",
+ " - max_new_tokens (int): max no of new tokens to generate\n",
+ " - temperature (float): softmax temperature for sampling\n",
+ " - top_k (int): no of top tokens to consider in sampling\n",
+ "\n",
+ " Returns:\n",
+ " - generated_tokens (list): list of generated token indices\n",
+ " \"\"\"\n",
+ " generated_tokens = []\n",
+ "\n",
+ " for _ in range(max_new_tokens):\n",
+ " idx_cond = idx[:, -self.block_size:]\n",
+ " logits, _ = self(idx_cond)\n",
" logits = logits[:, -1, :]\n",
- " probs = F.softmax(logits, dim=-1)\n",
- " idx_next = torch.multinomial(probs, num_samples=1)\n",
- " idx = torch.cat((idx, idx_next), dim=1)\n",
"\n",
- " return idx, loss\n",
+ " scaled_logits = logits / temperature\n",
+ " if top_k > 0:\n",
+ " scaled_logits = self._top_k_filtering(scaled_logits, top_k)\n",
+ "\n",
+ " probs = F.softmax(scaled_logits, dim=-1)\n",
+ " sampled_idx = torch.multinomial(probs, num_samples=1)\n",
+ " generated_tokens.append(sampled_idx.item())\n",
+ " idx = torch.cat((idx, sampled_idx), dim=1)\n",
+ "\n",
+ " return generated_tokens\n",
+ "\n",
+ " def generate_masked_tokens(self, idx, masked_indices, temperature=1.0, top_k=0):\n",
+ " \"\"\"\n",
+ " Generate predictions for masked tokens using the trained model.\n",
+ "\n",
+ " Args:\n",
+ " - idx (Tensor): input tensor representing token indices\n",
+ " - masked_indices (Tensor): tensor of indices indicating masked positions\n",
+ " - temperature (float): softmax temperature for sampling\n",
+ " - top_k (int): no of top tokens to consider in sampling\n",
+ "\n",
+ " Returns:\n",
+ " - predicted_tokens (Tensor): tensor of predicted token indices\n",
+ " \"\"\"\n",
+ " B, T = idx.shape\n",
+ "\n",
+ "        positions = torch.arange(T, device=idx.device)  # follow the input's device instead of the notebook-level global\n",
+ "        x = self.toked_model(idx) + self.pos_encod(positions)  # token + learned positional embeddings\n",
+ "        x_out = x  # encoder stream starts from the embeddings\n",
+ "\n",
+ "        for layer in self.enc_layer:\n",
+ "            x_out = layer(x_out)  # chain the encoder stack instead of feeding every layer the raw embeddings\n",
+ "\n",
+ "        for layer in self.dec_layer:\n",
+ "            x = layer(x, x_out)  # chain the decoder stack; each layer attends to the final encoder output\n",
+ "\n",
+ "        x_masked = x.clone()\n",
+ " x_masked[masked_indices] = self.toked_model(torch.tensor([6], device=device))\n",
+ "\n",
+ " x_masked = self.norm_final(x_masked)\n",
+ " logits = self.linear_final(x_masked)\n",
+ "\n",
+ " masked_logits = logits[masked_indices].view(-1, logits.size(-1))\n",
+ " scaled_logits = masked_logits / temperature\n",
+ " if top_k > 0:\n",
+ " scaled_logits = self._top_k_filtering(scaled_logits, top_k)\n",
+ "\n",
+ " probs = F.softmax(scaled_logits, dim=-1)\n",
+ " predicted_indices = torch.argmax(probs, dim=-1)\n",
"\n",
+ " return predicted_indices\n",
"\n",
- "model = Transformer()\n",
- "# checkpoint_path = '/content/drive/MyDrive/52.9_transformer_model.pth'\n",
- "# checkpoint = torch.load(checkpoint_path)\n",
- "# model.load_state_dict(checkpoint)\n",
+ " def _top_k_filtering(self, logits, top_k):\n",
+ " \"\"\"\n",
+ " filter logits to keep only the top-k tokens\n",
+ "\n",
+ " Args:\n",
+ " - logits (Tensor): input tensor representing unscaled logits\n",
+ " - top_k (int): no of top tokens to keep\n",
+ "\n",
+ " Returns:\n",
+ " - filtered_logits (Tensor): filtered logits with only top-k tokens remaining\n",
+ " \"\"\"\n",
+ "        k = min(top_k, logits.size(-1))  # torch.topk raises if k exceeds the size of the last dimension\n",
+ "        kth_value = torch.topk(logits, k, dim=-1).values[..., -1, None]  # smallest kept logit per row, broadcastable against logits\n",
+ "        filtered_logits = logits.masked_fill(logits < kth_value, float('-inf'))  # everything below the k-th logit gets zero probability after softmax\n",
+ "\n",
+ " return filtered_logits"
+ ],
+ "metadata": {
+ "id": "OusOJ_H8gARB"
+ },
+ "execution_count": 11,
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "\"\"\"\n",
+ " use this file to train the model\n",
+ "\n",
+ " working:\n",
+ "    - imports various dependencies first, and then loads the training data\n",
+ "    - tokenizes it with a tiktoken BPE encoding (sub-word tokens, not characters)\n",
+ " - loads the required hyper-parameters and the model file\n",
+ " - trains it till 'max_iters' and saves the model state, and generates outputs\n",
+ "\n",
+ "    with the current configuration, the model has ~886 million parameters\n",
+ " and can become ~99% accurate with next token prediction\n",
+ "\"\"\"\n",
+ "\n",
+ "torch.manual_seed(1400)\n",
+ "# data loading\n",
+ "def get_batch(split):\n",
+ " # generate a small batch of data of inputs x and targets y\n",
+ " data = train_data if split == 'train' else val_data\n",
+ " ix = torch.randint(len(data) - block_size, (batch_size,))\n",
+ " x = torch.stack([data[i:i+block_size] for i in ix])\n",
+ " y = torch.stack([data[i+1:i+block_size+1] for i in ix])\n",
+ " x, y = x.to(device), y.to(device)\n",
+ " return x, y\n",
+ "\n",
+ "@torch.no_grad()\n",
+ "def estimate_loss():\n",
+ " out = {}\n",
+ " model.eval()\n",
+ " for split in ['train', 'val']:\n",
+ " losses = torch.zeros(eval_iters)\n",
+ " for k in range(eval_iters):\n",
+ " X, Y = get_batch(split)\n",
+ " logits, loss = model(X, Y)\n",
+ " losses[k] = loss.item()\n",
+ " out[split] = losses.mean()\n",
+ " model.train()\n",
+ " return out\n",
+ "\n",
+ "vocab_size = tokenizer.n_vocab\n",
+ "model = Transformer(vocab_size)\n",
"m = model.to(device)\n",
"\n",
"# no of parameters\n",
"n_param = sum(p.numel() for p in m.parameters())/1e6\n",
- "print(n_param, 'million')\n",
- "\n",
- "# optimizer\n",
+ "print(f\"vocab size: {vocab_size}\")\n",
+ "print(f\"{n_param:.0f} million parameters\")\n",
"optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)\n",
+ "\n",
"steps = []\n",
"train_losses = []\n",
"val_losses = []\n",
@@ -391,22 +645,46 @@
" optimizer.step()"
],
"metadata": {
- "id": "OusOJ_H8gARB"
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "dORKqYKmPmit",
+ "outputId": "62d9926c-a50f-427b-abcf-bb71ec82348e"
},
- "execution_count": null,
- "outputs": []
+ "execution_count": 12,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "vocab size: 50281\n",
+ "886 million parameters\n",
+ "step 0: train loss 10.9287, val loss 10.9229\n",
+ "step 100: train loss 8.3812, val loss 9.0325\n",
+ "step 200: train loss 7.2081, val loss 8.0959\n",
+ "step 300: train loss 6.7217, val loss 7.8882\n",
+ "step 400: train loss 6.5446, val loss 7.8266\n",
+ "step 500: train loss 6.8072, val loss 7.8396\n",
+ "step 600: train loss 6.4265, val loss 7.6559\n",
+ "step 700: train loss 6.3871, val loss 7.7765\n",
+ "step 800: train loss 6.4383, val loss 7.5266\n",
+ "step 900: train loss 6.2296, val loss 7.3788\n",
+ "step 999: train loss 6.3129, val loss 7.3048\n"
+ ]
+ }
+ ]
},
{
"cell_type": "code",
"source": [
- "# save the trained model\n",
- "torch.save(model.state_dict(), f\"{n_param:.1f}_model_dict.pth\")\n",
- "torch.save(model, f\"{n_param:.1f}_model.pth\")"
+ "model_save_name = f'aiva_base-{n_param:.0f}m.pth'\n",
+ "path = f\"/content/drive/MyDrive/{model_save_name}\"\n",
+ "torch.save(model.state_dict(), path)"
],
"metadata": {
"id": "e6NM24zMhH_2"
},
- "execution_count": null,
+ "execution_count": 15,
"outputs": []
},
{
@@ -425,25 +703,63 @@
"plt.show()"
],
"metadata": {
- "id": "mmFhDw7KhK0t"
+ "id": "mmFhDw7KhK0t",
+ "colab": {
+ "base_uri": "https://localhost:8080/",
+ "height": 564
+ },
+ "outputId": "d2bd4dcf-f197-447b-8f19-48aaa3f3c0d2"
},
- "execution_count": null,
- "outputs": []
+ "execution_count": 17,
+ "outputs": [
+ {
+ "output_type": "display_data",
+ "data": {
+ "text/plain": [
+ "