diff --git a/examples/flux-control/train_control_flux.py b/examples/flux-control/train_control_flux.py
index 1432e346f0ce..35f9a5f80342 100644
--- a/examples/flux-control/train_control_flux.py
+++ b/examples/flux-control/train_control_flux.py
@@ -795,7 +795,7 @@ def main(args):
         flux_transformer.x_embedder = new_linear
 
     assert torch.all(flux_transformer.x_embedder.weight[:, initial_input_channels:].data == 0)
-    flux_transformer.register_to_config(in_channels=initial_input_channels * 2)
+    flux_transformer.register_to_config(in_channels=initial_input_channels * 2, out_channels=initial_input_channels)
 
     def unwrap_model(model):
         model = accelerator.unwrap_model(model)
@@ -1166,6 +1166,11 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
             flux_transformer.to(torch.float32)
         flux_transformer.save_pretrained(args.output_dir)
 
+        del flux_transformer
+        del text_encoding_pipeline
+        del vae
+        free_memory()
+
         # Run a final round of validation.
         image_logs = None
         if args.validation_prompt is not None:
diff --git a/examples/flux-control/train_control_lora_flux.py b/examples/flux-control/train_control_lora_flux.py
index 6d84e81d810a..b176a685c963 100644
--- a/examples/flux-control/train_control_lora_flux.py
+++ b/examples/flux-control/train_control_lora_flux.py
@@ -830,7 +830,7 @@ def main(args):
         flux_transformer.x_embedder = new_linear
 
     assert torch.all(flux_transformer.x_embedder.weight[:, initial_input_channels:].data == 0)
-    flux_transformer.register_to_config(in_channels=initial_input_channels * 2)
+    flux_transformer.register_to_config(in_channels=initial_input_channels * 2, out_channels=initial_input_channels)
 
     if args.train_norm_layers:
         for name, param in flux_transformer.named_parameters():
@@ -1319,6 +1319,11 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
             transformer_lora_layers=transformer_lora_layers,
         )
 
+        del flux_transformer
+        del text_encoding_pipeline
+        del vae
+        free_memory()
+
         # Run a final round of validation.
         image_logs = None
         if args.validation_prompt is not None: