From 06fb0c8b328a38ac69296604dde4754f9f1d39d9 Mon Sep 17 00:00:00 2001 From: philippmwirth Date: Wed, 11 Nov 2020 12:42:07 +0000 Subject: [PATCH 01/14] Refactor data._helpers --- lightly/data/_helpers.py | 129 ++++++++++----------------------------- lightly/data/_image.py | 100 ++++++++++++++++++++++++++++++ 2 files changed, 133 insertions(+), 96 deletions(-) create mode 100644 lightly/data/_image.py diff --git a/lightly/data/_helpers.py b/lightly/data/_helpers.py index 934198f08..e5e4afa70 100644 --- a/lightly/data/_helpers.py +++ b/lightly/data/_helpers.py @@ -4,105 +4,33 @@ # All Rights Reserved import os -import torchvision.datasets as datasets +from torchvision import datasets + +from lightly.data._image import DatasetFolder + +try: + from lightly.data._video import VideoDataset + VIDEO_DATASET_AVAILABLE = True +except Exception: + VIDEO_DATASET_AVAILABLE = False + from lightly.data._image_loaders import default_loader IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp') +VIDEO_EXTENSIONS = ('.mp4', '.mov', '.avi') -def _make_dataset(directory, extensions=None, is_valid_file=None): - """Return a list of all image files with targets in the directory - - Args: - directory: (str) Root directory path - (should not contain subdirectories!) - extensions: (List[str]) List of allowed extensions - is_valid_file: (callable) Used to check corrupt files - - Returns: - List of instance tuples: (path_i, target_i = 0) +def _contains_videos(root, extensions): """ - if extensions is not None: - def _is_valid_file(filename): - return filename.lower().endswith(extensions) - - instances = [] - for fname in os.listdir(directory): - - if not _is_valid_file(fname): - continue - - path = os.path.join(directory, fname) - item = (path, 0) - instances.append(item) - - return instances - - -class DatasetFolder(datasets.VisionDataset): - - def __init__(self, root, loader, extensions=None, transform=None, - target_transform=None, is_valid_file=None): - """Constructor based on torchvisions DatasetFolder - (https://pytorch.org/docs/stable/torchvision/datasets.html#datasetfolder) - - Args: - root: (str) Root directory path - loader: (callable) Function that loads file at path - extensions: (List[str]) List of allowed extensions - transform: Function that takes a PIL image and returns - transformed version - target_transform: As transform but for targets - is_valid_file: (callable) Used to check corrupt files - - Raises: - RuntimeError: If no supported files are found in root. - - """ - - super(DatasetFolder, self).__init__(root, transform=transform, - target_transform=target_transform) - - samples = _make_dataset(self.root, extensions, is_valid_file) - if len(samples) == 0: - msg = 'Found 0 files in folder: {}\n'.format(self.root) - if extensions is not None: - msg += 'Supported extensions are: {}'.format( - ','.join(extensions)) - raise RuntimeError(msg) - - self.loader = loader - self.extensions = extensions - - self.samples = samples - self.targets = [s[1] for s in samples] - - def __getitem__(self, index): - """Returns item at index - - Args: - index: (int) Index - - Returns: - tuple: (sample, target) where target is 0 - - """ - - path, target = self.samples[index] - sample = self.loader(path) - if self.transform is not None: - sample = self.transform(sample) - if self.target_transform is not None: - target = self.target_transform(target) - - return sample, target - - def __len__(self): - return len(self.samples) + """ + list_dir = os.listdir(root) + is_video = \ + [f.lower().endswith(extensions) for f in list_dir] + return any(is_video) def _contains_subdirs(root): @@ -115,11 +43,9 @@ def _contains_subdirs(root): True if root contains subdirectories else false """ - list_dir = os.listdir(root) - is_dir = [ - os.path.isdir(os.path.join(root, f)) for f in list_dir - ] + is_dir = \ + [os.path.isdir(os.path.join(root, f)) for f in list_dir] return any(is_dir) @@ -134,10 +60,21 @@ def _load_dataset_from_folder(root, transform): Dataset consisting of images in the root directory """ - if _contains_subdirs(root): - dataset = datasets.ImageFolder(root, transform=transform) + contains_videos = _contains_videos(root, VIDEO_EXTENSIONS) + if contains_videos and not VIDEO_DATASET_AVAILABLE: + raise ValueError(f'The input directory {root} contains videos ' + 'but the VideoDataset is not available.') + + if contains_videos: + dataset = VideoDataset(root, + extensions=VIDEO_EXTENSIONS, + transform=transform) + elif _contains_subdirs(root): + dataset = datasets.ImageFolder(root, + transform=transform) else: - dataset = DatasetFolder(root, default_loader, + dataset = DatasetFolder(root, + default_loader, extensions=IMG_EXTENSIONS, transform=transform) diff --git a/lightly/data/_image.py b/lightly/data/_image.py new file mode 100644 index 000000000..567614e85 --- /dev/null +++ b/lightly/data/_image.py @@ -0,0 +1,100 @@ +""" """ + +# +# + +import os +import torchvision.datasets as datasets + + +def _make_dataset(directory, extensions=None, is_valid_file=None): + """Return a list of all image files with targets in the directory + + Args: + directory: (str) Root directory path + (should not contain subdirectories!) + extensions: (List[str]) List of allowed extensions + is_valid_file: (callable) Used to check corrupt files + + Returns: + List of instance tuples: (path_i, target_i = 0) + + """ + + if extensions is not None: + def _is_valid_file(filename): + return filename.lower().endswith(extensions) + + instances = [] + for fname in os.listdir(directory): + + if not _is_valid_file(fname): + continue + + path = os.path.join(directory, fname) + item = (path, 0) + instances.append(item) + + return instances + + +class DatasetFolder(datasets.VisionDataset): + + def __init__(self, root, loader, extensions=None, transform=None, + target_transform=None, is_valid_file=None): + """Constructor based on torchvisions DatasetFolder + (https://pytorch.org/docs/stable/torchvision/datasets.html#datasetfolder) + + Args: + root: (str) Root directory path + loader: (callable) Function that loads file at path + extensions: (List[str]) List of allowed extensions + transform: Function that takes a PIL image and returns + transformed version + target_transform: As transform but for targets + is_valid_file: (callable) Used to check corrupt files + + Raises: + RuntimeError: If no supported files are found in root. + + """ + + super(DatasetFolder, self).__init__(root, transform=transform, + target_transform=target_transform) + + samples = _make_dataset(self.root, extensions, is_valid_file) + if len(samples) == 0: + msg = 'Found 0 files in folder: {}\n'.format(self.root) + if extensions is not None: + msg += 'Supported extensions are: {}'.format( + ','.join(extensions)) + raise RuntimeError(msg) + + self.loader = loader + self.extensions = extensions + + self.samples = samples + self.targets = [s[1] for s in samples] + + def __getitem__(self, index): + """Returns item at index + + Args: + index: (int) Index + + Returns: + tuple: (sample, target) where target is 0 + + """ + + path, target = self.samples[index] + sample = self.loader(path) + if self.transform is not None: + sample = self.transform(sample) + if self.target_transform is not None: + target = self.target_transform(target) + + return sample, target + + def __len__(self): + return len(self.samples) \ No newline at end of file From 6f5a69dfe515bb4cbd8b065aed626638769970d6 Mon Sep 17 00:00:00 2001 From: philippmwirth Date: Wed, 11 Nov 2020 12:42:32 +0000 Subject: [PATCH 02/14] Add first version of video dataset --- lightly/data/_video.py | 137 ++++++++++++++++++++++++++++++++++++++++ lightly/data/dataset.py | 10 ++- 2 files changed, 146 insertions(+), 1 deletion(-) create mode 100644 lightly/data/_video.py diff --git a/lightly/data/_video.py b/lightly/data/_video.py new file mode 100644 index 000000000..2e7b2601b --- /dev/null +++ b/lightly/data/_video.py @@ -0,0 +1,137 @@ +""" """ + +# +# + +import os +from PIL import Image +from torchvision import datasets + +from torchvision.io import read_video, read_video_timestamps + + +def _video_loader(path, timestamp, pts_unit='sec'): + """ + + """ + # random access read from video + frame, _, _ = read_video(path, + start_pts=timestamp, + end_pts=timestamp, + pts_unit=pts_unit) + # read_video returns tensor of shape 1 x W x H x C + frame = frame.squeeze() + # convert to PIL image + # TODO: can it be on CUDA? -> need to move it to CPU first + image = Image.fromarray(frame.numpy()) + + return image + + +def _make_dataset(directory, extensions=None, is_valid_file=None, pts_unit='sec'): + """ + + """ + + if extensions is not None: + def _is_valid_file(filename): + return filename.lower().endswith(extensions) + + # find all instances + instances = [] + for fname in os.listdir(directory): + + if not _is_valid_file(fname): + continue + + path = os.path.join(directory, fname) + instances.append(path) + + # get timestamps + timestamps, fpss = [], [] + for instance in instances: + ts, fps = read_video_timestamps(instance, pts_unit=pts_unit) + timestamps.append(ts) + fpss.append(fps) + + # get offsets + offsets = [len(ts) for ts in timestamps] + offsets = [0] + offsets[:-1] + for i in range(1, len(offsets)): + offsets[i] = offsets[i-1] + offsets[i] # cumsum + + return instances, timestamps, offsets, fpss + + +class VideoDataset(datasets.VisionDataset): + """ + + """ + + def __init__(self, root, loader=_video_loader, extensions=None, + transform=None, target_transform=None, is_valid_file=None): + + super(VideoDataset, self).__init__(root, transform=transform, + target_transform=target_transform) + + videos, video_timestamps, offsets, fpss = _make_dataset( + self.root, extensions, is_valid_file) + + self.extensions = extensions + self.loader = loader + + self.videos = videos + self.video_timestamps = video_timestamps + self.offsets = offsets + self.fpss = fpss + + def __getitem__(self, index): + """ + + """ + if index < 0 or index >= self.__len__(): + raise IndexError(f'Index {index} is out of bounds for VideoDataset' + f' of size {self.__len__()}.') + + # find video of the frame + i = 0 + while i < len(self.offsets) - 1: + if self.offsets[i] >= index: + break + i = i + 1 + + # find and return the frame as PIL image + target = i + sample = self.loader(self.videos[i], + self.video_timestamps[i][index - self.offsets[i]]) + if self.transform is not None: + sample = self.transform(sample) + if self.target_transform is not None: + target = self.target_transform(target) + + return sample, target + + def __len__(self): + """ + + """ + return sum((len(ts) for ts in self.video_timestamps)) + + def _get_filename(self, index): + """ + + """ + if index < 0 or index >= self.__len__(): + raise IndexError(f'Index {index} is out of bounds for VideoDataset' + f' of size {self.__len__()}.') + + # find video of the frame + i = 0 + while i < len(self.offsets) - 1: + if self.offsets[i] >= index: + break + i = i + 1 + + filename = '.'.join(self.videos[i].split('.')[:-1]) + timestamp = float(self.video_timestamps[i][index - self.offsets[i]]) + return '%s-%.8fs.png' % (filename, timestamp) diff --git a/lightly/data/dataset.py b/lightly/data/dataset.py index 7d7c7a7ac..c175e9223 100644 --- a/lightly/data/dataset.py +++ b/lightly/data/dataset.py @@ -1,4 +1,4 @@ -""" Lightly Dataset """ +""" Lightly Dataset """ # Copyright (c) 2020. Lightly AG and its affiliates. # All Rights Reserved @@ -12,6 +12,12 @@ from lightly.data._helpers import _load_dataset from lightly.data._helpers import DatasetFolder +try: + from lightly.data._video import VideoDataset + VIDEO_DATASET_AVAILABLE = True +except Exception: + VIDEO_DATASET_AVAILABLE = False + class LightlyDataset(data.Dataset): """Provides a uniform data interface for the embedding models. @@ -91,6 +97,8 @@ def _get_filename_by_index(self, index) -> str: elif isinstance(self.dataset, DatasetFolder): full_path = self.dataset.samples[index][0] return os.path.relpath(full_path, self.root_folder) + elif VIDEO_DATASET_AVAILABLE and isinstance(self.dataset, VideoDataset): + return self.dataset._get_filename(index) else: return str(index) From a7faed633f2535f4c1b825a060b603ed64c8ecfa Mon Sep 17 00:00:00 2001 From: philippmwirth Date: Wed, 18 Nov 2020 14:47:27 +0000 Subject: [PATCH 03/14] Fix faulty route to get user quota Add missing suffix "/quota" and return the status code along with the maxDatasetSize. This allows error handling outside of the get function. --- lightly/api/routes/users/service.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/lightly/api/routes/users/service.py b/lightly/api/routes/users/service.py index e340a36a8..d0160298a 100644 --- a/lightly/api/routes/users/service.py +++ b/lightly/api/routes/users/service.py @@ -32,15 +32,16 @@ def get_quota(token: str): A token to identify the user. Returns: - A dictionary with the quota for the user. + The quota for the user and the status code of the response. """ - dst_url = _prefix() + dst_url = _prefix() + '/quota' payload = { 'token': token } - try: - response = requests.get(dst_url, params=payload) - return response.json() - except Exception: - return {'maxDatasetSize': LIGHTLY_MAXIMUM_DATASET_SIZE} + response = requests.get(dst_url, params=payload) + status_code = response.status_code + if status_code == 200: + return response.json()['maxDatasetSize'], status_code + else: + return LIGHTLY_MAXIMUM_DATASET_SIZE, status_code From e241e5d09f1381e2cd13d0a2d0236371c3235ae2 Mon Sep 17 00:00:00 2001 From: philippmwirth Date: Wed, 18 Nov 2020 14:51:08 +0000 Subject: [PATCH 04/14] Handle missing requirements case The video dataset has additional requirements. If these are not installed and a VideoDataset is instantiated, the LightlyDataset will raise an error notifying the user of the missing requirements. --- lightly/data/_video.py | 28 ++++++++++--------- lightly/data/dataset.py | 59 +++++++++++++++++++++++++++++++++++++++-- 2 files changed, 72 insertions(+), 15 deletions(-) diff --git a/lightly/data/_video.py b/lightly/data/_video.py index 2e7b2601b..8d105bc38 100644 --- a/lightly/data/_video.py +++ b/lightly/data/_video.py @@ -4,6 +4,7 @@ # import os +import av from PIL import Image from torchvision import datasets @@ -94,11 +95,9 @@ def __getitem__(self, index): f' of size {self.__len__()}.') # find video of the frame - i = 0 - while i < len(self.offsets) - 1: - if self.offsets[i] >= index: - break - i = i + 1 + i = len(self.offsets) - 1 + while (self.offsets[i] > index): + i = i - 1 # find and return the frame as PIL image target = i @@ -117,7 +116,7 @@ def __len__(self): """ return sum((len(ts) for ts in self.video_timestamps)) - def _get_filename(self, index): + def get_filename(self, index): """ """ @@ -126,12 +125,15 @@ def _get_filename(self, index): f' of size {self.__len__()}.') # find video of the frame - i = 0 - while i < len(self.offsets) - 1: - if self.offsets[i] >= index: - break - i = i + 1 + i = len(self.offsets) - 1 + while (self.offsets[i] > index): + i = i - 1 - filename = '.'.join(self.videos[i].split('.')[:-1]) + filename = self.videos[i] + filename = os.path.relpath(filename, self.root) + + splits = filename.split('.') + video_format = splits[-1] + video_name = '.'.join(splits[:-1]) timestamp = float(self.video_timestamps[i][index - self.offsets[i]]) - return '%s-%.8fs.png' % (filename, timestamp) + return '%s-%.8fs-%s.png' % (video_name, timestamp, video_format) diff --git a/lightly/data/dataset.py b/lightly/data/dataset.py index c175e9223..6c608eecc 100644 --- a/lightly/data/dataset.py +++ b/lightly/data/dataset.py @@ -4,7 +4,9 @@ # All Rights Reserved import os -from typing import List +import shutil +from PIL import Image +from typing import List, Union import torch.utils.data as data import torchvision.datasets as datasets @@ -78,6 +80,59 @@ def __init__(self, if from_folder: self.root_folder = from_folder + def dump_image(self, output_dir: str, filename: str, format: Union[str, None] = None): + """ + + """ + index = self.get_filenames().index(filename) + image, _ = self.dataset[index] + + source = os.path.join(self.root_folder, filename) + target = os.path.join(output_dir, filename) + + dirname = os.path.dirname(target) + os.makedirs(dirname, exist_ok=True) + + if os.path.isfile(source): + # copy the file from the source to the target + shutil.copyfile(source, target) + else: + # the source is not a file (e.g. when loading a video frame) + try: + # try to save the image with the specified format or + # derive the format from the filename (if format=None) + image.save(target, format=format) + except ValueError: + # could not determine format from filename + image.save(os.path.join(output_dir, filename), format='png') + + def dump(self, output_dir: str, filenames: Union[List[str], None] = None, format: Union[str, None] = None): + """Saves all specified images to the output directory. + + Args: + output_dir: + TODO + filenames: + TODO + format: + TODO + + """ + # make sure no transforms are applied to the images + if self.dataset.transform is not None: + pass + + # create directory if it doesn't exist yet + os.makedirs(output_dir, exist_ok=True) + + # get all filenames + if filenames is None: + filenames = self.get_filenames() + + # dump images + for filename in filenames: + self.dump_image(output_dir, filename, format=format) + def get_filenames(self) -> List[str]: """Returns all filenames in the dataset. @@ -98,7 +153,7 @@ def _get_filename_by_index(self, index) -> str: full_path = self.dataset.samples[index][0] return os.path.relpath(full_path, self.root_folder) elif VIDEO_DATASET_AVAILABLE and isinstance(self.dataset, VideoDataset): - return self.dataset._get_filename(index) + return self.dataset.get_filename(index) else: return str(index) From d19415df7c466591328b615f825833e8280ba8b2 Mon Sep 17 00:00:00 2001 From: philippmwirth Date: Wed, 18 Nov 2020 14:53:43 +0000 Subject: [PATCH 05/14] Refactor upload/download to work with video data --- lightly/api/upload.py | 226 +++++++++++++++++++++++++----------- lightly/api/utils.py | 41 +++---- lightly/cli/download_cli.py | 15 ++- 3 files changed, 189 insertions(+), 93 deletions(-) diff --git a/lightly/api/upload.py b/lightly/api/upload.py index d5e492c7c..55eb966f9 100644 --- a/lightly/api/upload.py +++ b/lightly/api/upload.py @@ -18,6 +18,7 @@ from lightly.api.utils import get_thumbnail_from_img from lightly.api.utils import check_image +from lightly.api.utils import check_filename from lightly.api.utils import PIL_to_bytes from lightly.api.utils import put_request @@ -77,7 +78,8 @@ def upload_embeddings_from_csv(path_to_embeddings: str, dataset_id: str, token: str, max_upload: int = 32, - embedding_name: str = 'default'): + embedding_name: str = 'default', + verbose: bool = True): """Uploads embeddings from a csv file to the cloud solution. The csv file should be in the format specified by lightly. See the @@ -136,6 +138,9 @@ def upload_embeddings_from_csv(path_to_embeddings: str, batch['embeddings'] = data['embeddings'][left:right] embedding_batches[i] = batch + if verbose: + print('Uploading embeddings:') + pbar = tqdm.tqdm(unit='embs', total=n_embeddings) for i, batch in enumerate(embedding_batches): _upload_single_batch( @@ -143,31 +148,37 @@ def upload_embeddings_from_csv(path_to_embeddings: str, pbar.update(len(batch['embeddings'])) -def _upload_single_image(input_dir, fname, mode, dataset_id, token): +def _upload_single_image(image, + label, + filename, + dataset_id, + token, + mode): """Uploads a single image to the Lightly platform. """ - # random delay of uniform[0, 0.01] seconds to prevent API bursts - rnd_delay = random.random() * 0.01 - time.sleep(rnd_delay) - - # get PIL image handles, metadata, and check if corrupted - metadata, is_corrupted = check_image( - os.path.join(input_dir, fname) - ) - - # filename is too long, cannot accept this file - if not metadata: + # check whether the filename is too long + basename = filename + if not check_filename(basename): + msg = (f'Filename {basename} is longer than the allowed maximum of ' + 'characters and will be skipped.') + warnings.warn(msg) return False - # upload sample - basename = fname + # calculate metadata, and check if corrupted + metadata = check_image(image) + + # generate thumbnail if necessary thumbname = None - if mode in ['full', 'thumbnails'] and not is_corrupted: + thumbnail = None + if mode == 'thumbnails' and not metadata['is_corrupted']: thumbname = '.'.join(basename.split('.')[:-1]) + '_thumb.webp' + thumbnail = get_thumbnail_from_img(image) + # upload sample with metadata sample_upload_success = True + try: sample_id = routes.users.datasets.samples.post( basename, thumbname, metadata, dataset_id, token @@ -177,28 +188,25 @@ def _upload_single_image(input_dir, fname, mode, dataset_id, token): # upload thumbnail thumbnail_upload_success = True - if mode == 'thumbnails' and not is_corrupted: + if mode == 'thumbnails' and not metadata['is_corrupted']: try: # try to get signed url for thumbnail signed_url = routes.users.datasets.samples. \ get_presigned_upload_url( thumbname, dataset_id, sample_id, token) - - # try to create thumbnail - image_path = os.path.join(input_dir, fname) - with Image.open(image_path) as temp_image: - thumbnail = get_thumbnail_from_img(temp_image) # try to upload thumbnail upload_file_with_signed_url( PIL_to_bytes(thumbnail, ext='webp', quality=70), signed_url ) + # close thumbnail + thumbnail.close() except RuntimeError: thumbnail_upload_success = False # upload full image image_upload_success = True - if mode == 'full' and not is_corrupted: + if mode == 'full' and not metadata['is_corrupted']: try: # try to get signed url for image signed_url = routes.users.datasets.samples. \ @@ -206,12 +214,12 @@ def _upload_single_image(input_dir, fname, mode, dataset_id, token): basename, dataset_id, sample_id, token) # try to upload image - image_path = os.path.join(input_dir, fname) - with open(image_path, 'rb') as temp_image: - upload_file_with_signed_url( - temp_image, - signed_url - ) + upload_file_with_signed_url( + PIL_to_bytes(image), + signed_url + ) + # close image + image.close() except RuntimeError: image_upload_success = False @@ -221,17 +229,17 @@ def _upload_single_image(input_dir, fname, mode, dataset_id, token): return success -def upload_images_from_folder(path_to_folder: str, - dataset_id: str, - token: str, - max_workers: int = 8, - max_requests: int = 32, - mode: str = 'thumbnails'): +def upload_dataset(dataset: LightlyDataset, + dataset_id: str, + token: str, + max_workers: int = 8 , + mode: str = 'thumbnails', + verbose: bool = True): """Uploads images from a directory to the Lightly cloud solution. Args: - path_to_folder: - Path to the folder containing the images. + dataset + The dataset to upload dataset_id: The unique identifier for the dataset. token: @@ -247,21 +255,24 @@ def upload_images_from_folder(path_to_folder: str, Raises: ValueError if dataset is too large. + RuntimeError if the connection to the server failed. RuntimeError if dataset already has an initial tag. """ - bds = LightlyDataset(from_folder=path_to_folder) - fnames = bds.get_filenames() - # check the allowed dataset size - api_max_dataset_size = routes.users.get_quota(token)['maxDatasetSize'] + api_max_dataset_size, status_code = routes.users.get_quota(token) max_dataset_size = min(api_max_dataset_size, LIGHTLY_MAXIMUM_DATASET_SIZE) - if len(fnames) > max_dataset_size: - msg = f'Your dataset has {len(fnames)} samples which' + if len(dataset) > max_dataset_size: + msg = f'Your dataset has {len(dataset)} samples which' msg += f' is more than the allowed maximum of {max_dataset_size}' raise ValueError(msg) + # check whether connection to server was possible + if status_code != 200: + msg = f'Connection to server failed with status code {status_code}.' + raise RuntimeError(msg) + # check whether the dataset alreadys has existing tags tags = routes.users.datasets.tags.get(dataset_id, token) if len(tags) > 0: @@ -270,36 +281,73 @@ def upload_images_from_folder(path_to_folder: str, msg += f'{tag_names}' raise RuntimeError(msg) - # split the samples in batches of equal size - n_batches = len(fnames) // max_requests - n_batches = n_batches + 1 if len(fnames) % max_requests else n_batches - fname_batches = [ - list(islice(fnames, i * max_requests, (i + 1) * max_requests)) - for i in range(n_batches) - ] + # handle the case where len(dataset) < max_workers + max_workers = min(len(dataset), max_workers) - chunksize = max(max_requests // max_workers, 1) - executor = ThreadPoolExecutor(max_workers=max_workers) + # define lambda function for parallel upload + """ + def lambda_(index): + # wait for a random amount of time to prevent server overload + time.sleep(index * 0.1) + # find batch start and end + start = index * batch_size + stop = min((index + 1) * batch_size, len(dataset)) + # progress bar + pbar = tqdm.tqdm(unit='imgs', total=stop - start) + # upload all images in the batch + success_list = [] + for i in range(start, stop): + time.sleep(random.random() * 0.01) + image, label, filename = dataset[i] + success = _upload_single_image( + image=image, + label=label, + filename=filename, + dataset_id=dataset_id, + token=token, + mode=mode, + ) + success_list.append(success) + pbar.update(1) + pbar.refresh() + return all(success_list) + """ # upload the samples - pbar = tqdm.tqdm(unit='imgs', total=len(fnames)) - for i, batch in enumerate(fname_batches): - - mapped = list(executor.map(lambda x: _upload_single_image( - input_dir=path_to_folder, - fname=x, - mode=mode, + if verbose: + print(f'Uploading images (with {max_workers} workers).', flush=True) + + pbar = tqdm.tqdm(unit='imgs', total=len(dataset)) + tqdm_lock = tqdm.tqdm.get_lock() + + def lambda_(i): + # load image + image, label, filename = dataset[i] + # upload image + success = _upload_single_image( + image=image, + label=label, + filename=filename, dataset_id=dataset_id, token=token, - ), batch, chunksize=chunksize)) - - if not all(mapped): - msg = 'Warning: Unsuccessful upload(s) in batch {}! '.format(i) - msg += 'This could cause problems when uploading embeddings.' - msg += 'Failed at file: {}'.format(mapped.index(False)) - warnings.warn(msg) - - pbar.update(len(batch)) + mode=mode, + ) + # update the progress bar + tqdm_lock.acquire() # lock + pbar.update(1) # update + tqdm_lock.release() # unlock + # return whether the upload was successful + return success + + with ThreadPoolExecutor(max_workers=max_workers) as executor: + results = list(executor.map( + lambda_, [i for i in range(len(dataset))], chunksize=1)) + + if not all(results): + msg = 'Warning: Unsuccessful upload(s)! ' + msg += 'This could cause problems when uploading embeddings.' + msg += 'Failed at image: {}'.format(results.index(False)) + warnings.warn(msg) # set image type of data and create initial tag if mode == 'full': @@ -313,6 +361,48 @@ def upload_images_from_folder(path_to_folder: str, routes.users.datasets.tags.post(dataset_id, token) +def upload_images_from_folder(path_to_folder: str, + dataset_id: str, + token: str, + max_workers: int = 8 , + mode: str = 'thumbnails', + verbose: bool = True): + """Uploads images from a directory to the Lightly cloud solution. + + Args: + path_to_folder: + Path to the folder which holds the input images. + dataset_id: + The unique identifier for the dataset. + token: + Token for authentication. + max_workers: + Maximum number of workers uploading images in parallel. + max_requests: + Maximum number of requests a single worker can do before he has + to wait for the others. + mode: + One of [full, thumbnails, metadata]. Whether to upload thumbnails, + full images, or metadata only. + + Raises: + ValueError if dataset is too large. + RuntimeError if the connection to the server failed. + RuntimeError if dataset already has an initial tag. + + """ + + dataset = LightlyDataset(from_folder=path_to_folder) + upload_dataset( + dataset, + dataset_id, + token, + max_workers=max_workers, + mode=mode, + verbose=verbose, + ) + + def _upload_metadata_from_json(path_to_embeddings: str, dataset_id: str, token: str): diff --git a/lightly/api/utils.py b/lightly/api/utils.py index 67f9a5834..b6bed2bc7 100644 --- a/lightly/api/utils.py +++ b/lightly/api/utils.py @@ -113,7 +113,7 @@ def size_in_bytes(img): """ img_file = io.BytesIO() - img.save(img_file, format=img.format) + img.save(img_file, format='png') return img_file.tell() @@ -196,7 +196,14 @@ def resize_image(image, max_width: int, max_height: int): return new_image -def check_image(filename: str): +def check_filename(basename): + """ + + """ + return len(basename) <= MAXIMUM_FILENAME_LENGTH + + +def check_image(image): """Checks whether an image is corrupted or not. The function reports the metadata, and opens the file to check whether @@ -206,41 +213,27 @@ def check_image(filename: str): filename: (str) Path to the file. Returns: - A dictionary of metadata and a flag whether file is corrupted. + A dictionary of metadata of the image. """ - basename = os.path.basename(filename) - if len(basename) > MAXIMUM_FILENAME_LENGTH: - msg = f'Filename {basename} is longer than the allowed maximum of' - msg += f'{MAXIMUM_FILENAME_LENGTH} character and will be skipped.' - warnings.warn(msg) - return {}, True - - is_corrupted = False - corruption = '' - image = Image.open(filename) + # try to load the image to see whether it's corrupted or not try: image.load() + is_corrupted = False + corruption = '' except IOError as e: is_corrupted = True corruption = e - if not is_corrupted and image.format.lower() not in LEGAL_IMAGE_FORMATS: - is_corrupted = True - corruption = f'Illegal image format {image.format}.' - image.close() + # calculate metadata from image if is_corrupted: - metadata = { - 'corruption': corruption - } + metadata = { 'corruption': corruption } else: - image = Image.open(filename, 'r') metadata = get_meta_from_img(image) metadata['corruption'] = '' - image.close() - metadata['is_corrupted'] = is_corrupted - return metadata, is_corrupted + metadata['is_corrupted'] = is_corrupted + return metadata def post_request(dst_url, data=None, json=None, diff --git a/lightly/cli/download_cli.py b/lightly/cli/download_cli.py index 2bddc67de..def21b31f 100644 --- a/lightly/cli/download_cli.py +++ b/lightly/cli/download_cli.py @@ -53,6 +53,18 @@ def _download_cli(cfg, is_cli_call=True): print(msg) if cfg['input_dir'] and cfg['output_dir']: + + input_dir = fix_input_path(cfg['input_dir']) + output_dir = fix_input_path(cfg['output_dir']) + print(f'Copying files from {input_dir} to {output_dir}.') + + # + dataset = data.LightlyDataset(from_folder=input_dir) + + # + dataset.dump(output_dir, samples) + + """ # "name.jpg" -> "/name.jpg" to prevent bugs like this: # "path/to/1234.jpg" ends with both "234.jpg" and "1234.jpg" samples = [os.path.join(' ', s)[1:] for s in samples] @@ -71,11 +83,12 @@ def _download_cli(cfg, is_cli_call=True): indices = [i for i in range(len(source_names)) if any([source_names[i].endswith(s) for s in samples])] - print(f'Copying files from {input_dir} to {output_dir}.') + for i in tqdm(indices): dirname = os.path.dirname(target_names[i]) os.makedirs(dirname, exist_ok=True) shutil.copy(source_names[i], target_names[i]) + """ @hydra.main(config_path='config', config_name='config') From 3120900df7a6cdfc94af866a0763e6bca2e761cf Mon Sep 17 00:00:00 2001 From: philippmwirth Date: Wed, 18 Nov 2020 14:55:50 +0000 Subject: [PATCH 06/14] Add additional requirements for dev/video --- requirements/dev.txt | 3 ++- requirements/video.txt | 2 ++ 2 files changed, 4 insertions(+), 1 deletion(-) create mode 100644 requirements/video.txt diff --git a/requirements/dev.txt b/requirements/dev.txt index 7a9ea9d04..8a4c57743 100644 --- a/requirements/dev.txt +++ b/requirements/dev.txt @@ -7,4 +7,5 @@ responses sphinx-gallery sphinx_rtd_theme matplotlib -pre-commit \ No newline at end of file +pre-commit +opencv-python \ No newline at end of file diff --git a/requirements/video.txt b/requirements/video.txt new file mode 100644 index 000000000..f805cb2c2 --- /dev/null +++ b/requirements/video.txt @@ -0,0 +1,2 @@ +torchvision>=0.8.0 +av>=8.0.2 \ No newline at end of file From 3fd880f84e28831a897e5f9b9f5bdc2bf513ec85 Mon Sep 17 00:00:00 2001 From: philippmwirth Date: Wed, 18 Nov 2020 14:56:49 +0000 Subject: [PATCH 07/14] Add tests for video datasets --- setup.py | 9 +++- tests/core/test_Core.py | 3 +- tests/data/test_LightlyDataset.py | 63 +++++++++++++++++++++++++++ tests/data/test_VideoDataset.py | 71 +++++++++++++++++++++++++++++++ tox.ini | 21 ++++++++- 5 files changed, 164 insertions(+), 3 deletions(-) create mode 100644 tests/data/test_VideoDataset.py diff --git a/setup.py b/setup.py index ec51be106..e1e69509b 100644 --- a/setup.py +++ b/setup.py @@ -65,7 +65,14 @@ def load_requirements(path_dir=PATH_ROOT, filename='base.txt', comment_char='#') python_requires = '>=3.6' install_requires = load_requirements() + video_requires = load_requirements(filename='video.txt') dev_requires = load_requirements(filename='dev.txt') + all_requires = dev_requires + video_requires + extras_require = { + 'video': video_requires, + 'dev': dev_requires, + 'all': all_requires, + } packages = [ 'lightly', @@ -126,7 +133,7 @@ def load_requirements(path_dir=PATH_ROOT, filename='base.txt', comment_char='#') long_description=long_description, long_description_content_type='text/markdown', install_requires=install_requires, - extras_require={'dev': dev_requires}, + extras_require=extras_require, python_requires=python_requires, packages=packages, classifiers=classifiers, diff --git a/tests/core/test_Core.py b/tests/core/test_Core.py index 2d58f9a64..5c1e5ccd6 100644 --- a/tests/core/test_Core.py +++ b/tests/core/test_Core.py @@ -53,7 +53,8 @@ def test_train_and_embed(self): trainer = { 'max_epochs': 1 } - train_model_and_embed_images(input_dir=dataset_dir, trainer=trainer) + train_model_and_embed_images( + input_dir=dataset_dir, trainer=trainer) shutil.rmtree(dataset_dir) pattern = 'lightly_epoch(.*)?.ckpt$' for root, dirs, files in os.walk(os.getcwd()): diff --git a/tests/data/test_LightlyDataset.py b/tests/data/test_LightlyDataset.py index 53184eaf9..97dac249c 100644 --- a/tests/data/test_LightlyDataset.py +++ b/tests/data/test_LightlyDataset.py @@ -3,8 +3,17 @@ import shutil import torchvision import tempfile +import warnings +import numpy as np from lightly.data import LightlyDataset +try: + from lightly.data._video import VideoDataset + import cv2 + VIDEO_DATASET_AVAILABLE = True +except Exception: + VIDEO_DATASET_AVAILABLE = False + class TestLightlyDataset(unittest.TestCase): @@ -46,6 +55,24 @@ def create_dataset(self, n_subfolders=5, n_samples_per_subfolder=20): sample_names[sample_idx])) return tmp_dir, folder_names, sample_names + def create_video_dataset(self, n_videos=5, n_frames_per_video=10, w=32, h=32, c=3): + + self.n_videos = n_videos + self.n_frames_per_video = n_frames_per_video + + self.input_dir = tempfile.mkdtemp() + self.ensure_dir(self.input_dir) + self.frames = (np.random.randn(n_frames_per_video, w, h, c) * 255).astype(np.uint8) + self.extensions = ('.avi') + + for i in range(5): + path = os.path.join(self.input_dir, f'output-{i}.avi') + print(path) + out = cv2.VideoWriter(path, 0, 1, (w, h)) + for frame in self.frames: + out.write(frame) + out.release() + def test_create_lightly_dataset_from_folder(self): n_subfolders = 5 n_samples_per_subfolder = 10 @@ -68,7 +95,15 @@ def test_create_lightly_dataset_from_folder(self): self.assertEqual(len(dataset), n_tot_files) self.assertListEqual(sorted(fnames), sorted(filenames)) + out_dir = tempfile.mkdtemp() + dataset.dump(out_dir) + self.assertEqual( + sum(len(os.listdir(os.path.join(out_dir, subdir))) for subdir in os.listdir(out_dir)), + len(dataset), + ) + shutil.rmtree(dataset_dir) + shutil.rmtree(out_dir) def test_create_lightly_dataset_from_folder_nosubdir(self): @@ -103,6 +138,8 @@ def test_create_lightly_dataset_from_torchvision(self): for dataset_name in self.available_dataset_names: dataset = LightlyDataset(root=tmp_dir, name=dataset_name) self.assertIsNotNone(dataset) + + shutil.rmtree(tmp_dir) def test_not_existing_torchvision_dataset(self): list_of_non_existing_names = [ @@ -120,3 +157,29 @@ def test_not_existing_folder_dataset(self): LightlyDataset( from_folder='/a-random-hopefully-non/existing-path-to-nowhere/' ) + + def test_video_dataset(self): + + if not VIDEO_DATASET_AVAILABLE: + tmp_dir = tempfile.mkdtemp() + self.ensure_dir(tmp_dir) + # simulate a video + path = os.path.join(tmp_dir, 'my_file.png') + dataset = torchvision.datasets.FakeData(size=1, image_size=(3, 32, 32)) + image, _ = dataset[0] + image.save(path) + os.rename(path, os.path.join(tmp_dir, 'my_file.avi')) + with self.assertRaises(ValueError): + dataset = LightlyDataset(from_folder=tmp_dir) + + warnings.warn( + 'Did not test video dataset because of missing requirements') + shutil.rmtree(tmp_dir) + return + + self.create_video_dataset() + dataset = LightlyDataset(from_folder=self.input_dir) + + out_dir = tempfile.mkdtemp() + dataset.dump(out_dir) + self.assertEqual(len(os.listdir(out_dir)), len(dataset)) diff --git a/tests/data/test_VideoDataset.py b/tests/data/test_VideoDataset.py new file mode 100644 index 000000000..8108a6697 --- /dev/null +++ b/tests/data/test_VideoDataset.py @@ -0,0 +1,71 @@ +import unittest +import os +import shutil +import numpy as np +import tempfile +import warnings +import PIL + +try: + from lightly.data._video import VideoDataset + import cv2 + VIDEO_DATASET_AVAILABLE = True +except Exception: + VIDEO_DATASET_AVAILABLE = False + +class TestVideoDataset(unittest.TestCase): + + def ensure_dir(self, path_to_folder: str): + if not os.path.exists(path_to_folder): + os.makedirs(path_to_folder) + + def create_dataset(self, n_videos=5, n_frames_per_video=10, w=32, h=32, c=3): + + self.n_videos = n_videos + self.n_frames_per_video = n_frames_per_video + + self.input_dir = tempfile.mkdtemp() + self.ensure_dir(self.input_dir) + self.frames = (np.random.randn(n_frames_per_video, w, h, c) * 255).astype(np.uint8) + self.extensions = ('.avi') + + for i in range(5): + path = os.path.join(self.input_dir, f'output-{i}.avi') + print(path) + out = cv2.VideoWriter(path, 0, 1, (w, h)) + for frame in self.frames: + out.write(frame) + out.release() + + def test_video_dataset_from_folder(self): + + if not VIDEO_DATASET_AVAILABLE: + warnings.warn( + 'Did not test video dataset because of missing requirements') + return + + self.create_dataset() + + # create dataset + dataset = VideoDataset(self.input_dir, extensions=self.extensions) + + # __len__ + self.assertEqual(len(dataset), self.n_frames_per_video * self.n_videos) + + # __getitem__ + for i in range(len(dataset)): + frame, label = dataset[i] + self.assertIsInstance(frame, PIL.Image.Image) + self.assertEqual(label, i // self.n_frames_per_video) + + # get_filename + for i in range(len(dataset)): + frame, label = dataset[i] + filename = dataset.get_filename(i) + self.assertTrue( + filename.endswith( + f"-{float(i % self.n_frames_per_video):.8f}s-avi.png" + ) + ) + + shutil.rmtree(self.input_dir) diff --git a/tox.ini b/tox.ini index ac139e26c..60926c55c 100644 --- a/tox.ini +++ b/tox.ini @@ -18,7 +18,7 @@ # Requirement already satisfied: setuptools in /usr/local/lib/python3.6/site-packages (40.6.2) # Requirement already satisfied: pip in /usr/local/lib/python3.6/site-packages (18.1) [tox] -envlist = cuda, cpu, cpu-minimal +envlist = cuda, cpu, cpu-minimal, video # we install the package manually later skipsdist=True @@ -82,3 +82,22 @@ commands = pip install .[dev] echo "Running cpu-minimal test" make test + +[testenv:video] +# test the full package on the cpu with minimal configuration +basepython = python3.7 + +# suppress warnings +whitelist_externals = make + pip + echo + +passenv = * +setenv = LIGHTLY_SERVER_LOCATION = https://api-dev.lightly.ai + CUDA_VISIBLE_DEVICES = -1 + +commands = + pip install torch==1.7.0+cu101 torchvision==0.8.1+cu101 -f https://download.pytorch.org/whl/torch_stable.html + pip install .[all] + echo "Running video test" + make test From 6393196317778c19b87d379e2e0278627798bf03 Mon Sep 17 00:00:00 2001 From: philippmwirth Date: Wed, 18 Nov 2020 16:56:24 +0000 Subject: [PATCH 08/14] Update docstrings and minor formatting --- lightly/api/upload.py | 36 ++--------- lightly/api/utils.py | 9 ++- lightly/cli/download_cli.py | 33 ++-------- lightly/data/_helpers.py | 48 +++++++++------ lightly/data/_image.py | 85 +++++++++++++++---------- lightly/data/_video.py | 120 +++++++++++++++++++++++++++++++----- lightly/data/dataset.py | 44 ++++++++----- 7 files changed, 230 insertions(+), 145 deletions(-) diff --git a/lightly/api/upload.py b/lightly/api/upload.py index 55eb966f9..c88d7c7ec 100644 --- a/lightly/api/upload.py +++ b/lightly/api/upload.py @@ -232,14 +232,14 @@ def _upload_single_image(image, def upload_dataset(dataset: LightlyDataset, dataset_id: str, token: str, - max_workers: int = 8 , + max_workers: int = 8, mode: str = 'thumbnails', verbose: bool = True): """Uploads images from a directory to the Lightly cloud solution. Args: dataset - The dataset to upload + The dataset to upload. dataset_id: The unique identifier for the dataset. token: @@ -284,35 +284,6 @@ def upload_dataset(dataset: LightlyDataset, # handle the case where len(dataset) < max_workers max_workers = min(len(dataset), max_workers) - # define lambda function for parallel upload - """ - def lambda_(index): - # wait for a random amount of time to prevent server overload - time.sleep(index * 0.1) - # find batch start and end - start = index * batch_size - stop = min((index + 1) * batch_size, len(dataset)) - # progress bar - pbar = tqdm.tqdm(unit='imgs', total=stop - start) - # upload all images in the batch - success_list = [] - for i in range(start, stop): - time.sleep(random.random() * 0.01) - image, label, filename = dataset[i] - success = _upload_single_image( - image=image, - label=label, - filename=filename, - dataset_id=dataset_id, - token=token, - mode=mode, - ) - success_list.append(success) - pbar.update(1) - pbar.refresh() - - return all(success_list) - """ # upload the samples if verbose: print(f'Uploading images (with {max_workers} workers).', flush=True) @@ -320,6 +291,7 @@ def lambda_(index): pbar = tqdm.tqdm(unit='imgs', total=len(dataset)) tqdm_lock = tqdm.tqdm.get_lock() + # define lambda function for concurrent upload def lambda_(i): # load image image, label, filename = dataset[i] @@ -364,7 +336,7 @@ def lambda_(i): def upload_images_from_folder(path_to_folder: str, dataset_id: str, token: str, - max_workers: int = 8 , + max_workers: int = 8, mode: str = 'thumbnails', verbose: bool = True): """Uploads images from a directory to the Lightly cloud solution. diff --git a/lightly/api/utils.py b/lightly/api/utils.py index b6bed2bc7..92ce34434 100644 --- a/lightly/api/utils.py +++ b/lightly/api/utils.py @@ -197,7 +197,11 @@ def resize_image(image, max_width: int, max_height: int): def check_filename(basename): - """ + """Checks the length of the filename. + + Args: + basename: + Basename of the file. """ return len(basename) <= MAXIMUM_FILENAME_LENGTH @@ -210,7 +214,8 @@ def check_image(image): it is corrupt or not. Args: - filename: (str) Path to the file. + image: + PIL image from which metadata will be computed. Returns: A dictionary of metadata of the image. diff --git a/lightly/cli/download_cli.py b/lightly/cli/download_cli.py index def21b31f..20fcaff81 100644 --- a/lightly/cli/download_cli.py +++ b/lightly/cli/download_cli.py @@ -48,9 +48,10 @@ def _download_cli(cfg, is_cli_call=True): with open(cfg['tag_name'] + '.txt', 'w') as f: for item in samples: f.write("%s\n" % item) + msg = 'The list of files in tag {} is stored at: '.format(cfg['tag_name']) msg += os.path.join(os.getcwd(), cfg['tag_name'] + '.txt') - print(msg) + print(msg, flush=True) if cfg['input_dir'] and cfg['output_dir']: @@ -58,38 +59,12 @@ def _download_cli(cfg, is_cli_call=True): output_dir = fix_input_path(cfg['output_dir']) print(f'Copying files from {input_dir} to {output_dir}.') - # + # create a dataset from the input directory dataset = data.LightlyDataset(from_folder=input_dir) - # + # dump the dataset in the output directory dataset.dump(output_dir, samples) - """ - # "name.jpg" -> "/name.jpg" to prevent bugs like this: - # "path/to/1234.jpg" ends with both "234.jpg" and "1234.jpg" - samples = [os.path.join(' ', s)[1:] for s in samples] - - # copy all images from one folder to the other - input_dir = fix_input_path(cfg['input_dir']) - output_dir = fix_input_path(cfg['output_dir']) - - dataset = data.LightlyDataset(from_folder=input_dir) - basenames = dataset.get_filenames() - - source_names = [os.path.join(input_dir, f) for f in basenames] - target_names = [os.path.join(output_dir, f) for f in basenames] - - # only copy files which are in the tag - indices = [i for i in range(len(source_names)) - if any([source_names[i].endswith(s) for s in samples])] - - - for i in tqdm(indices): - dirname = os.path.dirname(target_names[i]) - os.makedirs(dirname, exist_ok=True) - shutil.copy(source_names[i], target_names[i]) - """ - @hydra.main(config_path='config', config_name='config') def download_cli(cfg): diff --git a/lightly/data/_helpers.py b/lightly/data/_helpers.py index e5e4afa70..07a4cba4d 100644 --- a/lightly/data/_helpers.py +++ b/lightly/data/_helpers.py @@ -15,17 +15,20 @@ VIDEO_DATASET_AVAILABLE = False -from lightly.data._image_loaders import default_loader - IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp') VIDEO_EXTENSIONS = ('.mp4', '.mov', '.avi') -def _contains_videos(root, extensions): - """ +def _contains_videos(root: str, extensions: tuple): + """Checks whether directory contains video files. + + Args: + root: Root directory path. + Returns: + True if root contains subdirectories else false. """ list_dir = os.listdir(root) is_video = \ @@ -33,14 +36,14 @@ def _contains_videos(root, extensions): return any(is_video) -def _contains_subdirs(root): - """Check whether directory contains subdirectories +def _contains_subdirs(root: str): + """Checks whether directory contains subdirectories. Args: - root: (str) Root directory path + root: Root directory path. Returns: - True if root contains subdirectories else false + True if root contains subdirectories else false. """ list_dir = os.listdir(root) @@ -49,45 +52,52 @@ def _contains_subdirs(root): return any(is_dir) -def _load_dataset_from_folder(root, transform): - """Initialize dataset from folder +def _load_dataset_from_folder(root: str, transform): + """Initializes dataset from folder. Args: root: (str) Root directory path transform: (torchvision.transforms.Compose) image transformations Returns: - Dataset consisting of images in the root directory + Dataset consisting of images in the root directory. + """ + # if there is a video in the input directory but we do not have + # the right dependencies, raise a ValueError contains_videos = _contains_videos(root, VIDEO_EXTENSIONS) if contains_videos and not VIDEO_DATASET_AVAILABLE: raise ValueError(f'The input directory {root} contains videos ' - 'but the VideoDataset is not available.') + 'but the VideoDataset is not available. \n' + 'Make sure you have installed the right ' + 'dependencies.') if contains_videos: + # root contains videos -> create a video dataset dataset = VideoDataset(root, extensions=VIDEO_EXTENSIONS, transform=transform) elif _contains_subdirs(root): + # root contains subdirectories -> create an image folder dataset dataset = datasets.ImageFolder(root, transform=transform) else: + # root contains plain images -> create a folder dataset dataset = DatasetFolder(root, - default_loader, extensions=IMG_EXTENSIONS, transform=transform) return dataset -def _load_dataset(root='', - name='cifar10', - train=True, - download=True, +def _load_dataset(root: str = '', + name: str = 'cifar10', + train: bool = True, + download: bool = True, transform=None, - from_folder=''): - """ Initialize dataset from torchvision or from folder + from_folder: str = ''): + """Initializes dataset from torchvision or from folder. Args: root: (str) Directory where dataset is stored diff --git a/lightly/data/_image.py b/lightly/data/_image.py index 567614e85..e37479242 100644 --- a/lightly/data/_image.py +++ b/lightly/data/_image.py @@ -1,23 +1,27 @@ -""" """ +""" Image Dataset """ -# -# +# Copyright (c) 2020. Lightly AG and its affiliates. +# All Rights Reserved import os import torchvision.datasets as datasets +from lightly.data._image_loaders import default_loader + def _make_dataset(directory, extensions=None, is_valid_file=None): - """Return a list of all image files with targets in the directory + """Returns a list of all image files with targets in the directory. Args: - directory: (str) Root directory path - (should not contain subdirectories!) - extensions: (List[str]) List of allowed extensions - is_valid_file: (callable) Used to check corrupt files + directory: + Root directory path (should not contain subdirectories!). + extensions: + Tuple of valid extensions. + is_valid_file: + Used to find valid files. Returns: - List of instance tuples: (path_i, target_i = 0) + List of instance tuples: (path_i, target_i = 0). """ @@ -39,27 +43,40 @@ def _is_valid_file(filename): class DatasetFolder(datasets.VisionDataset): + """Implements a dataset folder. + + DatasetFolder based on torchvisions implementation. + (https://pytorch.org/docs/stable/torchvision/datasets.html#datasetfolder) + + Attributes: + root: + Root directory path + loader: + Function that loads file at path + extensions: + Tuple of allowed extensions + transform: + Function that takes a PIL image and returns transformed version + target_transform: + As transform but for targets + is_valid_file: + Used to check corrupt files + + Raises: + RuntimeError: If no supported files are found in root. - def __init__(self, root, loader, extensions=None, transform=None, - target_transform=None, is_valid_file=None): - """Constructor based on torchvisions DatasetFolder - (https://pytorch.org/docs/stable/torchvision/datasets.html#datasetfolder) - - Args: - root: (str) Root directory path - loader: (callable) Function that loads file at path - extensions: (List[str]) List of allowed extensions - transform: Function that takes a PIL image and returns - transformed version - target_transform: As transform but for targets - is_valid_file: (callable) Used to check corrupt files - - Raises: - RuntimeError: If no supported files are found in root. + """ - """ + def __init__(self, + root: str, + loader=default_loader, + extensions=None, + transform=None, + target_transform=None, + is_valid_file=None): - super(DatasetFolder, self).__init__(root, transform=transform, + super(DatasetFolder, self).__init__(root, + transform=transform, target_transform=target_transform) samples = _make_dataset(self.root, extensions, is_valid_file) @@ -76,14 +93,15 @@ def __init__(self, root, loader, extensions=None, transform=None, self.samples = samples self.targets = [s[1] for s in samples] - def __getitem__(self, index): - """Returns item at index + def __getitem__(self, index: int): + """Returns item at index. Args: - index: (int) Index + index: + Index of the sample to retrieve. Returns: - tuple: (sample, target) where target is 0 + A tuple (sample, target) where target is 0. """ @@ -97,4 +115,7 @@ def __getitem__(self, index): return sample, target def __len__(self): - return len(self.samples) \ No newline at end of file + """Returns the number of samples in the dataset. + + """ + return len(self.samples) diff --git a/lightly/data/_video.py b/lightly/data/_video.py index 8d105bc38..1b7b1e29f 100644 --- a/lightly/data/_video.py +++ b/lightly/data/_video.py @@ -1,7 +1,7 @@ -""" """ +""" Video Dataset """ -# -# +# Copyright (c) 2020. Lightly AG and its affiliates. +# All Rights Reserved import os import av @@ -12,10 +12,18 @@ def _video_loader(path, timestamp, pts_unit='sec'): - """ + """Reads a frame from a video at a random timestamp. + + Args: + path: + Path to the video file. + timestamp: + The timestamp at which to retrieve the frame in seconds. + pts_unit: + Unit of the timestamp. """ - # random access read from video + # random access read from video (slow) frame, _, _ = read_video(path, start_pts=timestamp, end_pts=timestamp, @@ -29,22 +37,45 @@ def _video_loader(path, timestamp, pts_unit='sec'): return image -def _make_dataset(directory, extensions=None, is_valid_file=None, pts_unit='sec'): - """ +def _make_dataset(directory, + extensions=None, + is_valid_file=None, + pts_unit='sec'): + """Returns a list of all video files, timestamps, and offsets. + + Args: + directory: + Root directory path (should not contain subdirectories). + extensions: + Tuple of valid extensions. + is_valid_file: + Used to find valid files. + pts_unit: + Unit of the timestamps. + + Returns: + A list of video files, timestamps, frame offsets, and fps. """ + # use filename to find valid files if extensions is not None: def _is_valid_file(filename): return filename.lower().endswith(extensions) - # find all instances + # overwrite function to find valid files + if is_valid_file is not None: + _is_valid_file = is_valid_file + + # find all instances (no subdirectories) instances = [] for fname in os.listdir(directory): + # skip invalid files if not _is_valid_file(fname): continue + # keep track of valid files path = os.path.join(directory, fname) instances.append(path) @@ -55,7 +86,7 @@ def _is_valid_file(filename): timestamps.append(ts) fpss.append(fps) - # get offsets + # get frame offsets offsets = [len(ts) for ts in timestamps] offsets = [0] + offsets[:-1] for i in range(1, len(offsets)): @@ -65,18 +96,48 @@ def _is_valid_file(filename): class VideoDataset(datasets.VisionDataset): - """ + """Implementation of a video dataset. + + The VideoDataset allows random reads from a video file without extracting + all frames beforehand. This is more storage efficient but is slower. + + Attributes: + root: + Root directory path. + loader: + Function that loads file at path. + extensions: + Tuple of allowed extensions. + transform: + Function that takes a PIL image and returns transformed version + target_transform: + As transform but for targets + is_valid_file: + Used to check corrupt files """ - def __init__(self, root, loader=_video_loader, extensions=None, - transform=None, target_transform=None, is_valid_file=None): + def __init__(self, + root, + loader=_video_loader, + extensions=None, + transform=None, + target_transform=None, + is_valid_file=None): - super(VideoDataset, self).__init__(root, transform=transform, + super(VideoDataset, self).__init__(root, + transform=transform, target_transform=target_transform) videos, video_timestamps, offsets, fpss = _make_dataset( self.root, extensions, is_valid_file) + + if len(videos) == 0: + msg = 'Found 0 videos in folder: {}\n'.format(self.root) + if extensions is not None: + msg += 'Supported extensions are: {}'.format( + ','.join(extensions)) + raise RuntimeError(msg) self.extensions = extensions self.loader = loader @@ -87,7 +148,21 @@ def __init__(self, root, loader=_video_loader, extensions=None, self.fpss = fpss def __getitem__(self, index): - """ + """Returns item at index. + + Finds the video of the frame at index with the help of the frame + offsets. Then, loads the frame from the video, applies the transforms, + and returns the frame along with the index of the video (as target). + + Args: + index: + Index of the sample to retrieve. + + Returns: + A tuple (sample, target) where target indicates the video index. + + Raises: + IndexError if index is out of bounds. """ if index < 0 or index >= self.__len__(): @@ -111,14 +186,27 @@ def __getitem__(self, index): return sample, target def __len__(self): - """ + """Returns the number of samples in the dataset. """ return sum((len(ts) for ts in self.video_timestamps)) def get_filename(self, index): - """ + """Returns a filename for the frame at index. + + The filename is created from the video filename, the timestamp, and + the video format. E.g. when retrieving a sample from the video + `my_video.mp4` at time 0.5s, the filename will be: + + >>> my_video-0.50000000s-mp4.png + + Args: + index: + Index of the frame to retrieve. + Returns: + The filename of the frame as described above. + """ if index < 0 or index >= self.__len__(): raise IndexError(f'Index {index} is out of bounds for VideoDataset' diff --git a/lightly/data/dataset.py b/lightly/data/dataset.py index 6c608eecc..c19b32fb5 100644 --- a/lightly/data/dataset.py +++ b/lightly/data/dataset.py @@ -63,14 +63,6 @@ def __init__(self, download: bool = True, from_folder: str = '', transform=None): - """ Constructor - - - - Raises: - ValueError: If the specified dataset doesn't exist - - """ super(LightlyDataset, self).__init__() self.dataset = _load_dataset( @@ -80,8 +72,23 @@ def __init__(self, if from_folder: self.root_folder = from_folder - def dump_image(self, output_dir: str, filename: str, format: Union[str, None] = None): - """ + def dump_image(self, + output_dir: str, + filename: str, + format: Union[str, None] = None): + """Saves a single image to the output directory. + + Will copy the image from the input directory to the output directory + if possible. If not (e.g. for VideoDatasets), will load the image and + then save it to the output directory with the specified format. + + Args: + output_dir: + Output directory where the image is stored. + filename: + Filename of the image to store. + format: + Image format. """ index = self.get_filenames().index(filename) @@ -106,16 +113,23 @@ def dump_image(self, output_dir: str, filename: str, format: Union[str, None] = # could not determine format from filename image.save(os.path.join(output_dir, filename), format='png') - def dump(self, output_dir: str, filenames: Union[List[str], None] = None, format: Union[str, None] = None): - """Saves all specified images to the output directory. + def dump(self, + output_dir: str, + filenames: Union[List[str], None] = None, + format: Union[str, None] = None): + """Saves images to the output directory. + + Will copy the images from the input directory to the output directory + if possible. If not (e.g. for VideoDatasets), will load the images and + then save them to the output directory with the specified format. Args: output_dir: - TODO + Output directory where the image is stored. filenames: - TODO + Filenames of the images to store. If None, stores all images. format: - TODO + Image format. """ # make sure no transforms are applied to the images From e78e1f79eaa07d03646072b9fe1e67b43cf7890d Mon Sep 17 00:00:00 2001 From: philippmwirth Date: Thu, 19 Nov 2020 11:08:30 +0000 Subject: [PATCH 09/14] Move checkpoint loading to models --- lightly/cli/_helpers.py | 12 --------- lightly/cli/embed_cli.py | 15 +++++------ lightly/cli/train_cli.py | 17 ++++++------- lightly/models/_helpers.py | 13 ++++++++++ lightly/models/moco.py | 51 ++++++++++++++++++++++++++++++++++++++ lightly/models/simclr.py | 47 +++++++++++++++++++++++++++++++++++ 6 files changed, 125 insertions(+), 30 deletions(-) create mode 100644 lightly/models/_helpers.py diff --git a/lightly/cli/_helpers.py b/lightly/cli/_helpers.py index 4b7c86748..5724b8616 100644 --- a/lightly/cli/_helpers.py +++ b/lightly/cli/_helpers.py @@ -19,18 +19,6 @@ def fix_input_path(path): return path -def filter_state_dict(state_dict): - """Prevent unexpected key error when loading PyTorch-Lightning checkpoints - by removing the unnecessary prefix model. from each key. - - """ - new_state_dict = {} - for key, item in state_dict.items(): - new_key = '.'.join(key.split('.')[1:]) - new_state_dict[new_key] = item - return new_state_dict - - def is_url(checkpoint): """Check whether the checkpoint is a url or not. diff --git a/lightly/cli/embed_cli.py b/lightly/cli/embed_cli.py index 0c8ad9a90..e4c8a5628 100644 --- a/lightly/cli/embed_cli.py +++ b/lightly/cli/embed_cli.py @@ -21,7 +21,6 @@ from lightly.cli._helpers import get_ptmodel_from_config from lightly.cli._helpers import fix_input_path -from lightly.cli._helpers import filter_state_dict from lightly.cli._helpers import load_state_dict_from_url @@ -48,8 +47,6 @@ def _embed_cli(cfg, is_cli_call=True): else: device = torch.device('cpu') - model = ResNetSimCLR(**cfg['model']).to(device) - transform = torchvision.transforms.Compose([ torchvision.transforms.Resize((cfg['collate']['input_size'], cfg['collate']['input_size'])), @@ -71,10 +68,8 @@ def _embed_cli(cfg, is_cli_call=True): ) dataloader = torch.utils.data.DataLoader(dataset, **cfg['loader']) - # load the PyTorch state dictionary and map it to the current device - # then, remove the model. prefix which is caused by the pytorch-lightning - # checkpoint saver and load the model from the "filtered" state dict - # this approach is compatible with pytorch_lightning 0.7.1 - 0.8.4 (latest) + # load the PyTorch state dictionary and map it to the current device + state_dict = None if not checkpoint: checkpoint, key = get_ptmodel_from_config(cfg['model']) if not checkpoint: @@ -91,8 +86,10 @@ def _embed_cli(cfg, is_cli_call=True): )['state_dict'] if state_dict is not None: - state_dict = filter_state_dict(state_dict) - model.load_state_dict(state_dict) + model = ResNetSimCLR.from_state_dict(state_dict, **cfg['model']) + model = model.to(device) + else: + model = ResNetSimCLR(**cfg['model']).to(device) encoder = SelfSupervisedEmbedding(model, None, None, None) embeddings, labels, filenames = encoder.embed(dataloader, device=device) diff --git a/lightly/cli/train_cli.py b/lightly/cli/train_cli.py index 3fbd43129..99a352a44 100644 --- a/lightly/cli/train_cli.py +++ b/lightly/cli/train_cli.py @@ -16,12 +16,11 @@ from lightly.data import LightlyDataset from lightly.embedding import SelfSupervisedEmbedding from lightly.loss import NTXentLoss -from lightly.models import ResNetSimCLR +from lightly.models import ResNetSimCLR, ResNetMoCo from lightly.cli._helpers import is_url from lightly.cli._helpers import get_ptmodel_from_config from lightly.cli._helpers import fix_input_path -from lightly.cli._helpers import filter_state_dict from lightly.cli._helpers import load_state_dict_from_url @@ -59,7 +58,7 @@ def _train_cli(cfg, is_cli_call=True): msg += 'loader.batch_size=BSZ' warnings.warn(msg) - model = ResNetSimCLR(**cfg['model']) + state_dict = None checkpoint = cfg['checkpoint'] if cfg['pre_trained'] and not checkpoint: # if checkpoint wasn't specified explicitly and pre_trained is True @@ -75,9 +74,6 @@ def _train_cli(cfg, is_cli_call=True): if checkpoint: # load the PyTorch state dictionary and map it to the current device - # then, remove the model. prefix which is caused by the pytorch-lightning - # checkpoint saver and load the model from the "filtered" state dict - # this approach is compatible with pytorch_lightning 0.7.1 - 0.8.4 (latest) if is_url(checkpoint): state_dict = load_state_dict_from_url( checkpoint, map_location=device @@ -86,9 +82,12 @@ def _train_cli(cfg, is_cli_call=True): state_dict = torch.load( checkpoint, map_location=device )['state_dict'] - if state_dict is not None: - state_dict = filter_state_dict(state_dict) - model.load_state_dict(state_dict) + + # load model + if state_dict is not None: + model = ResNetSimCLR.from_state_dict(state_dict, **cfg['model']) + else: + model = ResNetSimCLR(**cfg['model']) criterion = NTXentLoss(**cfg['criterion']) optimizer = torch.optim.SGD(model.parameters(), **cfg['optimizer']) diff --git a/lightly/models/_helpers.py b/lightly/models/_helpers.py new file mode 100644 index 000000000..ce27a934c --- /dev/null +++ b/lightly/models/_helpers.py @@ -0,0 +1,13 @@ +from difflib import ndiff + + +def filter_state_dict(state_dict): + """Prevent unexpected key error when loading PyTorch-Lightning checkpoints + by removing the unnecessary prefix model. from each key. + + """ + new_state_dict = {} + for key, item in state_dict.items(): + new_key = '.'.join(key.split('.')[1:]) + new_state_dict[new_key] = item + return new_state_dict \ No newline at end of file diff --git a/lightly/models/moco.py b/lightly/models/moco.py index 496bcebea..e1371f383 100644 --- a/lightly/models/moco.py +++ b/lightly/models/moco.py @@ -6,6 +6,7 @@ import torch import torch.nn as nn from lightly.models.resnet import ResNetGenerator +from lightly.models._helpers import filter_state_dict def _get_features_and_projections(resnet, num_ftrs, out_dim): @@ -83,6 +84,56 @@ def __init__(self, param_k.requires_grad = False self._momentum_update_key_encoder(0.) + @classmethod + def from_state_dict(cls, + state_dict: dict, + name: str = 'resnet-18', + width: float = 1., + num_ftrs: int = 32, + out_dim: int = 128, + m: float = 0.999, + strict: bool = True, + apply_filter: bool = True): + """Initializes a ResNetMoCo and loads weights from a checkpoint. + + Args: + state_dict: + State dictionary with layer weights. + name: + ResNet version, choose from resnet-{9, 18, 34, 50, 101, 152}. + width: + Width of the ResNet. + num_ftrs: + Dimension of the embedding (before the projection head). + out_dim: + Dimension of the output (after the projection head). + m: + Momentum for momentum update of the key-encoder. + strict: + Set to False when loading from a partial state_dict. + apply_filter: + If True, removes the `model.` prefix from keys in the state_dict. + + """ + model = cls( + name=name, + width=width, + num_ftrs=num_ftrs, + out_dim=out_dim, + m=m, + ) + + # remove the model. prefix which is caused by the pytorch-lightning + # checkpoint saver and load the model from the "filtered" state dict + # this approach is compatible with pytorch_lightning 0.7.1 - 0.8.4 (latest) + if apply_filter: + state_dict_ = filter_state_dict(state_dict) + else: + state_dict_ = state_dict + + model.load_state_dict(state_dict_, strict=strict) + + return model @torch.no_grad() def _momentum_update_key_encoder(self, m=0.): diff --git a/lightly/models/simclr.py b/lightly/models/simclr.py index 1c854371b..72aaec0fa 100644 --- a/lightly/models/simclr.py +++ b/lightly/models/simclr.py @@ -7,6 +7,7 @@ import torch.nn as nn from lightly.models.resnet import ResNetGenerator +from lightly.models._helpers import filter_state_dict def _get_features_and_projections(resnet, num_ftrs, out_dim): @@ -69,6 +70,52 @@ def __init__(self, self.features, self.projection_head = _get_features_and_projections( resnet, self.num_ftrs, self.out_dim) + @classmethod + def from_state_dict(cls, + state_dict: dict, + name: str = 'resnet-18', + width: float = 1., + num_ftrs: int = 32, + out_dim: int = 128, + strict: bool = True, + apply_filter: bool = True): + """Initializes a ResNetMoCo and loads weights from a checkpoint. + + Args: + state_dict: + State dictionary with layer weights. + name: + ResNet version, choose from resnet-{9, 18, 34, 50, 101, 152}. + width: + Width of the ResNet. + num_ftrs: + Dimension of the embedding (before the projection head). + out_dim: + Dimension of the output (after the projection head). + strict: + Set to False when loading from a partial state_dict. + apply_filter: + If True, removes the `model.` prefix from keys in the state_dict. + + """ + model = cls( + name=name, + width=width, + num_ftrs=num_ftrs, + out_dim=out_dim, + ) + + # remove the model. prefix which is caused by the pytorch-lightning + # checkpoint saver and load the model from the "filtered" state dict + # this approach is compatible with pytorch_lightning 0.7.1 - 0.8.4 (latest) + if apply_filter: + state_dict_ = filter_state_dict(state_dict) + else: + state_dict_ = state_dict + + model.load_state_dict(state_dict_, strict=strict) + + return model def forward(self, x: torch.Tensor): """Forward pass through ResNetSimCLR. From 3aa00c3c41b2f4a77b628c22ccc1ec2da8863f00 Mon Sep 17 00:00:00 2001 From: philippmwirth Date: Thu, 19 Nov 2020 12:38:10 +0000 Subject: [PATCH 10/14] Minor formatting changes --- lightly/cli/train_cli.py | 2 +- lightly/models/_helpers.py | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/lightly/cli/train_cli.py b/lightly/cli/train_cli.py index 99a352a44..42573a83e 100644 --- a/lightly/cli/train_cli.py +++ b/lightly/cli/train_cli.py @@ -16,7 +16,7 @@ from lightly.data import LightlyDataset from lightly.embedding import SelfSupervisedEmbedding from lightly.loss import NTXentLoss -from lightly.models import ResNetSimCLR, ResNetMoCo +from lightly.models import ResNetSimCLR from lightly.cli._helpers import is_url from lightly.cli._helpers import get_ptmodel_from_config diff --git a/lightly/models/_helpers.py b/lightly/models/_helpers.py index ce27a934c..6fcca977d 100644 --- a/lightly/models/_helpers.py +++ b/lightly/models/_helpers.py @@ -1,4 +1,3 @@ -from difflib import ndiff def filter_state_dict(state_dict): From 97f7c802cd8ee77f3c1084f0fd5d9c7d48a64bd4 Mon Sep 17 00:00:00 2001 From: philippmwirth Date: Thu, 19 Nov 2020 13:44:56 +0000 Subject: [PATCH 11/14] Add requested changes --- lightly/__init__.py | 1 - lightly/data/_helpers.py | 6 ++++-- lightly/data/_video.py | 25 ++++++++++++++++++++++--- tests/data/test_LightlyDataset.py | 6 +++--- 4 files changed, 29 insertions(+), 9 deletions(-) diff --git a/lightly/__init__.py b/lightly/__init__.py index c2faf4941..2191c05c3 100644 --- a/lightly/__init__.py +++ b/lightly/__init__.py @@ -44,7 +44,6 @@ def is_prefetch_generator_available(): return _prefetch_generator_available - # import core functionalities from lightly.core import train_model_and_embed_images from lightly.core import train_embedding_model diff --git a/lightly/data/_helpers.py b/lightly/data/_helpers.py index 07a4cba4d..14b74d814 100644 --- a/lightly/data/_helpers.py +++ b/lightly/data/_helpers.py @@ -11,8 +11,9 @@ try: from lightly.data._video import VideoDataset VIDEO_DATASET_AVAILABLE = True -except Exception: +except Exception as e: VIDEO_DATASET_AVAILABLE = False + VIDEO_DATASET_ERRORMSG = e IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', @@ -71,7 +72,8 @@ def _load_dataset_from_folder(root: str, transform): raise ValueError(f'The input directory {root} contains videos ' 'but the VideoDataset is not available. \n' 'Make sure you have installed the right ' - 'dependencies.') + 'dependencies. The error from the imported ' + f'module was: {VIDEO_DATASET_ERRORMSG}') if contains_videos: # root contains videos -> create a video dataset diff --git a/lightly/data/_video.py b/lightly/data/_video.py index 1b7b1e29f..e24e4b6e8 100644 --- a/lightly/data/_video.py +++ b/lightly/data/_video.py @@ -144,6 +144,8 @@ def __init__(self, self.videos = videos self.video_timestamps = video_timestamps + # offsets[i] indicates the index of the first frame of the i-th video. + # e.g. for two videos of length 10 and 20, the offsets will be [0, 10]. self.offsets = offsets self.fpss = fpss @@ -154,6 +156,19 @@ def __getitem__(self, index): offsets. Then, loads the frame from the video, applies the transforms, and returns the frame along with the index of the video (as target). + For example, if there are two videos with 10 and 20 frames respectively + in the input directory: + + Requesting the 5th sample returns the 5th frame from the first video and + the target indicates the index of the source video which is 0. + >>> dataset[5] + >>> > , 0 + + Requesting the 20th sample returns the 10th frame from the second video + and the target indicates the index of the source video which is 1. + >>> dataset[20] + >>> > , 1 + Args: index: Index of the sample to retrieve. @@ -169,15 +184,17 @@ def __getitem__(self, index): raise IndexError(f'Index {index} is out of bounds for VideoDataset' f' of size {self.__len__()}.') - # find video of the frame + # each sample belongs to a video, to load the sample at index, we need + # to find the video to which the sample belongs and then read the frame + # from this video on the disk. i = len(self.offsets) - 1 while (self.offsets[i] > index): i = i - 1 # find and return the frame as PIL image - target = i sample = self.loader(self.videos[i], self.video_timestamps[i][index - self.offsets[i]]) + target = i if self.transform is not None: sample = self.transform(sample) if self.target_transform is not None: @@ -212,7 +229,9 @@ def get_filename(self, index): raise IndexError(f'Index {index} is out of bounds for VideoDataset' f' of size {self.__len__()}.') - # find video of the frame + # each sample belongs to a video, to load the sample at index, we need + # to find the video to which the sample belongs and then read the frame + # from this video on the disk. i = len(self.offsets) - 1 while (self.offsets[i] > index): i = i - 1 diff --git a/tests/data/test_LightlyDataset.py b/tests/data/test_LightlyDataset.py index 97dac249c..e9d7d89dd 100644 --- a/tests/data/test_LightlyDataset.py +++ b/tests/data/test_LightlyDataset.py @@ -18,8 +18,7 @@ class TestLightlyDataset(unittest.TestCase): def ensure_dir(self, path_to_folder: str): - if not os.path.exists(path_to_folder): - os.makedirs(path_to_folder) + os.makedirs(path_to_folder, exist_ok=True) def setUp(self): self.available_dataset_names = ['cifar10', @@ -162,8 +161,9 @@ def test_video_dataset(self): if not VIDEO_DATASET_AVAILABLE: tmp_dir = tempfile.mkdtemp() - self.ensure_dir(tmp_dir) # simulate a video + # the video dataset will check to see whether there exists a file + # with a video extension, it's enough to fake a video file here path = os.path.join(tmp_dir, 'my_file.png') dataset = torchvision.datasets.FakeData(size=1, image_size=(3, 32, 32)) image, _ = dataset[0] From 256343d4a9e860880a5e4855bec1ce846f8a5472 Mon Sep 17 00:00:00 2001 From: philippmwirth Date: Fri, 20 Nov 2020 08:21:45 +0000 Subject: [PATCH 12/14] Add video folder datasets to input structure --- .../source/tutorials/structure_your_input.rst | 33 +++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/docs/source/tutorials/structure_your_input.rst b/docs/source/tutorials/structure_your_input.rst index 90859fc44..67e674642 100644 --- a/docs/source/tutorials/structure_your_input.rst +++ b/docs/source/tutorials/structure_your_input.rst @@ -106,6 +106,39 @@ For the structure above, lightly will understand the input as follows: 10, ] +Video Folder Datasets +--------------------- +The lightly Python package allows you to work `directly` on video data, without having +to exctract the frames first. This can save a lot of disc space as video files are +typically strongly compressed. Using lightly on video data is as simple as pointing +the software at an input directory where one or more videos are stored. The package will +automatically detect all video files and index them so that each frame can be accessed. + +An example for an input directory with videos could look like this: + +.. code-block:: bash + + data/ + +-- my_video_1.mov + +-- my_video_2.mp4 + +-- my_video_3.avi + +The example also shows the currently supported video file formats (.mov, .mp4, and .avi). +To upload the three videos from above to the platform, you can use + +.. code-block:: bash + + lightly-upload token='123' dataset_id='XYZ' input_dir='data/' + +All other operations (like training a self-supervised model and embedding the frames individually) +also work on video data. Give it a try! + +.. note:: + + Randomly accessing video frames is slower compared to accessing the extracted frames on disc. However, + by working directly on video files, one can save a lot of disc space because the frames do not have to + be exctracted beforehand. + Torchvision Datasets -------------------- From f0149be4809c8ec75f8b18d435c66595eec4c033 Mon Sep 17 00:00:00 2001 From: philippmwirth Date: Fri, 20 Nov 2020 09:18:44 +0000 Subject: [PATCH 13/14] Change version to 1.0.4 --- lightly/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lightly/__init__.py b/lightly/__init__.py index c47754d00..0ce1ab074 100644 --- a/lightly/__init__.py +++ b/lightly/__init__.py @@ -15,7 +15,7 @@ # All Rights Reserved __name__ = 'lightly' -__version__ = '1.0.3' +__version__ = '1.0.4' try: From efa48326f9c4b885181a19960d2fdfb764b9b21d Mon Sep 17 00:00:00 2001 From: philippmwirth Date: Fri, 20 Nov 2020 16:19:59 +0000 Subject: [PATCH 14/14] Update docker documentation for videos --- .../docker/configuration/configuration.rst | 3 ++ .../docker/getting_started/first_steps.rst | 50 +++++++++++++++++++ 2 files changed, 53 insertions(+) diff --git a/docs/source/docker/configuration/configuration.rst b/docs/source/docker/configuration/configuration.rst index 291d84144..94a0b47de 100644 --- a/docs/source/docker/configuration/configuration.rst +++ b/docs/source/docker/configuration/configuration.rst @@ -34,6 +34,9 @@ The following are parameters which can be passed to the container: # remove exact duplicates remove_exact_duplicates: True + # dump the final dataset to the output directory + dump_dataset: False + # pass checkpoint checkpoint: '' diff --git a/docs/source/docker/getting_started/first_steps.rst b/docs/source/docker/getting_started/first_steps.rst index bc9c419db..caf30716f 100644 --- a/docs/source/docker/getting_started/first_steps.rst +++ b/docs/source/docker/getting_started/first_steps.rst @@ -185,6 +185,56 @@ move the embeddings file to the shared directory, and specify the filename like stopping_condition.n_samples=0.3 \ embeddings=my_embeddings.csv +Sampling from Video Files +-------------------------- +In case you are working with video files, it is possible to point the docker container +directly to the video files. This prevents the need to extract the individual frames beforehand. +To do so, simply store all videos you want to work with in a single directory, the lightly software +will automatically load all frames from the videos. + +.. code-block:: console + + # work on a single video + data/ + +-- my_video.mp4 + + # work on several videos + data/ + +-- my_video_1.mp4 + +-- my_video_2.avi + +As you can see, the videos do not need to be in the same file format. An example command for a folder +structure as shown above could then look like this: + +.. code-block:: console + + docker run --gpus all --rm -it \ + -v INPUT_DIR:/home/input_dir:ro \ + -v SHARED_DIR:/home/shared_dir:ro \ + -v OUTPUT_DIR:/home/output_dir \ + lightly/sampling:latest \ + token=MYAWESOMETOKEN \ + stopping_condition.n_samples=0.3 + +Where INPUT_DIR is the path to the directory containing the video files. + +Removing Exact Duplicates +--------------------------- +With the docker solution, it is possible to remove **only exact duplicates** from the dataset. For this, +simply set the stopping condition `n_samples` to 1.0 (which translates to 100% of the data). The exact command is: + +.. code-block:: console + + docker run --gpus all --rm -it \ + -v INPUT_DIR:/home/input_dir:ro \ + -v SHARED_DIR:/home/shared_dir:ro \ + -v OUTPUT_DIR:/home/output_dir \ + lightly/sampling:latest \ + token=MYAWESOMETOKEN \ + remove_exact_duplicates=True \ + stopping_condition.n_samples=1. + + Reporting -----------------------------------