Skip to content

Commit

Permalink
tiny improvement (#19341)
Browse files Browse the repository at this point in the history
  • Loading branch information
tchaton authored Jan 24, 2024
1 parent 577bd85 commit 0a75d3b
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 13 deletions.
43 changes: 30 additions & 13 deletions src/lightning/data/streaming/downloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# limitations under the License.
import os
import shutil
import subprocess
from abc import ABC
from typing import Any, Dict, List
from urllib import parse
Expand Down Expand Up @@ -40,29 +41,45 @@ def download_file(self, remote_chunkpath: str, local_chunkpath: str) -> None:
class S3Downloader(Downloader):
def __init__(self, remote_dir: str, cache_dir: str, chunks: List[Dict[str, Any]]):
super().__init__(remote_dir, cache_dir, chunks)
self._client = S3Client()
self._s5cmd_available = os.system("s5cmd > /dev/null 2>&1") == 0

if not self._s5cmd_available:
self._client = S3Client()

def download_file(self, remote_filepath: str, local_filepath: str) -> None:
obj = parse.urlparse(remote_filepath)

if obj.scheme != "s3":
raise ValueError(f"Expected obj.scheme to be `s3`, instead, got {obj.scheme} for remote={remote_filepath}")

from boto3.s3.transfer import TransferConfig

extra_args: Dict[str, Any] = {}
if os.path.exists(local_filepath):
return

try:
with FileLock(local_filepath + ".lock", timeout=1):
if not os.path.exists(local_filepath):
# Issue: https://github.com/boto/boto3/issues/3113
self._client.client.download_file(
obj.netloc,
obj.path.lstrip("/"),
local_filepath,
ExtraArgs=extra_args,
Config=TransferConfig(use_threads=False),
with FileLock(local_filepath + ".lock", timeout=0):
if self._s5cmd_available:
proc = subprocess.Popen(
f"s5cmd --numworkers 64 cp {remote_filepath} {local_filepath}",
shell=True,
stdout=subprocess.PIPE,
)
proc.wait()
else:
from boto3.s3.transfer import TransferConfig

extra_args: Dict[str, Any] = {}

# try:
# with FileLock(local_filepath + ".lock", timeout=1):
if not os.path.exists(local_filepath):
# Issue: https://github.com/boto/boto3/issues/3113
self._client.client.download_file(
obj.netloc,
obj.path.lstrip("/"),
local_filepath,
ExtraArgs=extra_args,
Config=TransferConfig(use_threads=False),
)
except Timeout:
# another process is responsible to download that file, continue
pass
Expand Down
13 changes: 13 additions & 0 deletions tests/tests_data/streaming/test_downloader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
import os
from unittest.mock import MagicMock

from lightning.data.streaming.downloader import S3Downloader, subprocess


def test_s3_downloader_fast(tmpdir, monkeypatch):
monkeypatch.setattr(os, "system", MagicMock(return_value=0))
popen_mock = MagicMock()
monkeypatch.setattr(subprocess, "Popen", MagicMock(return_value=popen_mock))
downloader = S3Downloader(tmpdir, tmpdir, [])
downloader.download_file("s3://random_bucket/a.txt", os.path.join(tmpdir, "a.txt"))
popen_mock.wait.assert_called()

0 comments on commit 0a75d3b

Please sign in to comment.