From 9a0603908faefc1cecb58c9a54c0695f8baae385 Mon Sep 17 00:00:00 2001 From: Kye Date: Sun, 31 Mar 2024 20:16:36 -0700 Subject: [PATCH] [BUFG][RMSNorm Fix] --- README.md | 12 ++++++++---- bit_linear_new.py | 5 +++-- bitnet/__init__.py | 2 +- bitnet/{bit_linear.py => bit_linear_new.py} | 5 +++-- bitnet/bitlinear.py | 7 ++++--- example.py | 3 ++- pyproject.toml | 2 +- 7 files changed, 22 insertions(+), 14 deletions(-) rename bitnet/{bit_linear.py => bit_linear_new.py} (91%) diff --git a/README.md b/README.md index cdefcd3..bdba960 100644 --- a/README.md +++ b/README.md @@ -33,7 +33,7 @@ import torch from bitnet import BitLinear # Input -x = torch.randn(10, 512) +x = torch.randn(10, 1000, 512) # BitLinear layer layer = BitLinear(512, 400) @@ -42,6 +42,7 @@ layer = BitLinear(512, 400) y = layer(x) print(y) + ``` ### BitLinearNew @@ -50,17 +51,20 @@ import torch from bitnet import BitLinearNew # Create a random tensor of shape (16, 10) -x = torch.randn(16, 10) +x = torch.randn(16, 1000, 512) # Create an instance of the BitLinearNew class with input size 10, output size 20, and 2 groups -layer = BitLinearNew(10, 20, num_groups=2) +layer = BitLinearNew( + 512, + 20, +) # Perform a forward pass through the BitLinearNew layer with input x output = layer(x) # Print the output tensor print(output) - +print(output.shape) ``` ---- diff --git a/bit_linear_new.py b/bit_linear_new.py index 10a77dc..5e972fe 100644 --- a/bit_linear_new.py +++ b/bit_linear_new.py @@ -2,11 +2,11 @@ from bitnet import BitLinearNew # Create a random tensor of shape (16, 10) -x = torch.randn(16, 10) +x = torch.randn(16, 1000, 512) # Create an instance of the BitLinearNew class with input size 10, output size 20, and 2 groups layer = BitLinearNew( - 10, + 512, 20, ) @@ -15,3 +15,4 @@ # Print the output tensor print(output) +print(output.shape) \ No newline at end of file diff --git a/bitnet/__init__.py b/bitnet/__init__.py index 9664030..b33e93e 100644 --- a/bitnet/__init__.py +++ b/bitnet/__init__.py @@ -1,6 +1,6 @@ from bitnet.bit_attention import BitMGQA from bitnet.bit_ffn import BitFeedForward -from bitnet.bit_linear import BitLinearNew +from bitnet.bit_linear_new import BitLinearNew from bitnet.bit_transformer import BitNetTransformer from bitnet.bitlinear import BitLinear from bitnet.inference import BitNetInference diff --git a/bitnet/bit_linear.py b/bitnet/bit_linear_new.py similarity index 91% rename from bitnet/bit_linear.py rename to bitnet/bit_linear_new.py index bf03f21..e0248cb 100644 --- a/bitnet/bit_linear.py +++ b/bitnet/bit_linear_new.py @@ -1,5 +1,5 @@ from torch import nn, Tensor -from zeta.nn import RMSNorm +# from zeta.nn import RMSNorm import torch.nn.functional as F @@ -51,7 +51,8 @@ def forward(self, x: Tensor) -> Tensor: """ w = self.weight - x_norm = RMSNorm(self.in_features)(x) + # x_norm = RMSNorm(self.in_features)(x) + x_norm = nn.LayerNorm(self.in_features)(x) # STE using detach x_quant = x_norm + (activation_quant(x_norm) - x_norm).detach() diff --git a/bitnet/bitlinear.py b/bitnet/bitlinear.py index 2714650..2718ade 100644 --- a/bitnet/bitlinear.py +++ b/bitnet/bitlinear.py @@ -1,5 +1,5 @@ from torch import nn, Tensor -from zeta.nn import RMSNorm +# from zeta.nn import RMSNorm import torch.nn.functional as F @@ -50,9 +50,10 @@ def forward(self, x: Tensor) -> Tensor: Tensor: The output tensor. """ - b, s, d = x.shape w = self.weight - x_norm = RMSNorm(d)(x) + # x_norm = RMSNorm(self.in_features)(x) + x_norm = nn.LayerNorm(self.in_features)(x) + print(x_norm.shape) # STE using detach x_quant = x_norm + (activation_quant(x_norm) - x_norm).detach() diff --git a/example.py b/example.py index 5c481ee..cc8d85a 100644 --- a/example.py +++ b/example.py @@ -3,7 +3,7 @@ from bitnet import BitLinear # Input -x = torch.randn(10, 512) +x = torch.randn(10, 1000, 512) # BitLinear layer layer = BitLinear(512, 400) @@ -12,3 +12,4 @@ y = layer(x) print(y) +print(y.shape) \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index d156c63..0cf56ae 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api" [tool.poetry] name = "bitnet" -version = "0.1.8" +version = "0.1.9" description = "bitnet - Pytorch" license = "MIT" authors = ["Kye Gomez "]