From 8668e8a38c9daa5a164c7d70401cdd6928e880cf Mon Sep 17 00:00:00 2001 From: Bram Stoeller Date: Thu, 2 Mar 2023 15:30:02 +0100 Subject: [PATCH 1/9] File downloader and extractor utility Signed-off-by: Bram Stoeller --- src/power_grid_model_io/utils/download.py | 240 +++++++++++++ src/power_grid_model_io/utils/zip.py | 79 +++++ tests/unit/utils/test_download.py | 398 ++++++++++++++++++++++ tests/unit/utils/test_zip.py | 125 +++++++ tests/utils.py | 11 + 5 files changed, 853 insertions(+) create mode 100644 src/power_grid_model_io/utils/download.py create mode 100644 src/power_grid_model_io/utils/zip.py create mode 100644 tests/unit/utils/test_download.py create mode 100644 tests/unit/utils/test_zip.py diff --git a/src/power_grid_model_io/utils/download.py b/src/power_grid_model_io/utils/download.py new file mode 100644 index 00000000..a1ad983e --- /dev/null +++ b/src/power_grid_model_io/utils/download.py @@ -0,0 +1,240 @@ +# SPDX-FileCopyrightText: 2022 Contributors to the Power Grid Model project +# +# SPDX-License-Identifier: MPL-2.0 +""" +Helper functions to download (and store) files from the internet + +The most simple (and intended) usage is: +url = "http://141.51.193.167/simbench/gui/usecase/download/?simbench_code=1-complete_data-mixed-all-0-sw&format=csv" +zip_file_path = download(url) + +It will download the zip file 1-complete_data-mixed-all-0-sw.zip to a folder in you systems temp dir; for example +"/tmp/1-complete_data-mixed-all-0-sw.zip". + +Another convenience function is download_and_extract(): + +csv_dir_path = download_and_extract(url) + +This downloads the zip file as described above, and then it extracts the files there as well, in a folder which +corresponds to the zip file name ("/tmp/1-complete_data-mixed-all-0-sw/" in our example), and it returns the path to +that directory. By default, it will not re-download or re-extract the zip file as long as the files exist in your +temp dir. Your temp dir is typically emptied whe you reboot your computer. + +""" + +import base64 +import hashlib +import re +import tempfile +from dataclasses import dataclass +from pathlib import Path +from shutil import rmtree as remove_dir +from typing import Optional, Union +from urllib import request + +import structlog +from tqdm import tqdm + +from power_grid_model_io.utils.zip import extract + +_log = structlog.get_logger(__name__) + + +@dataclass +class ResponseInfo: + """ + Struct to store response information extracted from the response header + """ + + status: int + file_name: Optional[str] = None + file_size: Optional[int] = None + + +class DownloadProgressHook: # pylint: disable=too-few-public-methods + """ + Report hook for request.urlretrieve() to update a progress bar based on the amount of downloaded blocks + """ + + def __init__(self, progress_bar: tqdm): + """ + Report hook for request.urlretrieve() to update a progress bar based on the amount of downloaded blocks + + Args: + progress_bar: A tqdb progress bar + """ + self._progress_bar = progress_bar + self._last_block = 0 + + def __call__(self, block_num: int, block_size: int, file_size: int) -> None: + """ + Args: + block_num: The last downloaded block number + block_size: The block size in bytes + file_size: The file size in bytes (may be 0 in the first call) + + """ + if file_size > 0: + self._progress_bar.total = file_size + self._progress_bar.update((block_num - self._last_block) * block_size) + self._last_block = block_num + + +def download_and_extract( + url: str, dir_path: Optional[Path] = None, file_name: Optional[Union[str, Path]] = None, overwrite: bool = False +) -> Path: + """ + Download a file from a URL and store it locally, extract the contents and return the path to the contents. + + Args: + url: The url to the .zip file + dir_path: An optional dir path to store the downloaded file. If no dir_path is given the current working dir + will be used. + file_name: An optional file name (or path relative to dir_path). If no file_name is given, a file name is + generated based on the url + overwrite: Should we download the file, even if we have downloaded already (and the file size still matches)? + Be careful with this option, as it will remove files from your drive irreversibly! + + Returns: + The path to the downloaded file + """ + + # Download the file and use the file name as the base name for the extraction directory + src_file_path = download(url=url, file_name=file_name, dir_path=dir_path, overwrite=overwrite) + dst_dir_path = src_file_path.with_suffix("") + + # If we explicitly want to overwrite the extracted files, remove the + if overwrite and dst_dir_path.is_dir(): + remove_dir(dst_dir_path) + + # Extract the files and return the path of the extraction directory + return extract(src_file_path=src_file_path, dst_dir_path=dst_dir_path, skip_if_exists=not overwrite) + + +def download( + url: str, file_name: Optional[Union[str, Path]] = None, dir_path: Optional[Path] = None, overwrite: bool = False +) -> Path: + """ + Download a file from a URL and store it locally + + Args: + url: The url to the file + file_name: An optional file name (or path relative to dir_path). If no file_name is given, a file name is + generated based on the url + dir_path: An optional dir path to store the downloaded file. If no dir_path is given the current working dir + will be used. + overwrite: Should we download the file, even if we have downloaded already (and the file size still matches)? + + Returns: + The path to the downloaded file + """ + + # get the response info, if the status is not 200 + info = get_response_info(url=url) + if info.status != 200: + raise IOError(f"Could not download from URL, status={info.status}") + + if file_name is None and info.file_name: + file_name = info.file_name + + file_path = get_download_path(dir_path=dir_path, file_name=file_name, unique_key=url) + log = _log.bind(url=url, file_path=file_path) + + if file_path.is_file(): + if overwrite: + log.debug("Forced re-downloading existing file") + # Don't remove the existing file just yet... Let's first see if we can download a new version. + else: + local_size = file_path.stat().st_size + if local_size == info.file_size: + log.debug("Skip downloading existing file") + return file_path + log.debug( + "Re-downloading existing file, because the size has changed", + local_size=local_size, + remote_size=info.file_size, + ) + else: + log.debug("Downloading file") + + # Download to a temp file first, so the results are not stored if the transfer fails + with tqdm(desc="Downloading", unit="B", unit_scale=True, leave=True) as progress_bar: + report_hook = DownloadProgressHook(progress_bar) + temp_file, _headers = request.urlretrieve(url, reporthook=report_hook) + + # Check if the file contains any content + temp_path = Path(temp_file) + if temp_path.stat().st_size == 0: + log.warning("Downloaded an empty file") + + # Remove the file, if it already exists + file_path.unlink(missing_ok=True) + + # Move the file to it's final destination + file_path.parent.mkdir(parents=True, exist_ok=True) + temp_path.rename(file_path) + log.debug("Downloaded file", file_size=file_path.stat().st_size) + + return file_path + + +def get_response_info(url: str) -> ResponseInfo: + """ + Retrieve the file size of a given URL (based on it's header) + + Args: + url: The url to the file + + Return: + The file size in bytes + """ + with request.urlopen(url) as context: + status = context.status + headers = context.headers + file_size = int(headers["Content-Length"]) if "Content-Length" in headers else None + matches = re.findall(r"filename=\"(.+)\"", headers.get("Content-Disposition", "")) + file_name = matches[0] if matches else None + + return ResponseInfo(status=status, file_size=file_size, file_name=file_name) + + +def get_download_path( + dir_path: Optional[Path] = None, + file_name: Optional[Union[str, Path]] = None, + unique_key: Optional[str] = None, +) -> Path: + """ + Determine the file path based on dir_path, file_name and/or data + + Args: + dir_path: An optional dir path to store the downloaded file. If no dir_path is given the system's temp dir + will be used. If omitted, the tempfolder is used. + file_name: An optional file name (or path relative to dir_path). If no file_name is given, a file name is + generated based on the unique key (e.g. an url) + unique_key: A unique string that can be used to generate a filename (e.g. a url). + """ + + # If no file_name is given, generate a file name + if file_name is None: + if unique_key is None: + raise ValueError("Supply data in order to auto generate a download path.") + + sha256 = hashlib.sha256() + sha256.update(unique_key.encode()) + hash_str = base64.b64encode(sha256.digest()).decode("ascii") + hash_str = hash_str.replace("/", "_").replace("+", "-").rstrip("=") + file_name = Path(f"{hash_str}.download") + + # If no dir_path is given, use the system's designated folder for temporary files + elif dir_path is None: + dir_path = Path(tempfile.gettempdir()) + + # Combine the two paths + assert file_name is not None + file_path = (dir_path / file_name) if dir_path else Path(file_name) + + # If the file_path exists, it should be a file (not a dir) + if file_path.exists() and not file_path.is_file(): + raise ValueError(f"Invalid file path: {file_path}") + + return file_path.resolve() diff --git a/src/power_grid_model_io/utils/zip.py b/src/power_grid_model_io/utils/zip.py new file mode 100644 index 00000000..cf001a83 --- /dev/null +++ b/src/power_grid_model_io/utils/zip.py @@ -0,0 +1,79 @@ +# SPDX-FileCopyrightText: 2022 Contributors to the Power Grid Model project +# +# SPDX-License-Identifier: MPL-2.0 +""" +Helper function to extract zip files + +csv_dir_path = extract("/tmp/1-complete_data-mixed-all-0-sw.zip") + +This extracts the files, in a folder which corresponds to the zip file name ("/tmp/1-complete_data-mixed-all-0-sw/" in +our example), and it returns the path to that directory. By default, it will not re-download or re-extract the zip +file as long as the files exist. + +""" + +import zipfile +from pathlib import Path +from typing import Optional + +import structlog +from tqdm import tqdm + +_log = structlog.get_logger(__name__) + + +def extract(src_file_path: Path, dst_dir_path: Optional[Path] = None, skip_if_exists=False) -> Path: + """ + Extract a .zip file and return the destination dir + + Args: + src_file_path: The .zip file to extract. + dst_dir_path: An optional destination path. If none is given, the src_file_path without .zip extension is used. + skip_if_exists: Skip existing files, otherwise raise an exception when a file exists. + + Returns: The path where the files are extracted + + """ + if src_file_path.suffix.lower() != ".zip": + raise ValueError(f"Only files with .zip extension are supported, got {src_file_path.name}") + + if dst_dir_path is None: + dst_dir_path = src_file_path.with_suffix("") + + log = _log.bind(src_file_path=src_file_path, dst_dir_path=dst_dir_path) + + if dst_dir_path.exists(): + if not dst_dir_path.is_dir(): + raise NotADirectoryError(f"Destination dir {dst_dir_path} exists and is not a directory") + + # Create the destination directory + dst_dir_path.mkdir(parents=True, exist_ok=True) + + # Extract per file, so we can show a progress bar + with zipfile.ZipFile(src_file_path, "r") as zip_file: + file_list = zip_file.namelist() + for file_path in tqdm(desc="Extracting", iterable=file_list, total=len(file_list), unit="file", leave=True): + dst_file_path = dst_dir_path / file_path + if dst_file_path.exists() and dst_file_path.stat().st_size > 0: + if skip_if_exists: + log.debug("Skip file extraction, destination file exists", dst_file_path=dst_file_path) + continue + raise FileExistsError(f"Destination file {dst_dir_path / file_path} exists and is not empty") + zip_file.extract(member=file_path, path=dst_dir_path) + + # Zip files often contain a single directory with the same name as the zip file. + # In that case, return the dir to that directory instead of the root dir + only_item: Optional[Path] = None + for item in dst_dir_path.iterdir(): + # If only_item is None, this is the first iteration, so item may be the only item + if only_item is None: + only_item = item + # Else, if only_item is not None, there are more than one items in the root of the directory. + # This means hat there is no 'only_item' and we can stop the loop + else: + only_item = None + break + if only_item and only_item.is_dir() and only_item.name == src_file_path.stem: + dst_dir_path = only_item + + return dst_dir_path.resolve() diff --git a/tests/unit/utils/test_download.py b/tests/unit/utils/test_download.py new file mode 100644 index 00000000..7e886703 --- /dev/null +++ b/tests/unit/utils/test_download.py @@ -0,0 +1,398 @@ +# SPDX-FileCopyrightText: 2022 Contributors to the Power Grid Model project +# +# SPDX-License-Identifier: MPL-2.0 +import tempfile +from collections import namedtuple +from pathlib import Path +from unittest.mock import ANY, MagicMock, patch + +import pytest +import structlog.testing + +from power_grid_model_io.utils.download import ( + DownloadProgressHook, + ResponseInfo, + download, + download_and_extract, + get_download_path, + get_response_info, +) + +from ...utils import assert_log_exists + +Response = namedtuple("Response", ["status", "headers"]) + + +@pytest.fixture() +def temp_dir(): + with tempfile.TemporaryDirectory() as tmp: + yield Path(tmp).resolve() + + +def make_file(file_path: Path, file_size: int = 0): + with open(file_path, "wb") as fp: + fp.write(b"\0" * file_size) + + +def test_progress_hook(): + # Arrange + progress_bar = MagicMock() + progress_bar.total = None + hook = DownloadProgressHook(progress_bar) + + # Act (block_num, block_size, file_size) + hook(2, 10, 0) + assert progress_bar.total is None # total is not updated + + hook(3, 10, 123) + assert progress_bar.total == 123 # total is updated + + hook(6, 10, 123) + + # Assert + assert progress_bar.update.call_args_list[0].args == (20,) + assert progress_bar.update.call_args_list[1].args == (10,) + assert progress_bar.update.call_args_list[2].args == (30,) + + +@patch("power_grid_model_io.utils.download.extract") +@patch("power_grid_model_io.utils.download.download") +def test_download_and_extract__paths(mock_download: MagicMock, mock_extract: MagicMock, temp_dir: Path): + # Arrange + url = MagicMock() + dir_path = MagicMock() + file_path = MagicMock() + src_file_path = temp_dir / "data.zip" + dst_dir_path = temp_dir / "data" + extract_dir_path = MagicMock() + + mock_download.return_value = src_file_path + mock_extract.return_value = extract_dir_path + + # Act + result = download_and_extract(url=url, dir_path=dir_path, file_name=file_path) + + # Assert + mock_download.assert_called_once_with(url=url, file_name=file_path, dir_path=dir_path, overwrite=False) + mock_extract.assert_called_once_with(src_file_path=src_file_path, dst_dir_path=dst_dir_path, skip_if_exists=True) + assert result == extract_dir_path + + +@patch("power_grid_model_io.utils.download.extract") +@patch("power_grid_model_io.utils.download.download") +def test_download_and_extract__no_paths(mock_download: MagicMock, mock_extract: MagicMock, temp_dir: Path): + # Arrange + url = MagicMock() + src_file_path = temp_dir / "data.zip" + dst_dir_path = temp_dir / "data" + + mock_download.return_value = src_file_path + + # Act + download_and_extract(url=url) + + # Assert + mock_download.assert_called_once_with(url=url, file_name=None, dir_path=None, overwrite=False) + mock_extract.assert_called_once_with(src_file_path=src_file_path, dst_dir_path=dst_dir_path, skip_if_exists=True) + + +@patch("power_grid_model_io.utils.download.extract", new=MagicMock) +@patch("power_grid_model_io.utils.download.download") +def test_download_and_extract__overwrite(mock_download: MagicMock, temp_dir: Path): + # Arrange + src_file_path = temp_dir / "data.zip" + mock_download.return_value = src_file_path + + dst_dir_path = temp_dir / "data" + dst_dir_path.mkdir() + + # Act / Assert + download_and_extract(url=MagicMock(), overwrite=False) + assert dst_dir_path.is_dir() + + # Act / Assert (dir does exist, overwrite = True) + download_and_extract(url=MagicMock(), overwrite=True) + assert not dst_dir_path.exists() + + +@patch("power_grid_model_io.utils.download.DownloadProgressHook") +@patch("power_grid_model_io.utils.download.request.urlretrieve") +@patch("power_grid_model_io.utils.download.tqdm") +@patch("power_grid_model_io.utils.download.get_download_path") +@patch("power_grid_model_io.utils.download.get_response_info") +def test_download( + mock_info: MagicMock, + mock_download_path: MagicMock, + mock_tqdm: MagicMock, + mock_urlretrieve: MagicMock, + mock_hook: MagicMock, + temp_dir: Path, +): + # Arrange + url = "https://www.source.com" + dir_path = temp_dir / "data" + file_path = temp_dir / "data.zip" + temp_file = temp_dir / "data.download" + download_path = temp_dir / "data.zip" + + def urlretrieve(*_args, **_kwargs): + make_file(temp_file, 100) + return temp_file, None + + mock_info.return_value = ResponseInfo(status=200, file_size=100, file_name="remote.zip") + mock_download_path.return_value = download_path + mock_urlretrieve.side_effect = urlretrieve + + # Act / Assert + with structlog.testing.capture_logs() as capture: + result = download(url=url, file_name=file_path, dir_path=dir_path) + assert_log_exists(capture, "debug", "Downloading file") + + # Assert + mock_download_path.assert_called_once_with( + dir_path=dir_path, file_name=file_path, unique_key="https://www.source.com" + ) + mock_tqdm.assert_called_once() + mock_hook.assert_called_once_with(mock_tqdm.return_value.__enter__.return_value) + mock_urlretrieve.assert_called_once_with(url, reporthook=mock_hook.return_value) + assert result == file_path + assert result.is_file() + assert result.stat().st_size == 100 + + +@patch("power_grid_model_io.utils.download.DownloadProgressHook", new=MagicMock()) +@patch("power_grid_model_io.utils.download.request.urlretrieve") +@patch("power_grid_model_io.utils.download.tqdm", new=MagicMock()) +@patch("power_grid_model_io.utils.download.get_download_path") +@patch("power_grid_model_io.utils.download.get_response_info") +def test_download__auto_file_name( + mock_info: MagicMock, mock_download_path: MagicMock, mock_urlretrieve: MagicMock, temp_dir: Path +): + # Arrange + temp_file = temp_dir / "data.download" + download_path = temp_dir / "data.zip" + + def urlretrieve(*_args, **_kwargs): + make_file(temp_file, 100) + return temp_file, None + + mock_info.return_value = ResponseInfo(status=200, file_size=None, file_name="remote.zip") + mock_download_path.return_value = download_path + mock_urlretrieve.side_effect = urlretrieve + + # Act + download(url=MagicMock()) + + # Assert + mock_download_path.assert_called_once_with(dir_path=None, file_name="remote.zip", unique_key=ANY) + + +@patch("power_grid_model_io.utils.download.DownloadProgressHook", new=MagicMock()) +@patch("power_grid_model_io.utils.download.request.urlretrieve") +@patch("power_grid_model_io.utils.download.tqdm", new=MagicMock()) +@patch("power_grid_model_io.utils.download.get_download_path") +@patch("power_grid_model_io.utils.download.get_response_info") +def test_download__empty_file( + mock_info: MagicMock, mock_download_path: MagicMock, mock_urlretrieve: MagicMock, temp_dir: Path +): + # Arrange + temp_file = temp_dir / "data.download" + download_path = temp_dir / "data.zip" + + def urlretrieve(*_args, **_kwargs): + with open(temp_file, "wb"): + pass + return temp_file, None + + mock_info.return_value = ResponseInfo(status=200, file_size=None, file_name="remote.zip") + mock_download_path.return_value = download_path + mock_urlretrieve.side_effect = urlretrieve + + # Act / Assert + with structlog.testing.capture_logs() as capture: + download(url=MagicMock()) + assert_log_exists(capture, "debug", "Downloading file") + assert_log_exists(capture, "warning", "Downloaded an empty file") + + +@patch("power_grid_model_io.utils.download.DownloadProgressHook", new=MagicMock()) +@patch("power_grid_model_io.utils.download.request.urlretrieve") +@patch("power_grid_model_io.utils.download.tqdm", new=MagicMock()) +@patch("power_grid_model_io.utils.download.get_download_path") +@patch("power_grid_model_io.utils.download.get_response_info") +def test_download__skip_existing_file( + mock_info: MagicMock, mock_download_path: MagicMock, mock_urlretrieve: MagicMock, temp_dir: Path +): + # Arrange + download_path = temp_dir / "data.zip" + make_file(download_path, 100) + + mock_info.return_value = ResponseInfo(status=200, file_size=100, file_name="remote.zip") + mock_download_path.return_value = download_path + + # Act + with structlog.testing.capture_logs() as capture: + download(url=MagicMock()) + assert_log_exists(capture, "debug", "Skip downloading existing file") + + # Assert + mock_urlretrieve.assert_not_called() + + +@patch("power_grid_model_io.utils.download.DownloadProgressHook", new=MagicMock()) +@patch("power_grid_model_io.utils.download.request.urlretrieve") +@patch("power_grid_model_io.utils.download.tqdm", new=MagicMock()) +@patch("power_grid_model_io.utils.download.get_download_path") +@patch("power_grid_model_io.utils.download.get_response_info") +def test_download__update_file( + mock_info: MagicMock, mock_download_path: MagicMock, mock_urlretrieve: MagicMock, temp_dir: Path +): + # Arrange + temp_file = temp_dir / "data.download" + download_path = temp_dir / "data.zip" + make_file(download_path, 100) + + def urlretrieve(*_args, **_kwargs): + make_file(temp_file, 101) + return temp_file, None + + mock_info.return_value = ResponseInfo(status=200, file_size=101, file_name="remote.zip") + mock_download_path.return_value = download_path + mock_urlretrieve.side_effect = urlretrieve + + # Act / Assert + with structlog.testing.capture_logs() as capture: + result = download(url=MagicMock()) + assert_log_exists(capture, "debug", "Re-downloading existing file, because the size has changed") + + # Assert + assert result == download_path + assert result.is_file() + assert result.stat().st_size == 101 + + +@patch("power_grid_model_io.utils.download.DownloadProgressHook", new=MagicMock()) +@patch("power_grid_model_io.utils.download.request.urlretrieve") +@patch("power_grid_model_io.utils.download.tqdm", new=MagicMock()) +@patch("power_grid_model_io.utils.download.get_download_path") +@patch("power_grid_model_io.utils.download.get_response_info") +def test_download__overwrite( + mock_info: MagicMock, mock_download_path: MagicMock, mock_urlretrieve: MagicMock, temp_dir: Path +): + # Arrange + temp_file = temp_dir / "data.download" + download_path = temp_dir / "data.zip" + make_file(download_path, 100) + + def urlretrieve(*_args, **_kwargs): + make_file(temp_file, 100) + return temp_file, None + + mock_info.return_value = ResponseInfo(status=200, file_size=100, file_name="remote.zip") + mock_download_path.return_value = download_path + mock_urlretrieve.side_effect = urlretrieve + + # Act / Assert + with structlog.testing.capture_logs() as capture: + result = download(url=MagicMock(), overwrite=True) + assert_log_exists(capture, "debug", "Forced re-downloading existing file") + + # Assert + assert result == download_path + assert result.is_file() + assert result.stat().st_size == 100 + + +@patch("power_grid_model_io.utils.download.get_response_info") +def test_download__status_error(mock_info: MagicMock): + # Arrange + mock_info.return_value = ResponseInfo(status=404, file_size=None, file_name=None) + + # Act / Assert + with pytest.raises(IOError, match=r"Could not download from URL, status=404"): + download(url=MagicMock()) + + +@patch("power_grid_model_io.utils.download.request.urlopen") +def test_get_response_info(mock_urlopen): + # Arrange + headers = {"Content-Length": "456", "Content-Disposition": 'form-data; name="ZipFile"; filename="filename.zip"'} + mock_urlopen.return_value.__enter__.return_value = Response(status=123, headers=headers) + + # Act / Assert + assert get_response_info("") == ResponseInfo(status=123, file_size=456, file_name="filename.zip") + + +@patch("power_grid_model_io.utils.download.request.urlopen") +def test_get_response_info__no_file_name(mock_urlopen): + # Arrange + headers = {"Content-Length": "456", "Content-Disposition": 'form-data; name="ZipFile"'} + mock_urlopen.return_value.__enter__.return_value = Response(status=123, headers=headers) + + # Act / Assert + assert get_response_info("") == ResponseInfo(status=123, file_size=456, file_name=None) + + +@patch("power_grid_model_io.utils.download.request.urlopen") +def test_get_response_info__no_disposition(mock_urlopen): + # Arrange + headers = {"Content-Length": "456"} + Context = namedtuple("Context", ["status", "headers"]) + mock_urlopen.return_value.__enter__.return_value = Context(status=123, headers=headers) + + # Act / Assert + assert get_response_info("") == ResponseInfo(status=123, file_size=456, file_name=None) + + +@patch("power_grid_model_io.utils.download.request.urlopen") +def test_get_response_info__no_length(mock_urlopen): + # Arrange + headers = {"Content-Disposition": 'form-data; name="ZipFile"; filename="filename.zip"'} + mock_urlopen.return_value.__enter__.return_value = Response(status=123, headers=headers) + + # Act / Assert + assert get_response_info("") == ResponseInfo(status=123, file_size=None, file_name="filename.zip") + + +def test_get_download_path(temp_dir: Path): + # Act + path = get_download_path(dir_path=temp_dir, file_name="file_name.zip", unique_key="foo") + + # Assert + assert path == temp_dir / "file_name.zip" + + +def test_get_download_path__auto_dir(): + # Act + path = get_download_path(file_name="file_name.zip") + + # Assert + assert path == Path(tempfile.gettempdir()).resolve() / "file_name.zip" + + +def test_get_download_path__auto_file_name(temp_dir: Path): + # Arrange + # The base64 representation of the sha256 hash of "foo" is LCa0a2j/xo/5m0U8HTBBNBNCLXBkg7+g+YpeiGJm564= + # The / and + will be replaced with a _ and - character and the trailing = character(s) will be removed. + expected_file_name = "LCa0a2j_xo_5m0U8HTBBNBNCLXBkg7-g-YpeiGJm564.download" + + # Act + path = get_download_path(dir_path=temp_dir, unique_key="foo") + + # Assert + assert path == temp_dir / expected_file_name + + +def test_get_download_path__missing_data(temp_dir: Path): + # Act / Assert + with pytest.raises(ValueError, match=r"Supply data in order to auto generate a download path\."): + get_download_path(dir_path=temp_dir) + + +def test_get_download_path__invalid_file_path(temp_dir: Path): + # Arrange + (temp_dir / "download").mkdir() + + # Act / Assert + with pytest.raises(ValueError, match=r"Invalid file path:"): + get_download_path(dir_path=temp_dir, file_name="download") diff --git a/tests/unit/utils/test_zip.py b/tests/unit/utils/test_zip.py new file mode 100644 index 00000000..dae83f58 --- /dev/null +++ b/tests/unit/utils/test_zip.py @@ -0,0 +1,125 @@ +# SPDX-FileCopyrightText: 2022 Contributors to the Power Grid Model project +# +# SPDX-License-Identifier: MPL-2.0 +import shutil +import tempfile +from pathlib import Path +from unittest.mock import MagicMock, patch + +import pytest +import structlog.testing + +from power_grid_model_io.utils.zip import extract + +from ...utils import MockTqdm, assert_log_exists + +DATA_DIR = Path(__file__).parents[2] / "data" / "zip" +ZIP1 = DATA_DIR / "foo.zip" +ZIP2 = DATA_DIR / "foo-bar.zip" + + +@pytest.fixture() +def temp_dir(): + with tempfile.TemporaryDirectory() as tmp: + yield Path(tmp).resolve() + + +@patch("power_grid_model_io.utils.download.tqdm") +def test_extract(mock_tqdm: MagicMock, temp_dir: Path): + # Arrange + src_file_path = temp_dir / "compressed.zip" + dst_dir_path = temp_dir / "extracted" + shutil.copyfile(ZIP2, src_file_path) + mock_tqdm.side_effect = MockTqdm + + # Act + extract_dir_path = extract(src_file_path=src_file_path, dst_dir_path=dst_dir_path) + + # Assert + assert extract_dir_path == dst_dir_path + assert (dst_dir_path / "foo.txt").is_file() + assert (dst_dir_path / "folder/bar.txt").is_file() + + +@patch("power_grid_model_io.utils.download.tqdm") +def test_extract__auto_dir(mock_tqdm: MagicMock, temp_dir: Path): + # Arrange + src_file_path = temp_dir / "compressed.zip" + shutil.copyfile(ZIP2, src_file_path) + mock_tqdm.side_effect = MockTqdm + + # Act + extract_dir_path = extract(src_file_path=src_file_path) + + # Assert + assert extract_dir_path == temp_dir / "compressed" + assert (temp_dir / "compressed" / "foo.txt").is_file() + assert (temp_dir / "compressed" / "folder" / "bar.txt").is_file() + + +def test_extract__invalid_file_extension(): + # Act / Assert + with pytest.raises(ValueError, match=r"Only files with \.zip extension are supported, got tempfile\.download"): + extract(src_file_path=Path("/tmp/dir/tempfile.download")) + + +def test_extract__invalid_dst_dir(temp_dir: Path): + # Arrange + with open(temp_dir / "notadir.txt", "wb"): + pass + + # Act / Assert + with pytest.raises(NotADirectoryError, match=r"Destination dir .*notadir\.txt exists and is not a directory"): + extract(src_file_path=Path("file.zip"), dst_dir_path=temp_dir / "notadir.txt") + + +@patch("power_grid_model_io.utils.download.tqdm") +def test_extract__file_exists(mock_tqdm: MagicMock, temp_dir: Path): + # Arrange + src_file_path = temp_dir / "compressed.zip" + dst_dir_path = temp_dir / "extracted" + shutil.copyfile(ZIP2, src_file_path) + mock_tqdm.side_effect = MockTqdm + + dst_dir_path.mkdir() + with open(dst_dir_path / "foo.txt", "wb") as fp: + fp.write(b"\0") + + # Act / Assert + with pytest.raises(FileExistsError, match=r"Destination file .*foo\.txt exists and is not empty"): + extract(src_file_path=src_file_path, dst_dir_path=dst_dir_path) + + +@patch("power_grid_model_io.utils.download.tqdm") +def test_extract__skip_if_exists(mock_tqdm: MagicMock, temp_dir: Path): + # Arrange + src_file_path = temp_dir / "compressed.zip" + dst_dir_path = temp_dir / "compressed" + shutil.copyfile(ZIP2, src_file_path) + mock_tqdm.side_effect = MockTqdm + + dst_dir_path.mkdir() + with open(dst_dir_path / "foo.txt", "wb") as fp: + fp.write(b"\0") + + # Act / Assert + with structlog.testing.capture_logs() as capture: + extract(src_file_path=src_file_path, dst_dir_path=dst_dir_path, skip_if_exists=True) + assert_log_exists( + capture, "debug", "Skip file extraction, destination file exists", dst_file_path=dst_dir_path / "foo.txt" + ) + + +@patch("power_grid_model_io.utils.download.tqdm") +def test_extract__return_subdir_path(mock_tqdm: MagicMock, temp_dir: Path): + # Arrange + src_file_path = temp_dir / "foo.zip" + shutil.copyfile(ZIP1, src_file_path) + mock_tqdm.side_effect = MockTqdm + + # Act + extract_dir_path = extract(src_file_path=src_file_path) + + # Assert + assert extract_dir_path == temp_dir / "foo" / "foo" + assert (temp_dir / "foo" / "foo" / "foo.txt").is_file() diff --git a/tests/utils.py b/tests/utils.py index 045686c7..a5ba999d 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -268,3 +268,14 @@ def sheet_names(self) -> List[str]: def parse(self, sheet_name: str, **_kwargs) -> pd.DataFrame: return self.data[sheet_name] + + +class MockTqdm: + """To use: for x in tqdm(iterable)""" + + def __init__(self, iterable=None, **kwargs): + self.iterable = iterable + + def __iter__(self): + for item in self.iterable: + yield item From d54986e18b40d04ce8fc306349f88ce1eaff8c7b Mon Sep 17 00:00:00 2001 From: Bram Stoeller Date: Thu, 2 Mar 2023 15:36:40 +0100 Subject: [PATCH 2/9] Add very small test .zip files Signed-off-by: Bram Stoeller --- tests/data/zip/foo-bar.zip | Bin 0 -> 247 bytes tests/data/zip/foo-bar.zip.license | 3 +++ tests/data/zip/foo.zip | Bin 0 -> 134 bytes tests/data/zip/foo.zip.license | 3 +++ 4 files changed, 6 insertions(+) create mode 100644 tests/data/zip/foo-bar.zip create mode 100644 tests/data/zip/foo-bar.zip.license create mode 100644 tests/data/zip/foo.zip create mode 100644 tests/data/zip/foo.zip.license diff --git a/tests/data/zip/foo-bar.zip b/tests/data/zip/foo-bar.zip new file mode 100644 index 0000000000000000000000000000000000000000..5f7c0598608bcbeb5f92ef6fa083c84039038b06 GIT binary patch literal 247 zcmWIWW@Zs#00FPy#IOdj7;`}&8-)3QI4wUXCACODDX~beq@qMmPcJ1uC%;IcII~0{ zF*mg&0It~usCgfAMm|Ur2(yDV=Yw?lU~k@aG;5Tpa5*A}c7Y<++?D;r3H2?z^;bODIN0079*E2jVe literal 0 HcmV?d00001 diff --git a/tests/data/zip/foo-bar.zip.license b/tests/data/zip/foo-bar.zip.license new file mode 100644 index 00000000..ebe16ce8 --- /dev/null +++ b/tests/data/zip/foo-bar.zip.license @@ -0,0 +1,3 @@ +SPDX-FileCopyrightText: 2022 Contributors to the Power Grid Model project + +SPDX-License-Identifier: MPL-2.0 \ No newline at end of file diff --git a/tests/data/zip/foo.zip b/tests/data/zip/foo.zip new file mode 100644 index 0000000000000000000000000000000000000000..8e13ab608434d4387c6fe15626003c35f659f68f GIT binary patch literal 134 zcmWIWW@Zs#00Eca#ISwL8ToubHVAVAaaw-9J`n4bRFwGS7p3MZWEK>c=IZI`1$Z+u ji7?>S2-M5K$e;kCfB?yu0B=?{kPssfS^{Zv5QhN(qP!L{ literal 0 HcmV?d00001 diff --git a/tests/data/zip/foo.zip.license b/tests/data/zip/foo.zip.license new file mode 100644 index 00000000..ebe16ce8 --- /dev/null +++ b/tests/data/zip/foo.zip.license @@ -0,0 +1,3 @@ +SPDX-FileCopyrightText: 2022 Contributors to the Power Grid Model project + +SPDX-License-Identifier: MPL-2.0 \ No newline at end of file From 8b4fa2a08fe9d68bc645e44deba9b74dc330f1ed Mon Sep 17 00:00:00 2001 From: Bram Stoeller Date: Thu, 2 Mar 2023 17:03:29 +0100 Subject: [PATCH 3/9] Add tqdm Signed-off-by: Bram Stoeller --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index f6e04ad7..d348bf3c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,6 +35,7 @@ dependencies = [ "power_grid_model>=1.4", "pyyaml", "structlog", + "tqdm", ] dynamic = ["version"] From d5daa30842ec7ad8e7bbcbb53b8647e0c919702e Mon Sep 17 00:00:00 2001 From: Bram Stoeller Date: Tue, 7 Mar 2023 09:43:56 +0100 Subject: [PATCH 4/9] Fix typos in comments Signed-off-by: Bram Stoeller --- src/power_grid_model_io/utils/download.py | 8 ++++---- src/power_grid_model_io/utils/zip.py | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/power_grid_model_io/utils/download.py b/src/power_grid_model_io/utils/download.py index a1ad983e..c25a9687 100644 --- a/src/power_grid_model_io/utils/download.py +++ b/src/power_grid_model_io/utils/download.py @@ -18,7 +18,7 @@ This downloads the zip file as described above, and then it extracts the files there as well, in a folder which corresponds to the zip file name ("/tmp/1-complete_data-mixed-all-0-sw/" in our example), and it returns the path to that directory. By default, it will not re-download or re-extract the zip file as long as the files exist in your -temp dir. Your temp dir is typically emptied whe you reboot your computer. +temp dir. Your temp dir is typically emptied when you reboot your computer. """ @@ -43,7 +43,7 @@ @dataclass class ResponseInfo: """ - Struct to store response information extracted from the response header + Data class to store response information extracted from the response header """ status: int @@ -61,7 +61,7 @@ def __init__(self, progress_bar: tqdm): Report hook for request.urlretrieve() to update a progress bar based on the amount of downloaded blocks Args: - progress_bar: A tqdb progress bar + progress_bar: A tqdm progress bar """ self._progress_bar = progress_bar self._last_block = 0 @@ -103,7 +103,7 @@ def download_and_extract( src_file_path = download(url=url, file_name=file_name, dir_path=dir_path, overwrite=overwrite) dst_dir_path = src_file_path.with_suffix("") - # If we explicitly want to overwrite the extracted files, remove the + # If we explicitly want to overwrite the extracted files, remove the destination dir. if overwrite and dst_dir_path.is_dir(): remove_dir(dst_dir_path) diff --git a/src/power_grid_model_io/utils/zip.py b/src/power_grid_model_io/utils/zip.py index cf001a83..0d5622fb 100644 --- a/src/power_grid_model_io/utils/zip.py +++ b/src/power_grid_model_io/utils/zip.py @@ -69,7 +69,7 @@ def extract(src_file_path: Path, dst_dir_path: Optional[Path] = None, skip_if_ex if only_item is None: only_item = item # Else, if only_item is not None, there are more than one items in the root of the directory. - # This means hat there is no 'only_item' and we can stop the loop + # This means that there is no 'only_item' and we can stop the loop else: only_item = None break From 94f2c3688acfbe96976f4c0261609a79e616abc8 Mon Sep 17 00:00:00 2001 From: Bram Stoeller Date: Tue, 7 Mar 2023 10:01:32 +0100 Subject: [PATCH 5/9] Move only-item logic to a separate function _get_only_item_in_dir() Signed-off-by: Bram Stoeller --- src/power_grid_model_io/utils/zip.py | 40 +++++++++++++++++++++------- 1 file changed, 30 insertions(+), 10 deletions(-) diff --git a/src/power_grid_model_io/utils/zip.py b/src/power_grid_model_io/utils/zip.py index 0d5622fb..fa188e6c 100644 --- a/src/power_grid_model_io/utils/zip.py +++ b/src/power_grid_model_io/utils/zip.py @@ -63,17 +63,37 @@ def extract(src_file_path: Path, dst_dir_path: Optional[Path] = None, skip_if_ex # Zip files often contain a single directory with the same name as the zip file. # In that case, return the dir to that directory instead of the root dir - only_item: Optional[Path] = None - for item in dst_dir_path.iterdir(): - # If only_item is None, this is the first iteration, so item may be the only item - if only_item is None: - only_item = item - # Else, if only_item is not None, there are more than one items in the root of the directory. - # This means that there is no 'only_item' and we can stop the loop - else: - only_item = None - break + only_item = _get_only_item_in_dir(dst_dir_path) if only_item and only_item.is_dir() and only_item.name == src_file_path.stem: dst_dir_path = only_item return dst_dir_path.resolve() + + +def _get_only_item_in_dir(dir_path: Path) -> Optional[Path]: + """ + If dir path contains only a single item, return that item. + Return None otherwise (if there are no items at all, or more than one item). + + Args: + dir_path: The path tho the directory + + Returns: + A path to the only item (dir or file) in the directory + """ + + only_item: Optional[Path] = None + for item in dir_path.iterdir(): + + # If only_item is not None at this point, it must have been set in the first iteration, i.e. there are more + # than one items in the directory, so return None. + if only_item is not None: + return None + + # Else, if only_item is None, we are in the first iteration, i.e. the first item in the dir. This item may be + # the only item in the dir, so let's remember it. + only_item = item + + # If we have come to this point, there were zero or one items in the directory. Return the path to that item (or + # None, the initial value). + return only_item From a0930d352869c997cfe8412bf11b405dd106b3f7 Mon Sep 17 00:00:00 2001 From: Bram Stoeller Date: Tue, 7 Mar 2023 10:40:33 +0100 Subject: [PATCH 6/9] Use unique key as a sub diretory if both a file name and unique key are supplied Signed-off-by: Bram Stoeller --- src/power_grid_model_io/utils/download.py | 24 ++++++++++++++--------- tests/unit/utils/test_download.py | 16 ++++++++------- tests/utils.py | 2 ++ 3 files changed, 26 insertions(+), 16 deletions(-) diff --git a/src/power_grid_model_io/utils/download.py b/src/power_grid_model_io/utils/download.py index c25a9687..782fd8e6 100644 --- a/src/power_grid_model_io/utils/download.py +++ b/src/power_grid_model_io/utils/download.py @@ -214,20 +214,26 @@ def get_download_path( unique_key: A unique string that can be used to generate a filename (e.g. a url). """ - # If no file_name is given, generate a file name - if file_name is None: + # If no dir_path is given, use the system's designated folder for temporary files. + if dir_path is None: + dir_path = Path(tempfile.gettempdir()) + + # If no specific download path was given, we need to generate a unique key (based on the given unique key) + if file_name is None or unique_key is not None: if unique_key is None: - raise ValueError("Supply data in order to auto generate a download path.") + raise ValueError("Supply a unique key in order to auto generate a download path.") sha256 = hashlib.sha256() sha256.update(unique_key.encode()) - hash_str = base64.b64encode(sha256.digest()).decode("ascii") - hash_str = hash_str.replace("/", "_").replace("+", "-").rstrip("=") - file_name = Path(f"{hash_str}.download") + unique_key = base64.b64encode(sha256.digest()).decode("ascii") + unique_key = unique_key.replace("/", "_").replace("+", "-").rstrip("=") - # If no dir_path is given, use the system's designated folder for temporary files - elif dir_path is None: - dir_path = Path(tempfile.gettempdir()) + # If no file name was given, use the unique key as a file name + if file_name is None: + file_name = Path(f"{unique_key}.download") + # Otherwise, use the unique key as a sub directory + else: + dir_path /= unique_key # Combine the two paths assert file_name is not None diff --git a/tests/unit/utils/test_download.py b/tests/unit/utils/test_download.py index 7e886703..370eba49 100644 --- a/tests/unit/utils/test_download.py +++ b/tests/unit/utils/test_download.py @@ -22,6 +22,10 @@ Response = namedtuple("Response", ["status", "headers"]) +# The base64 representation of the sha256 hash of "foo" is LCa0a2j/xo/5m0U8HTBBNBNCLXBkg7+g+YpeiGJm564= +# The / and + will be replaced with a _ and - character and the trailing = character(s) will be removed. +FOO_KEY = "LCa0a2j_xo_5m0U8HTBBNBNCLXBkg7-g-YpeiGJm564" + @pytest.fixture() def temp_dir(): @@ -356,7 +360,7 @@ def test_get_response_info__no_length(mock_urlopen): def test_get_download_path(temp_dir: Path): # Act - path = get_download_path(dir_path=temp_dir, file_name="file_name.zip", unique_key="foo") + path = get_download_path(dir_path=temp_dir, file_name="file_name.zip") # Assert assert path == temp_dir / "file_name.zip" @@ -364,17 +368,15 @@ def test_get_download_path(temp_dir: Path): def test_get_download_path__auto_dir(): # Act - path = get_download_path(file_name="file_name.zip") + path = get_download_path(file_name="file_name.zip", unique_key="foo") # Assert - assert path == Path(tempfile.gettempdir()).resolve() / "file_name.zip" + assert path == Path(tempfile.gettempdir()).resolve() / FOO_KEY / "file_name.zip" def test_get_download_path__auto_file_name(temp_dir: Path): # Arrange - # The base64 representation of the sha256 hash of "foo" is LCa0a2j/xo/5m0U8HTBBNBNCLXBkg7+g+YpeiGJm564= - # The / and + will be replaced with a _ and - character and the trailing = character(s) will be removed. - expected_file_name = "LCa0a2j_xo_5m0U8HTBBNBNCLXBkg7-g-YpeiGJm564.download" + expected_file_name = f"{FOO_KEY}.download" # Act path = get_download_path(dir_path=temp_dir, unique_key="foo") @@ -385,7 +387,7 @@ def test_get_download_path__auto_file_name(temp_dir: Path): def test_get_download_path__missing_data(temp_dir: Path): # Act / Assert - with pytest.raises(ValueError, match=r"Supply data in order to auto generate a download path\."): + with pytest.raises(ValueError, match=r"Supply a unique key in order to auto generate a download path\."): get_download_path(dir_path=temp_dir) diff --git a/tests/utils.py b/tests/utils.py index a5ba999d..2d8a85da 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -178,6 +178,8 @@ def eq(left, right) -> bool: return False if isinstance(left, NDFrame): return (left == right).all() + if isinstance(right, NDFrame): + return False if isnan(left) and isnan(right): return True return left == right From 216f1edbf64950b3ab825972122309194688221d Mon Sep 17 00:00:00 2001 From: Bram Stoeller Date: Tue, 7 Mar 2023 10:47:48 +0100 Subject: [PATCH 7/9] Unit tests for _get_only_item_in_dir() Signed-off-by: Bram Stoeller --- tests/unit/utils/test_zip.py | 35 ++++++++++++++++++++++++++++++++++- 1 file changed, 34 insertions(+), 1 deletion(-) diff --git a/tests/unit/utils/test_zip.py b/tests/unit/utils/test_zip.py index dae83f58..e1376075 100644 --- a/tests/unit/utils/test_zip.py +++ b/tests/unit/utils/test_zip.py @@ -9,7 +9,7 @@ import pytest import structlog.testing -from power_grid_model_io.utils.zip import extract +from power_grid_model_io.utils.zip import _get_only_item_in_dir, extract from ...utils import MockTqdm, assert_log_exists @@ -123,3 +123,36 @@ def test_extract__return_subdir_path(mock_tqdm: MagicMock, temp_dir: Path): # Assert assert extract_dir_path == temp_dir / "foo" / "foo" assert (temp_dir / "foo" / "foo" / "foo.txt").is_file() + + +def test_get_only_item_in_dir__no_items(temp_dir): + # Act / Assert + assert _get_only_item_in_dir(temp_dir) == None + + +def test_get_only_item_in_dir__one_file(temp_dir): + # Arrange + with open(temp_dir / "file.txt", "wb"): + pass + + # Act / Assert + assert _get_only_item_in_dir(temp_dir) == temp_dir / "file.txt" + + +def test_get_only_item_in_dir__one_dir(temp_dir): + # Arrange + (temp_dir / "subdir").mkdir() + + # Act / Assert + assert _get_only_item_in_dir(temp_dir) == temp_dir / "subdir" + + +def test_get_only_item_in_dir__two_files(temp_dir): + # Arrange + with open(temp_dir / "file_1.txt", "wb"): + pass + with open(temp_dir / "file_2.txt", "wb"): + pass + + # Act / Assert + assert _get_only_item_in_dir(temp_dir) == None From 44c764a62bde0db7fd711d20aad72ff76ba269ac Mon Sep 17 00:00:00 2001 From: Bram Stoeller Date: Tue, 7 Mar 2023 11:01:16 +0100 Subject: [PATCH 8/9] Only use unique key for dirs if no dir_path was given Signed-off-by: Bram Stoeller --- src/power_grid_model_io/utils/download.py | 13 +++++++------ tests/unit/utils/test_download.py | 8 ++++++++ 2 files changed, 15 insertions(+), 6 deletions(-) diff --git a/src/power_grid_model_io/utils/download.py b/src/power_grid_model_io/utils/download.py index 782fd8e6..631a9ffc 100644 --- a/src/power_grid_model_io/utils/download.py +++ b/src/power_grid_model_io/utils/download.py @@ -214,12 +214,8 @@ def get_download_path( unique_key: A unique string that can be used to generate a filename (e.g. a url). """ - # If no dir_path is given, use the system's designated folder for temporary files. - if dir_path is None: - dir_path = Path(tempfile.gettempdir()) - # If no specific download path was given, we need to generate a unique key (based on the given unique key) - if file_name is None or unique_key is not None: + if dir_path is None or file_name is None: if unique_key is None: raise ValueError("Supply a unique key in order to auto generate a download path.") @@ -233,7 +229,12 @@ def get_download_path( file_name = Path(f"{unique_key}.download") # Otherwise, use the unique key as a sub directory else: - dir_path /= unique_key + assert dir_path is None # sanity check + dir_path = Path(tempfile.gettempdir()) / unique_key + + # If no dir_path is given, use the system's designated folder for temporary files. + if dir_path is None: + dir_path = Path(tempfile.gettempdir()) # Combine the two paths assert file_name is not None diff --git a/tests/unit/utils/test_download.py b/tests/unit/utils/test_download.py index 370eba49..98d74eb3 100644 --- a/tests/unit/utils/test_download.py +++ b/tests/unit/utils/test_download.py @@ -366,6 +366,14 @@ def test_get_download_path(temp_dir: Path): assert path == temp_dir / "file_name.zip" +def test_get_download_path__ignore_unique_key(temp_dir: Path): + # Act + path = get_download_path(dir_path=temp_dir, file_name="file_name.zip", unique_key="foo") + + # Assert + assert path == temp_dir / "file_name.zip" + + def test_get_download_path__auto_dir(): # Act path = get_download_path(file_name="file_name.zip", unique_key="foo") From c9a25c8e7afb8832d9d89aee61a6ec6131af325f Mon Sep 17 00:00:00 2001 From: Bram Stoeller Date: Tue, 7 Mar 2023 11:14:37 +0100 Subject: [PATCH 9/9] Fix auto_dir bug Signed-off-by: Bram Stoeller --- src/power_grid_model_io/utils/download.py | 5 ++--- tests/unit/utils/test_download.py | 12 +++++++++++- 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/src/power_grid_model_io/utils/download.py b/src/power_grid_model_io/utils/download.py index 631a9ffc..88731aeb 100644 --- a/src/power_grid_model_io/utils/download.py +++ b/src/power_grid_model_io/utils/download.py @@ -215,7 +215,7 @@ def get_download_path( """ # If no specific download path was given, we need to generate a unique key (based on the given unique key) - if dir_path is None or file_name is None: + if file_name is None or unique_key is not None: if unique_key is None: raise ValueError("Supply a unique key in order to auto generate a download path.") @@ -228,8 +228,7 @@ def get_download_path( if file_name is None: file_name = Path(f"{unique_key}.download") # Otherwise, use the unique key as a sub directory - else: - assert dir_path is None # sanity check + elif dir_path is None: dir_path = Path(tempfile.gettempdir()) / unique_key # If no dir_path is given, use the system's designated folder for temporary files. diff --git a/tests/unit/utils/test_download.py b/tests/unit/utils/test_download.py index 98d74eb3..8db11c3d 100644 --- a/tests/unit/utils/test_download.py +++ b/tests/unit/utils/test_download.py @@ -26,6 +26,8 @@ # The / and + will be replaced with a _ and - character and the trailing = character(s) will be removed. FOO_KEY = "LCa0a2j_xo_5m0U8HTBBNBNCLXBkg7-g-YpeiGJm564" +TEMP_DIR = Path(tempfile.gettempdir()).resolve() + @pytest.fixture() def temp_dir(): @@ -374,12 +376,20 @@ def test_get_download_path__ignore_unique_key(temp_dir: Path): assert path == temp_dir / "file_name.zip" +def test_get_download_path__temp_dir(): + # Act + path = get_download_path(file_name="file_name.zip") + + # Assert + assert path == TEMP_DIR / "file_name.zip" + + def test_get_download_path__auto_dir(): # Act path = get_download_path(file_name="file_name.zip", unique_key="foo") # Assert - assert path == Path(tempfile.gettempdir()).resolve() / FOO_KEY / "file_name.zip" + assert path == TEMP_DIR / FOO_KEY / "file_name.zip" def test_get_download_path__auto_file_name(temp_dir: Path):