diff --git a/.github/workflows/build-test.yaml b/.github/workflows/build-test.yaml
index f4b217a..8708cdb 100644
--- a/.github/workflows/build-test.yaml
+++ b/.github/workflows/build-test.yaml
@@ -221,6 +221,8 @@ jobs:
           max_attempts: 5
           command: |
             python -c "import geant4_python_application as g4; g4.install_datasets()"
+            # Check that the datasets were installed correctly
+            python -c "import geant4_python_application as g4; g4.files.datasets.check_datasets(throw=True)"
 
       - name: Run tests
         run: |
diff --git a/src/geant4_python_application/files/datasets.py b/src/geant4_python_application/files/datasets.py
index a7a6095..05ee991 100644
--- a/src/geant4_python_application/files/datasets.py
+++ b/src/geant4_python_application/files/datasets.py
@@ -17,11 +17,9 @@
 
 url = "https://cern.ch/geant4-data/datasets"
 
-# It is discouraged to use the package directory to store data
+# It is discouraged to use the Python package directory to store data
 # data_dir = os.path.join(os.path.dirname(__file__), "geant4/data")
 # another idea is to use 'platformdirs' to store data in a platform-specific location
-
-
 def data_directory() -> str:
     return os.path.join(
         geant4_python_application.application_directory(),
@@ -32,7 +30,11 @@ def data_directory() -> str:
     )
 
 
-# the datasets versions should be updated with each Geant4 version
+def geant4_data_directory() -> str:
+    return os.environ["GEANT4_DATA_DIR"]
+
+
+# the dataset versions should be updated with each Geant4 version (remember to update the checksums too!)
 # https://geant4.web.cern.ch/download/11.2.2.html#datasets
 Dataset = namedtuple("Dataset", ["name", "version", "filename", "env", "md5sum"])
 
@@ -42,7 +44,7 @@
         version="4.7.1",
         filename="G4NDL",
         env="G4NEUTRONHPDATA",
-        md5sum="b001a2091bf9392e6833830347672ea2",
+        md5sum="54f0ed3995856f02433d42ec96d70bc6",
     ),
     Dataset(
         name="G4EMLOW",
@@ -155,6 +157,9 @@ def _download_extract_dataset(dataset: Dataset, pbar: tqdm):
         for chunk in iter(lambda: f.read(chunk_size), b""):
             md5.update(chunk)
         if md5.hexdigest() != dataset.md5sum:
-            raise RuntimeError(f"MD5 checksum mismatch for {filename}")
+            raise RuntimeError(
+                f"MD5 checksum mismatch for {filename}: "
+                f"got {md5.hexdigest()}, expected {dataset.md5sum}"
+            )
 
         f.seek(0)
@@ -162,23 +167,50 @@ def _download_extract_dataset(dataset: Dataset, pbar: tqdm):
             tar.extractall(data_directory())
 
 
-def install_datasets(force: bool = False, show_progress: bool = True):
-    os.environ["GEANT4_DATA_DIR"] = data_directory()
+def missing_datasets(directory: str | None = None) -> list[Dataset]:
+    if directory is None:
+        directory = data_directory()
     datasets_to_download = []
+
     for dataset in datasets:
-        path = os.path.join(data_directory(), dataset.name + dataset.version)
-        os.environ[dataset.env] = path
-        if not os.path.exists(path) or force:
+        path = os.path.join(directory, dataset.name + dataset.version)
+        if not os.path.exists(path):
             datasets_to_download.append(dataset)
-
-    if len(datasets_to_download) == 0:
+    return datasets_to_download
+
+
+def check_datasets(throw: bool = False) -> bool:
+    datasets_to_download = missing_datasets()
+    if datasets_to_download:
+        if throw:
+            raise RuntimeError(
+                f"Missing Geant4 datasets: {', '.join([f'{dataset.name}@v{dataset.version}' for dataset in datasets_to_download])}"
+            )
+        return False
+    return True
+
+
+def install_datasets(show_progress: bool = True):
+    # first, check whether the datasets are already installed in the application directory
+    datasets_to_download = missing_datasets()
+    if not datasets_to_download:
+        # datasets found in the application directory
+        os.environ["GEANT4_DATA_DIR"] = data_directory()
         return
 
+    # otherwise, fall back to datasets already present in an externally set Geant4 directory
+    geant4_dir = os.environ.get("GEANT4_DATA_DIR")
+    if geant4_dir and not missing_datasets(geant4_dir):
+        return
+
+    # download and extract the datasets into the application directory
+    os.environ["GEANT4_DATA_DIR"] = data_directory()
+
     os.makedirs(data_directory(), exist_ok=True)
     if show_progress:
         print(
             f"""
-Geant4 datasets (<2GB) will be installed to "{data_directory()}".
+Geant4 datasets will be installed to "{data_directory()}".
 This may take a while but only needs to be done once.
 You can override the default location by calling `application_directory(path)` or `application_directory(temp=True)` to use a temporary directory.
 The following Geant4 datasets will be installed: {", ".join([f"{dataset.name}@v{dataset.version}" for dataset in datasets_to_download])}"""
@@ -196,7 +228,7 @@ def install_datasets(force: bool = False, show_progress: bool = True):
     ) as executor:
         futures = [
             executor.submit(_download_extract_dataset, dataset, pbar)
-            for i, dataset in enumerate(datasets_to_download)
+            for dataset in datasets_to_download
         ]
         concurrent.futures.wait(futures)
 
@@ -212,7 +244,6 @@ def uninstall_datasets():
     package_dir = os.path.dirname(__file__)
 
     if not os.path.relpath(package_dir, dir_to_remove).startswith(".."):
-        # make sure we don't accidentally delete something important
         raise RuntimeError(
             f"Refusing to remove {dir_to_remove} because it is not a subdirectory of {package_dir}"
         )
@@ -221,4 +252,4 @@
 
 def reinstall_datasets():
     uninstall_datasets()
-    install_datasets(force=True)
+    install_datasets()
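
A minimal usage sketch of the new dataset API, assuming the top-level `g4.install_datasets` re-export used by the CI step above; the recovery branch calling `reinstall_datasets()` is illustrative, not part of this PR:

```python
import geant4_python_application as g4

# Install any datasets missing from the application directory.
# This is a no-op when everything is already present.
g4.install_datasets()

# Hard check, as used in CI: raises RuntimeError listing any missing datasets.
g4.files.datasets.check_datasets(throw=True)

# Soft check: returns a bool instead of raising, convenient for recovery logic.
if not g4.files.datasets.check_datasets():
    # remove the application data directory and download everything again
    g4.files.datasets.reinstall_datasets()
```

Since `install_datasets()` no longer takes `force`, a forced refresh goes through `reinstall_datasets()`, which removes the data directory so the subsequent install downloads everything from scratch.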