Small-Scale Matrices

LoRA Library for GPT-J and GPT-2

LoRA (Low-Rank Adaptation of Large Language Models) implementations for GPT-J and GPT-2.

In the source code:

  • Change float32 to float16 if needed.
  • Change cpu to cuda if available.
  • Change adapter_dim if needed.

FYI, the official GPT-2 LoRA implementation: microsoft/LoRA
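
The knobs above map onto the standard LoRA decomposition: the pretrained weight W is frozen, and a pair of small matrices B (out_features × adapter_dim) and A (adapter_dim × in_features) provides the trainable low-rank update. The sketch below only illustrates that idea; the class name LoRALinear and its internals are hypothetical and are not this library's convert_to_lora. It simply shows where the float32/float16, cpu/cuda, and adapter_dim choices apply.

import torch
import torch.nn as nn

class LoRALinear(nn.Module):
    """Hypothetical sketch: a frozen nn.Linear plus a trainable low-rank update."""
    def __init__(self, base: nn.Linear, adapter_dim: int = 16,
                 dtype=torch.float32, device="cpu"):
        super().__init__()
        self.base = base.to(device=device, dtype=dtype)
        for p in self.base.parameters():
            p.requires_grad_(False)  # freeze the pretrained weights
        # A: (adapter_dim, in_features), B: (out_features, adapter_dim)
        self.lora_A = nn.Parameter(torch.randn(adapter_dim, base.in_features,
                                               dtype=dtype, device=device) * 0.01)
        self.lora_B = nn.Parameter(torch.zeros(base.out_features, adapter_dim,
                                               dtype=dtype, device=device))

    def forward(self, x):
        # y = W x + B A x; only the low-rank term is trained
        return self.base(x) + x @ self.lora_A.T @ self.lora_B.T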

How to Use

Monkey-patch GPT-J for convenience. For example:

import torch
import transformers

# convert_to_lora() is provided by this library; it applies LoRA to the given
# module in place.

class GPTJBlock(transformers.models.gptj.modeling_gptj.GPTJBlock):
    def __init__(self, config):
        super().__init__(config)
        convert_to_lora(self.attn)  # LoRA on the attention module
        convert_to_lora(self.mlp)   # LoRA on the MLP module

class GPTJModel(transformers.models.gptj.modeling_gptj.GPTJModel):
    def __init__(self, config):
        super().__init__(config)
        convert_to_lora(self)

class GPTJForCausalLM(transformers.models.gptj.modeling_gptj.GPTJForCausalLM):
    def __init__(self, config):
        super().__init__(config)
        convert_to_lora(self)

# Patch the block class so that from_pretrained() builds LoRA-enabled blocks.
transformers.models.gptj.modeling_gptj.GPTJBlock = GPTJBlock

Now you can use the LoRA-applied GPT-J just like the original one:

model = GPTJForCausalLM.from_pretrained(
    "EleutherAI/gpt-j-6B", revision="float16",
    torch_dtype=torch.float16, low_cpu_mem_usage=True
).to(device='cuda', non_blocking=True)
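
After loading, a typical LoRA fine-tuning loop optimizes only the parameters that remain trainable and checkpoints just the adapter tensors (compare the small "Adapters Saved" size in the test output below). A minimal sketch, assuming the adapter parameters are the only ones left with requires_grad=True and that their names contain "lora"; both assumptions are illustrative, not guaranteed by this library:

import torch

# Optimize only the (small) set of trainable LoRA parameters.
trainable = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.AdamW(trainable, lr=1e-4)

# Save just the adapter tensors instead of the full 6B-parameter checkpoint.
adapter_state = {name: tensor.cpu()
                 for name, tensor in model.state_dict().items()
                 if "lora" in name.lower()}
torch.save(adapter_state, "gptj_lora_adapters.pt")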

Test

For example, print gpt-j-6B parameter counts with:

$ python test/gptj_lora.py

Model Loaded.

Default Model
- # of           params:        6050882784
- # of trainable params:        6050882784
- # of          buffers:        117440540
Model Loaded.

LoRA-applied Model
- # of           params:        35635424
- # of trainable params:        34774016
- # of          buffers:        6167461916

Adapters Saved:                 69733529
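
The counts above are plain sums over the model's parameters and buffers. A minimal sketch of how such a report can be produced (the helper name report is hypothetical, not the test script's actual code):

def report(model, label):
    n_params    = sum(p.numel() for p in model.parameters())
    n_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    n_buffers   = sum(b.numel() for b in model.buffers())
    print(label)
    print(f"- # of           params:\t{n_params}")
    print(f"- # of trainable params:\t{n_trainable}")
    print(f"- # of          buffers:\t{n_buffers}")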

References

  • Edward J. Hu, Yelong Shen, Phillip Wallis, Zeyuan Allen-Zhu, Yuanzhi Li, Shean Wang, Lu Wang, Weizhu Chen. "LoRA: Low-Rank Adaptation of Large Language Models." arXiv:2106.09685.
