Add lora script and docs
Blaizzy committed Sep 28, 2024
1 parent f206ded commit dab901c
Showing 2 changed files with 162 additions and 0 deletions.
63 changes: 63 additions & 0 deletions mlx_vlm/LORA.MD
@@ -0,0 +1,63 @@
# lora.py - NanoLLaVA LoRA Training Script

## Overview

`lora.py` is a Python script for fine-tuning a NanoLLaVA model using Low-Rank Adaptation (LoRA). This script allows you to train the model on your custom dataset, adjusting various parameters through command-line arguments.

## Requirements

- Python 3.7+
- MLX (`mlx`)
- MLX VLM (`mlx_vlm`)
- `argparse` (part of the Python standard library, no separate install needed)

## Usage

To use the script, run it from the command line with the desired arguments:

```bash
python lora.py --dataset /path/to/your/dataset [other options]
```

## Arguments

The script accepts the following command-line arguments:

- `--model_path`: Path to the pre-trained model (default: "mlx-community/nanoLLaVA-1.5-bf16")
- `--dataset`: Path to your dataset (required; see the format sketch after this list)
- `--learning_rate`: Learning rate for the optimizer (default: 1e-4)
- `--batch_size`: Batch size for training (default: 2)
- `--epochs`: Number of epochs to train (default: 1)
- `--steps`: Number of steps per epoch (default: 100)
- `--print_every`: Print loss every n steps (default: 10)
- `--output_path`: Path to save the trained adapter (default: "nanollava_lora_adapter.safetensors")
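
The script expects conversation-style records. Below is a minimal sketch of a single entry, assuming a LLaVA-style layout: the `conversations` field and the `<image>` token mirror what `add_image_token` in `lora.py` expects, while the image field name and path are hypothetical and depend on how your dataset is stored.

```python
example_item = {
    # Hypothetical image reference; adjust to your dataset's actual layout.
    "images": ["images/0001.jpg"],
    "conversations": [
        {"role": "user", "content": "<image>\nWhat is shown in this picture?"},
        {"role": "assistant", "content": "A cat sitting on a windowsill."},
    ],
}
```

If a user turn does not start with `<image>`, the script prepends it automatically before training.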

## Example

Here's an example of how to run the script with custom parameters:

```bash
python lora.py --dataset /path/to/your/dataset --epochs 2 --steps 200 --batch_size 4 --learning_rate 5e-5
```

This command will:
- Use the dataset at `/path/to/your/dataset`
- Train for 2 epochs
- Perform 200 steps per epoch
- Use a batch size of 4
- Set the learning rate to 5e-5

## Output

The script will print the training loss at regular intervals (defined by `--print_every`). After training, it will save the LoRA adapter to the specified output path.
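
To check what was written, you can load the adapter file directly with MLX. A minimal sketch, assuming the default `--output_path`:

```python
import mlx.core as mx

# mx.load reads .safetensors files into a dict mapping parameter names to arrays.
weights = mx.load("nanollava_lora_adapter.safetensors")
for name, array in weights.items():
    print(name, array.shape, array.dtype)
```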

## Note

Make sure you have the necessary permissions to read the dataset and write the output file. Also, ensure that your system has sufficient computational resources to handle the specified batch size and model.

## Contributing

Feel free to submit issues or pull requests if you find any bugs or have suggestions for improvements.

## License

[Specify the license here, e.g., MIT, Apache 2.0, etc.]
99 changes: 99 additions & 0 deletions mlx_vlm/lora.py
@@ -0,0 +1,99 @@
import argparse

import mlx.optimizers as optim

from mlx_vlm.trainer import Dataset, Trainer
from mlx_vlm.trainer.lora import *
from mlx_vlm.trainer.utils import *
from mlx_vlm.utils import load, load_image_processor


def add_image_token(items, image_token="<image>"):
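    """Prepend `image_token` to user turns that do not already start with it."""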
conversations = []
for item in items["conversations"]:
if item["role"] == "user":
if item["content"].startswith(image_token):
conversations.append({"role": "user", "content": item["content"]})
else:
conversations.append(
{"role": "user", "content": image_token + "\n" + item["content"]}
)
else:
conversations.append({"role": "assistant", "content": item["content"]})
return {"conversations": conversations}


def main(args):
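    """Load the model, attach LoRA adapters, run training, and save the adapter."""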
model, processor = load(
args.model_path, processor_config={"trust_remote_code": True}
)
image_processor = load_image_processor(args.model_path)

dataset = Dataset(
args.dataset,
model.config.__dict__,
processor,
image_processor=image_processor,
take=None,
split=None,
)
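    # Ensure every user turn contains the <image> token expected by the model.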
dataset = dataset.map(add_image_token)

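    # Adam optimizer and the library's Trainer handle the update steps.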
optimizer = optim.Adam(learning_rate=args.learning_rate)
trainer = Trainer(model, optimizer)

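    # Wrap all linear layers of the language model with LoRA adapters.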
list_of_modules = find_all_linear_names(model.language_model.model)
model = get_peft_model(model, list_of_modules)

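    # Freeze the vision tower so its weights stay fixed, then switch to training mode.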
model.vision_tower.freeze()
model.train()

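    # Train on contiguous, non-overlapping batches, printing the loss at the requested interval.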
for epoch in range(args.epochs):
for i in range(args.steps):
loss = trainer.train_step(
dataset[i * args.batch_size : (i + 1) * args.batch_size]
)
if i % args.print_every == 0:
print(f"Epoch {epoch} Step {i} Loss {loss}")

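    # Save the trained LoRA adapter weights to the requested output path.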
save_adapter(model, args.output_path)


if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Train NanoLLaVA model")
parser.add_argument(
"--model_path",
type=str,
default="mlx-community/nanoLLaVA-1.5-bf16",
help="Path to the pre-trained model",
)
parser.add_argument(
"--dataset", type=str, required=True, help="Path to the dataset"
)
parser.add_argument(
"--learning_rate",
type=float,
default=1e-4,
help="Learning rate for the optimizer",
)
parser.add_argument(
"--batch_size", type=int, default=2, help="Batch size for training"
)
parser.add_argument(
"--epochs", type=int, default=1, help="Number of epochs to train"
)
parser.add_argument(
"--steps", type=int, default=100, help="Number of steps per epoch"
)
parser.add_argument(
"--print_every", type=int, default=10, help="Print loss every n steps"
)
parser.add_argument(
"--output_path",
type=str,
default="nanollava_lora_adapter.safetensors",
help="Path to save the trained adapter",
)

args = parser.parse_args()
main(args)
