# lora.py - NanoLLaVA LoRA Training Script

## Overview

`lora.py` is a Python script for fine-tuning a NanoLLaVA model with Low-Rank Adaptation (LoRA). It lets you train the model on your own dataset, with the training configuration controlled through command-line arguments.
## Requirements

- Python 3.7+
- MLX (`mlx`)
- MLX VLM library (`mlx_vlm`); `argparse` comes with the Python standard library
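Assuming the packages are published on PyPI under the names `mlx` and `mlx-vlm` (matching this script's imports), they can be installed with:

```
pip install mlx mlx-vlm
```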
## Usage

To use the script, run it from the command line with the desired arguments:

```
python lora.py --dataset /path/to/your/dataset [other options]
```
## Arguments

The script accepts the following command-line arguments:

- `--model_path`: Path to the pre-trained model (default: `mlx-community/nanoLLaVA-1.5-bf16`)
- `--dataset`: Path to your dataset (required; see the format sketch below)
- `--learning_rate`: Learning rate for the optimizer (default: `1e-4`)
- `--batch_size`: Batch size for training (default: `2`)
- `--epochs`: Number of epochs to train (default: `1`)
- `--steps`: Number of steps per epoch (default: `100`)
- `--print_every`: Print the loss every n steps (default: `10`)
- `--output_path`: Path to save the trained adapter (default: `nanollava_lora_adapter.safetensors`)
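The on-disk layout is handled by `mlx_vlm.trainer.Dataset`, but each record carries a `conversations` list of role/content turns. This is a hypothetical sketch of one record, inferred from the script's `add_image_token` helper, which prepends an `<image>` token to user turns that lack one:

```
{
  "conversations": [
    {"role": "user", "content": "What is shown in this image?"},
    {"role": "assistant", "content": "A small dog sitting on a couch."}
  ]
}
```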
## Example

Here's an example of how to run the script with custom parameters:

```
python lora.py --dataset /path/to/your/dataset --epochs 2 --steps 200 --batch_size 4 --learning_rate 5e-5
```

This command will:

- Use the dataset at `/path/to/your/dataset`
- Train for 2 epochs
- Perform 200 steps per epoch
- Use a batch size of 4
- Set the learning rate to 5e-5
## Output

The script will print the training loss at regular intervals (defined by `--print_every`). After training, it will save the LoRA adapter to the specified output path.
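The log lines follow the `print` call in the training loop; the loss values below are placeholders, not real results:

```
Epoch 0 Step 0 Loss <loss>
Epoch 0 Step 10 Loss <loss>
Epoch 0 Step 20 Loss <loss>
```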
## Note

Make sure you have the necessary permissions to read the dataset and write the output file. Also, ensure that your system has sufficient computational resources to handle the specified batch size and model.
## Contributing

Feel free to submit issues or pull requests if you find any bugs or have suggestions for improvements.

## License

[Specify the license here, e.g., MIT, Apache 2.0, etc.]
## lora.py

```
import argparse

import mlx.optimizers as optim

from mlx_vlm.trainer import Dataset, Trainer
# The wildcard imports supply the LoRA helpers used below
# (find_all_linear_names, get_peft_model, save_adapter).
from mlx_vlm.trainer.lora import *
from mlx_vlm.trainer.utils import *
from mlx_vlm.utils import load, load_image_processor


def add_image_token(items, image_token="<image>"):
    """Ensure every user turn starts with the <image> placeholder token."""
    conversations = []
    for item in items["conversations"]:
        if item["role"] == "user":
            if item["content"].startswith(image_token):
                conversations.append({"role": "user", "content": item["content"]})
            else:
                conversations.append(
                    {"role": "user", "content": image_token + "\n" + item["content"]}
                )
        else:
            conversations.append({"role": "assistant", "content": item["content"]})
    return {"conversations": conversations}
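
# A hypothetical before/after for a single record:
#   in:  {"conversations": [{"role": "user", "content": "Describe the image."}, ...]}
#   out: {"conversations": [{"role": "user", "content": "<image>\nDescribe the image."}, ...]}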


def main(args):
    # Load the pre-trained model and its processor; trust_remote_code is
    # needed because nanoLLaVA ships custom processor code.
    model, processor = load(
        args.model_path, processor_config={"trust_remote_code": True}
    )
    image_processor = load_image_processor(args.model_path)

    # Build the dataset and make sure every user turn carries an <image> token.
    dataset = Dataset(
        args.dataset,
        model.config.__dict__,
        processor,
        image_processor=image_processor,
        take=None,
        split=None,
    )
    dataset = dataset.map(add_image_token)

    optimizer = optim.Adam(learning_rate=args.learning_rate)
    trainer = Trainer(model, optimizer)

    # Replace the language model's linear layers with LoRA-wrapped versions.
    list_of_modules = find_all_linear_names(model.language_model.model)
    model = get_peft_model(model, list_of_modules)

    # Freeze the vision tower; only the LoRA parameters are trained.
    model.vision_tower.freeze()
    model.train()

    for epoch in range(args.epochs):
        for i in range(args.steps):
            loss = trainer.train_step(
                dataset[i * args.batch_size : (i + 1) * args.batch_size]
            )
            if i % args.print_every == 0:
                print(f"Epoch {epoch} Step {i} Loss {loss}")

    # Save only the trained adapter weights.
    save_adapter(model, args.output_path)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Train NanoLLaVA model")
    parser.add_argument(
        "--model_path",
        type=str,
        default="mlx-community/nanoLLaVA-1.5-bf16",
        help="Path to the pre-trained model",
    )
    parser.add_argument(
        "--dataset", type=str, required=True, help="Path to the dataset"
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=1e-4,
        help="Learning rate for the optimizer",
    )
    parser.add_argument(
        "--batch_size", type=int, default=2, help="Batch size for training"
    )
    parser.add_argument(
        "--epochs", type=int, default=1, help="Number of epochs to train"
    )
    parser.add_argument(
        "--steps", type=int, default=100, help="Number of steps per epoch"
    )
    parser.add_argument(
        "--print_every", type=int, default=10, help="Print loss every n steps"
    )
    parser.add_argument(
        "--output_path",
        type=str,
        default="nanollava_lora_adapter.safetensors",
        help="Path to save the trained adapter",
    )

    args = parser.parse_args()
    main(args)
```