Skip to content

Commit

Permalink
Merge pull request #273 from SNEWS2/Sheshuk/OOP_model_downloader
Browse files Browse the repository at this point in the history
Update model downloader
  • Loading branch information
JostMigenda authored Nov 24, 2023
2 parents 939e81b + c5650c5 commit 7a2176f
Showing 1 changed file with 104 additions and 58 deletions.
162 changes: 104 additions & 58 deletions python/snewpy/_model_downloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,11 @@
from snewpy import __version__ as snewpy_version

import logging
logger = logging.getLogger('FileHandle')

logger = logging.getLogger("FileHandle")

def _md5(fname:str) -> str:

def _md5(fname: str) -> str:
"""calculate the md5sum hash of a file."""
hash_md5 = hashlib.md5()
with open(fname, "rb") as f:
Expand All @@ -33,138 +34,183 @@ def _md5(fname:str) -> str:
return hash_md5.hexdigest()


def _download(src:str, dest:str, chunk_size=8192):
def _download(src: str, dest: str, chunk_size=8192):
"""Download a file from 'src' to 'dest' and show the progress bar."""
#make sure parent dir exists
# make sure parent dir exists
Path(dest).parent.mkdir(exist_ok=True, parents=True)
with requests.get(src, stream=True) as r:
r.raise_for_status()
fileSize = int(r.headers.get('content-length', 0))
with tqdm(desc=dest.name, total=fileSize, unit='iB', unit_scale=True, unit_divisor=1024) as bar:
with open(dest, 'wb') as f:
fileSize = int(r.headers.get("content-length", 0))
with tqdm(
desc=dest.name,
total=fileSize,
unit="iB",
unit_scale=True,
unit_divisor=1024,
) as bar:
with open(dest, "wb") as f:
for chunk in r.iter_content(chunk_size=chunk_size):
size = f.write(chunk)
bar.update(size)


class ChecksumError(FileNotFoundError):
"""Raise an exception due to a mismatch in the MD5 checksum.
"""
"""Raise an exception due to a mismatch in the MD5 checksum."""

def __init__(self, path, md5_exp, md5_actual):
super().__init__(f'Checksum error for file {path}: {md5_actual}!={md5_exp}')
super().__init__(f"Checksum error for file {path}: {md5_actual}!={md5_exp}")

pass


class MissingFileError(FileNotFoundError):
"""Raise an exception due to a missing file.
"""
"""Raise an exception due to a missing file."""

pass


@dataclass
class FileHandle:
"""Object storing local path, remote URL (optional), and MD5 sum
(optional) for a SNEWPY model file. If the requested file is already present
locally, open it. Otherwise, download it from the remote URL to the desired
local path.
(optional) for a SNEWPY model file. If the requested file is already present
locally, open it. Otherwise, download it from the remote URL to the desired
local path.
"""

path: Path
remote: str = None
md5: Optional[str] = None

def check(self) -> None:
"""Check if the given file exists locally and has a correct md5 sum.
Raises
------
:class:`MissingFileError`
if the local copy of the file is missing
:class:`ChecksumError`
:class:`ChecksumError`
if the local file exists, but the checksum is wrong"""
if not self.path.exists():
raise MissingFileError(self.path)
if self.md5:
logger.info(f'File {self.path}: checking md5')
logger.info(f"File {self.path}: checking md5")
md5 = _md5(self.path)
logger.debug(f'{md5} vs expected {self.md5}')
if (md5 != self.md5):
logger.debug(f"{md5} vs expected {self.md5}")
if md5 != self.md5:
raise ChecksumError(self.path, self.md5, md5)

def load(self) -> Path:
"""Make sure that local file exists and has a correct checksum.
Download the file if needed.
"""
try:
self.check()
except FileNotFoundError as e:
logger.info(f'Downloading file {self.path}')
logger.info(f"Downloading file {self.path}")
_download(self.remote, self.path)
self.check()
return self.path


def from_zenodo(zenodo_id:str, model:str, filename:str):
def _from_zenodo(zenodo_id: str, filename: str):
"""Access files on Zenodo.
Parameters
----------
zenodo_id : Zenodo record for model files.
model : Name of the model class for this model file.
filename : Expected filename storing simulation data.
Returns
-------
file_url, md5sum
"""
zenodo_url = f'https://zenodo.org/api/records/{zenodo_id}'
zenodo_url = f"https://zenodo.org/api/records/{zenodo_id}"
record = requests.get(zenodo_url).json()
# Search for model file string in Zenodo request for this record.
file = next((_file for _file in record['files'] if _file['key'] == filename), None)
# Search for file string in Zenodo request for this record.
file = next((_file for _file in record["files"] if _file["key"] == filename), None)

# If matched, return a tuple of URL and checksum.Otherwise raise an exception.
if file is not None:
return file['links']['self'], file['checksum'].split(':')[1]
return file["links"]["self"], file["checksum"].split(":")[1]
else:
raise MissingFileError(filename)

def get_model_data(model: str, filename: str, path: str = model_path) -> Path:

class ModelRegistry:
"""Access model data. Configuration for each model is in a YAML file
distributed with SNEWPY.
"""

def __init__(self, config_file: Path = None, local_path: str = model_path):
"""
Parameters
----------
config_file: YAML configuration file. If None (default) use the 'model_files.yml' from SNEWPY resources
local_path: local installation path (defaults to astropy cache).
"""
if config_file is None:
context = open_text("snewpy.models", "model_files.yml")
else:
context = open(config_file)
with context as f:
self.config = yaml.safe_load(f)
self.models = self.config["models"]
self.local_path = local_path

def get_file(self, config_path: str, filename: str) -> Path:
"""Get the requested data file from the models file repository
Parameters
----------
config_path : dot-separated path of the model in the YAML configuration (e.g. "ccsn.Bollig_2016")
filename : Name of simulation datafile, or a relative path from the model sub-directory
Returns
-------
Path of downloaded file.
"""
tokens = config_path.split(".")
config = self.models
for t in tokens:
config = config[t]
# store the model name
model = tokens[-1]
# Get data from GitHub or Zenodo.
repo = config["repository"]
if repo == "zenodo":
url, md5 = _from_zenodo(zenodo_id=config["zenodo_id"], filename=filename)
else:
# format the url directly
params = {
"model": model,
"filename": filename,
"snewpy_version": snewpy_version,
}
params.update(
config
) # default parameters can be overriden in the model config
url, md5 = repo.format(**params), None
localpath = Path(model_path) / str(model)
localpath.mkdir(exist_ok=True, parents=True)
fh = FileHandle(path=localpath / filename, remote=url, md5=md5)
return fh.load()


registry = ModelRegistry()


def get_model_data(model: str, filename: str) -> Path:
"""Get the requested data file from the models file repository
Parameters
----------
model : Name of the model class for this model file.
filename : Name of simulation datafile, or a relative path from the model sub-directory
path : Local installation path (defaults to astropy cache).
model : dot-separated path of the model in the YAML configuration (e.g. "ccsn.Bollig_2016")
filename : Absolute path to simulation datafile, or a relative path from the model sub-directory
Returns
-------
Path of downloaded file.
"""
if os.path.isabs(filename):
return Path(filename)

params = { 'model':model, 'filename':filename, 'snewpy_version':snewpy_version}

# Parse YAML file with model repository configurations.
with open_text('snewpy.models', 'model_files.yml') as f:
config = yaml.safe_load(f)
models = config['models']
# Search for model in YAML configuration.
if model in models.keys():
# Get data from GitHub or Zenodo.
modconf = models[model]
repo = modconf.pop('repository')
if repo == 'zenodo':
params['zenodo_id'] = modconf['zenodo_id']
url, md5 = from_zenodo(**params)
else:
#format the url directly
params.update(modconf) #default parameters can be overriden in the model config
url, md5 = repo.format(**params), None
localpath = Path(path)/str(model)
localpath.mkdir(exist_ok=True, parents=True)
fh = FileHandle(path = localpath/filename,remote = url, md5=md5)
return fh.load()
else:
raise KeyError(f'No configuration for {model}')

else:
return registry.get_file(model, filename)

0 comments on commit 7a2176f

Please sign in to comment.