Deleting cross attention between documents during pretraining #16
Hi! In the last commit I updated the
Throughout this week I'll run some benchmarks and resolve the 2 issues previously mentioned.
In this PR I include the mechanism to delete cross attention between different documents during pretraining. I'm developing this PR on top of #14, as it reuses most of the code for the Llama model.
To use this feature you will need to tokenise the data with the updated tool (it just contains a little patch, soon to be merged into main, to NOT shuffle tokens with datatrove), and it's necessary to add the `eos_token`:

```bash
python3 tools/preprocess_data.py --tokenizer-name-or-path meta-llama/Meta-Llama-3-8B --eos-token "<|end_of_text|>" --output-folder datasets/SlimPajama-6B --n-tasks 16 hf --dataset DKYoon/SlimPajama-6B
```
Then in the config file you need to set `remove_document_xattention` to `True`. This will build the `LlamaForSFT` model and configure the collator to produce the correct `position_ids` & `label_mask`.

In the collator we create the correct position ids for each document, in short, resetting the position ids after every `eos_token`. For the `label_mask` we mask the loss for the token preceding the `eos_token`, so the model doesn't learn to predict the `eos_token`, and for the `eos_token` itself, as it doesn't make any sense to compute the loss of its prediction.

In this image you can see what we are feeding into the model, with the correct values for the `position_ids` and `label_mask`. And another example for a sample with > 2 documents (this is the boundary of the 3rd & 4th document):
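To make the collator behaviour concrete, here is a minimal sketch of the logic described above, assuming a flat 1-D sample and an illustrative function name (the PR's actual collator may differ in shapes and details):

```python
import numpy as np

def build_position_ids_and_label_mask(input_ids, eos_token_id):
    """Illustrative only: reset position ids after every eos_token and mask the
    loss on the eos_token and on the token right before it (whose target would
    be the eos_token)."""
    input_ids = np.asarray(input_ids)
    position_ids = np.zeros_like(input_ids)
    label_mask = np.ones_like(input_ids, dtype=bool)

    pos = 0
    for i, tok in enumerate(input_ids):
        position_ids[i] = pos
        pos += 1
        if tok == eos_token_id:
            pos = 0                        # next document restarts at position 0
            label_mask[i] = False          # no loss on the eos_token itself
            if i > 0:
                label_mask[i - 1] = False  # don't learn to predict the eos_token
    return position_ids, label_mask
```

For example, with `input_ids = [5, 6, 7, EOS, 8, 9, EOS, 10]` this yields `position_ids = [0, 1, 2, 3, 0, 1, 2, 0]` and `label_mask = [1, 1, 0, 0, 1, 0, 0, 1]`.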
The main difference between `LlamaForSFT` & `LlamaForTraining` is that `LlamaForTraining` leverages FA RoPE embeddings with triton for better performance but doesn't support position ids. If we manage to develop a custom triton kernel for RoPEs with position ids, we could keep just `LlamaForSFT`.
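For reference, this is roughly what applying RoPE with explicit position ids looks like in plain PyTorch; a sketch under the rotate-half convention, not the FA/triton kernel, and the function name is made up:

```python
import torch

def apply_rope_with_position_ids(x, position_ids, theta: float = 10000.0):
    """x: [batch, seq_len, n_heads, head_dim]; position_ids: [batch, seq_len].
    The rotation angle depends only on position_ids, so resetting them at every
    document boundary makes each document start again at position 0."""
    b, s, h, d = x.shape
    inv_freq = 1.0 / (theta ** (torch.arange(0, d, 2, dtype=torch.float32, device=x.device) / d))  # [d/2]
    angles = position_ids[..., None].float() * inv_freq   # [b, s, d/2]
    cos = angles.cos()[:, :, None, :]                      # broadcast over heads
    sin = angles.sin()[:, :, None, :]
    x1, x2 = x[..., : d // 2], x[..., d // 2 :]
    return torch.cat([x1 * cos - x2 * sin, x2 * cos + x1 * sin], dim=-1).type_as(x)
```

A custom triton kernel would essentially fuse this computation while accepting an arbitrary `position_ids` tensor.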