diff --git a/diffusion/datasets/image.py b/diffusion/datasets/image.py index 20a2c329..aaaee8dc 100644 --- a/diffusion/datasets/image.py +++ b/diffusion/datasets/image.py @@ -12,6 +12,8 @@ from torch.utils.data import DataLoader from torchvision import transforms +from diffusion.datasets.utils import make_streams + log = logging.getLogger(__name__) # Disable PIL max image size limit @@ -93,6 +95,9 @@ def build_streaming_image_dataloader( transform: Optional[List[Callable]] = None, image_key: str = 'image', image_output_key: Optional[str] = 'image', + proportion: Optional[list] = None, + repeat: Optional[list] = None, + choose: Optional[list] = None, streaming_kwargs: Optional[Dict] = None, dataloader_kwargs: Optional[Dict] = None, ): @@ -106,6 +111,9 @@ def build_streaming_image_dataloader( image_key (str): Key associated with the image in the streaming dataset. Default: ``'image'``. image_output_key (optional, str): Optional output key for the image. If none, the value of `image_key` will be used. Default: ``image``. + proportion (list, optional): Specifies how to sample this Stream relative to other Streams. Default: ``None``. + repeat (list, optional): Specifies the degree to which a Stream is upsampled or downsampled. Default: ``None``. + choose (list, optional): Specifies the number of samples to choose from a Stream. Default: ``None``. streaming_kwargs (dict, optional): Additional arguments to pass to the ``StreamingDataset``. Default: ``None``. dataloader_kwargs (dict, optional): Additional arguments to pass to the ``DataLoader``. Default: ``None``. 
""" @@ -115,21 +123,8 @@ def build_streaming_image_dataloader( if dataloader_kwargs is None: dataloader_kwargs = {} - # Check types for remote and local - if isinstance(remote, str) and isinstance(local, str): - remote, local = [remote], [local] - elif isinstance(remote, Sequence) and isinstance(local, Sequence): - if len(remote) != len(local): - raise ValueError( - f'remote and local Sequences must be the same length, got lengths {len(remote)} and {len(local)}') - else: - raise ValueError( - f'remote and local must be both Strings or Sequences, got types {type(remote)} and {type(local)}.') - - # Create a Stream for each (remote, local) pair - streams = [] - for r, l in zip(remote, local): - streams.append(Stream(remote=r, local=l)) + # Set up streams + streams = make_streams(remote, local=local, proportion=proportion, repeat=repeat, choose=choose) if transform is None: transform = [transforms.ToTensor()] diff --git a/diffusion/datasets/image_caption.py b/diffusion/datasets/image_caption.py index b65d21e0..81939c61 100644 --- a/diffusion/datasets/image_caption.py +++ b/diffusion/datasets/image_caption.py @@ -6,7 +6,6 @@ import logging import random from io import BytesIO -from pathlib import Path from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union import torch @@ -17,6 +16,7 @@ from torchvision import transforms from diffusion.datasets.laion.transforms import LargestCenterSquare, RandomCropAspectRatioTransorm, RandomCropSquare +from diffusion.datasets.utils import make_streams from diffusion.models.text_encoder import MultiTokenizer log = logging.getLogger(__name__) @@ -182,6 +182,9 @@ def build_streaming_image_caption_dataloader( crop_type: Optional[str] = 'square', zero_dropped_captions: bool = True, sdxl_conditioning: bool = False, + proportion: Optional[list] = None, + repeat: Optional[list] = None, + choose: Optional[list] = None, streaming_kwargs: Optional[Dict] = None, dataloader_kwargs: Optional[Dict] = None, ): @@ -213,6 +216,9 @@ def 
build_streaming_image_caption_dataloader( Default: ``'square'``. zero_dropped_captions (bool): If True, zero out text embeddings for dropped captions. Default: ``True``. sdxl_conditioning (bool): Whether or not to include SDXL microconditioning in a sample. Default: `False`. + proportion (list, optional): Specifies how to sample this Stream relative to other Streams. Default: ``None``. + repeat (list, optional): Specifies the degree to which a Stream is upsampled or downsampled. Default: ``None``. + choose (list, optional): Specifies the number of samples to choose from a Stream. Default: ``None``. streaming_kwargs (dict, optional): Additional arguments to pass to the ``StreamingDataset``. Default: ``None``. dataloader_kwargs (dict, optional): Additional arguments to pass to the ``DataLoader``. Default: ``None``. """ @@ -231,25 +237,8 @@ def build_streaming_image_caption_dataloader( if dataloader_kwargs is None: dataloader_kwargs = {} - # Check types for remote and local - - if isinstance(remote, str): - remote = [remote] - if isinstance(local, str): - local = [local] - if not local: - local = [_make_default_local_path(r) for r in remote] - if isinstance(remote, Sequence) and isinstance(local, Sequence): - if len(remote) != len(local): - ValueError( - f'remote and local Sequences must be the same length, got lengths {len(remote)} and {len(local)}') - else: - ValueError(f'remote and local must be both Strings or Sequences, got types {type(remote)} and {type(local)}.') - - # Create a Stream for each (remote, local) pair - streams = [] - for r, l in zip(remote, local): - streams.append(Stream(remote=r, local=l)) + # Set up streams + streams = make_streams(remote, local=local, proportion=proportion, repeat=repeat, choose=choose) # Set the crop to apply if crop_type == 'square': @@ -290,7 +279,3 @@ def build_streaming_image_caption_dataloader( ) return dataloader - - -def _make_default_local_path(remote_path): - return str(Path(*['/tmp'] + 
list(Path(remote_path).parts[1:]))) diff --git a/diffusion/datasets/utils.py b/diffusion/datasets/utils.py new file mode 100644 index 00000000..c1d1f197 --- /dev/null +++ b/diffusion/datasets/utils.py @@ -0,0 +1,64 @@ +# Copyright 2022 MosaicML Diffusion authors +# SPDX-License-Identifier: Apache-2.0 + +"""Useful functions for dealing with streaming datasets.""" + +from pathlib import Path +from typing import Sequence + +from streaming import Stream + + +def make_streams(remote, local=None, proportion=None, repeat=None, choose=None): + """Helper function to create a list of Stream objects from a set of remotes and stream weights. + + Args: + remote (Union[str, Sequence[str]]): The remote path or paths to stream from. + local (Union[str, Sequence[str]], optional): The local path or paths to cache the data. If not provided, the + default local path is used. Default: ``None``. + proportion (list, optional): Specifies how to sample this Stream relative to other Streams. Default: ``None``. + repeat (list, optional): Specifies the degree to which a Stream is upsampled or downsampled. Default: ``None``. + choose (list, optional): Specifies the number of samples to choose from a Stream. Default: ``None``. + + Returns: + List[Stream]: A list of Stream objects. 
+ """ + remote, local = _make_remote_and_local_sequences(remote, local) + proportion, repeat, choose = _make_weighting_sequences(remote, proportion, repeat, choose) + + streams = [] + for i, (r, l) in enumerate(zip(remote, local)): + streams.append(Stream(remote=r, local=l, proportion=proportion[i], repeat=repeat[i], choose=choose[i])) + return streams + + +def _make_remote_and_local_sequences(remote, local=None): + if isinstance(remote, str): + remote = [remote] + if isinstance(local, str): + local = [local] + if not local: + local = [_make_default_local_path(r) for r in remote] + + if isinstance(remote, Sequence) and isinstance(local, Sequence): + if len(remote) != len(local): + raise ValueError( + f'remote and local Sequences must be the same length, got lengths {len(remote)} and {len(local)}') + else: + raise ValueError(f'remote and local must be both Strings or Sequences, got types {type(remote)} and {type(local)}.') + return remote, local + + +def _make_default_local_path(remote_path): + return str(Path(*['/tmp'] + list(Path(remote_path).parts[1:]))) + + +def _make_weighting_sequences(remote, proportion=None, repeat=None, choose=None): + weights = {'proportion': proportion, 'repeat': repeat, 'choose': choose} + for name, weight in weights.items(): + if weight is not None and len(remote) != len(weight): + raise ValueError(f'{name} must be the same length as remote, got lengths {len(remote)} and {len(weight)}') + proportion = weights['proportion'] if weights['proportion'] is not None else [None] * len(remote) + repeat = weights['repeat'] if weights['repeat'] is not None else [None] * len(remote) + choose = weights['choose'] if weights['choose'] is not None else [None] * len(remote) + return proportion, repeat, choose