diff --git a/python/ray/_private/runtime_env/packaging.py b/python/ray/_private/runtime_env/packaging.py index 83efeb8953e19..87412e06abf13 100644 --- a/python/ray/_private/runtime_env/packaging.py +++ b/python/ray/_private/runtime_env/packaging.py @@ -628,6 +628,7 @@ def upload_package_if_needed( package_file = package_file.with_name( f"{time.time_ns()}_{os.getpid()}_{package_file.name}" ) + create_package( module_path, package_file, @@ -656,6 +657,7 @@ async def download_and_unpack_package( base_directory: str, gcs_aio_client: Optional["GcsAioClient"] = None, # noqa: F821 logger: Optional[logging.Logger] = default_logger, + overwrite: bool = False, ) -> str: """Download the package corresponding to this URI and unpack it if zipped. @@ -668,6 +670,7 @@ async def download_and_unpack_package( directory for the unpacked files. gcs_aio_client: Client to use for downloading from the GCS. logger: The logger to use. + overwrite: If True, overwrite the existing package. Returns: Path to the local directory containing the unpacked package files. @@ -695,10 +698,21 @@ async def download_and_unpack_package( local_dir = get_local_dir_from_uri(pkg_uri, base_directory) assert local_dir != pkg_file, "Invalid pkg_file!" - if local_dir.exists(): + + download_package: bool = True + if local_dir.exists() and not overwrite: + download_package = False assert local_dir.is_dir(), f"{local_dir} is not a directory" - else: + elif local_dir.exists(): + logger.info(f"Removing {local_dir} with pkg_file {pkg_file}") + shutil.rmtree(local_dir) + + if download_package: protocol, _ = parse_uri(pkg_uri) + logger.info( + f"Downloading package from {pkg_uri} to {pkg_file} " + f"with protocol {protocol}" + ) if protocol == Protocol.GCS: if gcs_aio_client is None: raise ValueError( diff --git a/python/ray/_private/runtime_env/working_dir.py b/python/ray/_private/runtime_env/working_dir.py index 51a0d2b91a572..6902153235013 100644 --- a/python/ray/_private/runtime_env/working_dir.py +++ b/python/ray/_private/runtime_env/working_dir.py @@ -161,7 +161,11 @@ async def create( logger: logging.Logger = default_logger, ) -> int: local_dir = await download_and_unpack_package( - uri, self._resources_dir, self._gcs_aio_client, logger=logger + uri, + self._resources_dir, + self._gcs_aio_client, + logger=logger, + overwrite=True, ) return get_directory_size_bytes(local_dir) diff --git a/python/ray/tests/test_runtime_env_working_dir.py b/python/ray/tests/test_runtime_env_working_dir.py index e667b0c712b10..35f9d1390c406 100644 --- a/python/ray/tests/test_runtime_env_working_dir.py +++ b/python/ray/tests/test_runtime_env_working_dir.py @@ -45,6 +45,35 @@ def insert_test_dir_in_pythonpath(): yield +@pytest.mark.asyncio +async def test_working_dir_cleanup(tmpdir, ray_start_regular): + gcs_aio_client = gcs_utils.GcsAioClient( + address=ray.worker.global_worker.gcs_client.address + ) + + plugin = WorkingDirPlugin(tmpdir, gcs_aio_client) + await plugin.create(HTTPS_PACKAGE_URI, {}, RuntimeEnvContext()) + + files = os.listdir(f"{tmpdir}/working_dir_files") + + # Iterate over the files and storing creation metadata. + creation_metadata = {} + for file in files: + file_metadata = os.stat(f"{tmpdir}/working_dir_files/{file}") + creation_time = file_metadata.st_ctime + creation_metadata[file] = creation_time + + time.sleep(1) + + await plugin.create(HTTPS_PACKAGE_URI, {}, RuntimeEnvContext()) + files = os.listdir(f"{tmpdir}/working_dir_files") + + for file in files: + file_metadata = os.stat(f"{tmpdir}/working_dir_files/{file}") + creation_time_after = file_metadata.st_ctime + assert creation_metadata[file] != creation_time_after + + @pytest.mark.asyncio async def test_create_delete_size_equal(tmpdir, ray_start_regular): """Tests that `create` and `delete_uri` return the same size for a URI.""" diff --git a/python/ray/train/_internal/storage.py b/python/ray/train/_internal/storage.py index 05970988862e3..46119176c713e 100644 --- a/python/ray/train/_internal/storage.py +++ b/python/ray/train/_internal/storage.py @@ -250,7 +250,7 @@ def _list_at_fs_path( selector = pyarrow.fs.FileSelector(fs_path, allow_not_found=True, recursive=False) return [ - os.path.relpath(file_info.path.lstrip("/"), start=fs_path.lstrip("/")) + os.path.relpath(os.path.abspath(file_info.path), start=os.path.abspath(fs_path)) for file_info in fs.get_file_info(selector) if file_filter(file_info) ]