diff --git a/audio engine/model.py b/audio engine/model.py
index 413e0cc..a716e93 100644
--- a/audio engine/model.py
+++ b/audio engine/model.py
@@ -1,11 +1,13 @@
 import torch
-import torch.nn as nn
+from torch import Tensor, nn
 import torch.nn.functional as F
 import math
+from typing import Dict, Iterable, Optional
+import numpy as np

 device = 'cuda' if torch.cuda.is_available() else 'cpu'

-class ConfigModel():
+class ConfigModel:
     d_model = 768
     block_size = 1024
     n_head = 12
@@ -50,6 +52,32 @@ def forward(self, x):
         output = self._norm(x.float()).type_as(x)
         return output * self.weight

+class Linear(nn.Linear):
+    def forward(self, x: Tensor) -> Tensor:
+        return F.linear(
+            x,
+            self.weight.to(x.dtype),
+            None if self.bias is None else self.bias.to(x.dtype),
+        )
+
+class Conv1d(nn.Conv1d):
+    def _conv_forward(self, x: Tensor, weight: Tensor, bias: Optional[Tensor]) -> Tensor:
+        return super()._conv_forward(
+            x, weight.to(x.dtype), None if bias is None else bias.to(x.dtype)
+        )
+
+def sinusoids(length, channels, max_timescale=10000):
+    """Returns sinusoids for positional embedding"""
+    assert channels % 2 == 0
+    log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)
+    inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2))
+    scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
+    return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)
+
+class NewGELU(nn.Module):
+    def forward(self, input):
+        return 0.5 * input * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (input + 0.044715 * torch.pow(input, 3.0))))
+
 class PositionalEncoding(nn.Module):
     def __init__(self, d_model, block_size, dropout):
         super().__init__()
@@ -66,10 +94,6 @@ def forward(self, x):
         x = x + self.pe[:x.size(0), :]
         return self.dropout(x)

-class NewGELU(nn.Module):
-    def forward(self, input):
-        return 0.5 * input * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (input + 0.044715 * torch.pow(input, 3.0))))
-
 class SelfAttention(nn.Module):
     def __init__(self, head_size, d_model, block_size, dropout):
         super().__init__()
diff --git a/audio engine/whisper.py.py b/audio engine/whisper.py
similarity index 100%
rename from audio engine/whisper.py.py
rename to audio engine/whisper.py
diff --git a/test.py b/test.py
index db3ef25..741f45c 100644
--- a/test.py
+++ b/test.py
@@ -1,5 +1,8 @@
 import torch
 import tiktoken
+import torch.nn as nn
+import math
+import numpy as np

 tokenizer = tiktoken.get_encoding("p50k_base")
 tokenizer = tiktoken.encoding_for_model("text-davinci-003")
@@ -13,4 +16,42 @@
 B, T = x.shape
 z = x.view(B*T)

-print(z)
\ No newline at end of file
+print(z)
+
+class PositionalEncoding(nn.Module):
+    def __init__(self, d_model, block_size, dropout):
+        super().__init__()
+        self.dropout = nn.Dropout(p=dropout)
+        pe = torch.zeros(block_size, d_model)
+        position = torch.arange(0, block_size, dtype=torch.float).unsqueeze(1)
+        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
+        pe[:, 0::2] = torch.sin(position * div_term)
+        pe[:, 1::2] = torch.cos(position * div_term)
+        pe = pe.unsqueeze(0).transpose(0, 1)
+        self.register_buffer('pe', pe)
+
+    def forward(self, x):
+        x = x + self.pe[:x.size(0), :]
+        return self.dropout(x)
+
+def sinusoids(length, channels, max_timescale=10000):
+    """Returns sinusoids for positional embedding"""
+    assert channels % 2 == 0
+    log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)
+    inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2))
+    scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
+    return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)
+
+length = 100
+channels = 100
+positional_embeddings = sinusoids(length, channels).numpy()
+
+import matplotlib.pyplot as plt
+
+plt.figure(figsize=(10, 8))
+plt.imshow(positional_embeddings, aspect='auto', cmap='viridis')
+plt.colorbar()
+plt.xlabel('Embedding Dimension')
+plt.ylabel('Position')
+plt.title('Sinusoidal Positional Embeddings')
+plt.show()
\ No newline at end of file
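
A few review notes on the new code follow; the sketches are mine, not part of the diff. First, the added Linear/Conv1d overrides appear to follow the pattern from OpenAI's Whisper (consistent with the whisper.py rename above): parameters stay in float32 and are cast to the activation dtype at call time, so the same module accepts half-precision inputs without converting weights in place. A minimal standalone check of that behavior (bfloat16 chosen here only because it is well supported on CPU):

import torch
from torch import Tensor, nn
import torch.nn.functional as F

class Linear(nn.Linear):
    # Same override as in the diff: cast weight/bias to the input's dtype.
    def forward(self, x: Tensor) -> Tensor:
        return F.linear(
            x,
            self.weight.to(x.dtype),
            None if self.bias is None else self.bias.to(x.dtype),
        )

layer = Linear(768, 768)                        # parameters stay float32
x = torch.randn(4, 768, dtype=torch.bfloat16)   # half-precision activations
y = layer(x)
print(layer.weight.dtype, y.dtype)              # torch.float32 torch.bfloat16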
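
Second, the tree now contains two sinusoidal position tables: the Whisper-style sinusoids and the PositionalEncoding module. They are not interchangeable: sinusoids concatenates [sin | cos] along the channel axis with timescales 10000**(-i/(channels/2 - 1)), while PositionalEncoding interleaves sin/cos with timescales 10000**(-2i/d_model). A quick side-by-side comparison (copies of both implementations from the diff):

import math
import torch
import numpy as np

def sinusoids(length, channels, max_timescale=10000):
    assert channels % 2 == 0
    log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)
    inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2))
    scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
    return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)

def tutorial_pe(length, d_model):
    # The table built inside PositionalEncoding.__init__, without the module wrapper.
    pe = torch.zeros(length, d_model)
    position = torch.arange(0, length, dtype=torch.float).unsqueeze(1)
    div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
    pe[:, 0::2] = torch.sin(position * div_term)
    pe[:, 1::2] = torch.cos(position * div_term)
    return pe

a = sinusoids(100, 64)    # layout: [sin | cos]
b = tutorial_pe(100, 64)  # layout: interleaved sin/cos
print(torch.allclose(a[:, 0], b[:, 0]))       # True: the lowest frequency matches
print(torch.allclose(a[:, :32], b[:, 0::2]))  # False: timescale spacing differs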
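
Finally, NewGELU is the tanh approximation of GELU; since PyTorch 1.12 the same curve is available built in as F.gelu(x, approximate='tanh'), which may be preferable to the hand-rolled module. A quick equivalence check:

import math
import torch
import torch.nn.functional as F

x = torch.randn(1024)
new_gelu = 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))
print(torch.allclose(new_gelu, F.gelu(x, approximate='tanh'), atol=1e-6))  # True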