Skip to content

Commit

Permalink
Fix OOM error in SDXL Fine-Tuning validation stage (huggingface#1134)
Browse files Browse the repository at this point in the history
Signed-off-by: Daniel Socek <daniel.socek@intel.com>
  • Loading branch information
dsocek committed Jul 12, 2024
1 parent f0568ff commit c06b920
Showing 1 changed file with 16 additions and 25 deletions.
41 changes: 16 additions & 25 deletions examples/stable-diffusion/training/train_text_to_image_sdxl.py
Original file line number Diff line number Diff line change
Expand Up @@ -1136,6 +1136,7 @@ def unwrap_model(model, training=False):
logger.info(f" Total optimization steps = {args.max_train_steps}")
global_step = 0
first_epoch = 0
pipeline = None

# Potentially load in the weights and states from a previous save
if args.resume_from_checkpoint:
Expand Down Expand Up @@ -1382,22 +1383,21 @@ def compute_time_ids(original_size, crops_coords_top_left):
ema_unet.copy_to(unet.parameters())

# create pipeline
vae = AutoencoderKL.from_pretrained(
vae_path,
subfolder=("vae" if args.pretrained_vae_model_name_or_path is None else None),
revision=args.revision,
variant=args.variant,
)
pipeline = GaudiStableDiffusionXLPipeline.from_pretrained(
args.pretrained_model_name_or_path,
vae=vae,
unet=unwrap_model(unet),
revision=args.revision,
variant=args.variant,
use_habana=True,
use_hpu_graphs=args.use_hpu_graphs_for_inference,
gaudi_config=args.gaudi_config_name,
)
if pipeline is None:
pipeline = GaudiStableDiffusionXLPipeline.from_pretrained(
args.pretrained_model_name_or_path,
vae=vae,
unet=unwrap_model(unet),
revision=args.revision,
variant=args.variant,
use_habana=True,
use_hpu_graphs=args.use_hpu_graphs_for_inference,
gaudi_config=args.gaudi_config_name,
)
else:
# vae and text encoders are frozen, only need to update unet
pipeline.unet = unwrap_model(unet)

if args.prediction_type is not None:
scheduler_args = {"prediction_type": args.prediction_type}
pipeline.scheduler = pipeline.scheduler.from_config(pipeline.scheduler.config, **scheduler_args)
Expand Down Expand Up @@ -1433,8 +1433,6 @@ def compute_time_ids(original_size, crops_coords_top_left):
}
)

del pipeline

if t0 is not None:
duration = time.perf_counter() - t0 - (checkpoint_time if args.adjust_throughput else 0)
ttt = time.perf_counter() - t_start
Expand All @@ -1457,13 +1455,6 @@ def compute_time_ids(original_size, crops_coords_top_left):
ema_unet.copy_to(unet.parameters())

# Serialize pipeline.
vae = AutoencoderKL.from_pretrained(
vae_path,
subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None,
revision=args.revision,
variant=args.variant,
torch_dtype=weight_dtype,
)
pipeline = GaudiStableDiffusionXLPipeline.from_pretrained(
args.pretrained_model_name_or_path,
unet=unet,
Expand Down

0 comments on commit c06b920

Please sign in to comment.