diff --git a/examples/stable-diffusion/training/train_text_to_image_sdxl.py b/examples/stable-diffusion/training/train_text_to_image_sdxl.py index 9a7193db24..38e08a8541 100644 --- a/examples/stable-diffusion/training/train_text_to_image_sdxl.py +++ b/examples/stable-diffusion/training/train_text_to_image_sdxl.py @@ -1136,6 +1136,7 @@ def unwrap_model(model, training=False): logger.info(f" Total optimization steps = {args.max_train_steps}") global_step = 0 first_epoch = 0 + pipeline = None # Potentially load in the weights and states from a previous save if args.resume_from_checkpoint: @@ -1382,22 +1383,21 @@ def compute_time_ids(original_size, crops_coords_top_left): ema_unet.copy_to(unet.parameters()) # create pipeline - vae = AutoencoderKL.from_pretrained( - vae_path, - subfolder=("vae" if args.pretrained_vae_model_name_or_path is None else None), - revision=args.revision, - variant=args.variant, - ) - pipeline = GaudiStableDiffusionXLPipeline.from_pretrained( - args.pretrained_model_name_or_path, - vae=vae, - unet=unwrap_model(unet), - revision=args.revision, - variant=args.variant, - use_habana=True, - use_hpu_graphs=args.use_hpu_graphs_for_inference, - gaudi_config=args.gaudi_config_name, - ) + if pipeline is None: + pipeline = GaudiStableDiffusionXLPipeline.from_pretrained( + args.pretrained_model_name_or_path, + vae=vae, + unet=unwrap_model(unet), + revision=args.revision, + variant=args.variant, + use_habana=True, + use_hpu_graphs=args.use_hpu_graphs_for_inference, + gaudi_config=args.gaudi_config_name, + ) + else: + # vae and text encoders are frozen, only need to update unet + pipeline.unet = unwrap_model(unet) + if args.prediction_type is not None: scheduler_args = {"prediction_type": args.prediction_type} pipeline.scheduler = pipeline.scheduler.from_config(pipeline.scheduler.config, **scheduler_args) @@ -1433,8 +1433,6 @@ def compute_time_ids(original_size, crops_coords_top_left): } ) - del pipeline - if t0 is not None: duration = time.perf_counter() - t0 - (checkpoint_time if args.adjust_throughput else 0) ttt = time.perf_counter() - t_start @@ -1457,13 +1455,6 @@ def compute_time_ids(original_size, crops_coords_top_left): ema_unet.copy_to(unet.parameters()) # Serialize pipeline. - vae = AutoencoderKL.from_pretrained( - vae_path, - subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None, - revision=args.revision, - variant=args.variant, - torch_dtype=weight_dtype, - ) pipeline = GaudiStableDiffusionXLPipeline.from_pretrained( args.pretrained_model_name_or_path, unet=unet,