Skip to content

Commit

Permalink
download datasets
Browse files Browse the repository at this point in the history
  • Loading branch information
lobis committed Dec 6, 2023
1 parent 6f151ab commit 58cd135
Show file tree
Hide file tree
Showing 3 changed files with 170 additions and 114 deletions.
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@ classifiers = [
]
dependencies = [
"awkward",
"numpy"
"numpy",
"tqdm",
]
description = "Geant4 Python Application"
name = "geant4_python_application"
Expand Down
123 changes: 10 additions & 113 deletions src/geant4_python_application/__init__.py
Original file line number Diff line number Diff line change
@@ -1,112 +1,6 @@
from __future__ import annotations

import ctypes
import os
import shutil
import subprocess
import sys
import warnings

# Locate the geant4-config helper script; it is the source of truth for the
# installed Geant4 version, library list and installation prefix.
geant4_config = shutil.which("geant4-config")

# Fail at import time: nothing in this package works without a sourced Geant4.
if geant4_config is None:
    raise ImportError(
        "Could not find geant4-config. Please make sure Geant4 is installed and sourced."
    )

# Installed Geant4 version string, e.g. "11.1.0".
_geant4_version = subprocess.check_output(
    [geant4_config, "--version"], encoding="utf-8"
).strip()

# Linker-style library list, e.g. "-L... -lG4global -lG4run ...".
_geant4_libs = subprocess.check_output(
    [geant4_config, "--libs"], encoding="utf-8"
).strip()

# Installation prefix, used to derive the library and data directories below.
_geant4_prefix = subprocess.check_output(
    [geant4_config, "--prefix"], encoding="utf-8"
).strip()

# NOTE(review): assumes libraries live in <prefix>/lib — some distros use
# lib64; confirm against the target installations.
_geant4_libs_dir = os.path.join(_geant4_prefix, "lib")


def load_libs():
    """Pre-load every Geant4 shared library listed by ``geant4-config --libs``.

    Only ``-l<name>`` entries are considered; each is resolved to an absolute
    ``lib<name>.so`` / ``lib<name>.dylib`` path under the Geant4 lib directory.
    Files missing on disk are silently skipped, and libraries that exist but
    fail to load only produce a warning, so a partially broken installation
    does not prevent this module from importing.

    Raises:
        RuntimeError: on platforms other than Linux and macOS.
    """
    for lib in _geant4_libs.split():
        if not lib.startswith("-l"):
            continue
        # lib has format -l<libname>, transform to file name which is lib<libname>
        library_absolute_path = os.path.join(_geant4_libs_dir, "lib" + lib[2:])
        if sys.platform == "darwin":
            library_absolute_path += ".dylib"
        elif sys.platform == "linux":
            library_absolute_path += ".so"
        else:
            # Fix: the original passed two arguments to RuntimeError, which
            # rendered the message as a tuple ('Unsupported platform: ', ...).
            raise RuntimeError(f"Unsupported platform: {sys.platform}")

        if os.path.exists(library_absolute_path):
            try:
                ctypes.cdll.LoadLibrary(library_absolute_path)
            except Exception as e:
                warnings.warn(
                    f"Could not load shared library: {library_absolute_path}. Error: {e}"
                )


def datasets() -> list[(str, str, str)]:
    """Return the dataset table reported by ``geant4-config --datasets``.

    Each non-empty output line is split on whitespace and returned as a tuple
    (name, version, directory).
    """
    output = subprocess.check_output(
        [geant4_config, "--datasets"], encoding="utf-8"
    ).strip()
    entries = []
    for line in output.split("\n"):
        if line.strip():
            entries.append(tuple(line.split()))
    return entries


def check_datasets() -> bool:
    """Return True when every dataset directory reported by geant4-config exists."""
    # Third column of each geant4-config dataset entry is the directory path.
    return all(os.path.exists(entry[2]) for entry in datasets())


def install_datasets(always: bool = False) -> None:
    """Install the Geant4 datasets via ``geant4-config --install-datasets``.

    By default the (potentially slow) installation is skipped when all dataset
    directories already exist; pass ``always=True`` to run it unconditionally.
    """
    if always or not check_datasets():
        subprocess.run([geant4_config, "--install-datasets"], check=True)


_minimum_geant4_version = "11.1.0"


def _parse_geant4_version(version: str) -> (int, int, int):
return tuple([int(part) for part in version.split(".")])


# Warn (but do not fail) when the available Geant4 is older than the minimum
# this package was developed against.
try:
    available_geant4_version_parsed = _parse_geant4_version(_geant4_version)
    minimum_geant4_version_parsed = _parse_geant4_version(_minimum_geant4_version)
    if available_geant4_version_parsed < minimum_geant4_version_parsed:
        warnings.warn(
            f"Geant4 version {_geant4_version} is lower than the minimum required version {_minimum_geant4_version}."
        )
except Exception as e:
    # An unparsable version string is a hard error: it means geant4-config
    # returned something unexpected.
    raise RuntimeError(
        f"Error comparing Geant4 version '{_geant4_version}' with '{_minimum_geant4_version}': {e}"
    )

# Best-effort pre-load of the Geant4 shared libraries; failure is reported
# but does not abort the import.
try:
    load_libs()
except Exception as e:
    print(f"Could not load Geant4 dynamic libraries: {e}")

# environment variable GEANT4_DATA_DIR is a recent addition to Geant4
if "GEANT4_DATA_DIR" not in os.environ:
    # Fall back to the conventional location under the installation prefix.
    os.environ["GEANT4_DATA_DIR"] = os.path.join(
        _geant4_prefix, "share", "Geant4", "data"
    )

# Warn early if the data directory is missing/empty — Geant4 would otherwise
# fail much later, at physics-table construction time.
_geant4_data_dir = os.environ["GEANT4_DATA_DIR"]
if not os.path.exists(_geant4_data_dir) or not os.listdir(_geant4_data_dir):
    warnings.warn(
        f"""Geant4 data directory {_geant4_data_dir} does not exist or is empty. Please set the
GEANT4_DATA_DIR environment variable to the correct directory or install the datasets via `geant4-config
--install-datasets` or calling `geant4_python_application.install_datasets()`."""
    )

from geant4_python_application.datasets import install_datasets
from geant4_python_application.gdml import basic_gdml
from geant4_python_application.geant4_application import (
Application,
Expand All @@ -119,12 +13,6 @@ def _parse_geant4_version(version: str) -> (int, int, int):
__version__,
)

if _geant4_version != __geant4_version__:
warnings.warn(
f"""Geant4 version mismatch. Python module was compiled with Geant4 version {__geant4_version__}
, but the available Geant4 version is {_geant4_version}."""
)

__all__ = [
"__doc__",
"__version__",
Expand All @@ -136,3 +24,12 @@ def _parse_geant4_version(version: str) -> (int, int, int):
"StackingAction",
"basic_gdml",
]


def _setup_manager(self, *args, **kwargs):
    # Wrapper around the original Application.setup_manager: make sure the
    # Geant4 datasets are installed (downloading them on first use) before
    # the run manager is created, then delegate to the saved original.
    install_datasets()
    self._setup_manager(*args, **kwargs)


# Monkey-patch: stash the original bound method under a private name, then
# replace the public setup_manager with the dataset-installing wrapper above.
# The assignment order is load-bearing.
Application._setup_manager = Application.setup_manager
Application.setup_manager = _setup_manager
158 changes: 158 additions & 0 deletions src/geant4_python_application/datasets.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
from __future__ import annotations

import concurrent.futures
import hashlib
import os
import tarfile
import tempfile
from collections import namedtuple

import requests
import tqdm

# Base URL of the official Geant4 dataset tarballs.
url = "https://cern.ch/geant4-data/datasets"
# Datasets are extracted inside the installed package, next to this module.
data_dir = os.path.join(os.path.dirname(__file__), "geant4/data")

# https://github.com/HaarigerHarald/geant4_pybind/blob/9bc90bc7f93df0d4966f29c90ffed5655e8d5904/source/datainit.py
# name: dataset name · version: release · filename: tarball stem ·
# env: Geant4 environment variable pointing at the extracted directory ·
# md5sum: expected checksum of the tarball.
Dataset = namedtuple("Dataset", ["name", "version", "filename", "env", "md5sum"])

# Pinned catalogue of every Geant4 dataset this package downloads, with the
# expected MD5 checksum of each release tarball. Versions must stay in sync
# with the Geant4 release the extension module is built against.
datasets = (
    Dataset(
        name="G4NDL",
        version="4.7",
        filename="G4NDL",
        env="G4NEUTRONHPDATA",
        md5sum="b001a2091bf9392e6833830347672ea2",
    ),
    Dataset(
        name="G4EMLOW",
        version="8.2",
        filename="G4EMLOW",
        env="G4LEDATA",
        md5sum="07773e57be3f6f2ebb744da5ed574f6d",
    ),
    Dataset(
        name="PhotonEvaporation",
        version="5.7",
        filename="G4PhotonEvaporation",
        env="G4LEVELGAMMADATA",
        md5sum="81ff27deb23af4aa225423e6b3a06b39",
    ),
    Dataset(
        name="RadioactiveDecay",
        version="5.6",
        filename="G4RadioactiveDecay",
        env="G4RADIOACTIVEDATA",
        md5sum="acc1dbeb87b6b708b2874ced729a3a8f",
    ),
    Dataset(
        name="G4PARTICLEXS",
        version="4.0",
        filename="G4PARTICLEXS",
        env="G4PARTICLEXSDATA",
        md5sum="d82a4d171d50f55864e28b6cd6f433c0",
    ),
    Dataset(
        name="G4PII",
        version="1.3",
        filename="G4PII",
        env="G4PIIDATA",
        md5sum="05f2471dbcdf1a2b17cbff84e8e83b37",
    ),
    Dataset(
        name="RealSurface",
        version="2.2",
        filename="G4RealSurface",
        env="G4REALSURFACEDATA",
        md5sum="ea8f1cfa8d8aafd64b71fb30b3e8a6d9",
    ),
    Dataset(
        name="G4SAIDDATA",
        version="2.0",
        filename="G4SAIDDATA",
        env="G4SAIDXSDATA",
        md5sum="d5d4e9541120c274aeed038c621d39da",
    ),
    Dataset(
        name="G4ABLA",
        version="3.1",
        filename="G4ABLA",
        env="G4ABLADATA",
        md5sum="180f1f5d937733b207f8d5677f76296e",
    ),
    Dataset(
        name="G4INCL",
        version="1.0",
        filename="G4INCL",
        env="G4INCLDATA",
        md5sum="85fe937b6df46d41814f07175d3f5b51",
    ),
    Dataset(
        name="G4ENSDFSTATE",
        version="2.3",
        filename="G4ENSDFSTATE",
        env="G4ENSDFSTATEDATA",
        md5sum="6f18fce8f217e7aaeaa3711be9b2c7bf",
    ),
)


def _download_extract_dataset(dataset: Dataset, progress_bar_position: int = 0):
    """Download one Geant4 dataset tarball, verify its MD5 and extract it.

    The archive is streamed into a temporary file (never kept on disk),
    checksummed, and extracted into the package-level ``data_dir``.

    Args:
        dataset: the Dataset entry (name, version, filename, env, md5sum).
        progress_bar_position: tqdm row index so that concurrent downloads
            render as stacked, non-interleaving bars.

    Raises:
        requests.HTTPError: if the download request fails.
        RuntimeError: if the downloaded file's MD5 does not match the pin.
    """
    filename = dataset.filename
    # Fix: restore the f-string placeholders that were garbled to "(unknown)".
    urlpath = f"{url}/{filename}.{dataset.version}.tar.gz"
    r = requests.get(urlpath, stream=True)
    r.raise_for_status()

    # Total size only drives the progress bar; 0 when the header is absent.
    total_size = int(r.headers.get("content-length", 0))
    chunk_size = 4096
    with tempfile.TemporaryFile() as f, tqdm.tqdm(
        total=total_size,
        unit="B",
        unit_scale=True,
        desc=f"Downloading and extracting {dataset.name}",
        position=progress_bar_position,
    ) as pbar:
        for chunk in r.iter_content(chunk_size=chunk_size):
            f.write(chunk)
            # Fix: count the actual bytes received — the last chunk is usually
            # shorter than chunk_size, so updating by chunk_size overshot.
            pbar.update(len(chunk))

        # Verify integrity before touching the archive contents.
        f.seek(0)
        md5 = hashlib.md5()
        for chunk in iter(lambda: f.read(chunk_size), b""):
            md5.update(chunk)
        if md5.hexdigest() != dataset.md5sum:
            raise RuntimeError(f"MD5 checksum mismatch for {dataset.name}")

        f.seek(0)
        # NOTE(review): extractall trusts archive member paths; acceptable here
        # because the tarball comes from CERN over HTTPS and is checksummed,
        # but consider tarfile's `filter="data"` once Python >= 3.12 is required.
        with tarfile.open(fileobj=f, mode="r:gz") as tar:
            tar.extractall(data_dir)
        # (The original final pbar.update(total_size) double-counted a bar
        # whose total is total_size; the bar is already full after download.)


def install_datasets(force: bool = False):
    """Ensure every Geant4 dataset is present under ``data_dir``.

    Exports GEANT4_DATA_DIR and each dataset's environment variable so Geant4
    can locate the data, then downloads missing datasets concurrently. With
    ``force=True`` every dataset is re-downloaded even if already present.
    """
    os.makedirs(data_dir, exist_ok=True)
    os.environ["GEANT4_DATA_DIR"] = data_dir

    pending = []
    for entry in datasets:
        target = os.path.join(data_dir, entry.name + entry.version)
        os.environ[entry.env] = target
        if force or not os.path.exists(target):
            pending.append(entry)

    if not pending:
        return

    # One thread per dataset: downloads are I/O-bound, so threads are enough.
    with concurrent.futures.ThreadPoolExecutor(max_workers=len(pending)) as pool:
        jobs = [
            pool.submit(_download_extract_dataset, entry, position)
            for position, entry in enumerate(pending)
        ]
        concurrent.futures.wait(jobs)


def reinstall_datasets():
    """Force a fresh download and extraction of every Geant4 dataset."""
    install_datasets(force=True)

0 comments on commit 58cd135

Please sign in to comment.