Update evaluation and inference code to handle other precisions and models (#179)
coryMosaicML authored Nov 14, 2024
1 parent b0a094f commit ba8ca02
Showing 7 changed files with 88 additions and 74 deletions.
15 changes: 9 additions & 6 deletions diffusion/datasets/image.py
@@ -26,8 +26,10 @@ class StreamingImageDataset(StreamingDataset):
Args:
streams (Sequence[Stream], optional): One or more Streams to stream/cache samples from.
``StreamingImageCaptionDataset`` uses either ``streams`` or ``remote``/``local``. Default:``None``.
remote (str, optional): Remote directory (S3 or local filesystem) where dataset is stored. Default: ``None``.
local (str, optional): Local filesystem directory where dataset is cached during operation. Default: ``None``.
remote (Union[str, Sequence[str]], optional): Remote directory (S3 or local filesystem) where dataset is
stored. Default: ``None``.
local (Union[str, Sequence[str]], optional): Local filesystem directory where dataset is cached during
operation. Default: ``None``.
transform (Callable, optional): The transforms to apply to the image. Default: ``None``.
image_key (str): Key associated with the image in the streaming dataset. Default: ``'image'``.
image_output_key (optional, str): Optional output key for the image. If none, the value of `image_key` will
@@ -41,8 +43,8 @@ class StreamingImageDataset(StreamingDataset):
def __init__(
self,
streams: Optional[Sequence[Stream]] = None,
remote: Optional[str] = None,
local: Optional[str] = None,
remote: Optional[Union[str, Sequence[str]]] = None,
local: Optional[Union[str, Sequence[str]]] = None,
transform: Optional[Callable] = None,
image_key: str = 'image',
image_output_key: Optional[str] = None,
@@ -54,10 +56,11 @@ def __init__(
streaming_kwargs.setdefault('shuffle_block_size', 1 << 18)
streaming_kwargs.setdefault('shuffle_algo', 'py1s')

# Make the streams if necessary
streams = make_streams(remote, local=local) if streams is None else streams

super().__init__(
streams=streams,
remote=remote,
local=local,
**streaming_kwargs,
)

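To illustrate the widened `remote`/`local` signature, here is a minimal usage sketch. The bucket paths, cache directories, and keyword values are placeholders rather than anything taken from this commit, and `make_streams` is assumed to pair each remote with its local cache directory.

```python
from diffusion.datasets.image import StreamingImageDataset

# Hypothetical example: remote/local may now be sequences, and the class builds
# Streams itself via make_streams instead of passing remote/local to StreamingDataset.
dataset = StreamingImageDataset(
    remote=['s3://my-bucket/shards-a', 's3://my-bucket/shards-b'],  # placeholder paths
    local=['/tmp/cache-a', '/tmp/cache-b'],
    image_key='image',
)
```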
14 changes: 4 additions & 10 deletions diffusion/evaluate.py
@@ -14,7 +14,7 @@
from composer.loggers import LoggerDestination
from composer.utils import reproducibility
from omegaconf import DictConfig, OmegaConf
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torchmetrics.multimodal import CLIPScore

from diffusion.evaluation.clean_fid_eval import CleanFIDEvaluator
@@ -31,14 +31,8 @@ def evaluate(config: DictConfig) -> None:
# The model to evaluate
model: ComposerModel = hydra.utils.instantiate(config.model)

tokenizer = model.tokenizer if hasattr(model, 'tokenizer') else None

# The dataloader to use for evaluation
if tokenizer:
eval_dataloader = hydra.utils.instantiate(config.eval_dataloader, tokenizer=tokenizer)

else:
eval_dataloader: DataLoader = hydra.utils.instantiate(config.eval_dataloader)
# The dataset
dataset: Dataset = hydra.utils.instantiate(config.dataset)

# The CLIPScores metric to use for evaluation
clip_metric: CLIPScore = hydra.utils.instantiate(config.clip_metric)
@@ -88,7 +82,7 @@ def evaluate(config: DictConfig) -> None:
evaluator: CleanFIDEvaluator = hydra.utils.instantiate(
config.evaluator,
model=model,
eval_dataloader=eval_dataloader,
dataset=dataset,
clip_metric=clip_metric,
loggers=logger,
)
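On the caller side, the evaluation config now needs a `dataset` entry instead of an `eval_dataloader` entry, and no tokenizer is threaded through at instantiation time. A hedged sketch of what such a config fragment could look like (the `_target_` and paths are placeholders, not values from this commit):

```python
import hydra
from omegaconf import OmegaConf

# Hypothetical config fragment: the evaluator consumes a Dataset, not a DataLoader.
cfg = OmegaConf.create({
    'dataset': {
        '_target_': 'diffusion.datasets.image.StreamingImageDataset',  # placeholder target
        'remote': 's3://my-bucket/eval-shards',                         # placeholder path
        'local': '/tmp/eval-cache',
    }
})
dataset = hydra.utils.instantiate(cfg.dataset)
```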
67 changes: 36 additions & 31 deletions diffusion/evaluation/clean_fid_eval.py
@@ -14,11 +14,10 @@
from composer.core import get_precision_context
from composer.loggers import LoggerDestination
from composer.utils import dist
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torchmetrics.multimodal import CLIPScore
from torchvision.transforms.functional import to_pil_image
from torchvision.transforms.functional import pil_to_tensor, to_pil_image
from tqdm.auto import tqdm
from transformers import PreTrainedTokenizerBase

os.environ['TOKENIZERS_PARALLELISM'] = 'false'

@@ -32,7 +31,7 @@ class CleanFIDEvaluator:
Args:
model (ComposerModel): The model to evaluate.
eval_dataloader (DataLoader): The dataloader to use for evaluation.
dataset (Dataset): The dataset to use the prompts from.
clip_metric (CLIPScore): The CLIPScore metric to use for evaluation.
load_path (str, optional): The path to load the model from. Default: ``None``.
guidance_scales (List[float]): The guidance scales to use for evaluation.
@@ -52,13 +51,14 @@ class CleanFIDEvaluator:
default_prompt (Optional[str]): An optional default prompt to add before each eval prompt. Default: ``None``.
default_negative_prompt (Optional[str]): An optional default negative prompt to add before each
negative prompt. Default: ``None``.
sdxl_conditioning (bool): Whether or not to include SDXL conditioning in the evaluation. Default: ``False``.
additional_generate_kwargs (Dict, optional): Additional keyword arguments to pass to the model.generate method.
"""

def __init__(self,
model: ComposerModel,
eval_dataloader: DataLoader,
dataset: Dataset,
clip_metric: CLIPScore,
load_path: Optional[str] = None,
guidance_scales: Optional[List[float]] = None,
@@ -75,10 +75,10 @@ def __init__(self,
prompts: Optional[List[str]] = None,
default_prompt: Optional[str] = None,
default_negative_prompt: Optional[str] = None,
sdxl_conditioning: bool = False,
additional_generate_kwargs: Optional[Dict] = None):
self.model = model
self.tokenizer: PreTrainedTokenizerBase = model.tokenizer
self.eval_dataloader = eval_dataloader
self.dataset = dataset
self.clip_metric = clip_metric
self.load_path = load_path
self.guidance_scales = guidance_scales if guidance_scales is not None else [1.0]
@@ -89,20 +89,19 @@ def __init__(self,
self.loggers = loggers
self.seed = seed
self.output_dir = output_dir
self.num_samples = num_samples if num_samples is not None else float('inf')
self.num_samples = num_samples
self.precision = precision
self.prompts = prompts if prompts is not None else ['A shiba inu wearing a blue sweater']
self.default_prompt = default_prompt
self.default_negative_prompt = default_negative_prompt
self.sdxl_conditioning = sdxl_conditioning
self.additional_generate_kwargs = additional_generate_kwargs if additional_generate_kwargs is not None else {}
self.sdxl = model.sdxl

# Load the model
trainer = Trainer(model=self.model,
load_path=self.load_path,
load_weights_only=True,
load_strict_model_weights=load_strict_model_weights,
eval_dataloader=self.eval_dataloader,
seed=self.seed,
loggers=self.loggers)
self.trainer = trainer
@@ -139,18 +138,27 @@ def _generate_images(self, guidance_scale: float):

# Storage for prompts
prompts = {}
# Iterate over the eval dataloader
num_batches = len(self.eval_dataloader)
starting_seed = self.seed + num_batches * dist.get_local_rank()
for batch_id, batch in tqdm(enumerate(self.eval_dataloader)):
# Break if enough samples have been generated
if batch_id * self.batch_size * dist.get_world_size() >= self.num_samples:
break

real_images = batch[self.image_key]
tokenized_captions = batch[self.caption_key]
# Get the prompts from the tokens
text_captions = self.tokenizer.batch_decode(tokenized_captions, skip_special_tokens=True)
# Partition the dataset across the ranks
dataset_len = self.dataset.num_samples # type: ignore
# Truncate the dataset if num_samples is specified
if self.num_samples is not None and self.num_samples <= dataset_len:
dataset_len = self.num_samples
elif self.num_samples is not None and self.num_samples > dataset_len:
raise ValueError(f'num_samples {self.num_samples} is greater than the dataset length {dataset_len}.')
samples_per_rank, remainder = divmod(dataset_len, dist.get_world_size())
start_idx = dist.get_global_rank() * samples_per_rank + min(remainder, dist.get_global_rank())
end_idx = start_idx + samples_per_rank
if dist.get_global_rank() < remainder:
end_idx += 1
print(f'Rank {dist.get_global_rank()} processing samples {start_idx} to {end_idx} of {dataset_len} total.')
# Iterate over the dataset
for sample_id in tqdm(range(start_idx, end_idx)):
# Set a unique seed for this sample to ensure reproducible but different randomness
seed = self.seed + sample_id
# Image and caption come from the dataset. Note the caption is untokenized
sample = self.dataset[sample_id]
real_images = pil_to_tensor(sample[self.image_key]).unsqueeze(0) / 255.0
text_captions = sample[self.caption_key]
# Add default prompts if specified
augmented_captions = text_captions
augmented_negative_prompt = None
@@ -159,15 +167,12 @@ def _generate_images(self, guidance_scale: float):
if self.default_negative_prompt:
augmented_negative_prompt = [f'{self.default_negative_prompt}' for _ in text_captions]

if self.sdxl:
crop_params = batch['cond_crops_coords_top_left']
input_size_params = batch['cond_original_size']
if self.sdxl_conditioning:
crop_params = torch.tensor([0, 0]).unsqueeze(0)
input_size_params = torch.tensor([self.size, self.size]).unsqueeze(0)
else:
crop_params = None
input_size_params = None

# Ensure a new seed for each batch, as randomness in model.generate is fixed.
seed = starting_seed + batch_id
# Generate images from the captions
with get_precision_context(self.precision):
generated_images = self.model.generate(prompt=augmented_captions,
@@ -188,11 +193,11 @@ def _generate_images(self, guidance_scale: float):
f'Images are expected to be in the range [0, 1]. Got max {real_images.max()} and min {real_images.min()}'
)
for i, img in enumerate(real_images):
to_pil_image(img).save(f'{real_image_path}/{batch_id}_{i}_rank_{dist.get_local_rank()}.png')
prompts[f'{batch_id}_{i}_rank_{dist.get_local_rank()}'] = text_captions[i]
to_pil_image(img).save(f'{real_image_path}/{sample_id}_rank_{dist.get_local_rank()}.png')
prompts[f'{sample_id}_rank_{dist.get_local_rank()}'] = text_captions[i]
# Save the generated images
for i, img in enumerate(generated_images):
to_pil_image(img).save(f'{gen_image_path}/{batch_id}_{i}_rank_{dist.get_local_rank()}.png')
to_pil_image(img).save(f'{gen_image_path}/{sample_id}_rank_{dist.get_local_rank()}.png')

# Save the prompts as json
json.dump(prompts, open(f'{real_image_path}/prompts_rank_{dist.get_local_rank()}.json', 'w'))
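The new per-rank partitioning replaces dataloader batching: each rank iterates its own contiguous slice of dataset indices. A standalone sketch of the same arithmetic shows that ranks below the remainder take one extra sample, so every index is covered exactly once.

```python
from typing import Tuple

def partition(dataset_len: int, world_size: int, rank: int) -> Tuple[int, int]:
    """Mirror of the start/end index computation in _generate_images above."""
    samples_per_rank, remainder = divmod(dataset_len, world_size)
    start_idx = rank * samples_per_rank + min(remainder, rank)
    end_idx = start_idx + samples_per_rank
    if rank < remainder:
        end_idx += 1
    return start_idx, end_idx

# 10 samples over 4 ranks -> [0, 3), [3, 6), [6, 8), [8, 10)
assert [partition(10, 4, r) for r in range(4)] == [(0, 3), (3, 6), (6, 8), (8, 10)]
```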
5 changes: 4 additions & 1 deletion diffusion/evaluation/generate_geneval_images.py
@@ -27,6 +27,7 @@ class GenevalImageGenerator:
load_path (str, optional): The path to load the model from. Default: ``None``.
local_checkpoint_path (str, optional): The local path to save the model checkpoint. Default: ``'/tmp/model.pt'``.
load_strict_model_weights (bool): Whether or not to strict load model weights. Default: ``True``.
precision (str): The precision to use for evaluation. Default: ``'amp_fp16'``.
guidance_scale (float): The guidance scale to use for evaluation. Default: ``7.0``.
height (int): The height of the generated images. Default: ``1024``.
width (int): The width of the generated images. Default: ``1024``.
@@ -46,6 +47,7 @@ def __init__(self,
load_path: Optional[str] = None,
local_checkpoint_path: str = '/tmp/model.pt',
load_strict_model_weights: bool = True,
precision: str = 'amp_fp16',
guidance_scale: float = 7.0,
height: int = 1024,
width: int = 1024,
@@ -77,6 +79,7 @@ def __init__(self,
self.load_path = load_path
self.local_checkpoint_path = local_checkpoint_path
self.load_strict_model_weights = load_strict_model_weights
self.precision = precision
self.guidance_scale = guidance_scale
self.height = height
self.width = width
@@ -148,7 +151,7 @@ def generate(self):
**self.additional_generate_kwargs).images[0]
img = generated_image
else:
with get_precision_context('amp_fp16'):
with get_precision_context(self.precision):
generated_image = self.model.generate(prompt=caption,
height=self.height,
width=self.width,
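A small sketch of how the new `precision` argument is consumed; `'amp_bf16'` is an assumed alternative value here, and whether a given precision works depends on the hardware and Composer version.

```python
import torch
from composer.core import get_precision_context

# Pick a precision the way a config might; fall back to fp32 when bf16 is unavailable.
use_bf16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
precision = 'amp_bf16' if use_bf16 else 'fp32'
device = 'cuda' if torch.cuda.is_available() else 'cpu'

with get_precision_context(precision):
    x = torch.randn(2, 8, device=device)
    _ = x @ x.T  # GenevalImageGenerator.generate wraps model.generate in this context
```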
14 changes: 12 additions & 2 deletions diffusion/inference/inference_model.py
@@ -235,13 +235,23 @@ class ModelInference():
model_name (str): Name of the model from `diffusion.models` to load. Ex: for stable diffusion xl, use 'stable_diffusion_xl'.
local_checkpoint_path (str): Path to the local checkpoint. Default: '/tmp/model.pt'.
strict (bool): Whether to load the model weights strictly. Default: False.
dtype: The data type to use. One of [`float32`, `float16`, `bfloat16`]. Default: `bfloat16`.
**model_kwargs: Keyword arguments to pass to the model initialization.
"""

def __init__(self, model_name, local_checkpoint_path: str = LOCAL_CHECKPOINT_PATH, strict=False, **model_kwargs):
def __init__(self,
model_name,
local_checkpoint_path: str = LOCAL_CHECKPOINT_PATH,
strict=False,
dtype='bfloat16',
**model_kwargs):
self.device = torch.cuda.current_device()
model_factory = getattr(diffusion.models, model_name)
model = model_factory(**model_kwargs)
dtype_map = {'float32': torch.float32, 'float16': torch.float16, 'bfloat16': torch.bfloat16}
if dtype not in dtype_map:
raise ValueError(f'Invalid dtype: {dtype}. Must be one of {list(dtype_map.keys())}')
self.dtype = dtype_map[dtype]

if 'pretrained' in model_kwargs and model_kwargs['pretrained']:
pass
@@ -290,7 +300,7 @@ def predict(self, model_requests: List[Dict[str, Any]]):
raise RuntimeError('There must be the same number of negative prompts as prompts.')

# Generate images
with torch.cuda.amp.autocast(True):
with torch.cuda.amp.autocast(True, dtype=self.dtype):
imgs = self.model.generate(prompt=prompts, negative_prompt=negative_prompts, **generate_kwargs).cpu()

# Send as bytes
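A standalone sketch of the dtype plumbing added here. The mapping and validation mirror the constructor above; the generation step is stubbed with a matmul, and the equivalent `torch.autocast` form is used in place of the `torch.cuda.amp.autocast` call in the diff.

```python
import torch

dtype_map = {'float32': torch.float32, 'float16': torch.float16, 'bfloat16': torch.bfloat16}
dtype = 'bfloat16'  # the new constructor default
if dtype not in dtype_map:
    raise ValueError(f'Invalid dtype: {dtype}. Must be one of {list(dtype_map.keys())}')

device = 'cuda' if torch.cuda.is_available() else 'cpu'
# ModelInference.predict wraps model.generate in an autocast like this one.
with torch.autocast(device_type=device, dtype=dtype_map[dtype], enabled=(device == 'cuda')):
    x = torch.randn(2, 8, device=device)
    _ = x @ x.T
```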
8 changes: 3 additions & 5 deletions diffusion/models/models.py
@@ -9,7 +9,8 @@

import torch
from composer.devices import DeviceGPU
from diffusers import AutoencoderKL, DDIMScheduler, DDPMScheduler, EulerDiscreteScheduler, UNet2DConditionModel
from diffusers import (AutoencoderKL, DDIMScheduler, DDPMScheduler, DPMSolverMultistepScheduler, EulerDiscreteScheduler,
UNet2DConditionModel)
from peft import LoraConfig
from torchmetrics import MeanSquaredError
from transformers import AutoModel, AutoTokenizer, CLIPTextModel, CLIPTokenizer, PretrainedConfig
@@ -770,16 +771,13 @@ def precomputed_text_latent_diffusion(
'beta_schedule': 'scaled_linear',
'trained_betas': None,
'prediction_type': prediction_type,
'interpolation_type': 'linear',
'use_karras_sigmas': False,
'timestep_spacing': 'leading',
'steps_offset': 1,
'rescale_betas_zero_snr': False,
}

if inference_noise_scheduler_params is not None:
inference_scheduler_params.update(inference_noise_scheduler_params)
inference_noise_scheduler = EulerDiscreteScheduler(**inference_scheduler_params)
inference_noise_scheduler = DPMSolverMultistepScheduler(**inference_scheduler_params)

# Shift noise scheduler to correct for resolution changes
noise_scheduler = shift_noise_schedule(noise_scheduler,
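For context, a hedged sketch of constructing the new inference scheduler. Only a subset of the config keys from the snippet above is shown, and the numeric values are assumptions rather than values taken from this hunk; the repo builds the dict from the function's arguments and then applies any `inference_noise_scheduler_params` overrides.

```python
from diffusers import DPMSolverMultistepScheduler

# Placeholder values for illustration only.
inference_scheduler_params = {
    'num_train_timesteps': 1000,
    'beta_start': 0.00085,
    'beta_end': 0.012,
    'beta_schedule': 'scaled_linear',
    'trained_betas': None,
    'prediction_type': 'epsilon',
}
inference_noise_scheduler = DPMSolverMultistepScheduler(**inference_scheduler_params)
inference_noise_scheduler.set_timesteps(num_inference_steps=30)  # ready for sampling
```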
39 changes: 20 additions & 19 deletions diffusion/models/precomputed_text_latent_diffusion.py
@@ -203,28 +203,29 @@ def decode_latents(self, latents):
def encode_text(self, text, device):
assert self.t5_tokenizer is not None and self.t5_encoder is not None
assert self.clip_tokenizer is not None and self.clip_encoder is not None
# Encode with T5
t5_tokenizer_out = self.t5_tokenizer(text,
padding='max_length',
max_length=self.t5_tokenizer.model_max_length,
truncation=True,
return_tensors='pt')
t5_tokenized_captions = t5_tokenizer_out['input_ids'].to(device)
t5_attn_mask = t5_tokenizer_out['attention_mask'].to(torch.bool).to(device)
t5_embed = self.t5_encoder(input_ids=t5_tokenized_captions, attention_mask=t5_attn_mask)[0]
# Encode with CLIP
clip_tokenizer_out = self.clip_tokenizer(text,
with torch.autocast(device_type='cuda', enabled=False):
# Encode with T5
t5_tokenizer_out = self.t5_tokenizer(text,
padding='max_length',
max_length=self.clip_tokenizer.model_max_length,
max_length=self.t5_tokenizer.model_max_length,
truncation=True,
return_tensors='pt')
clip_tokenized_captions = clip_tokenizer_out['input_ids'].to(device)
clip_attn_mask = clip_tokenizer_out['attention_mask'].to(torch.bool).to(device)
clip_out = self.clip_encoder(input_ids=clip_tokenized_captions,
attention_mask=clip_attn_mask,
output_hidden_states=True)
clip_embed = clip_out.hidden_states[-2]
pooled_embeddings = clip_out[1]
t5_tokenized_captions = t5_tokenizer_out['input_ids'].to(device)
t5_attn_mask = t5_tokenizer_out['attention_mask'].to(torch.bool).to(device)
t5_embed = self.t5_encoder(input_ids=t5_tokenized_captions, attention_mask=t5_attn_mask)[0]
# Encode with CLIP
clip_tokenizer_out = self.clip_tokenizer(text,
padding='max_length',
max_length=self.clip_tokenizer.model_max_length,
truncation=True,
return_tensors='pt')
clip_tokenized_captions = clip_tokenizer_out['input_ids'].to(device)
clip_attn_mask = clip_tokenizer_out['attention_mask'].to(torch.bool).to(device)
clip_out = self.clip_encoder(input_ids=clip_tokenized_captions,
attention_mask=clip_attn_mask,
output_hidden_states=True)
clip_embed = clip_out.hidden_states[-2]
pooled_embeddings = clip_out[1]
return t5_embed, clip_embed, t5_attn_mask, clip_attn_mask, pooled_embeddings

def prepare_text_embeddings(self, t5_embed: torch.Tensor, clip_embed: torch.Tensor, t5_mask: torch.Tensor,
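A standalone illustration of the pattern introduced above: nesting `torch.autocast(enabled=False)` forces the text encoders to run in full precision even when the surrounding forward pass runs under fp16/bf16 autocast.

```python
import torch

device = 'cuda' if torch.cuda.is_available() else 'cpu'
x = torch.randn(4, 4, device=device)

with torch.autocast(device_type=device, dtype=torch.bfloat16, enabled=(device == 'cuda')):
    low = x @ x  # bf16 when the outer autocast is active (GPU)
    with torch.autocast(device_type=device, enabled=False):
        full = x @ x  # always float32, like the T5/CLIP encoding in encode_text
print(low.dtype, full.dtype)  # bfloat16 vs float32 on GPU; both float32 on CPU
```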
