Skip to content

Commit

Permalink
Merge pull request #261 from cmacdonald/gsutil
Browse files Browse the repository at this point in the history
detect google colab and use gsutil (for NQ)
  • Loading branch information
seanmacavaney authored Apr 16, 2024
2 parents c4b9b98 + 2cd9cb8 commit e4bf818
Showing 1 changed file with 18 additions and 1 deletion.
19 changes: 18 additions & 1 deletion ir_datasets/util/download.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import json
import pkgutil
import os
import sys
from pathlib import Path
import atexit
from collections import deque
Expand All @@ -20,6 +21,20 @@ class BaseDownload:
def stream(self):
raise NotImplementedError()

class GoogleCloudBucketStream(BaseDownload):
    """Stream an object from a Google Cloud Storage bucket via ``gsutil cat``.

    The public ``https://storage.googleapis.com/<bucket>/<path>`` URL is
    rewritten to the equivalent ``gs://<bucket>/<path>`` URI, which is then
    passed to the ``gsutil`` command-line tool.

    Args:
        url: HTTPS URL (or an already-formed URI) of the object to stream.
        tries: Stored and shown in ``repr`` only; ``stream`` does not retry.
    """

    _HTTPS_PREFIX = "https://storage.googleapis.com/"

    def __init__(self, url, tries=None):
        prefix = self._HTTPS_PREFIX
        # Rewrite only the leading host prefix; a later occurrence of the same
        # text inside the object path must be left untouched.
        if url.startswith(prefix):
            self.uri = "gs://" + url[len(prefix):]
        else:
            self.uri = url
        self.tries = tries

    def __repr__(self):
        return f'GoogleCloudBucketStream({repr(self.uri)}, tries={self.tries})'

    @contextlib.contextmanager
    def stream(self):
        """Yield a buffered binary reader over the output of ``gsutil cat``.

        The child process is always reaped when the context exits. If the
        consumer read the stream to EOF and ``gsutil`` exited with a non-zero
        status, the failure is reported instead of silently handing back
        empty or truncated data.

        Raises:
            subprocess.CalledProcessError: ``gsutil`` exited with a non-zero
                status after its output was fully consumed.
        """
        import subprocess
        proc = subprocess.Popen(['gsutil', 'cat', self.uri], stdout=subprocess.PIPE)
        drained = False
        try:
            with io.BufferedReader(proc.stdout, buffer_size=io.DEFAULT_BUFFER_SIZE) as reader:
                yield reader
                # An empty peek means EOF: the consumer took everything gsutil
                # produced, so its exit status is meaningful. If the consumer
                # closed the reader or stopped early, a non-zero status is
                # expected (broken pipe) and must not be reported as an error.
                drained = (not reader.closed) and not reader.peek(1)
        finally:
            # Stop a child that is still producing output nobody will read,
            # then reap it so it does not linger as a zombie.
            if not drained and proc.poll() is None:
                proc.terminate()
            proc.wait()
        if drained and proc.returncode != 0:
            raise subprocess.CalledProcessError(proc.returncode, proc.args)

class GoogleDriveDownload(BaseDownload):
def __init__(self, url, tries=None):
Expand Down Expand Up @@ -333,7 +348,9 @@ def __getitem__(self, key):
local_msg = (f'If you have a local copy of {dlc["url"]}, you can symlink it here '
f'to avoid downloading it again: {local_path}')
sources.append(LocalDownload(local_path, local_msg, mkdir=False))
if dlc['url'].startswith('https://drive.google.com/'):
if dlc['url'].startswith("https://storage.googleapis.com/") and 'google.colab' in sys.modules:
sources.append(GoogleCloudBucketStream(dlc['url'], **download_args))
elif dlc['url'].startswith('https://drive.google.com/'):
sources.append(GoogleDriveDownload(dlc['url'], **download_args))
else:
sources.append(RequestsDownload(dlc['url'], **download_args))
Expand Down

0 comments on commit e4bf818

Please sign in to comment.