From 1966a570935bb58bf00c87549aee826b56bf21b0 Mon Sep 17 00:00:00 2001 From: ltrotter Date: Wed, 11 Dec 2024 17:03:31 +0100 Subject: [PATCH] fix(mem): minor fixes to use memory datasets --- data/io_utils.py | 4 +++- data/memory_dataset.py | 8 ++++++-- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/data/io_utils.py b/data/io_utils.py index 83db399..c6309b3 100644 --- a/data/io_utils.py +++ b/data/io_utils.py @@ -87,7 +87,9 @@ def write_to_file(data, path, format: Optional[str] = None, append = False) -> N if format is None: format = get_format_from_path(path) - os.makedirs(os.path.dirname(path), exist_ok = True) + dir = os.path.dirname(path) + if len(dir) > 0: + os.makedirs(os.path.dirname(path), exist_ok = True) if not os.path.exists(path): append = False diff --git a/data/memory_dataset.py b/data/memory_dataset.py index 8a8f95c..ef6ac97 100644 --- a/data/memory_dataset.py +++ b/data/memory_dataset.py @@ -32,7 +32,7 @@ def _read_data(self, input_key): else: return self.data_dict.pop(input_key) - def _write_data(self, output: xr.DataArray|pd.DataFrame, output_key: str): + def _write_data(self, output: xr.DataArray|pd.DataFrame, output_key: str, **kwargs): self.data_dict[output_key] = output def _rm_data(self, key): @@ -40,7 +40,11 @@ def _rm_data(self, key): ## METHODS TO CHECK DATA AVAILABILITY def _check_data(self, data_path) -> bool: - return data_path in self.data_dict + for key in self.data_dict.keys(): + if key.startswith(data_path): + return True + else: + return False def _walk(self, prefix): for key in self.data_dict.keys():