Commit
[BUGFIX][RMSNorm Fix]
Kye committed Apr 1, 2024
1 parent fe7086e commit 9a06039
Showing 7 changed files with 22 additions and 14 deletions.
12 changes: 8 additions & 4 deletions README.md
@@ -33,7 +33,7 @@ import torch
 from bitnet import BitLinear
 
 # Input
-x = torch.randn(10, 512)
+x = torch.randn(10, 1000, 512)
 
 # BitLinear layer
 layer = BitLinear(512, 400)
@@ -42,6 +42,7 @@ layer = BitLinear(512, 400)
 y = layer(x)
 
 print(y)
+
 ```
 
 ### BitLinearNew
@@ -50,17 +51,20 @@ import torch
 from bitnet import BitLinearNew
 
 # Create a random tensor of shape (16, 10)
-x = torch.randn(16, 10)
+x = torch.randn(16, 1000, 512)
 
-# Create an instance of the BitLinearNew class with input size 10, output size 20, and 2 groups
-layer = BitLinearNew(10, 20, num_groups=2)
+layer = BitLinearNew(
+    512,
+    20,
+)
 
 # Perform a forward pass through the BitLinearNew layer with input x
 output = layer(x)
 
 # Print the output tensor
 print(output)
 
+print(output.shape)
 ```
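
Both updated examples feed 3D (batch, sequence, features) tensors, which is the case this fix targets. Assuming BitLinear and BitLinearNew transform the last dimension the way nn.Linear does, the shapes to expect are (a sketch, not repo output):

```python
import torch
from bitnet import BitLinear, BitLinearNew

# Leading dims pass through; the last dim maps in_features -> out_features.
y = BitLinear(512, 400)(torch.randn(10, 1000, 512))
out = BitLinearNew(512, 20)(torch.randn(16, 1000, 512))
print(y.shape)    # expected: torch.Size([10, 1000, 400])
print(out.shape)  # expected: torch.Size([16, 1000, 20])
```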
----

5 changes: 3 additions & 2 deletions bit_linear_new.py
@@ -2,11 +2,11 @@
 from bitnet import BitLinearNew
 
 # Create a random tensor of shape (16, 10)
-x = torch.randn(16, 10)
+x = torch.randn(16, 1000, 512)
 
 # Create an instance of the BitLinearNew class with input size 10, output size 20, and 2 groups
 layer = BitLinearNew(
-    10,
+    512,
     20,
 )

@@ -15,3 +15,4 @@
 
 # Print the output tensor
 print(output)
+print(output.shape)
2 changes: 1 addition & 1 deletion bitnet/__init__.py
@@ -1,6 +1,6 @@
 from bitnet.bit_attention import BitMGQA
 from bitnet.bit_ffn import BitFeedForward
-from bitnet.bit_linear import BitLinearNew
+from bitnet.bit_linear_new import BitLinearNew
 from bitnet.bit_transformer import BitNetTransformer
 from bitnet.bitlinear import BitLinear
 from bitnet.inference import BitNetInference
5 changes: 3 additions & 2 deletions bitnet/bit_linear.py → bitnet/bit_linear_new.py
@@ -1,5 +1,5 @@
 from torch import nn, Tensor
-from zeta.nn import RMSNorm
+# from zeta.nn import RMSNorm
 import torch.nn.functional as F


@@ -51,7 +51,8 @@ def forward(self, x: Tensor) -> Tensor:
         """
         w = self.weight
-        x_norm = RMSNorm(self.in_features)(x)
+        # x_norm = RMSNorm(self.in_features)(x)
+        x_norm = nn.LayerNorm(self.in_features)(x)
 
         # STE using detach
         x_quant = x_norm + (activation_quant(x_norm) - x_norm).detach()
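A note on the "STE using detach" line that closes this hunk: it is the straight-through estimator pattern. Quantization runs in the forward pass, but the non-differentiable term is detached, so autograd sees the identity function and gradients pass straight through. A minimal self-contained sketch (the fake_quant helper is a hypothetical stand-in for the repo's activation_quant):

```python
import torch

def fake_quant(x: torch.Tensor) -> torch.Tensor:
    # Hypothetical stand-in for activation_quant: round to 8-bit levels.
    scale = 127.0 / x.abs().max().clamp(min=1e-5)
    return (x * scale).round().clamp(-128, 127) / scale

x = torch.randn(4, 8, requires_grad=True)
# Forward value equals fake_quant(x); backward treats y = x,
# because the rounding term is detached from the graph.
y = x + (fake_quant(x) - x).detach()
y.sum().backward()
print(x.grad)  # all ones: the quantizer is invisible to autograd
```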
7 changes: 4 additions & 3 deletions bitnet/bitlinear.py
@@ -1,5 +1,5 @@
 from torch import nn, Tensor
-from zeta.nn import RMSNorm
+# from zeta.nn import RMSNorm
 import torch.nn.functional as F


@@ -50,9 +50,10 @@ def forward(self, x: Tensor) -> Tensor:
             Tensor: The output tensor.
         """
-        b, s, d = x.shape
         w = self.weight
-        x_norm = RMSNorm(d)(x)
+        # x_norm = RMSNorm(self.in_features)(x)
+        x_norm = nn.LayerNorm(self.in_features)(x)
+        print(x_norm.shape)
 
         # STE using detach
         x_quant = x_norm + (activation_quant(x_norm) - x_norm).detach()
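This hunk swaps the zeta.nn RMSNorm for nn.LayerNorm constructed inside forward. For readers who want RMSNorm behaviour without the zeta dependency, a small local module would look roughly like the sketch below. This is the standard RMSNorm formula, not the zeta API, and constructing the norm once in __init__ keeps its scale parameter registered with the module:

```python
import torch
from torch import nn, Tensor

class SimpleRMSNorm(nn.Module):
    """Root-mean-square norm over the last dimension: x / rms(x) * g."""

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.g = nn.Parameter(torch.ones(dim))  # learned per-feature scale

    def forward(self, x: Tensor) -> Tensor:
        rms = x.pow(2).mean(dim=-1, keepdim=True).add(self.eps).sqrt()
        return x / rms * self.g

# Handles both 2D and 3D inputs, since only the last dim is normalized.
norm = SimpleRMSNorm(512)
print(norm(torch.randn(16, 1000, 512)).shape)  # torch.Size([16, 1000, 512])
```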
3 changes: 2 additions & 1 deletion example.py
@@ -3,7 +3,7 @@
 from bitnet import BitLinear
 
 # Input
-x = torch.randn(10, 512)
+x = torch.randn(10, 1000, 512)
 
 # BitLinear layer
 layer = BitLinear(512, 400)
@@ -12,3 +12,4 @@
 y = layer(x)
 
 print(y)
+print(y.shape)
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "poetry.core.masonry.api"
 
 [tool.poetry]
 name = "bitnet"
-version = "0.1.8"
+version = "0.1.9"
 description = "bitnet - Pytorch"
 license = "MIT"
 authors = ["Kye Gomez <kye@apac.ai>"]
