have tokenizers pass dict output
jazcollins committed Oct 2, 2023
1 parent 505b850 commit 054d1ef
Showing 4 changed files with 33 additions and 46 deletions.
diffusion/callbacks/log_diffusion_images.py: 20 changes (7 additions, 13 deletions)

@@ -62,24 +62,18 @@ def eval_batch_end(self, state: State, logger: Logger):
         else:
             model = state.model
 
+        if self.tokenized_prompts is None:
+            tokenized_prompts = [
+                model.tokenizer(p, padding='max_length', truncation=True,
+                                return_tensors='pt')['input_ids']  # type: ignore
+                for p in self.prompts
+            ]
         if model.sdxl:
-            if self.tokenized_prompts is None:
-                tokenized_prompts = [
-                    model.tokenizer(p, padding='max_length', truncation=True, return_tensors='pt',
-                                    input_ids=True)  # type: ignore
-                    for p in self.prompts
-                ]
             tokenized_prompts_1 = torch.cat([tp[0] for tp in tokenized_prompts]).to(state.batch[self.text_key].device)
             tokenized_prompts_2 = torch.cat([tp[1] for tp in tokenized_prompts]).to(state.batch[self.text_key].device)
             self.tokenized_prompts = [tokenized_prompts_1, tokenized_prompts_2]
         else:
-            if self.tokenized_prompts is None:
-                tokenized_prompts = [
-                    model.tokenizer(p, padding='max_length', truncation=True,
-                                    return_tensors='pt')['input_ids']  # type: ignore
-                    for p in self.prompts
-                ]
-                self.tokenized_prompts = torch.cat(tokenized_prompts)
+            self.tokenized_prompts = torch.cat(tokenized_prompts)
             self.tokenized_prompts = self.tokenized_prompts.to(state.batch[self.text_key].device)  # type: ignore
 
         # Generate images
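With the dict-style tokenizer output, the callback tokenizes once and only the post-processing branches on `model.sdxl`. A minimal shape sketch of the two cases, using dummy ids for a hypothetical two-prompt batch (the 77-token CLIP context length is an assumption):

```python
import torch

# Stand-ins for tokenizer(...)['input_ids'] per prompt:
# the SDXL wrapper yields an [ids_1, ids_2] pair; a plain CLIP
# tokenizer yields a single (1, 77) tensor.
sdxl_out = [torch.zeros(1, 77, dtype=torch.long), torch.zeros(1, 77, dtype=torch.long)]
clip_out = torch.zeros(1, 77, dtype=torch.long)

# SDXL branch: split the pairs, then concatenate across prompts.
tokenized_prompts = [sdxl_out, sdxl_out]
prompts_1 = torch.cat([tp[0] for tp in tokenized_prompts])  # shape (2, 77)
prompts_2 = torch.cat([tp[1] for tp in tokenized_prompts])  # shape (2, 77)

# Non-SDXL branch: each entry is already a tensor, so one cat suffices.
tokenized_prompts = [clip_out, clip_out]
merged = torch.cat(tokenized_prompts)  # shape (2, 77)
```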
diffusion/datasets/image_caption.py: 23 changes (10 additions, 13 deletions)

@@ -129,21 +129,18 @@ def __getitem__(self, index):
         if isinstance(caption, List) and self.caption_selection == 'random':
             caption = random.sample(caption, k=1)[0]
 
+        max_length = None if self.sdxl else self.tokenizer.model_max_length
+        tokenized_caption = self.tokenizer(
+            caption,
+            padding='max_length',
+            max_length=max_length,
+            truncation=True,
+            return_tensors='pt')['input_ids']
         if self.sdxl:
-            tokenized_captions = self.tokenizer(caption,
-                                                padding='max_length',
-                                                truncation=True,
-                                                return_tensors='pt',
-                                                input_ids=True)
-            tokenized_captions = [cap[0] for cap in tokenized_captions]
-            tokenized_caption = torch.stack(tokenized_captions)
+            tokenized_caption = [tokenized_cap.squeeze() for tokenized_cap in tokenized_caption]
+            tokenized_caption = torch.stack(tokenized_caption)
         else:
-            tokenized_caption = self.tokenizer(
-                caption,
-                padding='max_length',
-                max_length=self.tokenizer.model_max_length,  # type: ignore
-                truncation=True,
-                return_tensors='pt')['input_ids'][0]
+            tokenized_caption = tokenized_caption.squeeze()
         out['image'] = img
         out['captions'] = tokenized_caption
         return out
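The dataset path now emits a (2, 77) caption tensor per SDXL sample and a (77,) tensor otherwise. A runnable sketch of those shapes, with dummy ids and the 77-token CLIP context window assumed:

```python
import torch

# SDXL: ['input_ids'] is a two-element list of (1, 77) tensors;
# squeeze each and stack into a single (2, 77) caption tensor.
sdxl_ids = [torch.zeros(1, 77, dtype=torch.long), torch.zeros(1, 77, dtype=torch.long)]
caption_sdxl = torch.stack([ids.squeeze() for ids in sdxl_ids])
assert caption_sdxl.shape == (2, 77)

# Non-SDXL: ['input_ids'] is one (1, 77) tensor; squeeze drops the batch dim.
caption_plain = torch.zeros(1, 77, dtype=torch.long).squeeze()
assert caption_plain.shape == (77,)
```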
diffusion/models/models.py: 17 changes (9 additions, 8 deletions)

@@ -415,7 +415,7 @@ def forward(self, tokenized_text):
 class SDXLTokenizer:
     """Wrapper around HuggingFace tokenizers for SDXL.
 
-    Tokenizes prompt with two tokenizers and returns the outputs as a list.
+    Tokenizes prompt with two tokenizers and returns the joined output.
 
     Args:
         model_name (str): Name of the model's text encoders to load. Defaults to 'stabilityai/stable-diffusion-xl-base-1.0'.
@@ -425,18 +425,19 @@ def __init__(self, model_name='stabilityai/stable-diffusion-xl-base-1.0'):
         self.tokenizer = CLIPTokenizer.from_pretrained(model_name, subfolder='tokenizer')
         self.tokenizer_2 = CLIPTokenizer.from_pretrained(model_name, subfolder='tokenizer_2')
 
-    def __call__(self, prompt, padding, truncation, return_tensors, input_ids=False):
+    def __call__(self, prompt, padding, truncation, return_tensors, max_length=None):
         tokenized_output = self.tokenizer(prompt,
                                           padding=padding,
-                                          max_length=self.tokenizer.model_max_length,
+                                          max_length=self.tokenizer.model_max_length if max_length is None else max_length,
                                           truncation=truncation,
                                           return_tensors=return_tensors)
         tokenized_output_2 = self.tokenizer_2(prompt,
                                               padding=padding,
-                                              max_length=self.tokenizer_2.model_max_length,
+                                              max_length=self.tokenizer_2.model_max_length if max_length is None else max_length,
                                               truncation=truncation,
                                               return_tensors=return_tensors)
-        if input_ids:
-            tokenized_output = tokenized_output.input_ids
-            tokenized_output_2 = tokenized_output_2.input_ids
-        return [tokenized_output, tokenized_output_2]
+
+        # Add second tokenizer output to first tokenizer
+        for key in tokenized_output.keys():
+            tokenized_output[key] = [tokenized_output[key], tokenized_output_2[key]]
+        return tokenized_output
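This is the core of the commit: the wrapper now behaves like a single HuggingFace tokenizer whose fields each hold a two-element list, instead of returning a list of two encodings gated by an `input_ids=True` flag. A minimal usage sketch, assuming the SDXL tokenizer subfolders can be fetched from the Hub:

```python
from diffusion.models.models import SDXLTokenizer

tokenizer = SDXLTokenizer('stabilityai/stable-diffusion-xl-base-1.0')
out = tokenizer('a photo of a corgi', padding='max_length',
                truncation=True, return_tensors='pt')

# Each key maps to [output_of_tokenizer, output_of_tokenizer_2], so callers
# index the pair instead of passing the old input_ids=True flag.
ids_1, ids_2 = out['input_ids']
print(ids_1.shape, ids_2.shape)  # (1, 77) from each CLIP tokenizer
```

Because `max_length` defaults to `None` and falls back to each sub-tokenizer's `model_max_length`, callers can pass the same keyword arguments to this wrapper and to a plain CLIP tokenizer.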
diffusion/models/stable_diffusion.py: 19 changes (7 additions, 12 deletions)

@@ -491,22 +491,17 @@ def _prepare_text_embeddings(self, prompt, tokenized_prompts, prompt_embeds, num
         device = self.text_encoder.device
         pooled_text_embeddings = None
         if prompt_embeds is None:
+            max_length = None if self.sdxl else self.tokenizer.model_max_length
+            if tokenized_prompts is None:
+                tokenized_prompts = self.tokenizer(prompt,
+                                                   padding='max_length',
+                                                   max_length=max_length,
+                                                   truncation=True,
+                                                   return_tensors='pt').input_ids
             if self.sdxl:
-                if tokenized_prompts is None:
-                    tokenized_prompts = self.tokenizer(prompt,
-                                                       padding='max_length',
-                                                       truncation=True,
-                                                       return_tensors='pt',
-                                                       input_ids=True)
                 text_embeddings, pooled_text_embeddings = self.text_encoder(
                     [tokenized_prompts[0].to(device), tokenized_prompts[1].to(device)])  # type: ignore
             else:
-                if tokenized_prompts is None:
-                    tokenized_prompts = self.tokenizer(prompt,
-                                                       padding='max_length',
-                                                       max_length=self.tokenizer.model_max_length,
-                                                       truncation=True,
-                                                       return_tensors='pt').input_ids
                 text_embeddings = self.text_encoder(tokenized_prompts.to(device))[0]  # type: ignore
         else:
             if self.sdxl:
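For reference, the contract the SDXL branch assumes of its text encoder: an [ids_1, ids_2] pair goes in, and (sequence embeddings, pooled embeddings) come out. A stub sketch; the 2048/1280 hidden sizes are the standard SDXL values and serve only as stand-ins here:

```python
import torch

batch, seq_len = 2, 77

def stub_sdxl_text_encoder(ids_pair):
    # Stand-in for the module's SDXL text encoder: consumes the
    # [ids_1, ids_2] pair and returns (sequence embeddings, pooled embeddings).
    assert isinstance(ids_pair, list) and len(ids_pair) == 2
    return torch.zeros(batch, seq_len, 2048), torch.zeros(batch, 1280)

ids_pair = [torch.zeros(batch, seq_len, dtype=torch.long) for _ in range(2)]
text_embeddings, pooled_text_embeddings = stub_sdxl_text_encoder(ids_pair)
print(text_embeddings.shape, pooled_text_embeddings.shape)
```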
