From ba8ca02d93433c22bbeebb1955ac049391557ee7 Mon Sep 17 00:00:00 2001
From: coryMosaicML <83666378+coryMosaicML@users.noreply.github.com>
Date: Thu, 14 Nov 2024 15:34:31 -0800
Subject: [PATCH] Update evaluation and inference code to handle other precisions and models (#179)

---
 diffusion/datasets/image.py                      | 15 +++--
 diffusion/evaluate.py                            | 14 ++--
 diffusion/evaluation/clean_fid_eval.py           | 67 ++++++++++---------
 .../evaluation/generate_geneval_images.py        |  5 +-
 diffusion/inference/inference_model.py           | 14 +++-
 diffusion/models/models.py                       |  8 +--
 .../precomputed_text_latent_diffusion.py         | 39 +++++------
 7 files changed, 88 insertions(+), 74 deletions(-)

diff --git a/diffusion/datasets/image.py b/diffusion/datasets/image.py
index aaaee8dc..fef11ca7 100644
--- a/diffusion/datasets/image.py
+++ b/diffusion/datasets/image.py
@@ -26,8 +26,10 @@ class StreamingImageDataset(StreamingDataset):
     Args:
         streams (Sequence[Stream], optional): One or more Streams to stream/cache samples from.
             ``StreamingImageCaptionDataset`` uses either ``streams`` or ``remote``/``local``. Default:``None``.
-        remote (str, optional): Remote directory (S3 or local filesystem) where dataset is stored. Default: ``None``.
-        local (str, optional): Local filesystem directory where dataset is cached during operation. Default: ``None``.
+        remote (Union[str, Sequence[str]], optional): Remote directory (S3 or local filesystem) where dataset is
+            stored. Default: ``None``.
+        local (Union[str, Sequence[str]], optional): Local filesystem directory where dataset is cached during
+            operation. Default: ``None``.
         transform (Callable, optional): The transforms to apply to the image. Default: ``None``.
         image_key (str): Key associated with the image in the streaming dataset. Default: ``'image'``.
         image_output_key (optional, str): Optional output key for the image. If none, the value of `image_key` will
@@ -41,8 +43,8 @@ class StreamingImageDataset(StreamingDataset):
     def __init__(
         self,
         streams: Optional[Sequence[Stream]] = None,
-        remote: Optional[str] = None,
-        local: Optional[str] = None,
+        remote: Optional[Union[str, Sequence[str]]] = None,
+        local: Optional[Union[str, Sequence[str]]] = None,
         transform: Optional[Callable] = None,
         image_key: str = 'image',
         image_output_key: Optional[str] = None,
@@ -54,10 +56,11 @@ def __init__(
         streaming_kwargs.setdefault('shuffle_block_size', 1 << 18)
         streaming_kwargs.setdefault('shuffle_algo', 'py1s')
 
+        # Make the streams if necessary
+        streams = make_streams(remote, local=local) if streams is None else streams
+
         super().__init__(
             streams=streams,
-            remote=remote,
-            local=local,
             **streaming_kwargs,
         )
 
diff --git a/diffusion/evaluate.py b/diffusion/evaluate.py
index 9baad0a8..f72fae52 100644
--- a/diffusion/evaluate.py
+++ b/diffusion/evaluate.py
@@ -14,7 +14,7 @@
 from composer.loggers import LoggerDestination
 from composer.utils import reproducibility
 from omegaconf import DictConfig, OmegaConf
-from torch.utils.data import DataLoader
+from torch.utils.data import Dataset
 from torchmetrics.multimodal import CLIPScore
 
 from diffusion.evaluation.clean_fid_eval import CleanFIDEvaluator
@@ -31,14 +31,8 @@ def evaluate(config: DictConfig) -> None:
 
     # The model to evaluate
     model: ComposerModel = hydra.utils.instantiate(config.model)
-    tokenizer = model.tokenizer if hasattr(model, 'tokenizer') else None
-
-    # The dataloader to use for evaluation
-    if tokenizer:
-        eval_dataloader = hydra.utils.instantiate(config.eval_dataloader, tokenizer=tokenizer)
-
-    else:
-        eval_dataloader: DataLoader = hydra.utils.instantiate(config.eval_dataloader)
+    # The dataset
+    dataset: Dataset = hydra.utils.instantiate(config.dataset)
 
     # The CLIPScores metric to use for evaluation
     clip_metric: CLIPScore = hydra.utils.instantiate(config.clip_metric)
@@ -88,7 +82,7 @@ def evaluate(config: DictConfig) -> None:
     evaluator: CleanFIDEvaluator = hydra.utils.instantiate(
         config.evaluator,
         model=model,
-        eval_dataloader=eval_dataloader,
+        dataset=dataset,
         clip_metric=clip_metric,
         loggers=logger,
     )
diff --git a/diffusion/evaluation/clean_fid_eval.py b/diffusion/evaluation/clean_fid_eval.py
index eb198222..2b4221a7 100644
--- a/diffusion/evaluation/clean_fid_eval.py
+++ b/diffusion/evaluation/clean_fid_eval.py
@@ -14,11 +14,10 @@
 from composer.core import get_precision_context
 from composer.loggers import LoggerDestination
 from composer.utils import dist
-from torch.utils.data import DataLoader
+from torch.utils.data import Dataset
 from torchmetrics.multimodal import CLIPScore
-from torchvision.transforms.functional import to_pil_image
+from torchvision.transforms.functional import pil_to_tensor, to_pil_image
 from tqdm.auto import tqdm
-from transformers import PreTrainedTokenizerBase
 
 os.environ['TOKENIZERS_PARALLELISM'] = 'false'
 
@@ -32,7 +31,7 @@ class CleanFIDEvaluator:
 
     Args:
         model (ComposerModel): The model to evaluate.
-        eval_dataloader (DataLoader): The dataloader to use for evaluation.
+        dataset (Dataset): The dataset to use the prompts from.
         clip_metric (CLIPScore): The CLIPScore metric to use for evaluation.
         load_path (str, optional): The path to load the model from. Default: ``None``.
         guidance_scales (List[float]): The guidance scales to use for evaluation.
@@ -52,13 +51,14 @@ class CleanFIDEvaluator:
         default_prompt (Optional[str]): An optional default prompt to add before each eval prompt. Default: ``None``.
         default_negative_prompt (Optional[str]): An optional default negative prompt to add before each negative
             prompt. Default: ``None``.
+        sdxl_conditioning (bool): Whether or not to include SDXL conditioning in the evaluation. Default: ``False``.
         additional_generate_kwargs (Dict, optional): Additional keyword arguments to pass to the model.generate
             method.
     """
 
     def __init__(self,
                  model: ComposerModel,
-                 eval_dataloader: DataLoader,
+                 dataset: Dataset,
                  clip_metric: CLIPScore,
                  load_path: Optional[str] = None,
                  guidance_scales: Optional[List[float]] = None,
@@ -75,10 +75,10 @@ def __init__(self,
                  prompts: Optional[List[str]] = None,
                  default_prompt: Optional[str] = None,
                  default_negative_prompt: Optional[str] = None,
+                 sdxl_conditioning: bool = False,
                  additional_generate_kwargs: Optional[Dict] = None):
         self.model = model
-        self.tokenizer: PreTrainedTokenizerBase = model.tokenizer
-        self.eval_dataloader = eval_dataloader
+        self.dataset = dataset
         self.clip_metric = clip_metric
         self.load_path = load_path
         self.guidance_scales = guidance_scales if guidance_scales is not None else [1.0]
@@ -89,20 +89,19 @@ def __init__(self,
         self.loggers = loggers
         self.seed = seed
         self.output_dir = output_dir
-        self.num_samples = num_samples if num_samples is not None else float('inf')
+        self.num_samples = num_samples
         self.precision = precision
         self.prompts = prompts if prompts is not None else ['A shiba inu wearing a blue sweater']
         self.default_prompt = default_prompt
         self.default_negative_prompt = default_negative_prompt
+        self.sdxl_conditioning = sdxl_conditioning
         self.additional_generate_kwargs = additional_generate_kwargs if additional_generate_kwargs is not None else {}
-        self.sdxl = model.sdxl
 
         # Load the model
         trainer = Trainer(model=self.model,
                           load_path=self.load_path,
                           load_weights_only=True,
                           load_strict_model_weights=load_strict_model_weights,
-                          eval_dataloader=self.eval_dataloader,
                           seed=self.seed,
                           loggers=self.loggers)
         self.trainer = trainer
@@ -139,18 +138,27 @@ def _generate_images(self, guidance_scale: float):
         # Storage for prompts
         prompts = {}
 
-        # Iterate over the eval dataloader
-        num_batches = len(self.eval_dataloader)
-        starting_seed = self.seed + num_batches * dist.get_local_rank()
-        for batch_id, batch in tqdm(enumerate(self.eval_dataloader)):
-            # Break if enough samples have been generated
-            if batch_id * self.batch_size * dist.get_world_size() >= self.num_samples:
-                break
-
-            real_images = batch[self.image_key]
-            tokenized_captions = batch[self.caption_key]
-            # Get the prompts from the tokens
-            text_captions = self.tokenizer.batch_decode(tokenized_captions, skip_special_tokens=True)
+        # Partition the dataset across the ranks
+        dataset_len = self.dataset.num_samples  # type: ignore
+        # Truncate the dataset if num_samples is specified
+        if self.num_samples is not None and self.num_samples <= dataset_len:
+            dataset_len = self.num_samples
+        elif self.num_samples is not None and self.num_samples > dataset_len:
+            raise ValueError(f'num_samples {self.num_samples} is greater than the dataset length {dataset_len}.')
+        samples_per_rank, remainder = divmod(dataset_len, dist.get_world_size())
+        start_idx = dist.get_global_rank() * samples_per_rank + min(remainder, dist.get_global_rank())
+        end_idx = start_idx + samples_per_rank
+        if dist.get_global_rank() < remainder:
+            end_idx += 1
+        print(f'Rank {dist.get_global_rank()} processing samples {start_idx} to {end_idx} of {dataset_len} total.')
+        # Iterate over the dataset
+        for sample_id in tqdm(range(start_idx, end_idx)):
+            # Set a unique seed for this sample to ensure reproducible but different randomness
+            seed = self.seed + sample_id
+            # Image and caption come from the dataset. Note the caption is untokenized
+            sample = self.dataset[sample_id]
+            real_images = pil_to_tensor(sample[self.image_key]).unsqueeze(0) / 255.0
+            text_captions = sample[self.caption_key]
             # Add default prompts if specified
             augmented_captions = text_captions
             augmented_negative_prompt = None
@@ -159,15 +167,12 @@ def _generate_images(self, guidance_scale: float):
             if self.default_negative_prompt:
                 augmented_negative_prompt = [f'{self.default_negative_prompt}' for _ in text_captions]
 
-            if self.sdxl:
-                crop_params = batch['cond_crops_coords_top_left']
-                input_size_params = batch['cond_original_size']
+            if self.sdxl_conditioning:
+                crop_params = torch.tensor([0, 0]).unsqueeze(0)
+                input_size_params = torch.tensor([self.size, self.size]).unsqueeze(0)
             else:
                 crop_params = None
                 input_size_params = None
-
-            # Ensure a new seed for each batch, as randomness in model.generate is fixed.
-            seed = starting_seed + batch_id
             # Generate images from the captions
             with get_precision_context(self.precision):
                 generated_images = self.model.generate(prompt=augmented_captions,
@@ -188,11 +193,11 @@ def _generate_images(self, guidance_scale: float):
                     f'Images are expected to be in the range [0, 1]. Got max {real_images.max()} and min {real_images.min()}'
                 )
             for i, img in enumerate(real_images):
-                to_pil_image(img).save(f'{real_image_path}/{batch_id}_{i}_rank_{dist.get_local_rank()}.png')
-                prompts[f'{batch_id}_{i}_rank_{dist.get_local_rank()}'] = text_captions[i]
+                to_pil_image(img).save(f'{real_image_path}/{sample_id}_rank_{dist.get_local_rank()}.png')
+                prompts[f'{sample_id}_rank_{dist.get_local_rank()}'] = text_captions[i]
             # Save the generated images
             for i, img in enumerate(generated_images):
-                to_pil_image(img).save(f'{gen_image_path}/{batch_id}_{i}_rank_{dist.get_local_rank()}.png')
+                to_pil_image(img).save(f'{gen_image_path}/{sample_id}_rank_{dist.get_local_rank()}.png')
         # Save the prompts as json
         json.dump(prompts, open(f'{real_image_path}/prompts_rank_{dist.get_local_rank()}.json', 'w'))
 
diff --git a/diffusion/evaluation/generate_geneval_images.py b/diffusion/evaluation/generate_geneval_images.py
index 8b64e46a..2632ba7c 100644
--- a/diffusion/evaluation/generate_geneval_images.py
+++ b/diffusion/evaluation/generate_geneval_images.py
@@ -27,6 +27,7 @@ class GenevalImageGenerator:
         load_path (str, optional): The path to load the model from. Default: ``None``.
         local_checkpoint_path (str, optional): The local path to save the model checkpoint. Default: ``'/tmp/model.pt'``.
         load_strict_model_weights (bool): Whether or not to strict load model weights. Default: ``True``.
+        precision (str): The precision to use for evaluation. Default: ``'amp_fp16'``.
         guidance_scale (float): The guidance scale to use for evaluation. Default: ``7.0``.
         height (int): The height of the generated images. Default: ``1024``.
         width (int): The width of the generated images. Default: ``1024``.
@@ -46,6 +47,7 @@ def __init__(self,
                  load_path: Optional[str] = None,
                  local_checkpoint_path: str = '/tmp/model.pt',
                  load_strict_model_weights: bool = True,
+                 precision: str = 'amp_fp16',
                  guidance_scale: float = 7.0,
                  height: int = 1024,
                  width: int = 1024,
@@ -77,6 +79,7 @@ def __init__(self,
         self.load_path = load_path
         self.local_checkpoint_path = local_checkpoint_path
         self.load_strict_model_weights = load_strict_model_weights
+        self.precision = precision
         self.guidance_scale = guidance_scale
         self.height = height
         self.width = width
@@ -148,7 +151,7 @@ def generate(self):
                                              **self.additional_generate_kwargs).images[0]
                 img = generated_image
             else:
-                with get_precision_context('amp_fp16'):
+                with get_precision_context(self.precision):
                     generated_image = self.model.generate(prompt=caption,
                                                           height=self.height,
                                                           width=self.width,
diff --git a/diffusion/inference/inference_model.py b/diffusion/inference/inference_model.py
index d6925baf..e675c102 100644
--- a/diffusion/inference/inference_model.py
+++ b/diffusion/inference/inference_model.py
@@ -235,13 +235,23 @@ class ModelInference():
         model_name (str): Name of the model from `diffusion.models` to load. Ex: for stable diffusion xl, use 'stable_diffusion_xl'.
         local_checkpoint_path (str): Path to the local checkpoint. Default: '/tmp/model.pt'.
         strict (bool): Whether to load the model weights strictly. Default: False.
+        dtype: The data type to use. One of [`float32`, `float16`, `bfloat16`]. Default: `bfloat16`.
         **model_kwargs: Keyword arguments to pass to the model initialization.
     """
 
-    def __init__(self, model_name, local_checkpoint_path: str = LOCAL_CHECKPOINT_PATH, strict=False, **model_kwargs):
+    def __init__(self,
+                 model_name,
+                 local_checkpoint_path: str = LOCAL_CHECKPOINT_PATH,
+                 strict=False,
+                 dtype='bfloat16',
+                 **model_kwargs):
         self.device = torch.cuda.current_device()
         model_factory = getattr(diffusion.models, model_name)
         model = model_factory(**model_kwargs)
+        dtype_map = {'float32': torch.float32, 'float16': torch.float16, 'bfloat16': torch.bfloat16}
+        if dtype not in dtype_map:
+            raise ValueError(f'Invalid dtype: {dtype}. Must be one of {list(dtype_map.keys())}')
+        self.dtype = dtype_map[dtype]
 
         if 'pretrained' in model_kwargs and model_kwargs['pretrained']:
             pass
@@ -290,7 +300,7 @@ def predict(self, model_requests: List[Dict[str, Any]]):
             raise RuntimeError('There must be the same number of negative prompts as prompts.')
 
         # Generate images
-        with torch.cuda.amp.autocast(True):
+        with torch.cuda.amp.autocast(True, dtype=self.dtype):
             imgs = self.model.generate(prompt=prompts, negative_prompt=negative_prompts, **generate_kwargs).cpu()
 
         # Send as bytes
diff --git a/diffusion/models/models.py b/diffusion/models/models.py
index 82c1480f..8b3edf6a 100644
--- a/diffusion/models/models.py
+++ b/diffusion/models/models.py
@@ -9,7 +9,8 @@
 
 import torch
 from composer.devices import DeviceGPU
-from diffusers import AutoencoderKL, DDIMScheduler, DDPMScheduler, EulerDiscreteScheduler, UNet2DConditionModel
+from diffusers import (AutoencoderKL, DDIMScheduler, DDPMScheduler, DPMSolverMultistepScheduler, EulerDiscreteScheduler,
+                       UNet2DConditionModel)
 from peft import LoraConfig
 from torchmetrics import MeanSquaredError
 from transformers import AutoModel, AutoTokenizer, CLIPTextModel, CLIPTokenizer, PretrainedConfig
@@ -770,16 +771,13 @@ def precomputed_text_latent_diffusion(
         'beta_schedule': 'scaled_linear',
         'trained_betas': None,
         'prediction_type': prediction_type,
-        'interpolation_type': 'linear',
-        'use_karras_sigmas': False,
         'timestep_spacing': 'leading',
-        'steps_offset': 1,
         'rescale_betas_zero_snr': False,
     }
 
     if inference_noise_scheduler_params is not None:
         inference_scheduler_params.update(inference_noise_scheduler_params)
-    inference_noise_scheduler = EulerDiscreteScheduler(**inference_scheduler_params)
+    inference_noise_scheduler = DPMSolverMultistepScheduler(**inference_scheduler_params)
 
     # Shift noise scheduler to correct for resolution changes
     noise_scheduler = shift_noise_schedule(noise_scheduler,
diff --git a/diffusion/models/precomputed_text_latent_diffusion.py b/diffusion/models/precomputed_text_latent_diffusion.py
index 31acdb2f..fba48123 100644
--- a/diffusion/models/precomputed_text_latent_diffusion.py
+++ b/diffusion/models/precomputed_text_latent_diffusion.py
@@ -203,28 +203,29 @@ def decode_latents(self, latents):
     def encode_text(self, text, device):
         assert self.t5_tokenizer is not None and self.t5_encoder is not None
         assert self.clip_tokenizer is not None and self.clip_encoder is not None
-        # Encode with T5
-        t5_tokenizer_out = self.t5_tokenizer(text,
-                                             padding='max_length',
-                                             max_length=self.t5_tokenizer.model_max_length,
-                                             truncation=True,
-                                             return_tensors='pt')
-        t5_tokenized_captions = t5_tokenizer_out['input_ids'].to(device)
-        t5_attn_mask = t5_tokenizer_out['attention_mask'].to(torch.bool).to(device)
-        t5_embed = self.t5_encoder(input_ids=t5_tokenized_captions, attention_mask=t5_attn_mask)[0]
-        # Encode with CLIP
-        clip_tokenizer_out = self.clip_tokenizer(text,
+        with torch.autocast(device_type='cuda', enabled=False):
+            # Encode with T5
+            t5_tokenizer_out = self.t5_tokenizer(text,
                                                  padding='max_length',
-                                                 max_length=self.clip_tokenizer.model_max_length,
+                                                 max_length=self.t5_tokenizer.model_max_length,
                                                  truncation=True,
                                                  return_tensors='pt')
-        clip_tokenized_captions = clip_tokenizer_out['input_ids'].to(device)
-        clip_attn_mask = clip_tokenizer_out['attention_mask'].to(torch.bool).to(device)
-        clip_out = self.clip_encoder(input_ids=clip_tokenized_captions,
-                                     attention_mask=clip_attn_mask,
-                                     output_hidden_states=True)
-        clip_embed = clip_out.hidden_states[-2]
-        pooled_embeddings = clip_out[1]
+            t5_tokenized_captions = t5_tokenizer_out['input_ids'].to(device)
+            t5_attn_mask = t5_tokenizer_out['attention_mask'].to(torch.bool).to(device)
+            t5_embed = self.t5_encoder(input_ids=t5_tokenized_captions, attention_mask=t5_attn_mask)[0]
+            # Encode with CLIP
+            clip_tokenizer_out = self.clip_tokenizer(text,
+                                                     padding='max_length',
+                                                     max_length=self.clip_tokenizer.model_max_length,
+                                                     truncation=True,
+                                                     return_tensors='pt')
+            clip_tokenized_captions = clip_tokenizer_out['input_ids'].to(device)
+            clip_attn_mask = clip_tokenizer_out['attention_mask'].to(torch.bool).to(device)
+            clip_out = self.clip_encoder(input_ids=clip_tokenized_captions,
+                                         attention_mask=clip_attn_mask,
+                                         output_hidden_states=True)
+            clip_embed = clip_out.hidden_states[-2]
+            pooled_embeddings = clip_out[1]
         return t5_embed, clip_embed, t5_attn_mask, clip_attn_mask, pooled_embeddings
 
     def prepare_text_embeddings(self, t5_embed: torch.Tensor, clip_embed: torch.Tensor, t5_mask: torch.Tensor,