Skip to content

Commit

Permalink
fix datasets install (#71)
Browse files Browse the repository at this point in the history
* helper functions to check datasets

* Update datasets.py (#72)

* Update datasets.py

Signed-off-by: Kavya Wadhwa <72140379+kavyawadhwa134@users.noreply.github.com>

* style: pre-commit fixes

---------

Signed-off-by: Kavya Wadhwa <72140379+kavyawadhwa134@users.noreply.github.com>
Signed-off-by: Luis Antonio Obis Aparicio <35803280+lobis@users.noreply.github.com>
Co-authored-by: Kavya Wadhwa <72140379+kavyawadhwa134@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* restore checksum, fix bad checksum

* additional check for datasets

* method to get G4 data dir

* fix bad method name

---------

Signed-off-by: Kavya Wadhwa <72140379+kavyawadhwa134@users.noreply.github.com>
Signed-off-by: Luis Antonio Obis Aparicio <35803280+lobis@users.noreply.github.com>
Co-authored-by: Kavya Wadhwa <72140379+kavyawadhwa134@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
3 people authored Oct 6, 2024
1 parent 73a6eb1 commit dbdb909
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 16 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/build-test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,8 @@ jobs:
max_attempts: 5
command: |
python -c "import geant4_python_application as g4; g4.install_datasets()"
# Check that the datasets are installed correctly
python -c "import geant4_python_application as g4; g4.files.datasets.check_datasets(throw=True)"
- name: Run tests
run: |
Expand Down
63 changes: 47 additions & 16 deletions src/geant4_python_application/files/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,9 @@
url = "https://cern.ch/geant4-data/datasets"


# It is discouraged to use the package directory to store data
# Storing data inside the Python package directory is discouraged
# data_dir = os.path.join(os.path.dirname(__file__), "geant4/data")
# another idea is to use 'platformdirs' to store data in a platform-specific location


def data_directory() -> str:
return os.path.join(
geant4_python_application.application_directory(),
Expand All @@ -32,7 +30,11 @@ def data_directory() -> str:
)


# the datasets versions should be updated with each Geant4 version
def geant4_data_directory() -> str:
    """Return the Geant4 data directory named by ``GEANT4_DATA_DIR``.

    Returns:
        The value of the ``GEANT4_DATA_DIR`` environment variable.

    Raises:
        KeyError: If ``GEANT4_DATA_DIR`` is not set in the environment.
    """
    try:
        return os.environ["GEANT4_DATA_DIR"]
    except KeyError:
        # Keep the exception type callers may already catch, but explain
        # what is missing instead of surfacing only the bare key name.
        raise KeyError(
            "environment variable GEANT4_DATA_DIR is not set"
        ) from None


# The dataset versions must be updated with each Geant4 release (remember to update the md5sum checksums too!)
# https://geant4.web.cern.ch/download/11.2.2.html#datasets
Dataset = namedtuple("Dataset", ["name", "version", "filename", "env", "md5sum"])

Expand All @@ -42,7 +44,7 @@ def data_directory() -> str:
version="4.7.1",
filename="G4NDL",
env="G4NEUTRONHPDATA",
md5sum="b001a2091bf9392e6833830347672ea2",
md5sum="54f0ed3995856f02433d42ec96d70bc6",
),
Dataset(
name="G4EMLOW",
Expand Down Expand Up @@ -155,30 +157,60 @@ def _download_extract_dataset(dataset: Dataset, pbar: tqdm):
for chunk in iter(lambda: f.read(chunk_size), b""):
md5.update(chunk)
if md5.hexdigest() != dataset.md5sum:
print(
f"MD5 checksum mismatch for {filename}: got {md5.hexdigest()} but expected {dataset.md5sum}"
)
raise RuntimeError(f"MD5 checksum mismatch for {filename}")

f.seek(0)
with tarfile.open(fileobj=f, mode="r:gz") as tar:
tar.extractall(data_directory())


def install_datasets(force: bool = False, show_progress: bool = True):
os.environ["GEANT4_DATA_DIR"] = data_directory()
def missing_datasets(directory: str | None = None) -> list[Dataset]:
    """Return the datasets that are not present under ``directory``.

    A dataset counts as present when the path
    ``<directory>/<name><version>`` exists. This function only inspects the
    filesystem; it does not modify ``os.environ``.

    Args:
        directory: Directory to inspect. Defaults to ``data_directory()``.

    Returns:
        The entries of ``datasets`` whose path does not exist, in their
        original order. The list is empty when nothing is missing.
    """
    if directory is None:
        directory = data_directory()

    return [
        dataset
        for dataset in datasets
        if not os.path.exists(
            os.path.join(directory, dataset.name + dataset.version)
        )
    ]


def check_datasets(throw: bool = False) -> bool:
    """Report whether every Geant4 dataset is present in the default location.

    Args:
        throw: When true, raise instead of returning ``False`` if any
            dataset is missing.

    Returns:
        ``True`` when ``missing_datasets()`` reports nothing missing,
        ``False`` otherwise (only reachable when ``throw`` is false).

    Raises:
        RuntimeError: If datasets are missing and ``throw`` is true. The
            message lists each missing dataset as ``name@vversion``.
    """
    absent = missing_datasets()
    if not absent:
        return True
    if not throw:
        return False

    listing = ", ".join(f"{entry.name}@v{entry.version}" for entry in absent)
    raise RuntimeError(f"Missing Geant4 datasets: {listing}")


def install_datasets(show_progress: bool = True):
# first try to see if the datasets are installed in the application directory
datasets_to_download = missing_datasets()
if not datasets_to_download:
# datasets are installed in application directory
os.environ["GEANT4_DATA_DIR"] = data_directory()
return

# check if the datasets are present in the corresponding Geant4 directory
if not bool(missing_datasets(os.environ["GEANT4_DATA_DIR"])):
# return
...

# download and extract the datasets
os.environ["GEANT4_DATA_DIR"] = data_directory()

os.makedirs(data_directory(), exist_ok=True)
if show_progress:
print(
f"""
Geant4 datasets (<2GB) will be installed to "{data_directory()}".
Geant4 datasets will be installed to "{data_directory()}".
This may take a while but only needs to be done once.
You can override the default location by calling `application_directory(path)` or `application_directory(temp=True)` to use a temporary directory.
The following Geant4 datasets will be installed: {", ".join([f"{dataset.name}@v{dataset.version}" for dataset in datasets_to_download])}"""
Expand All @@ -196,7 +228,7 @@ def install_datasets(force: bool = False, show_progress: bool = True):
) as executor:
futures = [
executor.submit(_download_extract_dataset, dataset, pbar)
for i, dataset in enumerate(datasets_to_download)
for dataset in datasets_to_download
]
concurrent.futures.wait(futures)

Expand All @@ -212,7 +244,6 @@ def uninstall_datasets():
package_dir = os.path.dirname(__file__)

if not os.path.relpath(package_dir, dir_to_remove).startswith(".."):
# make sure we don't accidentally delete something important
raise RuntimeError(
f"Refusing to remove {dir_to_remove} because it is not a subdirectory of {package_dir}"
)
Expand All @@ -221,4 +252,4 @@ def uninstall_datasets():

def reinstall_datasets():
    """Uninstall the datasets and then install them again.

    Calls ``uninstall_datasets()`` followed by ``install_datasets()``.
    ``install_datasets`` no longer accepts a ``force`` argument, so it is
    called without one.
    """
    uninstall_datasets()
    install_datasets()

0 comments on commit dbdb909

Please sign in to comment.