This repository demonstrates training your own BitNet model based on the Llama 2 architecture. Unedited, the script will train a ~84M parameter model on ~1.5B tokens.
- train.py - the entire training process, including preparing the data, defining the model architecture, and training the model.
- utils.py - contains the BitLinear implementation and the convert_to_bitnet function for converting Hugging Face's LlamaForCausalLM to BitNet.
- inference.py - run inference with a trained BitNet model (a rough inference sketch follows this list).
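For instance, loading a trained checkpoint and generating text might look roughly like this. This is a minimal sketch, not the exact contents of inference.py: the checkpoint path is a placeholder, and the convert_to_bitnet call (including its copy_weights argument) is an assumption about how the helper in utils.py is invoked.

```python
# Minimal sketch of running inference with a trained BitNet checkpoint.
# The model path is a placeholder, and the convert_to_bitnet signature is
# assumed -- check utils.py / inference.py for the real invocation.
import torch
from transformers import AutoTokenizer, LlamaForCausalLM

from utils import convert_to_bitnet  # provided by this repo

MODEL_DIR = "./bitnet-checkpoint"  # placeholder path to your saved model

tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)
model = LlamaForCausalLM.from_pretrained(MODEL_DIR)
convert_to_bitnet(model, copy_weights=True)  # swap nn.Linear -> BitLinear (assumed signature)
model.eval()

prompt = "In a shocking finding, scientists discovered"
inputs = tokenizer(prompt, return_tensors="pt")
with torch.no_grad():
    output = model.generate(**inputs, max_new_tokens=64, do_sample=True, temperature=0.8)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```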
I wanted to make this process as straightforward and hackable as possible, so all of these scripts are minimal and easily adjustable.
The script currently uses a 15% subset of OpenWebText2 for training. This has been pretokenized at a context length of 256 for ease of testing, but code is also included to tokenize the data yourself. You can replace a couple of lines in the script to train on pretty much anything else you want.
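For example, pretokenizing a different Hugging Face dataset yourself might look something like the sketch below. The dataset name, tokenizer, and column name are placeholders, not what train.py actually ships with:

```python
# Illustrative sketch of pretokenizing your own dataset with Hugging Face
# datasets; the dataset, tokenizer, and "text" column are placeholders.
from datasets import load_dataset
from transformers import AutoTokenizer

CTX_LEN = 256  # match the context length used for training

tokenizer = AutoTokenizer.from_pretrained("NousResearch/Llama-2-7b-hf")  # any Llama tokenizer

raw = load_dataset("wikitext", "wikitext-103-raw-v1", split="train")  # placeholder dataset

def tokenize(batch):
    return tokenizer(
        batch["text"],
        truncation=True,
        max_length=CTX_LEN,
        padding="max_length",
    )

tokenized = raw.map(tokenize, batched=True, remove_columns=raw.column_names)
tokenized.save_to_disk("./tokenized_data")
```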
You'll want to install these packages. The last two are optional and are for logging and HF auth.
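For reference, a typical install might look like the following. This exact package list is an assumption; in practice, wandb and huggingface_hub would be the two optional ones for logging and Hugging Face authentication.

```bash
# Assumed package list -- see the repo for the authoritative one.
pip install torch transformers datasets wandb huggingface_hub
```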
The BitLinear definition is copied straight from the released training details manuscript. The BitNet architecture is defined by loading a blank Llama 2 model with Hugging Face and then making the necessary replacements (as per the manuscript; a rough sketch of this conversion follows the list below):
- Replace all nn.Linear in attention and SwiGLU with BitLinear
- Remove RMSNorm before attention and SwiGLU because BitLinear has built-in RMSNorm.
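In code, those two replacements might look roughly like the following sketch. The module class names come from Hugging Face's modeling_llama; the real convert_to_bitnet in utils.py may differ in its exact structure.

```python
# Rough sketch of the two replacements described above; the actual
# convert_to_bitnet in utils.py may differ in its details.
import torch.nn as nn
from transformers.models.llama.modeling_llama import (
    LlamaAttention,
    LlamaDecoderLayer,
    LlamaMLP,
)

from utils import BitLinear  # BitLinear as defined in this repo


def convert_to_bitnet_sketch(model, copy_weights=False):
    for _, module in list(model.named_modules()):
        # 1) Swap every nn.Linear inside attention and the SwiGLU MLP for BitLinear.
        if isinstance(module, (LlamaAttention, LlamaMLP)):
            for child_name, child in module.named_children():
                if isinstance(child, nn.Linear):
                    bitlinear = BitLinear(
                        child.in_features, child.out_features, bias=child.bias is not None
                    )
                    if copy_weights:
                        bitlinear.weight = child.weight
                        if child.bias is not None:
                            bitlinear.bias = child.bias
                    setattr(module, child_name, bitlinear)
        # 2) Drop the RMSNorms in front of attention / SwiGLU, since BitLinear
        #    normalizes its own input.
        elif isinstance(module, LlamaDecoderLayer):
            module.input_layernorm = nn.Identity()
            module.post_attention_layernorm = nn.Identity()
    return model
```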