Merge pull request #29 from Sunwood-ai-labs/jp
Jp
kyegomez authored Mar 10, 2024
2 parents db81ec8 + cda2c1f commit 0a89d6d
Showing 3 changed files with 346 additions and 138 deletions.
198 changes: 198 additions & 0 deletions Docs/README_EN.md
@@ -0,0 +1,198 @@
[![Multi-Modality](agorabanner.png)](https://discord.gg/qUtxnK2NMf)

# BitNet
![bitnet](/bitnet.png)
PyTorch implementation of the linear methods and model from the paper "BitNet: Scaling 1-bit Transformers for Large Language Models".

[Paper link](https://arxiv.org/pdf/2310.11453.pdf)

BitLinear = tensor -> layernorm -> Binarize -> abs max quantization -> dequant

"The implementation of the BitNet architecture is quite simple, requiring only the replacement of linear projections (i.e., nn.Linear in PyTorch) in the Transformer. " -- BitNet is really easy to implement just swap out the linears with the BitLinear modules!

## **NEWS**
- A BitNet Transformer has been trained using the `train.py` file on enwiki8, a small 1 GB Wikipedia dataset: [HERE IS THE LINK](https://drive.google.com/file/d/1gBuZRFBqMV3cVD902LXA_hmZl4e0dLyY/view)
- **New Iteration** 🔥 There is an all-new iteration from the paper "[The Era of 1-bit LLMs: All Large Language Models are in 1.58 Bits](https://arxiv.org/abs/2402.17764)"; we're implementing it now. Join the Agora Discord and contribute! [Join Here](https://discord.gg/hFzevCjG8c)
- **New Optimizations** The first `BitLinear` has been optimized, and we now have a Bit Attention module, `BitMGQA`, that uses BitLinear inside the attention mechanism. Multi-grouped query attention is also widely used for its fast decoding and long-context handling. Thanks to Frank for his easy-to-use implementation!

## Appreciation
- Dimitry and Nullonix for analysis, code review, and revision
- Vyom, for providing a 4080 to train on!

## Installation
`pip install bitnet`

## Usage:

### `BitLinear`
- Example of the BitLinear layer, which is the main innovation of the paper!
```python
import torch

from bitnet import BitLinear

# Input
x = torch.randn(10, 512)

# BitLinear layer
layer = BitLinear(512, 400)

# Output
y = layer(x)

print(y)
```
----

### `BitNetTransformer`
- Fully implemented Transformer as described in the diagram, with MHA and BitFeedForward blocks
- Can be used not just for text but also for images, and possibly even video or audio processing
- Complete with residuals and skip connections for gradient flow

```python
# Import the necessary libraries
import torch
from bitnet import BitNetTransformer

# Create a random tensor of integers
x = torch.randint(0, 20000, (1, 1024))

# Initialize the BitNetTransformer model
bitnet = BitNetTransformer(
    num_tokens=20000,  # Number of unique tokens in the input
    dim=1024,  # Dimension of the input and output embeddings
    depth=6,  # Number of transformer layers
    heads=8,  # Number of attention heads
    ff_mult=4,  # Multiplier for the hidden dimension in the feed-forward network
)

# Pass the tensor through the transformer model
logits = bitnet(x)

# Print the output logits
print(logits)

```


### `BitAttention`
This attention module has been modified to use BitLinear instead of the default linear projections. It also uses multi-grouped query attention instead of regular multi-head attention for faster decoding and longer context handling.

```python
import torch
from bitnet import BitMGQA

# Create a random tensor of shape (1, 10, 512)
x = torch.randn(1, 10, 512)

# Create an instance of the BitMGQA model with input size 512, 8 attention heads, and 4 layers
gqa = BitMGQA(512, 8, 4)

# Pass the input tensor through the BitMGQA model and get the output and attention weights
out, _ = gqa(x, x, x, need_weights=True)

# Print the output tensor
print(out)
```

### `BitFeedForward`
- Feedforward as shown in the diagram, with BitLinear and a GELU:
- Linear -> GELU -> Linear
- You can add dropout, layer normalization, or other layers for a better FFN

```python
import torch
from bitnet import BitFeedForward

# Create a random input tensor of shape (10, 512)
x = torch.randn(10, 512)

# Create an instance of the BitFeedForward class with the following parameters:
# - input_dim: 512
# - hidden_dim: 512
# - num_layers: 4
# - swish: True (use Swish activation function)
# - post_act_ln: True (apply Layer Normalization after each activation)
# - dropout: 0.1 (apply dropout with a probability of 0.1)
ff = BitFeedForward(512, 512, 4, swish=True, post_act_ln=True, dropout=0.1)

# Apply the BitFeedForward network to the input tensor x
y = ff(x)

# Print the shape of the output tensor y
print(y.shape)  # torch.Size([10, 512])
```

## Inference
```python
from bitnet import BitNetInference

bitnet = BitNetInference()
bitnet.load_model("../model_checkpoint.pth")  # Path to the downloaded model checkpoint
output_str = bitnet.generate("The dog jumped over the ", 512)
print(output_str)
```

## Huggingface Usage
```python
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

from bitnet import replace_linears_in_hf

# Load a model from Hugging Face's Transformers
model_name = "bert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name)

# Replace Linear layers with BitLinear
replace_linears_in_hf(model)

# Example text to classify
text = "Replace this with your text"
inputs = tokenizer(
    text, return_tensors="pt", padding=True, truncation=True, max_length=512
)

# Perform inference
model.eval() # Set the model to evaluation mode
with torch.no_grad():
    outputs = model(**inputs)
    predictions = torch.nn.functional.softmax(outputs.logits, dim=-1)
print(predictions)

# Process predictions
predicted_class_id = predictions.argmax().item()
print(f"Predicted class ID: {predicted_class_id}")

# Optionally, map the predicted class ID to a label, if you know the classification labels
# labels = ["Label 1", "Label 2", ...] # Define your labels corresponding to the model's classes
# print(f"Predicted label: {labels[predicted_class_id]}")
```

# License
MIT

# Citation
```bibtex
@misc{2310.11453,
Author = {Hongyu Wang and Shuming Ma and Li Dong and Shaohan Huang and Huaijie Wang and Lingxiao Ma and Fan Yang and Ruiping Wang and Yi Wu and Furu Wei},
Title = {BitNet: Scaling 1-bit Transformers for Large Language Models},
Year = {2023},
Eprint = {arXiv:2310.11453},
}
```


# Todo
- [x] Double check BitLinear implementation and make sure it works exactly as in paper
- [x] Implement training script for `BitNetTransformer`
- [x] Train on enwiki8, copy and paste code and data from Lucidrains' repos
- [x] Benchmark performance
- [x] Look into the Straight-Through Estimator for non-differentiable backprop (see the sketch after this list)
- [x] Implement BitFeedForward
- [x] Clean up codebase
- [x] Add unit tests for each module
- [ ] Implement the new BitNet 1.58b from the [paper](https://arxiv.org/abs/2402.17764)
- [ ] Implement BitNet 1.58b in CUDA
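For reference, here is a minimal sketch of the straight-through estimator mentioned above: the forward pass applies the hard, non-differentiable sign, while the backward pass lets gradients through unchanged so the latent full-precision weights can still be updated. The helper below is the common detach-based formulation, shown as an illustrative assumption rather than the library's exact code.

```python
import torch

def ste_sign(x: torch.Tensor) -> torch.Tensor:
    # Forward: hard sign (non-differentiable).
    # Backward: identity gradient, since the detached term contributes no gradient.
    return x + (torch.sign(x) - x).detach()
```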
104 changes: 104 additions & 0 deletions Docs/README_JP.md
@@ -0,0 +1,104 @@
# BitNet

[![Multi-Modality](agorabanner.png)](https://discord.gg/qUtxnK2NMf)

PyTorch implementation of the linear methods and model from the paper "BitNet: Scaling 1-bit Transformers for Large Language Models".

[Paper link](https://arxiv.org/pdf/2310.11453.pdf)

BitLinear = tensor -> layernorm -> binarize -> abs-max quantization -> dequantization

"The implementation of the BitNet architecture is quite simple, requiring only the replacement of linear projections (i.e., nn.Linear in PyTorch) in the Transformer." -- BitNet is really easy to implement: just swap out the linear layers for the BitLinear modules!

## **NEWS**
- A BitNet Transformer has been trained using the `train.py` file on enwiki8, a small 1 GB Wikipedia dataset: [HERE IS THE LINK](https://drive.google.com/file/d/1gBuZRFBqMV3cVD902LXA_hmZl4e0dLyY/view)
- **New Iteration** 🔥 There is an all-new iteration from the paper "[The Era of 1-bit LLMs: All Large Language Models are in 1.58 Bits](https://arxiv.org/abs/2402.17764)"; we're implementing it now. Join the Agora Discord and contribute! [Join Here](https://discord.gg/hFzevCjG8c)
- **New Optimizations** The first `BitLinear` has been optimized, and we now have a Bit Attention module, `BitMGQA`, that uses BitLinear inside the attention mechanism. Multi-grouped query attention is also widely used for its fast decoding and long-context handling. Thanks to Frank for his easy-to-use implementation!

## Appreciation
- Dimitry and Nullonix for analysis, code review, and revision
- Vyom, for providing a 4080 to train on!

## Installation

`pip install bitnet`

## Usage:

### `BitLinear`
- Example of the BitLinear layer, which is the main innovation of the paper:

```python
import torch

from bitnet import BitLinear

# Input
x = torch.randn(10, 512)

# BitLinear layer
layer = BitLinear(512, 400)

# Output
y = layer(x)

print(y)
```

---
### `BitNetTransformer`
- Fully implemented Transformer as described in the diagram, with MHA and BitFeedForward blocks
- Can be used not just for text but also for images, and possibly even video or audio processing
- Complete with residuals and skip connections for gradient flow

```python
# Import the necessary libraries
import torch
from bitnet import BitNetTransformer

# Create a random tensor of integers
x = torch.randint(0, 20000, (1, 1024))

# Initialize the BitNetTransformer model
bitnet = BitNetTransformer(
    num_tokens=20000,  # Number of unique tokens in the input
    dim=1024,  # Dimension of the input and output embeddings
    depth=6,  # Number of transformer layers
    heads=8,  # Number of attention heads
    ff_mult=4,  # Multiplier for the hidden dimension in the feed-forward network
)

# Pass the tensor through the transformer model
logits = bitnet(x)

# Print the output logits
print(logits)
```


### `BitAttention`

This attention module has been modified to use BitLinear instead of the default linear projections. It also uses multi-grouped query attention instead of regular multi-head attention for faster decoding and longer context handling.

```python
import torch
from bitnet import BitMGQA

# Create a random tensor of shape (1, 10, 512)
x = torch.randn(1, 10, 512)

# Create an instance of the BitMGQA model with input size 512, 8 attention heads, and 4 layers
gqa = BitMGQA(512, 8, 4)

# Pass the input tensor through the BitMGQA model and get the output and attention weights
out, _ = gqa(x, x, x, need_weights=True)

# Print the output tensor
print(out)
```


### `BitFeedForward`
- Feedforward as shown in the diagram, with BitLinear and a GELU:
- Linear -> GELU -> Linear
- You can add dropout, layer normalization, or other layers for a better FFN

```python
import torch
from bitnet import BitFeedForward
```
