
qlora #100

Open
RanchiZhao wants to merge 1 commit into main
Conversation


@RanchiZhao commented on Jul 28, 2023

This PR mainly involves the following aspects:

  • QLoRA overall logic (a rough end-to-end sketch follows this item):

    • First, quantize the model parameter file.
    • Then set the int4 field in the model's config to enable QLoRA fine-tuning.
    • The remaining steps are the same as basic task fine-tuning.
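For readers new to the flow, here is a minimal, illustrative sketch of the three steps above. Only the int4 switch and the script names come from this PR; the config path and key layout shown here are assumptions for illustration.

```python
import json

# Step 1 (offline): compress the original checkpoint with src/quantize_state_dict.py,
# producing the quantized state dict that QLoRA later loads as the model weights.

# Step 2: flip the new int4 switch in the model config.
# "src/config/cpm-bee-10b.json" is a hypothetical path; use your model's config file.
cfg_path = "src/config/cpm-bee-10b.json"
with open(cfg_path) as f:
    config = json.load(f)
config["int4"] = True          # the bool field added by this PR
with open(cfg_path, "w") as f:
    json.dump(config, f, indent=2)

# Step 3: fine-tune as in the basic task, e.g. via src/scripts/finetune_cpm_bee_qlora.sh,
# pointing it at the quantized state dict produced in step 1.
```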
  • Modifications to the model structure:

    • Add a bool-type field int4 to the model config files under src/config; it acts as a switch that controls whether QLoRA is used. The related modules (Attention/SelfAttentionBlock/FFNBlock/TransformerBlock/DenseGatedACT/FeedForward/Encoder/CPMBee) are adjusted accordingly so that the appropriate layers are built based on the int4 field.
    • In src/cpm_live/layers/feedforward.py, add class Linear4bit as the QLoRA linear layer, class Params4bit as the weight container for Linear4bit, and class DistributedParameter4Int8 to meet the encapsulation needs (the general idea is sketched below).
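The snippet below is only a minimal sketch of what a 4-bit linear layer of this kind typically looks like: a frozen, packed 4-bit base weight that is dequantized on the fly, plus trainable low-rank adapters. It is not the implementation in src/cpm_live/layers/feedforward.py; the class and function names mirror the PR purely for readability, and it ignores blockwise quantization, distributed parameters, and BMTrain integration.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class Params4bit(nn.Parameter):
    """Illustrative 4-bit weight container: two 4-bit codes packed per uint8 byte,
    plus the absmax needed to dequantize. Not the PR's actual class."""

    def __new__(cls, packed, absmax, orig_shape):
        p = super().__new__(cls, packed, requires_grad=False)  # frozen base weight
        p.absmax = absmax
        p.orig_shape = orig_shape
        return p


def quantize_4bit(w: torch.Tensor) -> "Params4bit":
    """Symmetric per-tensor absmax quantization to 4 bits (assumes an even numel)."""
    absmax = w.abs().max()
    q = (torch.clamp(torch.round(w / absmax * 7), -8, 7) + 8).to(torch.uint8).flatten()
    packed = (q[0::2] << 4) | q[1::2]  # two 4-bit codes per byte
    return Params4bit(packed, absmax, w.shape)


def dequantize_4bit(p: "Params4bit") -> torch.Tensor:
    hi = (p.data >> 4).to(torch.int8) - 8
    lo = (p.data & 0x0F).to(torch.int8) - 8
    q = torch.stack([hi, lo], dim=1).flatten().float()
    return (q / 7.0 * p.absmax).reshape(p.orig_shape)


class Linear4bit(nn.Module):
    """Frozen 4-bit base weight + trainable LoRA adapters (a sketch, not the PR code)."""

    def __init__(self, weight: torch.Tensor, r: int = 8):
        super().__init__()
        out_features, in_features = weight.shape
        self.weight = quantize_4bit(weight)                    # frozen, stored as uint8
        self.lora_A = nn.Parameter(torch.randn(r, in_features) * 0.01)
        self.lora_B = nn.Parameter(torch.zeros(out_features, r))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        w = dequantize_4bit(self.weight).to(x.dtype)           # dequantize on the fly
        return F.linear(x, w) + F.linear(F.linear(x, self.lora_A), self.lora_B)
```

In such a layer only lora_A and lora_B receive gradients; the quantized base weight stays frozen, which is what keeps QLoRA fine-tuning memory-friendly.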
  • Added scripts / sample code / README:

    • src/quantize_state_dict.py compresses the initial weights; QLoRA then loads the compressed dict as the model weights (a rough sketch of the idea follows this item).
    • src/finetune_cpm_bee_qlora.py is the fine-tuning sample code.
    • src/scripts/finetune_cpm_bee_qlora.sh is the fine-tuning sample script.
    • tutorials/basic_task_finetune/README_qlora.md is the QLoRA fine-tuning tutorial.
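As a rough illustration of what a weight-compression step like src/quantize_state_dict.py does, the sketch below quantizes every 2-D floating-point weight in a checkpoint to 4-bit codes (stored in uint8) plus a per-tensor absmax. The function name, output key layout, and file names are assumptions; the real script's interface may differ.

```python
import torch


def quantize_state_dict(src_path: str, dst_path: str) -> None:
    """Compress a checkpoint so it can be loaded as QLoRA base weights (illustrative)."""
    state = torch.load(src_path, map_location="cpu")
    out = {}
    for name, tensor in state.items():
        if torch.is_tensor(tensor) and tensor.dim() == 2 and tensor.is_floating_point():
            absmax = tensor.abs().max()
            codes = (torch.clamp(torch.round(tensor / absmax * 7), -8, 7) + 8).to(torch.uint8)
            out[name] = {"codes": codes, "absmax": absmax, "shape": tuple(tensor.shape)}
        else:
            out[name] = tensor  # embeddings, norms, etc. stay in their original dtype
    torch.save(out, dst_path)


# Hypothetical file names:
# quantize_state_dict("cpm-bee-10b.pt", "cpm-bee-10b-int4.pt")
```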
  • Other considerations (both issues are illustrated below):

    • The inspect-related code in src/finetune_cpm_bee_qlora.py has been commented out, because std and var are not supported for uint8 tensors.
    • The bug in BMTrain.blocklayer where requires_grad cannot be passed for uint8-typed parameters needs to be fixed in sync with this PR.
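A minimal reproduction of the two dtype issues mentioned above, in plain PyTorch and independent of BMTrain:

```python
import torch

x = torch.randint(0, 255, (16,), dtype=torch.uint8)

# 1) std/var are only defined for floating-point (and complex) tensors,
#    so inspect-style statistics fail on the uint8 quantized weights:
# x.std()  # raises RuntimeError

# 2) requires_grad cannot be set to True on an integer tensor, which is why
#    uint8 parameters must be created with requires_grad=False:
# torch.nn.Parameter(x)                        # raises RuntimeError (requires_grad defaults to True)
p = torch.nn.Parameter(x, requires_grad=False)  # works
print(p.dtype, p.requires_grad)                 # torch.uint8 False
```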
