diff --git a/src/geant4_python_application/__init__.py b/src/geant4_python_application/__init__.py index 57b3c25..57d835a 100644 --- a/src/geant4_python_application/__init__.py +++ b/src/geant4_python_application/__init__.py @@ -27,7 +27,7 @@ def _setup_manager(self, *args, **kwargs): - install_datasets() + install_datasets(show_progress=False) return self._setup_manager(*args, **kwargs) diff --git a/src/geant4_python_application/datasets.py b/src/geant4_python_application/datasets.py index e2f3ada..30cf7a2 100644 --- a/src/geant4_python_application/datasets.py +++ b/src/geant4_python_application/datasets.py @@ -3,6 +3,7 @@ import concurrent.futures import hashlib import os +import shutil import tarfile import tempfile from collections import namedtuple @@ -97,7 +98,9 @@ ) -def _download_extract_dataset(dataset: Dataset, progress_bar_position: int = 0): +def _download_extract_dataset( + dataset: Dataset, progress_bar_position: int = 0, show_progress: bool = True +): filename = dataset.filename urlpath = f"{url}/{filename}.{dataset.version}.tar.gz" r = requests.get(urlpath, stream=True) @@ -112,6 +115,7 @@ def _download_extract_dataset(dataset: Dataset, progress_bar_position: int = 0): unit_scale=True, desc=f"Downloading and Extracting {filename}", position=progress_bar_position, + disable=not show_progress, ) as pbar: for chunk in r.iter_content(chunk_size=chunk_size): f.write(chunk) @@ -130,7 +134,7 @@ def _download_extract_dataset(dataset: Dataset, progress_bar_position: int = 0): pbar.update(total_size) -def install_datasets(force: bool = False): +def install_datasets(force: bool = False, show_progress: bool = True): os.makedirs(data_dir, exist_ok=True) os.environ["GEANT4_DATA_DIR"] = data_dir @@ -148,7 +152,7 @@ def install_datasets(force: bool = False): max_workers=len(datasets_to_download) ) as executor: futures = [ - executor.submit(_download_extract_dataset, dataset, i) + executor.submit(_download_extract_dataset, dataset, i, show_progress) for i, dataset in enumerate(datasets_to_download) ] concurrent.futures.wait(futures) @@ -156,3 +160,14 @@ def install_datasets(force: bool = False): def reinstall_datasets(): install_datasets(force=True) + + +def uninstall_datasets(): + dir_to_remove = os.path.dirname(data_dir) + package_dir = os.path.dirname(__file__) + + if not os.path.relpath(package_dir, dir_to_remove).startswith(".."): + raise RuntimeError( + f"Refusing to remove {dir_to_remove} because it is not a subdirectory of {package_dir}" + ) + shutil.rmtree(dir_to_remove, ignore_errors=True)