Skip to content

Commit

Permalink
Only make a text encoder mask if mask_pad_tokens is true (#149)
Browse files Browse the repository at this point in the history
  • Loading branch information
coryMosaicML authored Jun 6, 2024
1 parent 7d3a7cc commit 93a5469
Showing 1 changed file with 2 additions and 4 deletions.
6 changes: 2 additions & 4 deletions diffusion/models/stable_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,11 +182,9 @@ def set_rng_generator(self, rng_generator: torch.Generator):

def forward(self, batch):
latents, text_embeds, text_pooled_embeds, attention_mask, encoder_attention_mask = None, None, None, None, None
if 'attention_mask' in batch:
if 'attention_mask' in batch and self.mask_pad_tokens:
attention_mask = batch['attention_mask'] # mask for text encoders
# text mask for U-Net
if self.mask_pad_tokens:
encoder_attention_mask = _create_unet_attention_mask(attention_mask)
encoder_attention_mask = _create_unet_attention_mask(attention_mask) # text mask for U-Net

# Use latents if specified and available. When specified, they might not exist during eval
if self.precomputed_latents and self.image_latents_key in batch and self.text_latents_key in batch:
Expand Down

0 comments on commit 93a5469

Please sign in to comment.