# Excerpt of python/snewpy/_model_downloader.py, starting at module line 21.
# NOTE(review): hashlib, os, requests, yaml, tqdm, Path, dataclass, Optional,
# open_text and model_path are used below but not bound in this excerpt --
# presumably they come from the module's import block above; confirm.
from snewpy import __version__ as snewpy_version
import logging

logger = logging.getLogger("FileHandle")


def _md5(fname: str) -> str:
    """Calculate the MD5 checksum of a file.

    Parameters
    ----------
    fname : Path of the file to hash.

    Returns
    -------
    Hexadecimal MD5 digest of the file contents.
    """
    hash_md5 = hashlib.md5()
    with open(fname, "rb") as f:
        # Feed the hash chunk by chunk so the file is never read into memory
        # in one piece; the digest does not depend on the chunk size.
        for chunk in iter(lambda: f.read(4096), b""):
            hash_md5.update(chunk)
    return hash_md5.hexdigest()


def _download(src: str, dest: str, chunk_size=8192):
    """Download a file from 'src' to 'dest' and show the progress bar.

    Parameters
    ----------
    src : URL to download from.
    dest : Local destination, given as a string or a ``Path``.
    chunk_size : Number of bytes requested per streamed chunk.

    Raises
    ------
    ``requests.HTTPError`` (via ``raise_for_status``) if the server reports
    an error status.
    """
    # Normalise once: the progress-bar label uses ``.name``, which a plain
    # string does not provide.
    dest = Path(dest)
    # make sure parent dir exists
    dest.parent.mkdir(exist_ok=True, parents=True)
    with requests.get(src, stream=True) as r:
        r.raise_for_status()
        # Falls back to 0 when the response carries no content-length header.
        file_size = int(r.headers.get("content-length", 0))
        with tqdm(
            desc=dest.name,
            total=file_size,
            unit="iB",
            unit_scale=True,
            unit_divisor=1024,
        ) as bar:
            with open(dest, "wb") as f:
                for chunk in r.iter_content(chunk_size=chunk_size):
                    # ``write`` returns the number of bytes written.
                    bar.update(f.write(chunk))


class ChecksumError(FileNotFoundError):
    """Raised when the MD5 checksum of a local file does not match the expected one."""

    def __init__(self, path, md5_exp, md5_actual):
        super().__init__(f"Checksum error for file {path}: {md5_actual}!={md5_exp}")


class MissingFileError(FileNotFoundError):
    """Raised when a requested file is missing."""


@dataclass
class FileHandle:
    """Object storing local path, remote URL (optional), and MD5 sum
    (optional) for a SNEWPY model file. If the requested file is already present
    locally, open it. Otherwise, download it from the remote URL to the desired
    local path.
    """

    path: Path
    remote: Optional[str] = None
    md5: Optional[str] = None

    def check(self) -> None:
        """Check if the given file exists locally and has a correct md5 sum.

        The checksum comparison is skipped when no ``md5`` is set.

        Raises
        ------
        :class:`MissingFileError`
            if the local copy of the file is missing
        :class:`ChecksumError`
            if the local file exists, but the checksum is wrong
        """
        if not self.path.exists():
            raise MissingFileError(self.path)
        if self.md5:
            logger.info(f"File {self.path}: checking md5")
            md5 = _md5(self.path)
            logger.debug(f"{md5} vs expected {self.md5}")
            if md5 != self.md5:
                raise ChecksumError(self.path, self.md5, md5)

    def load(self) -> Path:
        """Make sure that local file exists and has a correct checksum.
        Download the file if needed.

        Both a missing file and a checksum mismatch trigger a download, since
        both error classes derive from ``FileNotFoundError``. The file is
        checked again after the download.

        Returns
        -------
        Local path of the verified file.

        Raises
        ------
        :class:`MissingFileError` or :class:`ChecksumError`
            if the file is unusable and no remote URL is set, or if the
            freshly downloaded file still fails the check.
        """
        try:
            self.check()
        except FileNotFoundError:
            if self.remote is None:
                # Nothing to download from: report the actual problem instead
                # of failing inside the HTTP layer with an unrelated error.
                raise
            logger.info(f"Downloading file {self.path}")
            _download(self.remote, self.path)
            self.check()
        return self.path


def _from_zenodo(zenodo_id: str, filename: str):
    """Access files on Zenodo.

    Parameters
    ----------
    zenodo_id : Zenodo record for model files.
    filename : Expected filename storing simulation data.

    Returns
    -------
    file_url, md5sum

    Raises
    ------
    :class:`MissingFileError` if the record lists no file named ``filename``.
    ``requests.HTTPError`` (via ``raise_for_status``) if the record request
    reports an error status.
    """
    zenodo_url = f"https://zenodo.org/api/records/{zenodo_id}"
    response = requests.get(zenodo_url)
    # Fail on the HTTP error itself rather than on a missing key further down.
    response.raise_for_status()
    record = response.json()
    # Search for file string in Zenodo request for this record.
    entry = next((item for item in record["files"] if item["key"] == filename), None)
    # If matched, return a tuple of URL and checksum. Otherwise raise an exception.
    if entry is None:
        raise MissingFileError(filename)
    # Keeps the field after the first ':' -- presumably the checksum is
    # formatted as "<algorithm>:<hexdigest>"; confirm against the Zenodo API.
    return entry["links"]["self"], entry["checksum"].split(":")[1]


class ModelRegistry:
    """Access model data. Configuration for each model is in a YAML file
    distributed with SNEWPY.
    """

    def __init__(self, config_file: Optional[Path] = None, local_path: str = model_path):
        """
        Parameters
        ----------
        config_file: YAML configuration file. If None (default) use the 'model_files.yml' from SNEWPY resources
        local_path: local installation path (defaults to astropy cache).
        """
        if config_file is None:
            context = open_text("snewpy.models", "model_files.yml")
        else:
            context = open(config_file)
        with context as f:
            self.config = yaml.safe_load(f)
        self.models = self.config["models"]
        self.local_path = local_path

    def get_file(self, config_path: str, filename: str) -> Path:
        """Get the requested data file from the models file repository

        Parameters
        ----------
        config_path : dot-separated path of the model in the YAML configuration (e.g. "ccsn.Bollig_2016")
        filename : Name of simulation datafile, or a relative path from the model sub-directory

        Returns
        -------
        Path of downloaded file.

        Raises
        ------
        KeyError if ``config_path`` does not lead to an entry in the configuration.
        """
        tokens = config_path.split(".")
        config = self.models
        try:
            for t in tokens:
                config = config[t]
        except (KeyError, TypeError) as err:
            # Report the full requested path, not just the token that failed.
            raise KeyError(f"No configuration for {config_path}") from err
        # store the model name
        model = tokens[-1]
        # Get data from GitHub or Zenodo.
        repo = config["repository"]
        if repo == "zenodo":
            url, md5 = _from_zenodo(zenodo_id=config["zenodo_id"], filename=filename)
        else:
            # format the url directly;
            # default parameters can be overridden in the model config
            params = {
                "model": model,
                "filename": filename,
                "snewpy_version": snewpy_version,
                **config,
            }
            url, md5 = repo.format(**params), None
        # Use the directory this registry was configured with, so a custom
        # ``local_path`` passed to the constructor is honoured.
        localpath = Path(self.local_path) / str(model)
        localpath.mkdir(exist_ok=True, parents=True)
        fh = FileHandle(path=localpath / filename, remote=url, md5=md5)
        return fh.load()


# Module-level registry used by ``get_model_data``; it is created with the
# default arguments, so the configuration is loaded when this module is imported.
registry = ModelRegistry()


def get_model_data(model: str, filename: str) -> Path:
    """Get the requested data file from the models file repository

    Parameters
    ----------
    model : dot-separated path of the model in the YAML configuration (e.g. "ccsn.Bollig_2016")
    filename : Absolute path to simulation datafile, or a relative path from the model sub-directory

    Returns
    -------
    ``Path(filename)`` unchanged if ``filename`` is absolute (no existence or
    checksum test is made here); otherwise the path returned by the registry.
    """
    if os.path.isabs(filename):
        return Path(filename)
    return registry.get_file(model, filename)