generated from kyegomez/Python-Package-Template
-
-
Notifications
You must be signed in to change notification settings - Fork 157
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #29 from Sunwood-ai-labs/jp
Jp
- Loading branch information
Showing
3 changed files
with
346 additions
and
138 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,198 @@ | ||
[![Multi-Modality](agorabanner.png)](https://discord.gg/qUtxnK2NMf) | ||
|
||
# BitNet | ||
![bitnet](/bitnet.png) | ||
PyTorch Implementation of the linear methods and model from the paper "BitNet: Scaling 1-bit Transformers for Large Language Models" | ||
|
||
[Paper link:](https://arxiv.org/pdf/2310.11453.pdf) | ||
|
||
BitLinear = tensor -> layernorm -> Binarize -> abs max quantization -> dequant | ||
|
||
"The implementation of the BitNet architecture is quite simple, requiring only the replacement of linear projections (i.e., nn.Linear in PyTorch) in the Transformer. " -- BitNet is really easy to implement just swap out the linears with the BitLinear modules! | ||
|
||
## **NEWS** | ||
- BitNet Transformer has been trained using the `train.py` file that trains on enwiki8 a small 1gb dataset of wikipedia: [HERE IS THE LINK](https://drive.google.com/file/d/1gBuZRFBqMV3cVD902LXA_hmZl4e0dLyY/view) | ||
- **New Iteration** 🔥 There is an all-new iteration from the paper "[The Era of 1-bit LLMs: All Large Language Models are in 1.58 Bits](https://arxiv.org/abs/2402.17764)", we're implementing it now. Join the Agora discord and contribute! [Join Here](https://discord.gg/hFzevCjG8c) | ||
- **New Optimizations** The first `BitLinear` has been optimized and we now have a Bit Attention `BitMGQA` That implements BitLinear into the attention mechanism. Multi Grouped Query Attention is also widely recognized as the best attention for its fast decoding and long context handling, thanks to Frank for his easy to use implementation! | ||
|
||
## Appreciation | ||
- Dimitry, Nullonix for analysis and code review and revision | ||
- Vyom, for providing 4080 to train! | ||
|
||
## Installation | ||
`pip install bitnet` | ||
|
||
## Usage: | ||
|
||
### `BitLinear` | ||
- Example of the BitLinear layer which is the main innovation of the paper! | ||
```python | ||
import torch | ||
|
||
from bitnet import BitLinear | ||
|
||
# Input | ||
x = torch.randn(10, 512) | ||
|
||
# BitLinear layer | ||
layer = BitLinear(512, 400) | ||
|
||
# Output | ||
y = layer(x) | ||
|
||
print(y) | ||
``` | ||
---- | ||
|
||
### `BitNetTransformer` | ||
- Fully implemented Transformer as described in the diagram with MHA, and BitFeedforwards | ||
- Can be utilized not just for text but for images and maybe even video or audio processing | ||
- Complete with residuals and skip connections for gradient flow | ||
|
||
```python | ||
# Import the necessary libraries | ||
import torch | ||
from bitnet import BitNetTransformer | ||
|
||
# Create a random tensor of integers | ||
x = torch.randint(0, 20000, (1, 1024)) | ||
|
||
# Initialize the BitNetTransformer model | ||
bitnet = BitNetTransformer( | ||
num_tokens=20000, # Number of unique tokens in the input | ||
dim=1024, # Dimension of the input and output embeddings | ||
depth=6, # Number of transformer layers | ||
heads=8, # Number of attention heads | ||
ff_mult=4, # Multiplier for the hidden dimension in the feed-forward network | ||
) | ||
|
||
# Pass the tensor through the transformer model | ||
logits = bitnet(x) | ||
|
||
# Print the shape of the output | ||
print(logits) | ||
|
||
``` | ||
|
||
|
||
### `BitAttention` | ||
This Attention has been modified to use BitLinear instead of the default linear projection. It's also using Multi-Grouped Query Attention instead of regular multi-head attention for faster decoding and longer context handling. | ||
|
||
```python | ||
import torch | ||
from bitnet import BitMGQA | ||
|
||
# Create a random tensor of shape (1, 10, 512) | ||
x = torch.randn(1, 10, 512) | ||
|
||
# Create an instance of the BitMGQA model with input size 512, 8 attention heads, and 4 layers | ||
gqa = BitMGQA(512, 8, 4) | ||
|
||
# Pass the input tensor through the BitMGQA model and get the output and attention weights | ||
out, _ = gqa(x, x, x, need_weights=True) | ||
|
||
# Print the shapes of the output tensor and attention tensor | ||
print(out) | ||
``` | ||
|
||
### `BitFeedForward` | ||
- Feedforward as shown in the diagram with BitLinear and a GELU: | ||
- Linear -> GELU -> Linear | ||
- You can add dropouts, or layernorms, or other layers for a better ffn | ||
|
||
```python | ||
import torch | ||
from bitnet import BitFeedForward | ||
|
||
# Create a random input tensor of shape (10, 512) | ||
x = torch.randn(10, 512) | ||
|
||
# Create an instance of the BitFeedForward class with the following parameters: | ||
# - input_dim: 512 | ||
# - hidden_dim: 512 | ||
# - num_layers: 4 | ||
# - swish: True (use Swish activation function) | ||
# - post_act_ln: True (apply Layer Normalization after each activation) | ||
# - dropout: 0.1 (apply dropout with a probability of 0.1) | ||
ff = BitFeedForward(512, 512, 4, swish=True, post_act_ln=True, dropout=0.1) | ||
|
||
# Apply the BitFeedForward network to the input tensor x | ||
y = ff(x) | ||
|
||
# Print the shape of the output tensor y | ||
print(y) # torch.Size([10, 512]) | ||
``` | ||
|
||
## Inference | ||
```python | ||
from bitnet import BitNetInference | ||
|
||
bitnet = BitNetInference() | ||
bitnet.load_model("../model_checkpoint.pth") # Download model | ||
output_str = bitnet.generate("The dog jumped over the ", 512) | ||
print(output_str) | ||
``` | ||
|
||
## Huggingface Usage | ||
```python | ||
import torch | ||
from transformers import AutoModelForSequenceClassification, AutoTokenizer | ||
|
||
from bitnet import replace_linears_in_hf | ||
|
||
# Load a model from Hugging Face's Transformers | ||
model_name = "bert-base-uncased" | ||
tokenizer = AutoTokenizer.from_pretrained(model_name) | ||
model = AutoModelForSequenceClassification.from_pretrained(model_name) | ||
|
||
# Replace Linear layers with BitLinear | ||
replace_linears_in_hf(model) | ||
|
||
# Example text to classify | ||
text = "Replace this with your text" | ||
inputs = tokenizer( | ||
text, return_tensors="pt", padding=True, truncation=True, max_length=512 | ||
) | ||
|
||
# Perform inference | ||
model.eval() # Set the model to evaluation mode | ||
with torch.no_grad(): | ||
outputs = model(**inputs) | ||
predictions = torch.nn.functional.softmax(outputs.logits, dim=-1) | ||
print(predictions) | ||
|
||
# Process predictions | ||
predicted_class_id = predictions.argmax().item() | ||
print(f"Predicted class ID: {predicted_class_id}") | ||
|
||
# Optionally, map the predicted class ID to a label, if you know the classification labels | ||
# labels = ["Label 1", "Label 2", ...] # Define your labels corresponding to the model's classes | ||
# print(f"Predicted label: {labels[predicted_class_id]}") | ||
``` | ||
|
||
# License | ||
MIT | ||
|
||
# Citation | ||
```bibtex | ||
@misc{2310.11453, | ||
Author = {Hongyu Wang and Shuming Ma and Li Dong and Shaohan Huang and Huaijie Wang and Lingxiao Ma and Fan Yang and Ruiping Wang and Yi Wu and Furu Wei}, | ||
Title = {BitNet: Scaling 1-bit Transformers for Large Language Models}, | ||
Year = {2023}, | ||
Eprint = {arXiv:2310.11453}, | ||
} | ||
``` | ||
|
||
|
||
# Todo | ||
- [x] Double check BitLinear implementation and make sure it works exactly as in paper | ||
- [x] Implement training script for `BitNetTransformer` | ||
- [x] Train on Enwiki8, copy and past code and data from Lucidrains repos | ||
- [x] Benchmark performance | ||
- [x] Look into Straight Through Estimator for non-differentiable backprop | ||
- [x] Implement BitFeedForward | ||
- [x] Clean up codebase | ||
- [x] Add unit tests for each module | ||
- [ ] Implement the new BitNet1.5b from the [paper](https://arxiv.org/abs/2402.17764) | ||
- [ ] Implement the BitNet15b in Cuda |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,104 @@ | ||
# BitNet | ||
|
||
[![Multi-Modality](agorabanner.png)](https://discord.gg/qUtxnK2NMf) | ||
|
||
論文「BitNet: Scaling 1-bit Transformers for Large Language Models」からの線形手法とモデルのPyTorch実装です。 | ||
|
||
[論文リンク:](https://arxiv.org/pdf/2310.11453.pdf) | ||
|
||
BitLinear = テンソル -> レイヤーノーム -> 二値化 -> 絶対最大量子化 -> 逆量子化 | ||
|
||
「BitNetアーキテクチャの実装は非常にシンプルで、Transformer内の線形射影(つまり、PyTorchのnn.Linear)を置換するだけです。」 -- BitNetは実装が本当に簡単で、線形モジュールをBitLinearモジュールに交換するだけです! | ||
## **ニュース** | ||
- BitNet Transformerは、Wikipediaの小さな1GBデータセットであるenwiki8でトレーニングする`train.py`ファイルを使用してトレーニングされました:[こちらがリンクです](https://drive.google.com/file/d/1gBuZRFBqMV3cVD902LXA_hmZl4e0dLyY/view) | ||
- **新しい反復** 🔥 論文「[The Era of 1-bit LLMs: All Large Language Models are in 1.58 Bits](https://arxiv.org/abs/2402.17764) 」からの全く新しい反復があります。現在実装中です。Agoraのdiscordに参加して貢献しましょう! [こちらで参加](https://discord.gg/hFzevCjG8c) | ||
- **新しい最適化** 最初の`BitLinear`が最適化され、注意メカニズムにBitLinearを実装するBit Attention `BitMGQA`を新たに得ました。Multi Grouped Query Attentionは、その高速なデコーディングと長いコンテキスト処理により、最高の注意と広く認識されています。Frankによる使いやすい実装に感謝します! | ||
## 謝辞 | ||
- Dimitry, Nullonixによる分析、コードレビュー、および改訂 | ||
- トレーニング用に4080を提供してくれたVyom! | ||
## インストール | ||
|
||
`pip install bitnet` | ||
## 使用方法: | ||
### `BitLinear` | ||
- 論文の主な革新であるBitLinearレイヤーの例: | ||
|
||
```python | ||
import torch | ||
|
||
from bitnet import BitLinear | ||
|
||
# 入力 | ||
x = torch.randn(10, 512) | ||
|
||
# BitLinearレイヤー | ||
layer = BitLinear(512, 400) | ||
|
||
# 出力 | ||
y = layer(x) | ||
|
||
print(y) | ||
``` | ||
|
||
--- | ||
### `BitNetTransformer` | ||
- MHAとBitFeedforwardsを備えた図に記載されている通りに完全に実装されたTransformer | ||
- テキストだけでなく、画像やビデオ、オーディオ処理にも利用可能 | ||
- 勾配の流れのための残差とスキップ接続を完備 | ||
|
||
```python | ||
# 必要なライブラリをインポート | ||
import torch | ||
from bitnet import BitNetTransformer | ||
|
||
# 整数のランダムテンソルを作成 | ||
x = torch.randint(0, 20000, (1, 1024)) | ||
|
||
# BitNetTransformerモデルを初期化 | ||
bitnet = BitNetTransformer( | ||
num_tokens=20000, # 入力のユニークなトークン数 | ||
dim=1024, # 入力および出力エンベディングの次元 | ||
depth=6, # トランスフォーマーレイヤーの数 | ||
heads=8, # 注意ヘッドの数 | ||
ff_mult=4, # フィードフォワードネットワーク内の隠れ層の次元の倍数 | ||
) | ||
|
||
# テンソルをトランスフォーマーモデルを通して渡す | ||
logits = bitnet(x) | ||
|
||
# 出力の形状を印刷 | ||
print(logits) | ||
``` | ||
|
||
|
||
### `BitAttention` | ||
|
||
このAttentionは、デフォルトの線形射影の代わりにBitLinearを使用するように修正されました。また、通常のマルチヘッドアテンションの代わりにMulti-Grouped Query Attentionを使用しています。これにより、より高速なデコーディングとより長いコンテキスト処理が可能になります。 | ||
|
||
```python | ||
import torch | ||
from bitnet import BitMGQA | ||
|
||
# 形状が(1, 10, 512)のランダムテンソルを作成 | ||
x = torch.randn(1, 10, 512) | ||
|
||
# 入力サイズ512、注意ヘッド8、レイヤー4のBitMGQAモデルのインスタンスを作成 | ||
gqa = BitMGQA(512, 8, 4) | ||
|
||
# 入力テンソルをBitMGQAモデルを通して渡し、出力と注意重みを取得 | ||
out, _ = gqa(x, x, x, need_weights=True) | ||
|
||
# 出力テンソルと注意テンソルの形状を印刷 | ||
print(out) | ||
``` | ||
|
||
|
||
### `BitFeedForward` | ||
- BitLinearとGELUを使用した図に示されているフィードフォワード: | ||
- 線形 -> GELU -> 線形 | ||
- より良いffnのためにドロップアウトやレイヤーノーム、その他のレイヤーを追加できます | ||
|
||
```python | ||
import torch | ||
from bitnet | ||
``` |
Oops, something went wrong.