Skip to content

Commit

Permalink
[training] fix: registration of out_channels in the control flux scripts. (#10367)
Browse files Browse the repository at this point in the history

* fix: registration of out_channels in the control flux scripts.

* free memory.
  • Loading branch information
sayakpaul authored Dec 24, 2024
1 parent 023b0e0 commit 825979d
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 2 deletions.
7 changes: 6 additions & 1 deletion examples/flux-control/train_control_flux.py
Original file line number Diff line number Diff line change
Expand Up @@ -795,7 +795,7 @@ def main(args):
flux_transformer.x_embedder = new_linear

assert torch.all(flux_transformer.x_embedder.weight[:, initial_input_channels:].data == 0)
flux_transformer.register_to_config(in_channels=initial_input_channels * 2)
flux_transformer.register_to_config(in_channels=initial_input_channels * 2, out_channels=initial_input_channels)

def unwrap_model(model):
model = accelerator.unwrap_model(model)
Expand Down Expand Up @@ -1166,6 +1166,11 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
flux_transformer.to(torch.float32)
flux_transformer.save_pretrained(args.output_dir)

del flux_transformer
del text_encoding_pipeline
del vae
free_memory()

# Run a final round of validation.
image_logs = None
if args.validation_prompt is not None:
Expand Down
7 changes: 6 additions & 1 deletion examples/flux-control/train_control_lora_flux.py
Original file line number Diff line number Diff line change
Expand Up @@ -830,7 +830,7 @@ def main(args):
flux_transformer.x_embedder = new_linear

assert torch.all(flux_transformer.x_embedder.weight[:, initial_input_channels:].data == 0)
flux_transformer.register_to_config(in_channels=initial_input_channels * 2)
flux_transformer.register_to_config(in_channels=initial_input_channels * 2, out_channels=initial_input_channels)

if args.train_norm_layers:
for name, param in flux_transformer.named_parameters():
Expand Down Expand Up @@ -1319,6 +1319,11 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
transformer_lora_layers=transformer_lora_layers,
)

del flux_transformer
del text_encoding_pipeline
del vae
free_memory()

# Run a final round of validation.
image_logs = None
if args.validation_prompt is not None:
Expand Down

0 comments on commit 825979d

Please sign in to comment.