-
Notifications
You must be signed in to change notification settings - Fork 1
/
train.py
31 lines (24 loc) · 1.81 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
import argparse
from diffusion_model import DiffusionModel
def parse_arguments(argv=None) -> argparse.Namespace:
    """Parse the command-line options of the training script.

    Args:
        argv: Optional sequence of argument strings to parse instead of the
            process command line. When ``None`` (the default), argparse
            falls back to ``sys.argv[1:]``, so existing callers are unaffected.

    Returns:
        The parsed namespace with the attributes ``batch_size``, ``epochs``,
        ``dataset_name``, ``device``, ``lr``, ``checkpoint_name``,
        ``timesteps``, ``beta1``, ``beta2``, ``checkpoint_save_dir`` and
        ``image_save_dir``.
    """
    parser = argparse.ArgumentParser(description="Train diffusion model")

    # Optimisation / data options.
    parser.add_argument("--batch-size", type=int, default=128, help="Batch size for training")
    parser.add_argument("--epochs", type=int, default=50, help="Number of epochs for training")
    parser.add_argument("--dataset-name", type=str, default="mnist", help="Dataset name for training")
    parser.add_argument("--device", type=str, default=None, help="Device for training")
    parser.add_argument("--lr", type=float, default=1e-3, help="Learning rate for training")
    parser.add_argument("--checkpoint-name", type=str, default=None, help="Checkpoint name for pre-training")

    # DDPM-specific hyperparameters.
    parser.add_argument("--timesteps", type=int, default=500, help="Timesteps T for DDPM training")
    parser.add_argument("--beta1", type=float, default=0.0001, help="Hyperparameter for DDPM")
    parser.add_argument("--beta2", type=float, default=0.02, help="Hyperparameter for DDPM training")

    # Output locations; both default to None and are passed through as-is.
    parser.add_argument("--checkpoint-save-dir", type=str, default=None, help="Directory to save checkpoints")
    parser.add_argument("--image-save-dir", type=str, default=None, help="Directory to save generated images during training")

    # Passing argv explicitly makes the function usable from tests and other
    # Python code without touching sys.argv.
    return parser.parse_args(argv)
if __name__ == "__main__":
    # Script entry point: read the CLI options, construct the model, train it.
    args = parse_arguments()

    # Options consumed by the DiffusionModel constructor.
    model_options = {
        "device": args.device,
        "dataset_name": args.dataset_name,
        "checkpoint_name": args.checkpoint_name,
    }
    model = DiffusionModel(**model_options)

    # Options forwarded to the training call; note that the CLI flag
    # ``--epochs`` is passed under the keyword ``n_epoch``.
    training_options = {
        "batch_size": args.batch_size,
        "n_epoch": args.epochs,
        "lr": args.lr,
        "timesteps": args.timesteps,
        "beta1": args.beta1,
        "beta2": args.beta2,
        "checkpoint_save_dir": args.checkpoint_save_dir,
        "image_save_dir": args.image_save_dir,
    }
    model.train(**training_options)