diff --git a/ir_datasets/util/download.py b/ir_datasets/util/download.py index 94d0ccce..dd2dc202 100644 --- a/ir_datasets/util/download.py +++ b/ir_datasets/util/download.py @@ -1,6 +1,7 @@ import json import pkgutil import os +import sys from pathlib import Path import atexit from collections import deque @@ -20,6 +21,20 @@ class BaseDownload: def stream(self): raise NotImplementedError() +class GoogleCloudBucketStream(BaseDownload): + def __init__(self, url, tries=None): + self.uri = url.replace("https://storage.googleapis.com/", "gs://") + self.tries = tries + + def __repr__(self): + return f'GoogleCloudBucketStream({repr(self.uri)}, tries={self.tries})' + + @contextlib.contextmanager + def stream(self): + import subprocess + proc = subprocess.Popen(['gsutil', 'cat', self.uri], stdout=subprocess.PIPE) + with io.BufferedReader(proc.stdout, buffer_size=io.DEFAULT_BUFFER_SIZE) as stream: + yield stream class GoogleDriveDownload(BaseDownload): def __init__(self, url, tries=None): @@ -333,7 +348,9 @@ def __getitem__(self, key): local_msg = (f'If you have a local copy of {dlc["url"]}, you can symlink it here ' f'to avoid downloading it again: {local_path}') sources.append(LocalDownload(local_path, local_msg, mkdir=False)) - if dlc['url'].startswith('https://drive.google.com/'): + if dlc['url'].startswith("https://storage.googleapis.com/") and 'google.colab' in sys.modules: + sources.append(GoogleCloudBucketStream(dlc['url'], **download_args)) + elif dlc['url'].startswith('https://drive.google.com/'): sources.append(GoogleDriveDownload(dlc['url'], **download_args)) else: sources.append(RequestsDownload(dlc['url'], **download_args))