Skip to content

Commit

Permalink
Add OCI Cloud Storage support (#86)
Browse files Browse the repository at this point in the history
* Add OCI Cloud Storage support

* Fixed remote path and CIFAR10 datatype

* Fixed comment
  • Loading branch information
karan6181 committed Dec 7, 2022
1 parent 87bd2ac commit 1c867fe
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 1 deletion.
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
'transformers>=4.21.3,<5',
'xxhash>=3.0.0,<4',
'zstd>=1.5.2.5,<2',
'oci>=2.88,<3',
]

extra_deps = {}
Expand Down
27 changes: 27 additions & 0 deletions streaming/base/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,31 @@ def download_from_gcs(remote: str, local: str) -> None:
raise FileNotFoundError(f'Object {remote} not found.') from e


def download_from_oci(remote: str, local: str) -> None:
    """Download a file from remote OCI to local.

    The remote URI is validated before the OCI SDK is imported or any request
    is made, so a malformed path fails immediately with ``ValueError``.

    Args:
        remote (str): Remote path (OCI).
        local (str): Local path (local filesystem).

    Raises:
        ValueError: If ``remote`` does not use the ``oci`` scheme.
    """
    # Validate first: this needs neither credentials nor network access, and
    # keeps a bad URI from surfacing as an unrelated config/connection error.
    obj = urllib.parse.urlparse(remote)
    if obj.scheme != 'oci':
        raise ValueError(f'Expected obj.scheme to be "oci", got {obj.scheme} for remote={remote}')

    # Imported lazily so the SDK is only loaded when an OCI download is requested.
    import oci
    config = oci.config.from_file()
    client = oci.object_storage.ObjectStorageClient(
        config=config, retry_strategy=oci.retry.DEFAULT_RETRY_STRATEGY)
    namespace = client.get_namespace().data

    # The netloc may carry an "@<namespace>" suffix; keep only the part before it.
    bucket_name = obj.netloc.split('@' + namespace)[0]
    # Remove leading and trailing forward slash from string
    object_path = obj.path.strip('/')
    object_details = client.get_object(namespace, bucket_name, object_path)
    with open(local, 'wb') as f:
        # Stream the body in 2048**2-byte (4 MiB) chunks instead of reading the
        # whole object into memory at once.
        for chunk in object_details.data.raw.stream(2048**2, decode_content=False):
            f.write(chunk)


def download_from_local(remote: str, local: str) -> None:
"""Download a file from remote to local.
Expand Down Expand Up @@ -178,6 +203,8 @@ def download(remote: Optional[str], local: str, timeout: float):
download_from_sftp(remote, local)
elif remote.startswith('gs://'):
download_from_gcs(remote, local)
elif remote.startswith('oci://'):
download_from_oci(remote, local)
else:
download_from_local(remote, local)

Expand Down
2 changes: 1 addition & 1 deletion streaming/vision/convert/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def convert_image_class_dataset(dataset: Dataset,
'y': 'int',
}
hashes = hashes or []
indices = np.random.permutation(len(dataset)) # pyright: ignore
indices = np.random.permutation(len(dataset)).tolist() # pyright: ignore
if progbar:
indices = tqdm(indices, leave=leave)
split_dir = os.path.join(root, split)
Expand Down

0 comments on commit 1c867fe

Please sign in to comment.