diff --git a/requirements.txt b/requirements.txt index d311f78ac..1242e486e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,7 @@ matplotlib>=3.1 NREL-rex>=0.2.82 NREL-phygnn>=0.0.23 -NREL-rev>=0.6.6 +NREL-rev<0.8.0 NREL-farms>=1.0.4 pytest>=5.2 pillow diff --git a/sup3r/bias/bias_transforms.py b/sup3r/bias/bias_transforms.py index b6484625b..6fe7165fd 100644 --- a/sup3r/bias/bias_transforms.py +++ b/sup3r/bias/bias_transforms.py @@ -1,6 +1,7 @@ # -*- coding: utf-8 -*- """Bias correction transformation functions.""" import logging +import os from warnings import warn import numpy as np @@ -10,6 +11,62 @@ logger = logging.getLogger(__name__) +def get_spatial_bc_factors(lat_lon, feature_name, bias_fp, threshold=0.1): + """Get bc factors (scalar/adder) for the given feature for the given + domain (specified by lat_lon). + + Parameters + ---------- + lat_lon : ndarray + Array of latitudes and longitudes for the domain to bias correct + (n_lats, n_lons, 2) + feature_name : str + Name of feature that is being corrected. Datasets with names + "{feature_name}_scalar" and "{feature_name}_adder" will be retrieved + from bias_fp. + bias_fp : str + Filepath to bias correction file from the bias calc module. Must have + datasets "{feature_name}_scalar" and "{feature_name}_adder" that are + the full low-resolution shape of the forward pass input that will be + sliced using lr_padded_slice for the current chunk. + threshold : float + Nearest neighbor euclidean distance threshold. If the coordinates are + more than this value away from the bias correction lat/lon, an error is + raised. + """ + dset_scalar = f'{feature_name}_scalar' + dset_adder = f'{feature_name}_adder' + with Resource(bias_fp) as res: + lat = np.expand_dims(res['latitude'], axis=-1) + lon = np.expand_dims(res['longitude'], axis=-1) + lat_lon_bc = np.dstack((lat, lon)) + diff = lat_lon_bc - lat_lon[:1, :1] + diff = np.hypot(diff[..., 0], diff[..., 1]) + idy, idx = np.where(diff == diff.min()) + slice_y = slice(idy[0], idy[0] + lat_lon.shape[0]) + slice_x = slice(idx[0], idx[0] + lat_lon.shape[1]) + + if diff.min() > threshold: + msg = ( + 'The DataHandler top left coordinate of {} ' + 'appears to be {} away from the nearest ' + 'bias correction coordinate of {} from {}. ' + 'Cannot apply bias correction.'.format( + lat_lon, + diff.min(), + lat_lon_bc[idy, idx], + os.path.basename(bias_fp), + ) + ) + logger.error(msg) + raise RuntimeError(msg) + + assert dset_scalar in res.dsets and dset_adder in res.dsets + scalar = res[dset_scalar, slice_y, slice_x] + adder = res[dset_adder, slice_y, slice_x] + return scalar, adder + + def global_linear_bc(input, scalar, adder, out_range=None): """Bias correct data using a simple global *scalar +adder method. @@ -37,8 +94,15 @@ def global_linear_bc(input, scalar, adder, out_range=None): return out -def local_linear_bc(input, feature_name, bias_fp, lr_padded_slice, - out_range=None, smoothing=0): +def local_linear_bc( + input, + lat_lon, + feature_name, + bias_fp, + lr_padded_slice, + out_range=None, + smoothing=0, +): """Bias correct data using a simple annual (or multi-year) *scalar +adder method on a site-by-site basis. @@ -47,6 +111,9 @@ def local_linear_bc(input, feature_name, bias_fp, lr_padded_slice, input : np.ndarray Sup3r input data to be bias corrected, assumed to be 3D with shape (spatial, spatial, temporal) for a single feature. 
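# --- Illustrative sketch, not part of the diff ---
# How the new get_spatial_bc_factors helper locates the bias-correction window:
# it matches the top-left coordinate of the forward-pass domain against the
# bias file's lat/lon grid and builds slices of the same spatial extent. The
# arrays below are synthetic stand-ins for the datasets read from bias_fp.
import numpy as np

lat = np.linspace(40, 39, 100)
lon = np.linspace(-106, -105, 120)
lat_lon_bc = np.dstack(np.meshgrid(lat, lon, indexing='ij'))  # (100, 120, 2)

# domain to correct: a 10 x 12 window taken out of the bias-correction grid
lat_lon = lat_lon_bc[30:40, 48:60]

diff = lat_lon_bc - lat_lon[:1, :1]
diff = np.hypot(diff[..., 0], diff[..., 1])
idy, idx = np.where(diff == diff.min())
slice_y = slice(idy[0], idy[0] + lat_lon.shape[0])
slice_x = slice(idx[0], idx[0] + lat_lon.shape[1])

assert diff.min() < 0.1                                      # same threshold guard
assert (slice_y, slice_x) == (slice(30, 40), slice(48, 60))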
+ lat_lon : ndarray + Array of latitudes and longitudes for the domain to bias correct + (n_lats, n_lons, 2) feature_name : str Name of feature that is being corrected. Datasets with names "{feature_name}_scalar" and "{feature_name}_adder" will be retrieved @@ -77,12 +144,7 @@ def local_linear_bc(input, feature_name, bias_fp, lr_padded_slice, out = input * scalar + adder """ - scalar = f'{feature_name}_scalar' - adder = f'{feature_name}_adder' - with Resource(bias_fp) as res: - scalar = res[scalar] - adder = res[adder] - + scalar, adder = get_spatial_bc_factors(lat_lon, feature_name, bias_fp) # 3D bias correction factors have seasonal/monthly correction in last axis if len(scalar.shape) == 3 and len(adder.shape) == 3: scalar = scalar.mean(axis=-1) @@ -94,8 +156,10 @@ def local_linear_bc(input, feature_name, bias_fp, lr_padded_slice, adder = adder[spatial_slice] if np.isnan(scalar).any() or np.isnan(adder).any(): - msg = ('Bias correction scalar/adder values had NaNs for "{}" from: {}' - .format(feature_name, bias_fp)) + msg = ( + 'Bias correction scalar/adder values had NaNs for ' + f'"{feature_name}" from: {bias_fp}' + ) logger.warning(msg) warn(msg) @@ -107,12 +171,12 @@ def local_linear_bc(input, feature_name, bias_fp, lr_padded_slice, if smoothing > 0: for idt in range(scalar.shape[-1]): - scalar[..., idt] = gaussian_filter(scalar[..., idt], - smoothing, - mode='nearest') - adder[..., idt] = gaussian_filter(adder[..., idt], - smoothing, - mode='nearest') + scalar[..., idt] = gaussian_filter( + scalar[..., idt], smoothing, mode='nearest' + ) + adder[..., idt] = gaussian_filter( + adder[..., idt], smoothing, mode='nearest' + ) out = input * scalar + adder if out_range is not None: @@ -122,9 +186,17 @@ def local_linear_bc(input, feature_name, bias_fp, lr_padded_slice, return out -def monthly_local_linear_bc(input, feature_name, bias_fp, lr_padded_slice, - time_index, temporal_avg=True, out_range=None, - smoothing=0): +def monthly_local_linear_bc( + input, + lat_lon, + feature_name, + bias_fp, + lr_padded_slice, + time_index, + temporal_avg=True, + out_range=None, + smoothing=0, +): """Bias correct data using a simple monthly *scalar +adder method on a site-by-site basis. @@ -133,6 +205,9 @@ def monthly_local_linear_bc(input, feature_name, bias_fp, lr_padded_slice, input : np.ndarray Sup3r input data to be bias corrected, assumed to be 3D with shape (spatial, spatial, temporal) for a single feature. + lat_lon : ndarray + Array of latitudes and longitudes for the domain to bias correct + (n_lats, n_lons, 2) feature_name : str Name of feature that is being corrected. 
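# --- Illustrative sketch, not part of the diff ---
# The linear correction local_linear_bc applies once the factors are retrieved:
# optional Gaussian smoothing of the factor maps, out = input * scalar + adder,
# then clipping to out_range. Shapes and values here are made up.
import numpy as np
from scipy.ndimage import gaussian_filter

rng = np.random.default_rng(0)
data = rng.uniform(0, 10, (20, 20, 24)).astype(np.float32)  # (spatial, spatial, temporal)
scalar = np.full((20, 20, 1), 1.1, dtype=np.float32)
adder = np.full((20, 20, 1), -0.5, dtype=np.float32)
smoothing, out_range = 1.0, (0, 12)

if smoothing > 0:
    for idt in range(scalar.shape[-1]):
        scalar[..., idt] = gaussian_filter(scalar[..., idt], smoothing,
                                           mode='nearest')
        adder[..., idt] = gaussian_filter(adder[..., idt], smoothing,
                                          mode='nearest')

out = data * scalar + adder
if out_range is not None:
    out = np.maximum(out, np.min(out_range))
    out = np.minimum(out, np.max(out_range))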
Datasets with names "{feature_name}_scalar" and "{feature_name}_adder" will be retrieved @@ -172,12 +247,7 @@ def monthly_local_linear_bc(input, feature_name, bias_fp, lr_padded_slice, out : np.ndarray out = input * scalar + adder """ - - scalar = f'{feature_name}_scalar' - adder = f'{feature_name}_adder' - with Resource(bias_fp) as res: - scalar = res[scalar] - adder = res[adder] + scalar, adder = get_spatial_bc_factors(lat_lon, feature_name, bias_fp) assert len(scalar.shape) == 3, 'Monthly bias correct needs 3D scalars' assert len(adder.shape) == 3, 'Monthly bias correct needs 3D adders' @@ -199,25 +269,29 @@ def monthly_local_linear_bc(input, feature_name, bias_fp, lr_padded_slice, scalar = np.repeat(scalar, input.shape[-1], axis=-1) adder = np.repeat(adder, input.shape[-1], axis=-1) if len(time_index.month.unique()) > 2: - msg = ('Bias correction method "monthly_local_linear_bc" was used ' - 'with temporal averaging over a time index with >2 months.') + msg = ( + 'Bias correction method "monthly_local_linear_bc" was used ' + 'with temporal averaging over a time index with >2 months.' + ) warn(msg) logger.warning(msg) if np.isnan(scalar).any() or np.isnan(adder).any(): - msg = ('Bias correction scalar/adder values had NaNs for "{}" from: {}' - .format(feature_name, bias_fp)) + msg = ( + 'Bias correction scalar/adder values had NaNs for ' + f'"{feature_name}" from: {bias_fp}' + ) logger.warning(msg) warn(msg) if smoothing > 0: for idt in range(scalar.shape[-1]): - scalar[..., idt] = gaussian_filter(scalar[..., idt], - smoothing, - mode='nearest') - adder[..., idt] = gaussian_filter(adder[..., idt], - smoothing, - mode='nearest') + scalar[..., idt] = gaussian_filter( + scalar[..., idt], smoothing, mode='nearest' + ) + adder[..., idt] = gaussian_filter( + adder[..., idt], smoothing, mode='nearest' + ) out = input * scalar + adder if out_range is not None: diff --git a/sup3r/pipeline/forward_pass.py b/sup3r/pipeline/forward_pass.py index 1467a0b48..c41c09795 100644 --- a/sup3r/pipeline/forward_pass.py +++ b/sup3r/pipeline/forward_pass.py @@ -849,6 +849,14 @@ def get_time_index(self, file_paths, max_workers=None, **kwargs): ---------- file_paths : list List of file paths for source data + max_workers : int | None + Number of workers to use to extract the time index from the given + files. This is used when a large number of single timestep netcdf + files is provided. + **kwargs : dict + Dictionary of kwargs passed to the resource handler opening the + given file_paths. For netcdf files this is xarray.open_mfdataset(). + For h5 files this is usually rex.Resource(). Returns ------- @@ -1061,7 +1069,8 @@ def __init__(self, strategy, chunk_index=0, node_index=0): self.data_handler.load_cached_data() self.input_data = self.data_handler.data - self.input_data = self.bias_correct_source_data(self.input_data) + self.input_data = self.bias_correct_source_data( + self.input_data, self.strategy.lr_lat_lon) exo_s_en = self.exo_kwargs.get('s_enhancements', None) out = self.pad_source_data(self.input_data, self.pad_width, @@ -1400,7 +1409,7 @@ def pad_source_data(input_data, pad_width, exo_data, return out, exo_data - def bias_correct_source_data(self, data): + def bias_correct_source_data(self, data, lat_lon): """Bias correct data using a method defined by the bias_correct_method input to ForwardPassStrategy @@ -1409,6 +1418,10 @@ def bias_correct_source_data(self, data): data : np.ndarray Any source data to be bias corrected, with the feature channel in the last axis. 
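# --- Illustrative sketch, not part of the diff ---
# How monthly_local_linear_bc selects monthly factors for the chunk being
# corrected when temporal_avg=True: keep the factor maps for the months present
# in time_index, average them, and tile across all timesteps. This is a
# simplified stand-in for the real selection logic; data are synthetic.
import numpy as np
import pandas as pd

time_index = pd.date_range('2015-01-01', '2015-02-28', freq='1D')
scalar = np.ones((10, 10, 12))                  # one factor map per calendar month

months = time_index.month.unique().to_numpy()   # e.g. array([1, 2])
scalar = scalar[..., months - 1]                # months present in this chunk
scalar = scalar.mean(axis=-1)[..., np.newaxis]  # temporal_avg=True
scalar = np.repeat(scalar, len(time_index), axis=-1)
assert scalar.shape == (10, 10, len(time_index))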
+ lat_lon : np.ndarray + Latitude longitude array for the given data. Used to get the + correct bc factors for the appropriate domain. + (n_lats, n_lons, 2) Returns ------- @@ -1433,7 +1446,8 @@ def bias_correct_source_data(self, data): 'using function: {} with kwargs: {}' .format(feature, idf, method, feature_kwargs)) - data[..., idf] = method(data[..., idf], **feature_kwargs) + data[..., idf] = method(data[..., idf], lat_lon, + **feature_kwargs) return data @@ -1716,8 +1730,8 @@ def _single_proc_run(cls, strategy, node_index, chunk_index): @classmethod def run(cls, strategy, node_index): - """This routine runs forward passes on all spatiotemporal chunks for - the given node index. + """Runs forward passes on all spatiotemporal chunks for the given node + index. Parameters ---------- @@ -1738,8 +1752,8 @@ def run(cls, strategy, node_index): @classmethod def _run_serial(cls, strategy, node_index): - """This routine runs forward passes, on all spatiotemporal chunks for - the given node index, in serial. + """Runs forward passes, on all spatiotemporal chunks for the given node + index, in serial. Parameters ---------- @@ -1772,9 +1786,8 @@ def _run_serial(cls, strategy, node_index): @classmethod def _run_parallel(cls, strategy, node_index): - """This routine runs forward passes, on all spatiotemporal chunks for - the given node index, with data extraction and forward pass routines in - parallel. + """Runs forward passes, on all spatiotemporal chunks for the given node + index, with data extraction and forward pass routines in parallel. Parameters ---------- diff --git a/sup3r/postprocessing/collection.py b/sup3r/postprocessing/collection.py index 4d0cd14a6..4f27ee1dd 100644 --- a/sup3r/postprocessing/collection.py +++ b/sup3r/postprocessing/collection.py @@ -1,21 +1,21 @@ # -*- coding: utf-8 -*- """H5 file collection.""" -from concurrent.futures import as_completed, ThreadPoolExecutor +import glob import logging -import numpy as np import os -import pandas as pd -import psutil import time -import glob +from concurrent.futures import ThreadPoolExecutor, as_completed from warnings import warn -from scipy.spatial import KDTree -from rex.utilities.loggers import init_logger +import numpy as np +import pandas as pd +import psutil from rex.utilities.fun_utils import get_fun_call_str +from rex.utilities.loggers import init_logger +from scipy.spatial import KDTree from sup3r.pipeline import Status -from sup3r.postprocessing.file_handling import RexOutputs, OutputMixIn +from sup3r.postprocessing.file_handling import OutputMixIn, RexOutputs from sup3r.utilities import ModuleName from sup3r.utilities.cli import BaseCLI @@ -39,6 +39,7 @@ def __init__(self, file_paths): file_paths = glob.glob(file_paths) self.flist = sorted(file_paths) self.data = None + self.file_attrs = {} @classmethod def get_node_cmd(cls, config): @@ -51,36 +52,40 @@ def get_node_cmd(cls, config): run data collection. 
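# --- Illustrative sketch, not part of the diff ---
# The updated dispatch pattern in ForwardPass.bias_correct_source_data: the
# configured correction function is looked up by name and now receives the
# low-res lat/lon grid as its second positional argument. The method table,
# feature name, and kwargs below are simplified stand-ins; the real code
# resolves the method from sup3r.bias.bias_transforms and the kwargs from
# ForwardPassStrategy.bias_correct_kwargs.
import numpy as np

def local_linear_bc(data, lat_lon, scalar=1.0, adder=0.0):
    """Stand-in with the same calling convention as the real transform."""
    return data * scalar + adder

methods = {'local_linear_bc': local_linear_bc}
bias_correct_method = 'local_linear_bc'
bias_correct_kwargs = {'u_100m': {'scalar': 1.05, 'adder': -0.2}}

features = ['u_100m']
data = np.random.uniform(0, 20, (10, 10, 24, len(features)))
lat_lon = np.zeros((10, 10, 2))   # stand-in for strategy.lr_lat_lon

method = methods[bias_correct_method]
for idf, feature in enumerate(features):
    if feature in bias_correct_kwargs:
        data[..., idf] = method(data[..., idf], lat_lon,
                                **bias_correct_kwargs[feature])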
""" - import_str = ('from sup3r.postprocessing.collection ' - 'import Collector;\n' - 'from rex import init_logger;\n' - 'import time;\n' - 'from reV.pipeline.status import Status;\n') + import_str = ( + 'from sup3r.postprocessing.collection ' + 'import Collector;\n' + 'from rex import init_logger;\n' + 'import time;\n' + 'from reV.pipeline.status import Status;\n' + ) dc_fun_str = get_fun_call_str(cls.collect, config) log_file = config.get('log_file', None) log_level = config.get('log_level', 'INFO') - log_arg_str = (f'"sup3r", log_level="{log_level}"') + log_arg_str = f'"sup3r", log_level="{log_level}"' if log_file is not None: log_arg_str += f', log_file="{log_file}"' - cmd = (f"python -c \'{import_str}\n" - "t0 = time.time();\n" - f"logger = init_logger({log_arg_str});\n" - f"{dc_fun_str};\n" - "t_elap = time.time() - t0;\n" - ) + cmd = ( + f"python -c \'{import_str}\n" + "t0 = time.time();\n" + f"logger = init_logger({log_arg_str});\n" + f"{dc_fun_str};\n" + "t_elap = time.time() - t0;\n" + ) cmd = BaseCLI.add_status_cmd(config, ModuleName.DATA_COLLECT, cmd) - cmd += (";\'\n") + cmd += ";\'\n" return cmd.replace('\\', '/') @classmethod - def get_slices(cls, final_time_index, final_meta, new_time_index, - new_meta): + def get_slices( + cls, final_time_index, final_meta, new_time_index, new_meta + ): """Get index slices where the new ti/meta belong in the final ti/meta. Parameters @@ -109,25 +114,44 @@ def get_slices(cls, final_time_index, final_meta, new_time_index, col_loc = np.where(final_meta['gid'].isin(new_meta['gid']))[0] if not len(row_loc) > 0: - msg = ('Could not find row locations in file collection. ' - 'New time index: {} final time index: {}' - .format(new_time_index, final_time_index)) + msg = ( + 'Could not find row locations in file collection. ' + 'New time index: {} final time index: {}'.format( + new_time_index, final_time_index + ) + ) logger.error(msg) raise RuntimeError(msg) if not len(col_loc) > 0: - msg = ('Could not find col locations in file collection. ' - 'New index: {} final index: {}' - .format(new_index, final_index)) + msg = ( + 'Could not find col locations in file collection. ' + 'New index: {} final index: {}'.format(new_index, final_index) + ) logger.error(msg) raise RuntimeError(msg) row_slice = slice(np.min(row_loc), np.max(row_loc) + 1) + col_slice = slice(np.min(col_loc), np.max(col_loc) + 1) + + msg = ( + f'row_slice={row_slice} conflict with row_indices={row_loc}. ' + 'Indices do not seem to be increasing and/or contiguous.' + ) + assert (row_slice.stop - row_slice.start) == len(row_loc), msg + + msg = ( + f'col_slice={col_slice} conflict with col_indices={col_loc}. ' + 'Indices do not seem to be increasing and/or contiguous.' 
+ ) + check = (col_slice.stop - col_slice.start) == len(col_loc) + if not check: + logger.warning(msg) + warn(msg) return row_slice, col_loc - @classmethod - def get_coordinate_indices(cls, target_meta, full_meta, threshold=1e-4): + def get_coordinate_indices(self, target_meta, full_meta, threshold=1e-4): """Get coordindate indices in meta data for given targets Parameters @@ -139,18 +163,27 @@ def get_coordinate_indices(cls, target_meta, full_meta, threshold=1e-4): threshold : float Threshold distance for finding target coordinates within full meta """ - ll2 = np.vstack((full_meta.latitude.values, - full_meta.longitude.values)).T + ll2 = np.vstack( + (full_meta.latitude.values, full_meta.longitude.values) + ).T tree = KDTree(ll2) - - targets = np.vstack((target_meta.latitude.values, - target_meta.longitude.values)).T + targets = np.vstack( + (target_meta.latitude.values, target_meta.longitude.values) + ).T _, indices = tree.query(targets, distance_upper_bound=threshold) indices = indices[indices < len(full_meta)] return indices - def get_data(self, file_path, feature, time_index, meta, scale_factor, - dtype, threshold=1e-4): + def get_data( + self, + file_path, + feature, + time_index, + meta, + scale_factor, + dtype, + threshold=1e-4, + ): """Retreive a data array from a chunked file. Parameters @@ -161,7 +194,7 @@ def get_data(self, file_path, feature, time_index, meta, scale_factor, dataset to retrieve data from fpath. time_index : pd.Datetimeindex Time index of the final file. - final_meta : pd.DataFrame + meta : pd.DataFrame Meta data of the final file. scale_factor : int | float Final destination scale factor after collection. If the data @@ -189,29 +222,35 @@ def get_data(self, file_path, feature, time_index, meta, scale_factor, source_scale_factor = f.attrs[feature].get('scale_factor', 1) if feature not in f.dsets: - e = ('Trying to collect dataset "{}" but cannot find in ' - 'available: {}'.format(feature, f.dsets)) + e = ( + 'Trying to collect dataset "{}" but cannot find in ' + 'available: {}'.format(feature, f.dsets) + ) logger.error(e) raise KeyError(e) - mask = self.get_coordinate_indices(meta, f_meta, - threshold=threshold) + mask = self.get_coordinate_indices( + meta, f_meta, threshold=threshold + ) f_meta = f_meta.iloc[mask] f_data = f[feature][:, mask] if len(mask) == 0: - msg = ('No target coordinates found in masked meta. ' - f'Skipping collection for {file_path}.') + msg = ( + 'No target coordinates found in masked meta. ' + f'Skipping collection for {file_path}.' 
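# --- Illustrative sketch, not part of the diff ---
# The KDTree lookup used by get_coordinate_indices: find which rows of the full
# meta correspond to the target meta, within a distance threshold. Coordinates
# below are synthetic.
import numpy as np
import pandas as pd
from scipy.spatial import KDTree

full_meta = pd.DataFrame({'latitude': [40.0, 40.0, 39.9, 39.9],
                          'longitude': [-105.2, -105.1, -105.2, -105.1]})
target_meta = pd.DataFrame({'latitude': [40.0, 39.9],
                            'longitude': [-105.1, -105.2]})
threshold = 1e-4

ll2 = np.vstack((full_meta.latitude.values, full_meta.longitude.values)).T
tree = KDTree(ll2)
targets = np.vstack((target_meta.latitude.values,
                     target_meta.longitude.values)).T
_, indices = tree.query(targets, distance_upper_bound=threshold)
indices = indices[indices < len(full_meta)]   # drop unmatched targets
print(indices)                                # -> [1 2]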
+ ) logger.warning(msg) warn(msg) else: - row_slice, col_slice = Collector.get_slices(time_index, meta, - f_ti, f_meta) + row_slice, col_slice = Collector.get_slices( + time_index, meta, f_ti, f_meta + ) if scale_factor != source_scale_factor: f_data = f_data.astype(np.float32) - f_data *= (scale_factor / source_scale_factor) + f_data *= scale_factor / source_scale_factor if np.issubdtype(dtype, np.integer): f_data = np.round(f_data) @@ -219,17 +258,23 @@ def get_data(self, file_path, feature, time_index, meta, scale_factor, f_data = f_data.astype(dtype) self.data[row_slice, col_slice] = f_data - @staticmethod - def _get_file_attrs(file): + def _get_file_attrs(self, file): """Get meta data and time index for a single file""" - with RexOutputs(file, mode='r') as f: - meta = f.meta - time_index = f.time_index + if file in self.file_attrs: + meta = self.file_attrs[file]['meta'] + time_index = self.file_attrs[file]['time_index'] + else: + with RexOutputs(file, mode='r') as f: + meta = f.meta + time_index = f.time_index + if file not in self.file_attrs: + self.file_attrs[file] = {'meta': meta, 'time_index': time_index} return meta, time_index - @classmethod - def _get_collection_attrs_parallel(cls, file_paths, max_workers=None): - """Get meta data and time index from a file list to be collected. + def _get_collection_attrs( + self, file_paths, sort=True, sort_key=None, max_workers=None + ): + """Get important dataset attributes from a file list to be collected. Assumes the file list is chunked in time (row chunked). @@ -237,54 +282,141 @@ def _get_collection_attrs_parallel(cls, file_paths, max_workers=None): ---------- file_paths : list | str Explicit list of str file paths that will be sorted and collected - or a single string with unix-style /search/patt*ern.h5. Files - should have non-overlapping time_index dataset and fully - overlapping meta dataset. + or a single string with unix-style /search/patt*ern.h5. + sort : bool + flag to sort flist to determine meta data order. + sort_key : None | fun + Optional sort key to sort flist by (determines how meta is built + if out_file does not exist). max_workers : int | None Number of workers to use in parallel. 1 runs serial, None will use all available workers. + target_final_meta_file : str + Path to target final meta containing coordinates to keep from the + full list of coordinates present in the collected meta for the full + file list. + threshold : float + Threshold distance for finding target coordinates within full meta Returns ------- time_index : pd.datetimeindex - List of datetime indices for each file that is being collected + Concatenated full size datetime index from the flist that is + being collected meta : pd.DataFrame - List of meta data for each files that is being - collected + Concatenated full size meta data from the flist that is being + collected or provided target meta """ + if sort: + file_paths = sorted(file_paths, key=sort_key) + + logger.info( + 'Getting collection attrs for full dataset with ' + f'max_workers={max_workers}.' 
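# --- Illustrative sketch, not part of the diff ---
# The rescaling step in get_data when the source chunk and the destination file
# use different scale factors: convert to float, rescale, round for integer
# output dtypes, then cast. Values below are synthetic.
import numpy as np

f_data = np.array([[1013, 1017], [1020, 1025]], dtype=np.int16)
source_scale_factor, scale_factor = 10, 100
dtype = np.int32

if scale_factor != source_scale_factor:
    f_data = f_data.astype(np.float32)
    f_data *= scale_factor / source_scale_factor

if np.issubdtype(dtype, np.integer):
    f_data = np.round(f_data)

f_data = f_data.astype(dtype)   # -> [[10130 10170] [10200 10250]]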
+ ) time_index = [None] * len(file_paths) meta = [None] * len(file_paths) - futures = {} - with ThreadPoolExecutor(max_workers=max_workers) as exe: + if max_workers == 1: for i, fn in enumerate(file_paths): - future = exe.submit(cls._get_file_attrs, fn) - futures[future] = i + meta[i], time_index[i] = self._get_file_attrs(fn) + logger.debug(f'{i+1} / {len(file_paths)} files finished') + else: + futures = {} + with ThreadPoolExecutor(max_workers=max_workers) as exe: + for i, fn in enumerate(file_paths): + future = exe.submit(self._get_file_attrs, fn) + futures[future] = i - interval = int(np.ceil(len(futures) / 10)) - for i, future in enumerate(as_completed(futures)): - if i % interval == 0: + for i, future in enumerate(as_completed(futures)): mem = psutil.virtual_memory() - logger.info('Meta collection futures completed: ' - '{0} out of {1}. ' - 'Current memory usage is ' - '{2:.3f} GB out of {3:.3f} GB total.' - .format(i + 1, len(futures), - mem.used / 1e9, mem.total / 1e9)) - try: - idx = futures[future] - meta[idx], time_index[idx] = future.result() - except Exception as e: - msg = ('Falied to get attrs from ' - f'{file_paths[futures[future]]}') - logger.exception(msg) - raise RuntimeError(msg) from e - return meta, time_index + msg = ( + f'Meta collection futures completed: {i + 1} out ' + f'of {len(futures)}. Current memory usage is ' + f'{mem.used / 1e9:.3f} GB out of ' + f'{mem.total / 1e9:.3f} GB total.' + ) + logger.info(msg) + try: + idx = futures[future] + meta[idx], time_index[idx] = future.result() + except Exception as e: + msg = ( + 'Falied to get attrs from ' + f'{file_paths[futures[future]]}' + ) + logger.exception(msg) + raise RuntimeError(msg) from e + time_index = pd.DatetimeIndex(np.concatenate(time_index)) + time_index = time_index.sort_values() + time_index = time_index.drop_duplicates() + meta = pd.concat(meta) - @classmethod - def _get_collection_attrs(cls, file_paths, sort=True, - sort_key=None, max_workers=None, - target_final_meta_file=None, threshold=1e-4): + if 'latitude' in meta and 'longitude' in meta: + meta = meta.drop_duplicates(subset=['latitude', 'longitude']) + meta = meta.sort_values('gid') + + return time_index, meta + + def get_target_and_masked_meta( + self, meta, target_final_meta_file=None, threshold=1e-4 + ): + """Use combined meta for all files and target_final_meta_file to get + mapping from the full meta to the target meta and the mapping from the + target meta to the full meta, both of which are masked to remove + coordinates not present in the target_meta. + + Parameters + ---------- + meta : pd.DataFrame + Concatenated full size meta data from the flist that is being + collected or provided target meta + target_final_meta_file : str + Path to target final meta containing coordinates to keep from the + full list of coordinates present in the collected meta for the full + file list. 
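# --- Illustrative sketch, not part of the diff ---
# How _get_collection_attrs combines the per-file time indices and meta after
# gathering them (serially or with the thread pool): concatenate, sort,
# de-duplicate. Inputs below are synthetic two-file stand-ins.
import numpy as np
import pandas as pd

time_index = [pd.date_range('2015-01-01', periods=24, freq='1h'),
              pd.date_range('2015-01-02', periods=24, freq='1h')]
meta = [pd.DataFrame({'gid': [0, 1], 'latitude': [40.0, 39.9],
                      'longitude': [-105.1, -105.1]})] * 2

time_index = pd.DatetimeIndex(np.concatenate(time_index))
time_index = time_index.sort_values().drop_duplicates()

meta = pd.concat(meta)
if 'latitude' in meta and 'longitude' in meta:
    meta = meta.drop_duplicates(subset=['latitude', 'longitude'])
meta = meta.sort_values('gid')

assert len(time_index) == 48 and len(meta) == 2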
+ threshold : float + Threshold distance for finding target coordinates within full meta + + Returns + ------- + target_final_meta : pd.DataFrame + Concatenated full size meta data from the flist that is being + collected or provided target meta + masked_meta : pd.DataFrame + Concatenated full size meta data from the flist that is being + collected masked against target_final_meta + """ + if target_final_meta_file is not None and os.path.exists( + target_final_meta_file + ): + target_final_meta = pd.read_csv(target_final_meta_file) + if 'gid' in target_final_meta.columns: + target_final_meta = target_final_meta.drop('gid', axis=1) + mask = self.get_coordinate_indices( + target_final_meta, meta, threshold=threshold + ) + masked_meta = meta.iloc[mask] + logger.info(f'Masked meta coordinates: {len(masked_meta)}') + mask = self.get_coordinate_indices( + masked_meta, target_final_meta, threshold=threshold + ) + target_final_meta = target_final_meta.iloc[mask] + logger.info(f'Target meta coordinates: {len(target_final_meta)}') + else: + target_final_meta = masked_meta = meta + + return target_final_meta, masked_meta + + def get_collection_attrs( + self, + file_paths, + sort=True, + sort_key=None, + max_workers=None, + target_final_meta_file=None, + threshold=1e-4, + ): """Get important dataset attributes from a file list to be collected. Assumes the file list is chunked in time (row chunked). @@ -327,53 +459,21 @@ def _get_collection_attrs(cls, file_paths, sort=True, that all the files in file_paths have the same global file attributes). """ + logger.info(f'Using target_final_meta_file={target_final_meta_file}') + if isinstance(target_final_meta_file, str): + msg = ( + f'Provided target meta ({target_final_meta_file}) does not ' + 'exist.' + ) + assert os.path.exists(target_final_meta_file), msg - if sort: - file_paths = sorted(file_paths, key=sort_key) - - logger.info('Getting collection attrs for full dataset') - - if max_workers == 1: - meta = [] - time_index = None - for i, fn in enumerate(file_paths): - with RexOutputs(fn, mode='r') as f: - meta.append(f.meta) - - if time_index is None: - time_index = f.time_index - else: - time_index = time_index.append(f.time_index) - logger.debug(f'{i+1} / {len(file_paths)} files finished') - else: - meta, time_index = cls._get_collection_attrs_parallel( - file_paths, max_workers=max_workers) - time_index = pd.DatetimeIndex(np.concatenate(time_index)) - - time_index = time_index.sort_values() - time_index = time_index.drop_duplicates() - meta = pd.concat(meta) - - if 'latitude' in meta and 'longitude' in meta: - meta = meta.drop_duplicates(subset=['latitude', 'longitude']) - - meta = meta.sort_values('gid') + time_index, meta = self._get_collection_attrs( + file_paths, sort=sort, sort_key=sort_key, max_workers=max_workers + ) - if (target_final_meta_file is not None - and os.path.exists(target_final_meta_file)): - target_final_meta = pd.read_csv(target_final_meta_file) - if 'gid' in target_final_meta.columns: - target_final_meta = target_final_meta.drop('gid', axis=1) - mask = cls.get_coordinate_indices(target_final_meta, meta, - threshold=threshold) - masked_meta = meta.iloc[mask] - logger.info(f'Masked meta coordinates: {len(masked_meta)}') - mask = cls.get_coordinate_indices(masked_meta, target_final_meta, - threshold=threshold) - target_final_meta = target_final_meta.iloc[mask] - logger.info(f'Target meta coordinates: {len(target_final_meta)}') - else: - target_final_meta = masked_meta = meta + target_final_meta, masked_meta = 
self.get_target_and_masked_meta( + meta, target_final_meta_file, threshold=threshold + ) shape = (len(time_index), len(target_final_meta)) @@ -382,9 +482,72 @@ def _get_collection_attrs(cls, file_paths, sort=True, return time_index, target_final_meta, masked_meta, shape, global_attrs - def _collect_flist(self, feature, subset_masked_meta, time_index, shape, - file_paths, out_file, target_masked_meta, - max_workers=None): + def _write_flist_data( + self, + out_file, + feature, + time_index, + subset_masked_meta, + target_masked_meta, + ): + """Write spatiotemporal file list data to output file for given + feature + + Parameters + ---------- + out_file : str + Name of output file + feature : str + Name of feature for output chunk + time_index : pd.DateTimeIndex + Time index for corresponding file list data + subset_masked_meta : pd.DataFrame + Meta for corresponding file list data + target_masked_meta : pd.DataFrame + Meta for full output file + """ + with RexOutputs(out_file, mode='r') as f: + target_ti = f.time_index + y_write_slice, x_write_slice = Collector.get_slices( + target_ti, + target_masked_meta, + time_index, + subset_masked_meta, + ) + Collector._ensure_dset_in_output(out_file, feature) + + with RexOutputs(out_file, mode='a') as f: + try: + f[feature, y_write_slice, x_write_slice] = self.data + except Exception as e: + msg = ( + f'Problem with writing data to {out_file} with ' + f't_slice={y_write_slice}, ' + f's_slice={x_write_slice}. {e}' + ) + logger.error(msg) + raise OSError(msg) from e + + logger.debug( + 'Finished writing "{}" for row {} and col {} to: {}'.format( + feature, + y_write_slice, + x_write_slice, + os.path.basename(out_file), + ) + ) + + def _collect_flist( + self, + feature, + subset_masked_meta, + time_index, + shape, + file_paths, + out_file, + target_masked_meta, + max_workers=None, + ): """Collect a dataset from a file list without getting attributes first. This file list can be a subset of a full file list to be collected. @@ -418,45 +581,76 @@ def _collect_flist(self, feature, subset_masked_meta, time_index, shape, attrs, final_dtype = self.get_dset_attrs(feature) scale_factor = attrs.get('scale_factor', 1) - logger.debug('Collecting file list of shape {}: {}' - .format(shape, file_paths)) + logger.debug( + 'Collecting file list of shape {}: {}'.format( + shape, file_paths + ) + ) self.data = np.zeros(shape, dtype=final_dtype) mem = psutil.virtual_memory() - logger.debug('Initializing output dataset "{0}" in-memory with ' - 'shape {1} and dtype {2}. Current memory usage is ' - '{3:.3f} GB out of {4:.3f} GB total.' - .format(feature, shape, final_dtype, - mem.used / 1e9, mem.total / 1e9)) + logger.debug( + 'Initializing output dataset "{}" in-memory with ' + 'shape {} and dtype {}. Current memory usage is ' + '{:.3f} GB out of {:.3f} GB total.'.format( + feature, + shape, + final_dtype, + mem.used / 1e9, + mem.total / 1e9, + ) + ) if max_workers == 1: for i, fname in enumerate(file_paths): - logger.debug('Collecting data from file {} out of {}.' - .format(i + 1, len(file_paths))) - self.get_data(fname, feature, time_index, - subset_masked_meta, scale_factor, - final_dtype) + logger.debug( + 'Collecting data from file {} out of {}.'.format( + i + 1, len(file_paths) + ) + ) + self.get_data( + fname, + feature, + time_index, + subset_masked_meta, + scale_factor, + final_dtype, + ) else: - logger.info('Running parallel collection on {} workers.' 
- .format(max_workers)) + logger.info( + 'Running parallel collection on {} workers.'.format( + max_workers + ) + ) futures = {} completed = 0 with ThreadPoolExecutor(max_workers=max_workers) as exe: for fname in file_paths: - future = exe.submit(self.get_data, fname, feature, - time_index, subset_masked_meta, - scale_factor, final_dtype) + future = exe.submit( + self.get_data, + fname, + feature, + time_index, + subset_masked_meta, + scale_factor, + final_dtype, + ) futures[future] = fname for future in as_completed(futures): completed += 1 mem = psutil.virtual_memory() - logger.info('Collection futures completed: ' - '{0} out of {1}. ' - 'Current memory usage is ' - '{2:.3f} GB out of {3:.3f} GB total.' - .format(completed, len(futures), - mem.used / 1e9, mem.total / 1e9)) + logger.info( + 'Collection futures completed: ' + '{} out of {}. ' + 'Current memory usage is ' + '{:.3f} GB out of {:.3f} GB total.'.format( + completed, + len(futures), + mem.used / 1e9, + mem.total / 1e9, + ) + ) try: future.result() except Exception as e: @@ -464,26 +658,22 @@ def _collect_flist(self, feature, subset_masked_meta, time_index, shape, msg += f'{futures[future]}' logger.exception(msg) raise RuntimeError(msg) from e - with RexOutputs(out_file, mode='r') as f: - target_ti = f.time_index - y_write_slice, x_write_slice = Collector.get_slices( - target_ti, target_masked_meta, time_index, - subset_masked_meta) - Collector._ensure_dset_in_output(out_file, feature) - with RexOutputs(out_file, mode='a') as f: - f[feature, y_write_slice, x_write_slice] = self.data - - logger.debug('Finished writing "{}" for row {} and col {} to: {}' - .format(feature, y_write_slice, x_write_slice, - os.path.basename(out_file))) + self._write_flist_data( + out_file, + feature, + time_index, + subset_masked_meta, + target_masked_meta, + ) else: - msg = ('No target coordinates found in masked meta. Skipping ' - f'collection for {file_paths}.') + msg = ( + 'No target coordinates found in masked meta. Skipping ' + f'collection for {file_paths}.' + ) logger.warning(msg) warn(msg) - @classmethod - def group_time_chunks(cls, file_paths, n_writes=None): + def group_time_chunks(self, file_paths, n_writes=None): """Group files by temporal_chunk_index. Assumes file_paths have a suffix format like _{temporal_chunk_index}_{spatial_chunk_index}.h5 @@ -503,17 +693,21 @@ def group_time_chunks(cls, file_paths, n_writes=None): file_split = {} for file in file_paths: t_chunk = file.split('_')[-2] - file_split[t_chunk] = file_split.get(t_chunk, []) + [file] + file_split[t_chunk] = [*file_split.get(t_chunk, []), file] file_chunks = [] for files in file_split.values(): file_chunks.append(files) - logger.debug(f'Split file list into {len(file_chunks)} chunks ' - 'according to temporal chunk indices') + logger.debug( + f'Split file list into {len(file_chunks)} chunks ' + 'according to temporal chunk indices' + ) if n_writes is not None: - msg = (f'n_writes ({n_writes}) must be less than or equal ' - f'to the number of temporal chunks ({len(file_chunks)}).') + msg = ( + f'n_writes ({n_writes}) must be less than or equal ' + f'to the number of temporal chunks ({len(file_chunks)}).' + ) assert n_writes < len(file_chunks), msg return file_chunks @@ -543,28 +737,42 @@ def get_flist_chunks(self, file_paths, n_writes=None, join_times=False): multiple steps. 
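# --- Illustrative sketch, not part of the diff ---
# How group_time_chunks buckets chunked output files by their temporal chunk
# index, assuming the _{temporal_chunk_index}_{spatial_chunk_index}.h5 suffix
# convention. File names below are made up.
file_paths = ['sup3r_chunk_000000_000000.h5', 'sup3r_chunk_000000_000001.h5',
              'sup3r_chunk_000001_000000.h5', 'sup3r_chunk_000001_000001.h5']

file_split = {}
for file in file_paths:
    t_chunk = file.split('_')[-2]
    file_split[t_chunk] = [*file_split.get(t_chunk, []), file]

file_chunks = list(file_split.values())
assert len(file_chunks) == 2 and all(len(fc) == 2 for fc in file_chunks)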
""" if join_times: - flist_chunks = self.group_time_chunks(file_paths, - n_writes=n_writes) + flist_chunks = self.group_time_chunks( + file_paths, n_writes=n_writes + ) else: flist_chunks = [[f] for f in file_paths] if n_writes is not None: flist_chunks = np.array_split(flist_chunks, n_writes) - flist_chunks = [np.concatenate(fp_chunk) - for fp_chunk in flist_chunks] - logger.debug(f'Split file list into {len(flist_chunks)} ' - f'chunks according to n_writes={n_writes}') + flist_chunks = [ + np.concatenate(fp_chunk) for fp_chunk in flist_chunks + ] + logger.debug( + f'Split file list into {len(flist_chunks)} ' + f'chunks according to n_writes={n_writes}' + ) return flist_chunks @classmethod - def collect(cls, file_paths, out_file, features, max_workers=None, - log_level=None, log_file=None, write_status=False, - job_name=None, join_times=False, target_final_meta_file=None, - n_writes=None, overwrite=True, threshold=1e-4): + def collect( + cls, + file_paths, + out_file, + features, + max_workers=None, + log_level=None, + log_file=None, + write_status=False, + job_name=None, + join_times=False, + target_final_meta_file=None, + n_writes=None, + overwrite=True, + threshold=1e-4, + ): """Collect data files from a dir to one output file. - Assumes the file list is chunked in time (row chunked). - Filename requirements: - Should end with ".h5" @@ -617,29 +825,33 @@ def collect(cls, file_paths, out_file, features, max_workers=None, """ t0 = time.time() - logger.info(f'Initializing collection for file_paths={file_paths}, ' - f'with max_workers={max_workers}.') + logger.info( + f'Initializing collection for file_paths={file_paths}, ' + f'with max_workers={max_workers}.' + ) if log_level is not None: - init_logger('sup3r.preprocessing', log_file=log_file, - log_level=log_level) - - logger.info(f'Using target_final_meta_file={target_final_meta_file}') + init_logger( + 'sup3r.preprocessing', log_file=log_file, log_level=log_level + ) if not os.path.exists(os.path.dirname(out_file)): os.makedirs(os.path.dirname(out_file), exist_ok=True) collector = cls(file_paths) - logger.info('Collecting {} files to {}'.format(len(collector.flist), - out_file)) + logger.info( + 'Collecting {} files to {}'.format(len(collector.flist), out_file) + ) if overwrite and os.path.exists(out_file): logger.info(f'overwrite=True, removing {out_file}.') os.remove(out_file) - out = collector._get_collection_attrs( - collector.flist, max_workers=max_workers, + out = collector.get_collection_attrs( + collector.flist, + max_workers=max_workers, target_final_meta_file=target_final_meta_file, - threshold=threshold) + threshold=threshold, + ) time_index, target_final_meta, target_masked_meta = out[:3] shape, global_attrs = out[3:] @@ -647,41 +859,68 @@ def collect(cls, file_paths, out_file, features, max_workers=None, logger.debug('Collecting dataset "{}".'.format(dset)) if join_times or n_writes is not None: flist_chunks = collector.get_flist_chunks( - collector.flist, n_writes=n_writes, join_times=join_times) + collector.flist, n_writes=n_writes, join_times=join_times + ) else: flist_chunks = [collector.flist] if not os.path.exists(out_file): - collector._init_h5(out_file, time_index, target_final_meta, - global_attrs) + collector._init_h5( + out_file, time_index, target_final_meta, global_attrs + ) if len(flist_chunks) == 1: - collector._collect_flist(dset, target_masked_meta, time_index, - shape, flist_chunks[0], out_file, - target_masked_meta, - max_workers=max_workers) + collector._collect_flist( + dset, + target_masked_meta, + 
time_index, + shape, + flist_chunks[0], + out_file, + target_masked_meta, + max_workers=max_workers, + ) else: for j, flist in enumerate(flist_chunks): - logger.info('Collecting file list chunk {} out of {} ' - .format(j + 1, len(flist_chunks))) - time_index, target_final_meta, masked_meta, shape, _ = \ - collector._get_collection_attrs( - flist, max_workers=max_workers, - target_final_meta_file=target_final_meta_file, - threshold=threshold) - collector._collect_flist(dset, masked_meta, time_index, - shape, flist, out_file, - target_masked_meta, - max_workers=max_workers) + logger.info( + 'Collecting file list chunk {} out of {} '.format( + j + 1, len(flist_chunks) + ) + ) + ( + time_index, + target_final_meta, + masked_meta, + shape, + _, + ) = collector.get_collection_attrs( + flist, + max_workers=max_workers, + target_final_meta_file=target_final_meta_file, + threshold=threshold, + ) + collector._collect_flist( + dset, + masked_meta, + time_index, + shape, + flist, + out_file, + target_masked_meta, + max_workers=max_workers, + ) if write_status and job_name is not None: - status = {'out_dir': os.path.dirname(out_file), - 'fout': out_file, - 'flist': collector.flist, - 'job_status': 'successful', - 'runtime': (time.time() - t0) / 60} - Status.make_job_file(os.path.dirname(out_file), 'collect', - job_name, status) + status = { + 'out_dir': os.path.dirname(out_file), + 'fout': out_file, + 'flist': collector.flist, + 'job_status': 'successful', + 'runtime': (time.time() - t0) / 60, + } + Status.make_job_file( + os.path.dirname(out_file), 'collect', job_name, status + ) logger.info('Finished file collection.') diff --git a/sup3r/preprocessing/batch_handling.py b/sup3r/preprocessing/batch_handling.py index 17b4449b2..5ed05fc96 100644 --- a/sup3r/preprocessing/batch_handling.py +++ b/sup3r/preprocessing/batch_handling.py @@ -92,14 +92,18 @@ def reduce_features(high_res, output_features_ind=None): # pylint: disable=W0613 @classmethod - def get_coarse_batch(cls, high_res, - s_enhance, t_enhance=1, - temporal_coarsening_method='subsample', - output_features_ind=None, - output_features=None, - training_features=None, - smoothing=None, - smoothing_ignore=None): + def get_coarse_batch( + cls, + high_res, + s_enhance, + t_enhance=1, + temporal_coarsening_method='subsample', + output_features_ind=None, + output_features=None, + training_features=None, + smoothing=None, + smoothing_ignore=None, + ): """Coarsen high res data and return Batch with high res and low res data @@ -149,11 +153,13 @@ def get_coarse_batch(cls, high_res, smoothing_ignore = [] if t_enhance != 1: - low_res = temporal_coarsening(low_res, t_enhance, - temporal_coarsening_method) + low_res = temporal_coarsening( + low_res, t_enhance, temporal_coarsening_method + ) - low_res = smooth_data(low_res, training_features, smoothing_ignore, - smoothing) + low_res = smooth_data( + low_res, training_features, smoothing_ignore, smoothing + ) high_res = cls.reduce_features(high_res, output_features_ind) batch = cls(low_res, high_res) @@ -166,11 +172,18 @@ class ValidationData: # Classes to use for handling an individual batch obj. 
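# --- Illustrative sketch, not part of the diff ---
# The kind of spatial coarsening get_coarse_batch relies on. This is a
# simplified mean-pooling stand-in, not sup3r's spatial_coarsening utility;
# shapes are synthetic.
import numpy as np

def mean_pool_spatial(high_res, s_enhance):
    """(obs, rows, cols, time, features) -> block mean over s_enhance blocks"""
    o, r, c, t, f = high_res.shape
    out = high_res.reshape(o, r // s_enhance, s_enhance,
                           c // s_enhance, s_enhance, t, f)
    return out.mean(axis=(2, 4))

high_res = np.random.uniform(size=(8, 12, 12, 24, 2)).astype(np.float32)
low_res = mean_pool_spatial(high_res, s_enhance=3)
assert low_res.shape == (8, 4, 4, 24, 2)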
BATCH_CLASS = Batch - def __init__(self, data_handlers, batch_size=8, s_enhance=3, t_enhance=1, - temporal_coarsening_method='subsample', - output_features_ind=None, - output_features=None, - smoothing=None, smoothing_ignore=None): + def __init__( + self, + data_handlers, + batch_size=8, + s_enhance=3, + t_enhance=1, + temporal_coarsening_method='subsample', + output_features_ind=None, + output_features=None, + smoothing=None, + smoothing_ignore=None, + ): """ Parameters ---------- @@ -212,8 +225,7 @@ def __init__(self, data_handlers, batch_size=8, s_enhance=3, t_enhance=1, self.batch_size = batch_size self.sample_shape = handler_shapes[0] self.val_indices = self._get_val_indices() - self.max = np.ceil( - len(self.val_indices) / (batch_size)) + self.max = np.ceil(len(self.val_indices) / (batch_size)) self.s_enhance = s_enhance self.t_enhance = t_enhance self._remaining_observations = len(self.val_indices) @@ -240,14 +252,22 @@ def _get_val_indices(self): for i, h in enumerate(self.handlers): if h.val_data is not None: for _ in range(h.val_data.shape[2]): - spatial_slice = uniform_box_sampler(h.val_data, - self.sample_shape[:2]) - temporal_slice = uniform_time_sampler(h.val_data, - self.sample_shape[2]) - tuple_index = tuple([*spatial_slice, temporal_slice] - + [np.arange(h.val_data.shape[-1])]) - val_indices.append({'handler_index': i, - 'tuple_index': tuple_index}) + spatial_slice = uniform_box_sampler( + h.val_data, self.sample_shape[:2] + ) + temporal_slice = uniform_time_sampler( + h.val_data, self.sample_shape[2] + ) + tuple_index = tuple( + [ + *spatial_slice, + temporal_slice, + np.arange(h.val_data.shape[-1]), + ] + ) + val_indices.append( + {'handler_index': i, 'tuple_index': tuple_index} + ) return val_indices def any(self): @@ -268,10 +288,12 @@ def shape(self): time_steps = 0 for h in self.handlers: time_steps += h.val_data.shape[2] - return (self.handlers[0].val_data.shape[0], - self.handlers[0].val_data.shape[1], - time_steps, - self.handlers[0].val_data.shape[3]) + return ( + self.handlers[0].val_data.shape[0], + self.handlers[0].val_data.shape[1], + time_steps, + self.handlers[0].val_data.shape[3], + ) def __iter__(self): self._i = 0 @@ -302,13 +324,15 @@ def batch_next(self, high_res): batch : Batch """ return self.BATCH_CLASS.get_coarse_batch( - high_res, self.s_enhance, + high_res, + self.s_enhance, t_enhance=self.t_enhance, temporal_coarsening_method=self.temporal_coarsening_method, output_features_ind=self.output_features_ind, smoothing=self.smoothing, smoothing_ignore=self.smoothing_ignore, - output_features=self.output_features) + output_features=self.output_features, + ) def __next__(self): """Get validation data batch @@ -321,23 +345,32 @@ def __next__(self): """ if self._remaining_observations > 0: if self._remaining_observations > self.batch_size: - high_res = np.zeros((self.batch_size, self.sample_shape[0], - self.sample_shape[1], - self.sample_shape[2], - self.handlers[0].shape[-1]), - dtype=np.float32) + high_res = np.zeros( + ( + self.batch_size, + self.sample_shape[0], + self.sample_shape[1], + self.sample_shape[2], + self.handlers[0].shape[-1], + ), + dtype=np.float32, + ) else: - high_res = np.zeros((self._remaining_observations, - self.sample_shape[0], - self.sample_shape[1], - self.sample_shape[2], - self.handlers[0].shape[-1]), - dtype=np.float32) + high_res = np.zeros( + ( + self._remaining_observations, + self.sample_shape[0], + self.sample_shape[1], + self.sample_shape[2], + self.handlers[0].shape[-1], + ), + dtype=np.float32, + ) for i in 
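# --- Illustrative sketch, not part of the diff ---
# What the uniform box/time samplers used by _get_val_indices produce: random
# spatial and temporal slices of the sample shape, combined into a tuple index.
# These samplers are simplified stand-ins, not the sup3r utilities.
import numpy as np

def uniform_box_sampler(data, sample_shape):
    start_r = np.random.randint(0, data.shape[0] - sample_shape[0] + 1)
    start_c = np.random.randint(0, data.shape[1] - sample_shape[1] + 1)
    return [slice(start_r, start_r + sample_shape[0]),
            slice(start_c, start_c + sample_shape[1])]

def uniform_time_sampler(data, sample_len):
    start_t = np.random.randint(0, data.shape[2] - sample_len + 1)
    return slice(start_t, start_t + sample_len)

val_data = np.random.uniform(size=(50, 50, 100, 3))
spatial_slice = uniform_box_sampler(val_data, (20, 20))
temporal_slice = uniform_time_sampler(val_data, 24)
tuple_index = tuple([*spatial_slice, temporal_slice,
                     np.arange(val_data.shape[-1])])
assert val_data[tuple_index].shape == (20, 20, 24, 3)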
range(high_res.shape[0]): val_index = self.val_indices[self._i + i] high_res[i, ...] = self.handlers[ - val_index['handler_index']].val_data[ - val_index['tuple_index']] + val_index['handler_index'] + ].val_data[val_index['tuple_index']] self._remaining_observations -= 1 if self.sample_shape[2] == 1: @@ -357,11 +390,24 @@ class BatchHandler: BATCH_CLASS = Batch DATA_HANDLER_CLASS = None - def __init__(self, data_handlers, batch_size=8, s_enhance=3, t_enhance=1, - means=None, stds=None, norm=True, n_batches=10, - temporal_coarsening_method='subsample', stdevs_file=None, - means_file=None, overwrite_stats=False, smoothing=None, - smoothing_ignore=None, worker_kwargs=None): + def __init__( + self, + data_handlers, + batch_size=8, + s_enhance=3, + t_enhance=1, + means=None, + stds=None, + norm=True, + n_batches=10, + temporal_coarsening_method='subsample', + stdevs_file=None, + means_file=None, + overwrite_stats=False, + smoothing=None, + smoothing_ignore=None, + worker_kwargs=None, + ): """ Parameters ---------- @@ -435,7 +481,7 @@ def __init__(self, data_handlers, batch_size=8, s_enhance=3, t_enhance=1, self._norm_workers = worker_kwargs.get('norm_workers', norm_workers) self._load_workers = worker_kwargs.get('load_workers', load_workers) - msg = ('All data handlers must have the same sample_shape') + msg = 'All data handlers must have the same sample_shape' handler_shapes = np.array([d.sample_shape for d in data_handlers]) assert np.all(handler_shapes[0] == handler_shapes), msg @@ -459,18 +505,23 @@ def __init__(self, data_handlers, batch_size=8, s_enhance=3, t_enhance=1, self.overwrite_stats = overwrite_stats self.smoothing = smoothing self.smoothing_ignore = smoothing_ignore or [] - self.smoothed_features = [f for f in self.training_features - if f not in self.smoothing_ignore] + self.smoothed_features = [ + f for f in self.training_features if f not in self.smoothing_ignore + ] - logger.info(f'Initializing BatchHandler with smoothing={smoothing}. ' - f'Using stats_workers={self.stats_workers}, ' - f'norm_workers={self.norm_workers}, ' - f'load_workers={self.load_workers}.') + logger.info( + f'Initializing BatchHandler with smoothing={smoothing}. ' + f'Using stats_workers={self.stats_workers}, ' + f'norm_workers={self.norm_workers}, ' + f'load_workers={self.load_workers}.' + ) now = dt.now() self.parallel_load() - logger.debug(f'Finished loading data of shape {self.shape} ' - f'for BatchHandler in {dt.now() - now}.') + logger.debug( + f'Finished loading data of shape {self.shape} ' + f'for BatchHandler in {dt.now() - now}.' 
+ ) log_mem(logger, log_level='INFO') if norm: @@ -479,13 +530,16 @@ def __init__(self, data_handlers, batch_size=8, s_enhance=3, t_enhance=1, logger.debug('Getting validation data for BatchHandler.') self.val_data = self.VAL_CLASS( - data_handlers, batch_size=batch_size, - s_enhance=s_enhance, t_enhance=t_enhance, + data_handlers, + batch_size=batch_size, + s_enhance=s_enhance, + t_enhance=t_enhance, temporal_coarsening_method=temporal_coarsening_method, output_features_ind=self.output_features_ind, output_features=self.output_features, smoothing=self.smoothing, - smoothing_ignore=self.smoothing_ignore) + smoothing_ignore=self.smoothing_ignore, + ) logger.info('Finished initializing BatchHandler.') log_mem(logger, log_level='INFO') @@ -499,16 +553,18 @@ def feature_mem(self): def stats_workers(self): """Get max workers for calculating stats based on memory usage""" proc_mem = self.feature_mem - stats_workers = estimate_max_workers(self._stats_workers, proc_mem, - len(self.data_handlers)) + stats_workers = estimate_max_workers( + self._stats_workers, proc_mem, len(self.data_handlers) + ) return stats_workers @property def load_workers(self): """Get max workers for loading data handler based on memory usage""" proc_mem = len(self.data_handlers[0].features) * self.feature_mem - max_workers = estimate_max_workers(self._load_workers, proc_mem, - len(self.data_handlers)) + max_workers = estimate_max_workers( + self._load_workers, proc_mem, len(self.data_handlers) + ) return max_workers @property @@ -516,8 +572,9 @@ def norm_workers(self): """Get max workers used for calculating and normalization across features""" proc_mem = 2 * self.feature_mem - norm_workers = estimate_max_workers(self._norm_workers, proc_mem, - len(self.training_features)) + norm_workers = estimate_max_workers( + self._norm_workers, proc_mem, len(self.training_features) + ) return norm_workers @property @@ -539,8 +596,11 @@ def output_features_ind(self): if self.training_features == self.output_features: return None else: - out = [i for i, feature in enumerate(self.training_features) - if feature in self.output_features] + out = [ + i + for i, feature in enumerate(self.training_features) + if feature in self.output_features + ] return out @property @@ -555,8 +615,12 @@ def shape(self): dimension """ time_steps = np.sum([h.shape[-2] for h in self.data_handlers]) - return (self.data_handlers[0].shape[0], self.data_handlers[0].shape[1], - time_steps, self.data_handlers[0].shape[-1]) + return ( + self.data_handlers[0].shape[0], + self.data_handlers[0].shape[1], + time_steps, + self.data_handlers[0].shape[-1], + ) def parallel_normalization(self): """Normalize data in all data handlers in parallel.""" @@ -573,19 +637,25 @@ def parallel_normalization(self): future = exe.submit(d.normalize, self.means, self.stds) futures[future] = i - logger.info(f'Started normalizing {len(self.data_handlers)} ' - f'data handlers in {dt.now() - now}.') + logger.info( + f'Started normalizing {len(self.data_handlers)} ' + f'data handlers in {dt.now() - now}.' + ) for i, _ in enumerate(as_completed(futures)): try: future.result() except Exception as e: - msg = ('Error normalizing data handler number ' - f'{futures[future]}') + msg = ( + 'Error normalizing data handler number ' + f'{futures[future]}' + ) logger.exception(msg) raise RuntimeError(msg) from e - logger.debug(f'{i+1} out of {len(futures)} data handlers' - ' normalized.') + logger.debug( + f'{i+1} out of {len(futures)} data handlers' + ' normalized.' 
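# --- Illustrative sketch, not part of the diff ---
# The ThreadPoolExecutor / as_completed pattern shared by the parallel_*
# helpers: submit one job per data handler, then surface any failure with the
# handler index. The worker function below is a trivial stand-in.
from concurrent.futures import ThreadPoolExecutor, as_completed

def load_one(i):
    return f'handler {i} loaded'

futures = {}
with ThreadPoolExecutor(max_workers=4) as exe:
    for i in range(8):
        futures[exe.submit(load_one, i)] = i

    for n, future in enumerate(as_completed(futures)):
        try:
            future.result()
        except Exception as e:
            raise RuntimeError(
                f'Error loading data handler number {futures[future]}') from e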
+ ) def parallel_load(self): """Load data handler data in parallel""" @@ -603,25 +673,31 @@ def parallel_load(self): future = exe.submit(d.load_cached_data) futures[future] = i - logger.info(f'Started loading all {len(self.data_handlers)} ' - f'data handlers in {dt.now() - now}.') + logger.info( + f'Started loading all {len(self.data_handlers)} ' + f'data handlers in {dt.now() - now}.' + ) for i, future in enumerate(as_completed(futures)): try: future.result() except Exception as e: - msg = ('Error loading data handler number ' - f'{futures[future]}') + msg = ( + 'Error loading data handler number ' + f'{futures[future]}' + ) logger.exception(msg) raise RuntimeError(msg) from e - logger.debug(f'{i+1} out of {len(futures)} handlers ' - 'loaded.') + logger.debug( + f'{i+1} out of {len(futures)} handlers ' 'loaded.' + ) def parallel_stats(self): - """Get standard deviations and means for training features in parallel. - """ - logger.info(f'Calculating stats for {len(self.training_features)} ' - 'features.') + """Get standard deviations and means for training features in + parallel.""" + logger.info( + f'Calculating stats for {len(self.training_features)} ' 'features.' + ) max_workers = self.norm_workers if max_workers == 1: for f in self.training_features: @@ -634,21 +710,27 @@ def parallel_stats(self): future = exe.submit(self.get_stats_for_feature, f) futures[future] = i - logger.info('Started calculating stats for ' - f'{len(self.training_features)} features in ' - f'{dt.now() - now}.') + logger.info( + 'Started calculating stats for ' + f'{len(self.training_features)} features in ' + f'{dt.now() - now}.' + ) for i, future in enumerate(as_completed(futures)): try: future.result() except Exception as e: - msg = ('Error calculating stats for ' - f'{self.training_features[futures[future]]}') + msg = ( + 'Error calculating stats for ' + f'{self.training_features[futures[future]]}' + ) logger.exception(msg) raise RuntimeError(msg) from e - logger.debug(f'{i+1} out of ' - f'{len(self.training_features)} stats ' - 'calculated.') + logger.debug( + f'{i+1} out of ' + f'{len(self.training_features)} stats ' + 'calculated.' + ) def __len__(self): """Use user input of n_batches to specify length @@ -671,11 +753,11 @@ def check_cached_stats(self): stds : ndarray Array of stdevs for each feature """ - stdevs_check = (self.stdevs_file is not None - and not self.overwrite_stats) + stdevs_check = ( + self.stdevs_file is not None and not self.overwrite_stats + ) stdevs_check = stdevs_check and os.path.exists(self.stdevs_file) - means_check = (self.means_file is not None - and not self.overwrite_stats) + means_check = self.means_file is not None and not self.overwrite_stats means_check = means_check and os.path.exists(self.means_file) if stdevs_check and means_check: logger.info(f'Loading stdevs from {self.stdevs_file}') @@ -685,10 +767,12 @@ def check_cached_stats(self): with open(self.means_file, 'rb') as fh: self.means = pickle.load(fh) - msg = ('The training features and cached statistics are ' - 'incompatible. Number of training features is ' - f'{len(self.training_features)} and number of stats is' - f' {len(self.stds)}') + msg = ( + 'The training features and cached statistics are ' + 'incompatible. 
Number of training features is ' + f'{len(self.training_features)} and number of stats is' + f' {len(self.stds)}' + ) check = len(self.means) == len(self.training_features) check = check and (len(self.stds) == len(self.training_features)) assert check, msg @@ -740,7 +824,8 @@ def get_handler_mean(self, feature_idx, handler_idx): Feature mean """ return np.nanmean( - self.data_handlers[handler_idx].data[..., feature_idx]) + self.data_handlers[handler_idx].data[..., feature_idx] + ) def get_handler_variance(self, feature_idx, handler_idx, mean): """Get feature variance for a given handler @@ -803,14 +888,18 @@ def get_means_for_feature(self, feature, max_workers=None): future = exe.submit(self.get_handler_mean, idx, didx) futures[future] = didx - logger.info('Started calculating means for ' - f'{len(self.data_handlers)} data_handlers in ' - f'{dt.now() - now}.') + logger.info( + 'Started calculating means for ' + f'{len(self.data_handlers)} data_handlers in ' + f'{dt.now() - now}.' + ) for i, future in enumerate(as_completed(futures)): self.means[idx] += future.result() - logger.debug(f'{i+1} out of {len(self.data_handlers)} ' - 'means calculated.') + logger.debug( + f'{i+1} out of {len(self.data_handlers)} ' + 'means calculated.' + ) self.means[idx] /= len(self.data_handlers) return self.means[idx] @@ -829,25 +918,31 @@ def get_stdevs_for_feature(self, feature, max_workers=None): logger.debug(f'Calculating stdev for {feature}') if max_workers == 1: for didx, _ in enumerate(self.data_handlers): - self.stds[idx] += self.get_handler_variance(idx, didx, - self.means[idx]) + self.stds[idx] += self.get_handler_variance( + idx, didx, self.means[idx] + ) else: with ThreadPoolExecutor(max_workers=max_workers) as exe: futures = {} now = dt.now() for didx, _ in enumerate(self.data_handlers): - future = exe.submit(self.get_handler_variance, idx, didx, - self.means[idx]) + future = exe.submit( + self.get_handler_variance, idx, didx, self.means[idx] + ) futures[future] = didx - logger.info('Started calculating stdevs for ' - f'{len(self.data_handlers)} data_handlers in ' - f'{dt.now() - now}.') + logger.info( + 'Started calculating stdevs for ' + f'{len(self.data_handlers)} data_handlers in ' + f'{dt.now() - now}.' + ) for i, future in enumerate(as_completed(futures)): self.stds[idx] += future.result() - logger.debug(f'{i+1} out of {len(self.data_handlers)} ' - 'stdevs calculated.') + logger.debug( + f'{i+1} out of {len(self.data_handlers)} ' + 'stdevs calculated.' + ) self.stds[idx] /= len(self.data_handlers) self.stds[idx] = np.sqrt(self.stds[idx]) return self.stds[idx] @@ -867,16 +962,19 @@ def normalize(self, means=None, stds=None): if means is None or stds is None: self.get_stats() elif means is not None and stds is not None: - if (not np.array_equal(means, self.means) - or not np.array_equal(stds, self.stds)): + if not np.array_equal(means, self.means) or not np.array_equal( + stds, self.stds + ): self.unnormalize() self.means = means self.stds = stds now = dt.now() logger.info('Normalizing data in each data handler.') self.parallel_normalization() - logger.info('Finished normalizing data in all data handlers in ' - f'{dt.now() - now}.') + logger.info( + 'Finished normalizing data in all data handlers in ' + f'{dt.now() - now}.' 
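# --- Illustrative sketch, not part of the diff ---
# The two-pass statistics pattern BatchHandler uses across data handlers:
# average the per-handler means, then average the per-handler variances about
# that mean and take the square root. Data below are synthetic stand-ins for
# handler.data[..., feature_idx].
import numpy as np

handler_data = [np.random.normal(5, 2, (10, 10, 50)) for _ in range(4)]

mean = sum(np.nanmean(d) for d in handler_data) / len(handler_data)
var = sum(np.nanmean((d - mean) ** 2) for d in handler_data) / len(handler_data)
std = np.sqrt(var)   # close to (mean=5, std=2) for these synthetic inputs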
+ ) def unnormalize(self): """Remove normalization from stored means and stds""" @@ -901,22 +999,32 @@ def __next__(self): handler_index = np.random.randint(0, len(self.data_handlers)) self.current_handler_index = handler_index handler = self.data_handlers[handler_index] - high_res = np.zeros((self.batch_size, self.sample_shape[0], - self.sample_shape[1], self.sample_shape[2], - self.shape[-1]), dtype=np.float32) + high_res = np.zeros( + ( + self.batch_size, + self.sample_shape[0], + self.sample_shape[1], + self.sample_shape[2], + self.shape[-1], + ), + dtype=np.float32, + ) for i in range(self.batch_size): high_res[i, ...] = handler.get_next() self.current_batch_indices.append(handler.current_obs_index) batch = self.BATCH_CLASS.get_coarse_batch( - high_res, self.s_enhance, t_enhance=self.t_enhance, + high_res, + self.s_enhance, + t_enhance=self.t_enhance, temporal_coarsening_method=self.temporal_coarsening_method, output_features_ind=self.output_features_ind, output_features=self.output_features, training_features=self.training_features, smoothing=self.smoothing, - smoothing_ignore=self.smoothing_ignore) + smoothing_ignore=self.smoothing_ignore, + ) self._i += 1 return batch @@ -975,7 +1083,8 @@ def __next__(self): self.current_batch_indices.append(handler.current_obs_index) obs_hourly = self.BATCH_CLASS.reduce_features( - obs_hourly, self.output_features_ind) + obs_hourly, self.output_features_ind + ) if low_res is None: lr_shape = (self.batch_size, *obs_daily_avg.shape) @@ -989,21 +1098,25 @@ def __next__(self): high_res = self.reduce_high_res_sub_daily(high_res) low_res = spatial_coarsening(low_res, self.s_enhance) - if (self.output_features is not None - and 'clearsky_ratio' in self.output_features): + if ( + self.output_features is not None + and 'clearsky_ratio' in self.output_features + ): i_cs = self.output_features.index('clearsky_ratio') if np.isnan(high_res[..., i_cs]).any(): high_res[..., i_cs] = nn_fill_array(high_res[..., i_cs]) if self.smoothing is not None: - feat_iter = [j for j in range(low_res.shape[-1]) - if self.training_features[j] - not in self.smoothing_ignore] + feat_iter = [ + j + for j in range(low_res.shape[-1]) + if self.training_features[j] not in self.smoothing_ignore + ] for i in range(low_res.shape[0]): for j in feat_iter: - low_res[i, ..., j] = gaussian_filter(low_res[i, ..., j], - self.smoothing, - mode='nearest') + low_res[i, ..., j] = gaussian_filter( + low_res[i, ..., j], self.smoothing, mode='nearest' + ) batch = self.BATCH_CLASS(low_res, high_res) @@ -1084,9 +1197,12 @@ def __next__(self): hr_shape = (self.batch_size, *obs_daily_avg.shape) high_res = np.zeros(hr_shape, dtype=np.float32) - msg = ('SpatialBatchHandlerCC can only use n_temporal==1 ' - 'but received HR shape {} with n_temporal={}.' 
- .format(hr_shape, hr_shape[3])) + msg = ( + 'SpatialBatchHandlerCC can only use n_temporal==1 ' + 'but received HR shape {} with n_temporal={}.'.format( + hr_shape, hr_shape[3] + ) + ) assert hr_shape[3] == 1, msg high_res[i] = obs_daily_avg @@ -1096,23 +1212,28 @@ def __next__(self): high_res = high_res[:, :, :, 0, :] high_res = self.BATCH_CLASS.reduce_features( - high_res, self.output_features_ind) + high_res, self.output_features_ind + ) - if (self.output_features is not None - and 'clearsky_ratio' in self.output_features): + if ( + self.output_features is not None + and 'clearsky_ratio' in self.output_features + ): i_cs = self.output_features.index('clearsky_ratio') if np.isnan(high_res[..., i_cs]).any(): high_res[..., i_cs] = nn_fill_array(high_res[..., i_cs]) if self.smoothing is not None: - feat_iter = [j for j in range(low_res.shape[-1]) - if self.training_features[j] - not in self.smoothing_ignore] + feat_iter = [ + j + for j in range(low_res.shape[-1]) + if self.training_features[j] not in self.smoothing_ignore + ] for i in range(low_res.shape[0]): for j in feat_iter: - low_res[i, ..., j] = gaussian_filter(low_res[i, ..., j], - self.smoothing, - mode='nearest') + low_res[i, ..., j] = gaussian_filter( + low_res[i, ..., j], self.smoothing, mode='nearest' + ) batch = self.BATCH_CLASS(low_res, high_res) @@ -1125,21 +1246,28 @@ class SpatialBatchHandler(BatchHandler): def __next__(self): if self._i < self.n_batches: - handler_index = np.random.randint( - 0, len(self.data_handlers)) + handler_index = np.random.randint(0, len(self.data_handlers)) handler = self.data_handlers[handler_index] - high_res = np.zeros((self.batch_size, self.sample_shape[0], - self.sample_shape[1], self.shape[-1]), - dtype=np.float32) + high_res = np.zeros( + ( + self.batch_size, + self.sample_shape[0], + self.sample_shape[1], + self.shape[-1], + ), + dtype=np.float32, + ) for i in range(self.batch_size): high_res[i, ...] 
= handler.get_next()[..., 0, :] batch = self.BATCH_CLASS.get_coarse_batch( - high_res, self.s_enhance, + high_res, + self.s_enhance, output_features_ind=self.output_features_ind, training_features=self.training_features, smoothing=self.smoothing, - smoothing_ignore=self.smoothing_ignore) + smoothing_ignore=self.smoothing_ignore, + ) self._i += 1 return batch @@ -1172,15 +1300,23 @@ def _get_val_indices(self): h = self.handlers[h_idx] for _ in range(self.batch_size): spatial_slice = uniform_box_sampler( - h.data, self.sample_shape[:2]) + h.data, self.sample_shape[:2] + ) weights = np.zeros(self.N_TIME_BINS) weights[t] = 1 temporal_slice = weighted_time_sampler( - h.data, self.sample_shape[2], weights) - tuple_index = tuple([*spatial_slice, temporal_slice] - + [np.arange(h.data.shape[-1])]) - val_indices[t].append({'handler_index': h_idx, - 'tuple_index': tuple_index}) + h.data, self.sample_shape[2], weights + ) + tuple_index = tuple( + [ + *spatial_slice, + temporal_slice, + np.arange(h.data.shape[-1]), + ] + ) + val_indices[t].append( + {'handler_index': h_idx, 'tuple_index': tuple_index} + ) for s in range(self.N_SPACE_BINS): val_indices[s + self.N_TIME_BINS] = [] h_idx = np.random.choice(np.arange(len(self.handlers))) @@ -1189,34 +1325,51 @@ def _get_val_indices(self): weights = np.zeros(self.N_SPACE_BINS) weights[s] = 1 spatial_slice = weighted_box_sampler( - h.data, self.sample_shape[:2], weights) + h.data, self.sample_shape[:2], weights + ) temporal_slice = uniform_time_sampler( - h.data, self.sample_shape[2]) - tuple_index = tuple([*spatial_slice, temporal_slice] - + [np.arange(h.data.shape[-1])]) + h.data, self.sample_shape[2] + ) + tuple_index = tuple( + [ + *spatial_slice, + temporal_slice, + np.arange(h.data.shape[-1]), + ] + ) val_indices[s + self.N_TIME_BINS].append( - {'handler_index': h_idx, 'tuple_index': tuple_index}) + {'handler_index': h_idx, 'tuple_index': tuple_index} + ) return val_indices def __next__(self): if self._i < len(self.val_indices.keys()): - high_res = np.zeros((self.batch_size, self.sample_shape[0], - self.sample_shape[1], - self.sample_shape[2], - self.handlers[0].shape[-1]), - dtype=np.float32) + high_res = np.zeros( + ( + self.batch_size, + self.sample_shape[0], + self.sample_shape[1], + self.sample_shape[2], + self.handlers[0].shape[-1], + ), + dtype=np.float32, + ) val_indices = self.val_indices[self._i] for i, idx in enumerate(val_indices): - high_res[i, ...] = self.handlers[ - idx['handler_index']].data[idx['tuple_index']] + high_res[i, ...] 
= self.handlers[idx['handler_index']].data[ + idx['tuple_index'] + ] batch = self.BATCH_CLASS.get_coarse_batch( - high_res, self.s_enhance, t_enhance=self.t_enhance, + high_res, + self.s_enhance, + t_enhance=self.t_enhance, temporal_coarsening_method=self.temporal_coarsening_method, output_features_ind=self.output_features_ind, smoothing=self.smoothing, smoothing_ignore=self.smoothing_ignore, - output_features=self.output_features) + output_features=self.output_features, + ) self._i += 1 return batch else: @@ -1236,21 +1389,29 @@ class ValidationDataSpatialDC(ValidationDataDC): def __next__(self): if self._i < len(self.val_indices.keys()): - high_res = np.zeros((self.batch_size, self.sample_shape[0], - self.sample_shape[1], - self.handlers[0].shape[-1]), - dtype=np.float32) + high_res = np.zeros( + ( + self.batch_size, + self.sample_shape[0], + self.sample_shape[1], + self.handlers[0].shape[-1], + ), + dtype=np.float32, + ) val_indices = self.val_indices[self._i] for i, idx in enumerate(val_indices): - high_res[i, ...] = self.handlers[ - idx['handler_index']].data[idx['tuple_index']][..., 0, :] + high_res[i, ...] = self.handlers[idx['handler_index']].data[ + idx['tuple_index'] + ][..., 0, :] batch = self.BATCH_CLASS.get_coarse_batch( - high_res, self.s_enhance, + high_res, + self.s_enhance, output_features_ind=self.output_features_ind, smoothing=self.smoothing, smoothing_ignore=self.smoothing_ignore, - output_features=self.output_features) + output_features=self.output_features, + ) self._i += 1 return batch else: @@ -1279,13 +1440,16 @@ def __init__(self, *args, **kwargs): self.temporal_weights /= np.sum(self.temporal_weights) self.old_temporal_weights = [0] * self.val_data.N_TIME_BINS bin_range = self.data_handlers[0].data.shape[2] - bin_range -= (self.sample_shape[2] - 1) - self.temporal_bins = np.array_split(np.arange(0, bin_range), - self.val_data.N_TIME_BINS) + bin_range -= self.sample_shape[2] - 1 + self.temporal_bins = np.array_split( + np.arange(0, bin_range), self.val_data.N_TIME_BINS + ) self.temporal_bins = [b[0] for b in self.temporal_bins] - logger.info('Using temporal weights: ' - f'{[round(w, 3) for w in self.temporal_weights]}') + logger.info( + 'Using temporal weights: ' + f'{[round(w, 3) for w in self.temporal_weights]}' + ) self.temporal_sample_record = [0] * self.val_data.N_TIME_BINS self.norm_temporal_record = [0] * self.val_data.N_TIME_BINS @@ -1307,32 +1471,44 @@ def __next__(self): handler_index = np.random.randint(0, len(self.data_handlers)) self.current_handler_index = handler_index handler = self.data_handlers[handler_index] - high_res = np.zeros((self.batch_size, self.sample_shape[0], - self.sample_shape[1], self.sample_shape[2], - self.shape[-1]), dtype=np.float32) + high_res = np.zeros( + ( + self.batch_size, + self.sample_shape[0], + self.sample_shape[1], + self.sample_shape[2], + self.shape[-1], + ), + dtype=np.float32, + ) for i in range(self.batch_size): high_res[i, ...] 
= handler.get_next( - temporal_weights=self.temporal_weights) + temporal_weights=self.temporal_weights + ) self.current_batch_indices.append(handler.current_obs_index) self.update_training_sample_record() batch = self.BATCH_CLASS.get_coarse_batch( - high_res, self.s_enhance, t_enhance=self.t_enhance, + high_res, + self.s_enhance, + t_enhance=self.t_enhance, temporal_coarsening_method=self.temporal_coarsening_method, output_features_ind=self.output_features_ind, output_features=self.output_features, training_features=self.training_features, smoothing=self.smoothing, - smoothing_ignore=self.smoothing_ignore) + smoothing_ignore=self.smoothing_ignore, + ) self._i += 1 return batch else: total_count = self.n_batches * self.batch_size - self.norm_temporal_record = [c / total_count for c - in self.temporal_sample_record.copy()] + self.norm_temporal_record = [ + c / total_count for c in self.temporal_sample_record.copy() + ] self.old_temporal_weights = self.temporal_weights.copy() raise StopIteration @@ -1363,12 +1539,15 @@ def __init__(self, *args, **kwargs): self.max_cols = self.data_handlers[0].data.shape[1] + 1 self.max_cols -= self.sample_shape[1] bin_range = self.max_rows * self.max_cols - self.spatial_bins = np.array_split(np.arange(0, bin_range), - self.val_data.N_SPACE_BINS) + self.spatial_bins = np.array_split( + np.arange(0, bin_range), self.val_data.N_SPACE_BINS + ) self.spatial_bins = [b[0] for b in self.spatial_bins] - logger.info('Using spatial weights: ' - f'{[round(w, 3) for w in self.spatial_weights]}') + logger.info( + 'Using spatial weights: ' + f'{[round(w, 3) for w in self.spatial_weights]}' + ) self.spatial_sample_record = [0] * self.val_data.N_SPACE_BINS self.norm_spatial_record = [0] * self.val_data.N_SPACE_BINS @@ -1393,29 +1572,39 @@ def __next__(self): handler_index = np.random.randint(0, len(self.data_handlers)) self.current_handler_index = handler_index handler = self.data_handlers[handler_index] - high_res = np.zeros((self.batch_size, self.sample_shape[0], - self.sample_shape[1], self.shape[-1]), - dtype=np.float32) + high_res = np.zeros( + ( + self.batch_size, + self.sample_shape[0], + self.sample_shape[1], + self.shape[-1], + ), + dtype=np.float32, + ) for i in range(self.batch_size): high_res[i, ...] 
= handler.get_next( - spatial_weights=self.spatial_weights)[..., 0, :] + spatial_weights=self.spatial_weights + )[..., 0, :] self.current_batch_indices.append(handler.current_obs_index) self.update_training_sample_record() batch = self.BATCH_CLASS.get_coarse_batch( - high_res, self.s_enhance, + high_res, + self.s_enhance, output_features_ind=self.output_features_ind, training_features=self.training_features, smoothing=self.smoothing, - smoothing_ignore=self.smoothing_ignore) + smoothing_ignore=self.smoothing_ignore, + ) self._i += 1 return batch else: total_count = self.n_batches * self.batch_size - self.norm_spatial_record = [c / total_count - for c in self.spatial_sample_record] + self.norm_spatial_record = [ + c / total_count for c in self.spatial_sample_record + ] self.old_spatial_weights = self.spatial_weights.copy() raise StopIteration diff --git a/sup3r/preprocessing/data_handling.py b/sup3r/preprocessing/data_handling.py index 9245dc8cc..12b1d2aa6 100644 --- a/sup3r/preprocessing/data_handling.py +++ b/sup3r/preprocessing/data_handling.py @@ -12,6 +12,7 @@ from concurrent.futures import ThreadPoolExecutor, as_completed from datetime import datetime as dt from fnmatch import fnmatch +from typing import ClassVar import numpy as np import pandas as pd @@ -23,6 +24,7 @@ from scipy.spatial import KDTree from scipy.stats import mode +from sup3r.bias.bias_transforms import get_spatial_bc_factors from sup3r.preprocessing.feature_handling import ( BVFreqMon, BVFreqSquaredH5, @@ -78,8 +80,14 @@ class InputMixIn: """MixIn class with properties and methods for handling the spatiotemporal data domain to extract from source data.""" - def __init__(self, target, shape, raster_file=None, raster_index=None, - temporal_slice=slice(None, None, 1)): + def __init__( + self, + target, + shape, + raster_file=None, + raster_index=None, + temporal_slice=slice(None, None, 1), + ): """Provide properties of the spatiotemporal data domain Parameters @@ -144,8 +152,11 @@ def single_ts_files(self): if self._single_ts_files is None: logger.debug('Checking if input files are single timestep.') t_steps = self.get_time_index(self.file_paths[:1], max_workers=1) - check = (len(self._file_paths) == len(self.raw_time_index) - and t_steps is not None and len(t_steps) == 1) + check = ( + len(self._file_paths) == len(self.raw_time_index) + and t_steps is not None + and len(t_steps) == 1 + ) self._single_ts_files = check return self._single_ts_files @@ -207,8 +218,10 @@ def input_file_info(self): message to append to log output that does not include a huge info dump of file paths """ - msg = (f'source files with dates from {self.raw_time_index[0]} to ' - f'{self.raw_time_index[-1]}') + msg = ( + f'source files with dates from {self.raw_time_index[0]} to ' + f'{self.raw_time_index[-1]}' + ) return msg @property @@ -229,22 +242,26 @@ def temporal_slice(self, temporal_slice): elements and no more than three, corresponding to the inputs of slice() """ - msg = ('temporal_slice must be tuple, list, or slice') + msg = 'temporal_slice must be tuple, list, or slice' assert isinstance(temporal_slice, (tuple, list, slice)), msg if isinstance(temporal_slice, slice): self._temporal_slice = temporal_slice else: check = len(temporal_slice) <= 3 - msg = ('If providing list or tuple for temporal_slice length must ' - 'be <= 3') + msg = ( + 'If providing list or tuple for temporal_slice length must ' + 'be <= 3' + ) assert check, msg self._temporal_slice = slice(*temporal_slice) if self._temporal_slice.step is None: - self._temporal_slice = 
slice(self._temporal_slice.start, - self._temporal_slice.stop, 1) + self._temporal_slice = slice( + self._temporal_slice.start, self._temporal_slice.stop, 1 + ) if self._temporal_slice.start is None: - self._temporal_slice = slice(0, self._temporal_slice.stop, - self._temporal_slice.step) + self._temporal_slice = slice( + 0, self._temporal_slice.stop, self._temporal_slice.step + ) @property def file_paths(self): @@ -269,8 +286,10 @@ def file_paths(self, file_paths): else: self._file_paths = [self._file_paths] - msg = ('No valid files provided to DataHandler. ' - f'Received file_paths={file_paths}. Aborting.') + msg = ( + 'No valid files provided to DataHandler. ' + f'Received file_paths={file_paths}. Aborting.' + ) assert file_paths is not None and len(self._file_paths) > 0, msg self._file_paths = sorted(self._file_paths) @@ -301,7 +320,8 @@ def cache_pattern(self): self._cache_pattern += '.pkl' if '{feature}' not in self._cache_pattern: self._cache_pattern = self._cache_pattern.replace( - '.pkl', '_{feature}.pkl') + '.pkl', '_{feature}.pkl' + ) basedir = os.path.dirname(self._cache_pattern) if not os.path.exists(basedir): os.makedirs(basedir, exist_ok=True) @@ -316,14 +336,17 @@ def cache_pattern(self, cache_pattern): def need_full_domain(self): """Check whether we need to get the full lat/lon grid to determine target and shape values""" - no_raster_file = (self.raster_file is None - or not os.path.exists(self.raster_file)) - no_target_shape = (self._target is None or self._grid_shape is None) + no_raster_file = self.raster_file is None or not os.path.exists( + self.raster_file + ) + no_target_shape = self._target is None or self._grid_shape is None need_full = no_raster_file and no_target_shape if need_full: - logger.info('Target + shape not specified. Getting full domain ' - f'for {self.file_paths[0]}.') + logger.info( + 'Target + shape not specified. Getting full domain ' + f'for {self.file_paths[0]}.' 
+ ) return need_full @@ -344,8 +367,9 @@ def raw_lat_lon(self): ------- ndarray """ - raster_file_exists = (self.raster_file is not None - and os.path.exists(self.raster_file)) + raster_file_exists = self.raster_file is not None and os.path.exists( + self.raster_file + ) if self.full_raw_lat_lon is not None and raster_file_exists: self._raw_lat_lon = self.full_raw_lat_lon[self.raster_index] @@ -354,9 +378,9 @@ def raw_lat_lon(self): self._raw_lat_lon = self.full_raw_lat_lon if self._raw_lat_lon is None: - self._raw_lat_lon = self.get_lat_lon(self.file_paths[0:1], - self.raster_index, - invert_lat=False) + self._raw_lat_lon = self.get_lat_lon( + self.file_paths[0:1], self.raster_index, invert_lat=False + ) return self._raw_lat_lon @property @@ -397,7 +421,7 @@ def invert_lat(self): of the grid is at idx=(-1, 0) instead of idx=(0, 0)""" if self._invert_lat is None: lat_lon = self.raw_lat_lon - self._invert_lat = (not self.lats_are_descending(lat_lon)) + self._invert_lat = not self.lats_are_descending(lat_lon) return self._invert_lat @property @@ -484,19 +508,24 @@ def raw_time_index(self): time index for the raw input data.""" if self._raw_time_index is None: - check = (self.time_index_file is not None - and os.path.exists(self.time_index_file) - and not self.overwrite_ti_cache) + check = ( + self.time_index_file is not None + and os.path.exists(self.time_index_file) + and not self.overwrite_ti_cache + ) if check: - logger.debug('Loading raw_time_index from ' - f'{self.time_index_file}') + logger.debug( + 'Loading raw_time_index from ' f'{self.time_index_file}' + ) with open(self.time_index_file, 'rb') as f: self._raw_time_index = pd.DatetimeIndex(pickle.load(f)) else: self._raw_time_index = self._build_and_cache_time_index() - check = (self._raw_time_index is not None - and (self._raw_time_index.hour == 12).all()) + check = ( + self._raw_time_index is not None + and (self._raw_time_index.hour == 12).all() + ) if check: self._raw_time_index -= pd.Timedelta(12, 'h') elif self._raw_time_index is None: @@ -509,20 +538,24 @@ def raw_time_index(self): def time_index_conflict_check(self): """Check if the number of input files and the length of the time index is the same""" - msg = (f'Number of time steps ({len(self._raw_time_index)}) and files ' - f'({self.raw_tsteps}) conflict!') + msg = ( + f'Number of time steps ({len(self._raw_time_index)}) and files ' + f'({self.raw_tsteps}) conflict!' + ) check = len(self._raw_time_index) == self.raw_tsteps assert check, msg def _build_and_cache_time_index(self): """Build time index and cache if time_index_file is not None""" now = dt.now() - logger.debug(f'Getting time index for {len(self.file_paths)} ' - f'input files. Using ti_workers={self.ti_workers}' - f' and res_kwargs={self.res_kwargs}') - self._raw_time_index = self.get_time_index(self.file_paths, - max_workers=self.ti_workers, - **self.res_kwargs) + logger.debug( + f'Getting time index for {len(self.file_paths)} ' + f'input files. Using ti_workers={self.ti_workers}' + f' and res_kwargs={self.res_kwargs}' + ) + self._raw_time_index = self.get_time_index( + self.file_paths, max_workers=self.ti_workers, **self.res_kwargs + ) if self.time_index_file is not None: logger.debug(f'Saved raw_time_index to {self.time_index_file}') @@ -595,18 +628,40 @@ class DataHandler(FeatureHandler, InputMixIn): # model but are not part of the synthetic output and are not sent to the # discriminator. These are case-insensitive and follow the Unix shell-style # wildcard format. 
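For reference, a minimal sketch of how these case-insensitive, Unix shell-style wildcard patterns filter a feature list, mirroring the fnmatch check in the output_features property further down in this diff; the feature names used here are hypothetical.

from fnmatch import fnmatch

# Patterns as in TRAIN_ONLY_FEATURES below; feature names are made up.
train_only_patterns = ('BVF*', 'inversemoninobukhovlength_*', 'RMOL', 'topography')
features = ['U_100m', 'V_100m', 'BVF2_200m', 'topography']

# Keep only features that match none of the train-only patterns
# (case-insensitive wildcard match), as output_features does.
output_features = [
    f for f in features
    if not any(fnmatch(f.lower(), p.lower()) for p in train_only_patterns)
]
print(output_features)  # ['U_100m', 'V_100m']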
- TRAIN_ONLY_FEATURES = ('BVF*', 'inversemoninobukhovlength_*', 'RMOL', - 'topography') - - def __init__(self, file_paths, features, target=None, shape=None, - max_delta=20, temporal_slice=slice(None, None, 1), - hr_spatial_coarsen=None, time_roll=0, val_split=0.05, - sample_shape=(10, 10, 1), raster_file=None, raster_index=None, - shuffle_time=False, time_chunk_size=None, cache_pattern=None, - overwrite_cache=False, overwrite_ti_cache=False, - load_cached=False, train_only_features=None, - handle_features=None, single_ts_files=None, mask_nan=False, - worker_kwargs=None, res_kwargs=None): + TRAIN_ONLY_FEATURES = ( + 'BVF*', + 'inversemoninobukhovlength_*', + 'RMOL', + 'topography', + ) + + def __init__( + self, + file_paths, + features, + target=None, + shape=None, + max_delta=20, + temporal_slice=slice(None, None, 1), + hr_spatial_coarsen=None, + time_roll=0, + val_split=0.0, + sample_shape=(10, 10, 1), + raster_file=None, + raster_index=None, + shuffle_time=False, + time_chunk_size=None, + cache_pattern=None, + overwrite_cache=False, + overwrite_ti_cache=False, + load_cached=False, + train_only_features=None, + handle_features=None, + single_ts_files=None, + mask_nan=False, + worker_kwargs=None, + res_kwargs=None, + ): """ Parameters ---------- @@ -727,14 +782,19 @@ def __init__(self, file_paths, features, target=None, shape=None, 'chunks': {'south_north': 120, 'west_east': 120}} which then gets passed to xr.open_mfdataset(file, **res_kwargs) """ - InputMixIn.__init__(self, target=target, shape=shape, - raster_file=raster_file, - raster_index=raster_index, - temporal_slice=temporal_slice) + InputMixIn.__init__( + self, + target=target, + shape=shape, + raster_file=raster_file, + raster_index=raster_index, + temporal_slice=temporal_slice, + ) self.file_paths = file_paths - self.features = (features if isinstance(features, (list, tuple)) - else [features]) + self.features = ( + features if isinstance(features, (list, tuple)) else [features] + ) self.val_time_index = None self.max_delta = max_delta self.val_split = val_split @@ -767,34 +827,48 @@ def __init__(self, file_paths, features, target=None, shape=None, self._norm_workers = worker_kwargs.get('norm_workers', None) self._load_workers = worker_kwargs.get('load_workers', None) self._compute_workers = worker_kwargs.get('compute_workers', None) - self._worker_attrs = ['_ti_workers', '_norm_workers', - '_compute_workers', '_extract_workers', - '_load_workers'] + self._worker_attrs = [ + '_ti_workers', + '_norm_workers', + '_compute_workers', + '_extract_workers', + '_load_workers', + ] self.preflight() - try_load = (cache_pattern is not None - and not self.overwrite_cache - and all(os.path.exists(fp) for fp in self.cache_files)) + try_load = ( + cache_pattern is not None + and not self.overwrite_cache + and all(os.path.exists(fp) for fp in self.cache_files) + ) - overwrite = (self.overwrite_cache - and self.cache_files is not None - and all(os.path.exists(fp) for fp in self.cache_files)) + overwrite = ( + self.overwrite_cache + and self.cache_files is not None + and all(os.path.exists(fp) for fp in self.cache_files) + ) if try_load and self.load_cached: - logger.info(f'All {self.cache_files} exist. Loading from cache ' - f'instead of extracting from source files.') + logger.info( + f'All {self.cache_files} exist. Loading from cache ' + f'instead of extracting from source files.' + ) self.load_cached_data() elif try_load and not self.load_cached: self.clear_data() - logger.info(f'All {self.cache_files} exist. 
Call ' - 'load_cached_data() or use load_cache=True to load ' - 'this data from cache files.') + logger.info( + f'All {self.cache_files} exist. Call ' + 'load_cached_data() or use load_cache=True to load ' + 'this data from cache files.' + ) else: if overwrite: - logger.info(f'{self.cache_files} exists but overwrite_cache ' - 'is set to True. Proceeding with extraction.') + logger.info( + f'{self.cache_files} exists but overwrite_cache ' + 'is set to True. Proceeding with extraction.' + ) self._raster_size_check() self._run_data_init_if_needed() @@ -807,8 +881,11 @@ def __init__(self, file_paths, features, target=None, shape=None, if mask_nan: nan_mask = np.isnan(self.data).any(axis=(0, 1, 3)) - logger.info('Removing {} out of {} timesteps due to NaNs' - .format(nan_mask.sum(), self.data.shape[2])) + logger.info( + 'Removing {} out of {} timesteps due to NaNs'.format( + nan_mask.sum(), self.data.shape[2] + ) + ) self.data = self.data[:, :, ~nan_mask, :] logger.info('Finished intializing DataHandler.') @@ -819,20 +896,24 @@ def _run_data_init_if_needed(self): extraction""" if any(self.features): self.data = self.run_all_data_init() - nan_perc = (100 * np.isnan(self.data).sum() / self.data.size) + nan_perc = 100 * np.isnan(self.data).sum() / self.data.size if nan_perc > 0: - msg = ('Data has {:.2f}% NaN values!'.format(nan_perc)) + msg = 'Data has {:.2f}% NaN values!'.format(nan_perc) logger.warning(msg) warnings.warn(msg) def _raster_size_check(self): """Check if the sample_shape is larger than the requested raster size""" - bad_shape = (self.sample_shape[0] > self.grid_shape[0] - and self.sample_shape[1] > self.grid_shape[1]) + bad_shape = ( + self.sample_shape[0] > self.grid_shape[0] + and self.sample_shape[1] > self.grid_shape[1] + ) if bad_shape: - msg = (f'spatial_sample_shape {self.sample_shape[:2]} is ' - f'larger than the raster size {self.grid_shape}') + msg = ( + f'spatial_sample_shape {self.sample_shape[:2]} is ' + f'larger than the raster size {self.grid_shape}' + ) logger.warning(msg) warnings.warn(msg) @@ -842,11 +923,17 @@ def _val_split_check(self): if self.data is not None and self.val_split > 0.0: self.data, self.val_data = self.split_data() - msg = (f'Validation data has shape={self.val_data.shape} ' - f'and sample_shape={self.sample_shape}. Use a smaller ' - 'sample_shape and/or larger val_split.') - check = any(val_size < samp_size for val_size, samp_size - in zip(self.val_data.shape, self.sample_shape)) + msg = ( + f'Validation data has shape={self.val_data.shape} ' + f'and sample_shape={self.sample_shape}. Use a smaller ' + 'sample_shape and/or larger val_split.' + ) + check = any( + val_size < samp_size + for val_size, samp_size in zip( + self.val_data.shape, self.sample_shape + ) + ) if check: logger.warning(msg) warnings.warn(msg) @@ -900,22 +987,28 @@ def extract_workers(self): proc_mem /= len(self.time_chunks) n_procs = len(self.time_chunks) * len(self.extract_features) n_procs = int(np.ceil(n_procs)) - extract_workers = estimate_max_workers(self._extract_workers, proc_mem, - n_procs) + extract_workers = estimate_max_workers( + self._extract_workers, proc_mem, n_procs + ) return extract_workers @property def compute_workers(self): """Get upper bound for compute workers based on memory limits. 
Used to compute derived features from source dataset.""" - proc_mem = int(np.ceil(len(self.extract_features) - / np.maximum(len(self.derive_features), 1))) + proc_mem = int( + np.ceil( + len(self.extract_features) + / np.maximum(len(self.derive_features), 1) + ) + ) proc_mem *= 4 * self.grid_mem * len(self.time_index) proc_mem /= len(self.time_chunks) n_procs = len(self.time_chunks) * len(self.derive_features) n_procs = int(np.ceil(n_procs)) - compute_workers = estimate_max_workers(self._compute_workers, proc_mem, - n_procs) + compute_workers = estimate_max_workers( + self._compute_workers, proc_mem, n_procs + ) return compute_workers @property @@ -926,17 +1019,18 @@ def load_workers(self): n_procs = 1 if self.cache_files is not None: n_procs = len(self.cache_files) - load_workers = estimate_max_workers(self._load_workers, proc_mem, - n_procs) + load_workers = estimate_max_workers( + self._load_workers, proc_mem, n_procs + ) return load_workers @property def norm_workers(self): """Get upper bound on workers used for normalization.""" if self.data is not None: - norm_workers = estimate_max_workers(self._norm_workers, - 2 * self.feature_mem, - self.shape[-1]) + norm_workers = estimate_max_workers( + self._norm_workers, 2 * self.feature_mem, self.shape[-1] + ) else: norm_workers = self._norm_workers return norm_workers @@ -955,9 +1049,11 @@ def time_chunks(self): if self.is_time_independent: self._time_chunks = [slice(None)] else: - self._time_chunks = get_chunk_slices(len(self.raw_time_index), - self.time_chunk_size, - self.temporal_slice) + self._time_chunks = get_chunk_slices( + len(self.raw_time_index), + self.time_chunk_size, + self.temporal_slice, + ) return self._time_chunks @property @@ -982,10 +1078,13 @@ def time_chunk_size(self): if step_mem == 0: self._time_chunk_size = self.n_tsteps else: - self._time_chunk_size = np.min([int(1e9 / step_mem), - self.n_tsteps]) - logger.info('time_chunk_size arg not specified. Using ' - f'{self._time_chunk_size}.') + self._time_chunk_size = np.min( + [int(1e9 / step_mem), self.n_tsteps] + ) + logger.info( + 'time_chunk_size arg not specified. Using ' + f'{self._time_chunk_size}.' 
+ ) return self._time_chunk_size @property @@ -1040,25 +1139,34 @@ def noncached_features(self): """Get list of features needing extraction or derivation""" if self._noncached_features is None: self._noncached_features = self.check_cached_features( - self.features, cache_files=self.cache_files, + self.features, + cache_files=self.cache_files, overwrite_cache=self.overwrite_cache, - load_cached=self.load_cached) + load_cached=self.load_cached, + ) return self._noncached_features @property def extract_features(self): """Features to extract directly from the source handler""" lower_features = [f.lower() for f in self.handle_features] - return [f for f in self.raw_features - if self.lookup(f, 'compute') is None - or Feature.get_basename(f.lower()) in lower_features] + return [ + f + for f in self.raw_features + if self.lookup(f, 'compute') is None + or Feature.get_basename(f.lower()) in lower_features + ] @property def derive_features(self): """List of features which need to be derived from other features""" - derive_features = [f for f in set(list(self.noncached_features) - + list(self.extract_features)) - if f not in self.extract_features] + derive_features = [ + f + for f in set( + list(self.noncached_features) + list(self.extract_features) + ) + if f not in self.extract_features + ] return derive_features @property @@ -1072,7 +1180,8 @@ def raw_features(self): """Get list of features needed for computations""" if self._raw_features is None: self._raw_features = self.get_raw_feature_list( - self.noncached_features, self.handle_features) + self.noncached_features, self.handle_features + ) return self._raw_features @property @@ -1081,8 +1190,10 @@ def output_features(self): corresponding to the features in the high res batch array.""" out = [] for feature in self.features: - ignore = any(fnmatch(feature.lower(), pattern.lower()) - for pattern in self.train_only_features) + ignore = any( + fnmatch(feature.lower(), pattern.lower()) + for pattern in self.train_only_features + ) if not ignore: out.append(feature) return out @@ -1119,52 +1230,67 @@ def preflight(self): self.cap_worker_args(self.max_workers) if len(self.sample_shape) == 2: - logger.info('Found 2D sample shape of {}. Adding temporal dim of 1' - .format(self.sample_shape)) + logger.info( + 'Found 2D sample shape of {}. Adding temporal dim of 1'.format( + self.sample_shape + ) + ) self.sample_shape = (*self.sample_shape, 1) start = self.temporal_slice.start stop = self.temporal_slice.stop n_steps = self.n_tsteps - msg = (f'Temporal slice step ({self.temporal_slice.step}) does not ' - f'evenly divide the number of time steps ({n_steps})') + msg = ( + f'Temporal slice step ({self.temporal_slice.step}) does not ' + f'evenly divide the number of time steps ({n_steps})' + ) check = self.temporal_slice.step is None check = check or n_steps % self.temporal_slice.step == 0 if not check: logger.warning(msg) warnings.warn(msg) - msg = (f'sample_shape[2] ({self.sample_shape[2]}) cannot be larger ' - 'than the number of time steps in the raw data ' - f'({len(self.raw_time_index)}).') + msg = ( + f'sample_shape[2] ({self.sample_shape[2]}) cannot be larger ' + 'than the number of time steps in the raw data ' + f'({len(self.raw_time_index)}).' 
+ ) if len(self.raw_time_index) < self.sample_shape[2]: logger.warning(msg) warnings.warn(msg) - msg = (f'The requested time slice {self.temporal_slice} conflicts ' - f'with the number of time steps ({len(self.raw_time_index)}) ' - 'in the raw data') - t_slice_is_subset = (start is not None and stop is not None) - good_subset = (t_slice_is_subset - and (stop - start <= len(self.raw_time_index)) - and stop <= len(self.raw_time_index) - and start <= len(self.raw_time_index)) + msg = ( + f'The requested time slice {self.temporal_slice} conflicts ' + f'with the number of time steps ({len(self.raw_time_index)}) ' + 'in the raw data' + ) + t_slice_is_subset = start is not None and stop is not None + good_subset = ( + t_slice_is_subset + and (stop - start <= len(self.raw_time_index)) + and stop <= len(self.raw_time_index) + and start <= len(self.raw_time_index) + ) if t_slice_is_subset and not good_subset: logger.error(msg) raise RuntimeError(msg) - msg = (f'Initializing DataHandler {self.input_file_info}. ' - f'Getting temporal range {self.time_index[0]!s} to ' - f'{self.time_index[-1]!s} (inclusive) ' - f'based on temporal_slice {self.temporal_slice}') + msg = ( + f'Initializing DataHandler {self.input_file_info}. ' + f'Getting temporal range {self.time_index[0]!s} to ' + f'{self.time_index[-1]!s} (inclusive) ' + f'based on temporal_slice {self.temporal_slice}' + ) logger.info(msg) - logger.info(f'Using max_workers={self.max_workers}, ' - f'norm_workers={self.norm_workers}, ' - f'extract_workers={self.extract_workers}, ' - f'compute_workers={self.compute_workers}, ' - f'load_workers={self.load_workers}, ' - f'ti_workers={self.ti_workers}') + logger.info( + f'Using max_workers={self.max_workers}, ' + f'norm_workers={self.norm_workers}, ' + f'extract_workers={self.extract_workers}, ' + f'compute_workers={self.compute_workers}, ' + f'load_workers={self.load_workers}, ' + f'ti_workers={self.ti_workers}' + ) @classmethod def get_lat_lon(cls, file_paths, raster_index, invert_lat=False): @@ -1205,36 +1331,40 @@ def get_node_cmd(cls, config): initialize DataHandler and run data extraction. """ - import_str = ('from sup3r.preprocessing.data_handling ' - f'import {cls.__name__};\n' - 'import time;\n' - 'from reV.pipeline.status import Status;\n' - 'from rex import init_logger;\n') + import_str = ( + 'from sup3r.preprocessing.data_handling ' + f'import {cls.__name__};\n' + 'import time;\n' + 'from reV.pipeline.status import Status;\n' + 'from rex import init_logger;\n' + ) dh_init_str = get_fun_call_str(cls, config) log_file = config.get('log_file', None) log_level = config.get('log_level', 'INFO') - log_arg_str = (f'"sup3r", log_level="{log_level}"') + log_arg_str = f'"sup3r", log_level="{log_level}"' if log_file is not None: log_arg_str += f', log_file="{log_file}"' cache_check = config.get('cache_pattern', False) - msg = ('No cache file prefix provided.') + msg = 'No cache file prefix provided.' 
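As a side note on the cache_pattern convention being checked here, a minimal sketch (hypothetical pattern and feature list) of the '{feature}' substitution that get_cache_file_names performs just below.

# Hypothetical cache pattern and features, only to illustrate the naming scheme.
cache_pattern = './cache/wrf_{feature}.pkl'
features = ['U_100m', 'V_100m']

# Each feature gets its own pickle file, with the feature name lowercased,
# matching the list comprehension in get_cache_file_names.
cache_files = [cache_pattern.replace('{feature}', f.lower()) for f in features]
print(cache_files)
# ['./cache/wrf_u_100m.pkl', './cache/wrf_v_100m.pkl']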
if not cache_check: logger.warning(msg) warnings.warn(msg) - cmd = (f"python -c \'{import_str}\n" - "t0 = time.time();\n" - f"logger = init_logger({log_arg_str});\n" - f"data_handler = {dh_init_str};\n" - "t_elap = time.time() - t0;\n") + cmd = ( + f"python -c \'{import_str}\n" + "t0 = time.time();\n" + f"logger = init_logger({log_arg_str});\n" + f"data_handler = {dh_init_str};\n" + "t_elap = time.time() - t0;\n" + ) cmd = BaseCLI.add_status_cmd(config, ModuleName.DATA_EXTRACT, cmd) - cmd += (";\'\n") + cmd += ";\'\n" return cmd.replace('\\', '/') def get_cache_file_names(self, cache_pattern): @@ -1251,8 +1381,10 @@ def get_cache_file_names(self, cache_pattern): List of cache file names """ if cache_pattern is not None: - cache_files = [cache_pattern.replace('{feature}', f.lower()) - for f in self.features] + cache_files = [ + cache_pattern.replace('{feature}', f.lower()) + for f in self.features + ] for i, f in enumerate(cache_files): if '{shape}' in f: shape = f'{self.grid_shape[0]}x{self.grid_shape[1]}' @@ -1323,19 +1455,24 @@ def parallel_normalization(self, means, stds, max_workers=None): future = exe.submit(self._normalize_data, i, means[i], stds[i]) futures[future] = i - logger.info(f'Started normalizing {self.shape[-1]} features ' - f'in {dt.now() - now}.') + logger.info( + f'Started normalizing {self.shape[-1]} features ' + f'in {dt.now() - now}.' + ) for i, future in enumerate(as_completed(futures)): try: future.result() except Exception as e: - msg = ('Error while normalizing future number ' - f'{futures[future]}.') + msg = ( + 'Error while normalizing future number ' + f'{futures[future]}.' + ) logger.exception(msg) raise RuntimeError(msg) from e - logger.debug(f'{i+1} out of {self.shape[-1]} features ' - 'normalized.') + logger.debug( + f'{i+1} out of {self.shape[-1]} features ' 'normalized.' + ) def _normalize_data(self, feature_index, mean, std): """Normalize data with initialized mean and standard deviation for a @@ -1360,8 +1497,10 @@ def _normalize_data(self, feature_index, mean, std): self.val_data[..., feature_index] /= std self.data[..., feature_index] /= std else: - msg = ('Standard Deviation is zero for ' - f'{self.features[feature_index]}') + msg = ( + 'Standard Deviation is zero for ' + f'{self.features[feature_index]}' + ) logger.warning(msg) warnings.warn(msg) @@ -1377,7 +1516,8 @@ def get_observation_index(self): spatial_slice = uniform_box_sampler(self.data, self.sample_shape[:2]) temporal_slice = uniform_time_sampler(self.data, self.sample_shape[2]) return tuple( - [*spatial_slice, temporal_slice, np.arange(len(self.features))]) + [*spatial_slice, temporal_slice, np.arange(len(self.features))] + ) def get_next(self): """Get data for observation using random observation index. 
Loops @@ -1428,8 +1568,9 @@ def split_data(self, data=None): training_indices = all_indices[n_val_obs:] if not self.shuffle_time: - [self.val_data, self.data] = np.split(self.data, [n_val_obs], - axis=2) + [self.val_data, self.data] = np.split( + self.data, [n_val_obs], axis=2 + ) else: self.val_data = self.data[:, :, val_indices, :] self.data = self.data[:, :, training_indices, :] @@ -1463,19 +1604,25 @@ def cache_data(self, cache_file_paths): for i, fp in enumerate(cache_file_paths): if not os.path.exists(fp) or self.overwrite_cache: if self.overwrite_cache and os.path.exists(fp): - logger.info(f'Overwriting {self.features[i]} with shape ' - f'{self.data[..., i].shape} to {fp}') + logger.info( + f'Overwriting {self.features[i]} with shape ' + f'{self.data[..., i].shape} to {fp}' + ) else: - logger.info(f'Saving {self.features[i]} with shape ' - f'{self.data[..., i].shape} to {fp}') + logger.info( + f'Saving {self.features[i]} with shape ' + f'{self.data[..., i].shape} to {fp}' + ) tmp_file = fp.replace('.pkl', '.pkl.tmp') with open(tmp_file, 'wb') as fh: pickle.dump(self.data[..., i], fh, protocol=4) os.replace(tmp_file, fp) else: - msg = (f'Called cache_data but {fp} already exists. Set to ' - 'overwrite_cache to True to overwrite.') + msg = ( + f'Called cache_data but {fp} already exists. Set to ' + 'overwrite_cache to True to overwrite.' + ) logger.warning(msg) warnings.warn(msg) @@ -1496,19 +1643,25 @@ def parallel_load(self, max_workers=None): future = exe.submit(self.load_single_cached_feature, fp=fp) futures[future] = {'idx': i, 'fp': os.path.basename(fp)} - logger.info(f'Started loading all {len(self.cache_files)} cache ' - f'files in {dt.now() - now}.') + logger.info( + f'Started loading all {len(self.cache_files)} cache ' + f'files in {dt.now() - now}.' + ) for i, future in enumerate(as_completed(futures)): try: future.result() except Exception as e: - msg = ('Error while loading ' - f'{self.cache_files[futures[future]["idx"]]}') + msg = ( + 'Error while loading ' + f'{self.cache_files[futures[future]["idx"]]}' + ) logger.exception(msg) raise RuntimeError(msg) from e - logger.debug(f'{i+1} out of {len(futures)} cache files ' - f'loaded: {futures[future]["fp"]}') + logger.debug( + f'{i+1} out of {len(futures)} cache files ' + f'loaded: {futures[future]["fp"]}' + ) def load_single_cached_feature(self, fp): """Load single feature from given file @@ -1526,20 +1679,24 @@ def load_single_cached_feature(self, fp): idx = self.cache_files.index(fp) assert self.features[idx].lower() in fp.lower() fp = ignore_case_path_fetch(fp) - logger.info(f'Loading {self.features[idx]} from ' - f'{os.path.basename(fp)}') + logger.info( + f'Loading {self.features[idx]} from ' f'{os.path.basename(fp)}' + ) with open(fp, 'rb') as fh: try: - self.data[..., idx] = np.array(pickle.load(fh), - dtype=np.float32) + self.data[..., idx] = np.array( + pickle.load(fh), dtype=np.float32 + ) except Exception as e: - msg = ('Data loaded from from cache file "{}" ' - 'could not be written to feature channel {} ' - 'of full data array of shape {}. ' - 'The cached data has the wrong shape {}.' - .format(fp, idx, self.data.shape, - pickle.load(fh).shape)) + msg = ( + 'Data loaded from from cache file "{}" ' + 'could not be written to feature channel {} ' + 'of full data array of shape {}. 
' + 'The cached data has the wrong shape {}.'.format( + fp, idx, self.data.shape, pickle.load(fh).shape + ) + ) raise RuntimeError(msg) from e def load_cached_data(self): @@ -1549,19 +1706,27 @@ def load_cached_data(self): elif self.data is None: shape = get_raster_shape(self.raster_index) - requested_shape = (shape[0] // self.hr_spatial_coarsen, - shape[1] // self.hr_spatial_coarsen, - len(self.time_index), - len(self.features)) - - msg = ('Found {} cache files but need {} for features {}! ' - 'These are the cache files that were found: {}' - .format(len(self.cache_files), len(self.features), - self.features, self.cache_files)) + requested_shape = ( + shape[0] // self.hr_spatial_coarsen, + shape[1] // self.hr_spatial_coarsen, + len(self.time_index), + len(self.features), + ) + + msg = ( + 'Found {} cache files but need {} for features {}! ' + 'These are the cache files that were found: {}'.format( + len(self.cache_files), + len(self.features), + self.features, + self.cache_files, + ) + ) assert len(self.cache_files) == len(self.features), msg - self.data = np.full(shape=requested_shape, fill_value=np.nan, - dtype=np.float32) + self.data = np.full( + shape=requested_shape, fill_value=np.nan, dtype=np.float32 + ) logger.info(f'Loading cached data from: {self.cache_files}') max_workers = self.load_workers @@ -1571,20 +1736,27 @@ def load_cached_data(self): else: self.parallel_load(max_workers=max_workers) - nan_perc = (100 * np.isnan(self.data).sum() / self.data.size) + nan_perc = 100 * np.isnan(self.data).sum() / self.data.size if nan_perc > 0: - msg = ('Data has {:.2f}% NaN values!'.format(nan_perc)) + msg = 'Data has {:.2f}% NaN values!'.format(nan_perc) logger.warning(msg) warnings.warn(msg) - logger.debug('Splitting data into training / validation sets ' - f'({1 - self.val_split}, {self.val_split}) ' - f'for {self.input_file_info}') + logger.debug( + 'Splitting data into training / validation sets ' + f'({1 - self.val_split}, {self.val_split}) ' + f'for {self.input_file_info}' + ) self.data, self.val_data = self.split_data() @classmethod - def check_cached_features(cls, features, cache_files=None, - overwrite_cache=False, load_cached=False): + def check_cached_features( + cls, + features, + cache_files=None, + overwrite_cache=False, + load_cached=False, + ): """Check which features have been cached and check flags to determine whether to load or extract this features again @@ -1609,23 +1781,31 @@ def check_cached_features(cls, features, cache_files=None, # check if any features can be loaded from cache if cache_files is not None: for i, f in enumerate(features): - check = (os.path.exists(cache_files[i]) - and f.lower() in cache_files[i].lower()) + check = ( + os.path.exists(cache_files[i]) + and f.lower() in cache_files[i].lower() + ) if check: if not overwrite_cache: if load_cached: - msg = (f'{f} found in cache file {cache_files[i]}.' - ' Loading from cache instead of extracting ' - 'from source files') + msg = ( + f'{f} found in cache file {cache_files[i]}.' + ' Loading from cache instead of extracting ' + 'from source files' + ) logger.info(msg) else: - msg = (f'{f} found in cache file {cache_files[i]}.' - ' Call load_cached_data() or use ' - 'load_cached=True to load this data.') + msg = ( + f'{f} found in cache file {cache_files[i]}.' + ' Call load_cached_data() or use ' + 'load_cached=True to load this data.' + ) logger.info(msg) else: - msg = (f'{cache_files[i]} exists but overwrite_cache ' - 'is set to True. 
Proceeding with extraction.') + msg = ( + f'{cache_files[i]} exists but overwrite_cache ' + 'is set to True. Proceeding with extraction.' + ) logger.info(msg) extract_features.append(f) else: @@ -1654,8 +1834,9 @@ def run_all_data_init(self): shifted_time_chunks = [slice(None)] else: n_steps = len(self.raw_time_index[self.temporal_slice]) - shifted_time_chunks = get_chunk_slices(n_steps, - self.time_chunk_size) + shifted_time_chunks = get_chunk_slices( + n_steps, self.time_chunk_size + ) self.run_data_extraction() self.run_data_compute() @@ -1672,9 +1853,9 @@ def run_all_data_init(self): if self.hr_spatial_coarsen > 1: logger.debug('Applying hr spatial coarsening to data array') - self.data = spatial_coarsening(self.data, - s_enhance=self.hr_spatial_coarsen, - obs_axis=False) + self.data = spatial_coarsening( + self.data, s_enhance=self.hr_spatial_coarsen, obs_axis=False + ) if self.load_cached: for f in self.cached_features: f_index = self.features.index(f) @@ -1682,9 +1863,11 @@ def run_all_data_init(self): with open(self.cache_files[f_index], 'rb') as fh: self.data[..., f_index] = pickle.load(fh) - logger.info('Finished extracting data for ' - f'{self.input_file_info} in ' - f'{dt.now() - now}') + logger.info( + 'Finished extracting data for ' + f'{self.input_file_info} in ' + f'{dt.now() - now}' + ) return self.data def run_data_extraction(self): @@ -1692,25 +1875,33 @@ def run_data_extraction(self): un-manipulated datasets. """ if self.extract_features: - logger.info(f'Starting extraction of {self.extract_features} ' - f'using {len(self.time_chunks)} time_chunks.') + logger.info( + f'Starting extraction of {self.extract_features} ' + f'using {len(self.time_chunks)} time_chunks.' + ) if self.extract_workers == 1: - self._raw_data = self.serial_extract(self.file_paths, - self.raster_index, - self.time_chunks, - self.extract_features, - **self.res_kwargs) + self._raw_data = self.serial_extract( + self.file_paths, + self.raster_index, + self.time_chunks, + self.extract_features, + **self.res_kwargs, + ) else: - self._raw_data = self.parallel_extract(self.file_paths, - self.raster_index, - self.time_chunks, - self.extract_features, - self.extract_workers, - **self.res_kwargs) - - logger.info(f'Finished extracting {self.extract_features} for ' - f'{self.input_file_info}') + self._raw_data = self.parallel_extract( + self.file_paths, + self.raster_index, + self.time_chunks, + self.extract_features, + self.extract_workers, + **self.res_kwargs, + ) + + logger.info( + f'Finished extracting {self.extract_features} for ' + f'{self.input_file_info}' + ) def run_data_compute(self): """Run the data computation / derivation from raw features to desired @@ -1720,26 +1911,32 @@ def run_data_compute(self): logger.info(f'Starting computation of {self.derive_features}') if self.compute_workers == 1: - self._raw_data = self.serial_compute(self._raw_data, - self.file_paths, - self.raster_index, - self.time_chunks, - self.derive_features, - self.noncached_features, - self.handle_features) + self._raw_data = self.serial_compute( + self._raw_data, + self.file_paths, + self.raster_index, + self.time_chunks, + self.derive_features, + self.noncached_features, + self.handle_features, + ) elif self.compute_workers != 1: - self._raw_data = self.parallel_compute(self._raw_data, - self.file_paths, - self.raster_index, - self.time_chunks, - self.derive_features, - self.noncached_features, - self.handle_features, - self.compute_workers) - - logger.info(f'Finished computing {self.derive_features} for ' - 
f'{self.input_file_info}') + self._raw_data = self.parallel_compute( + self._raw_data, + self.file_paths, + self.raster_index, + self.time_chunks, + self.derive_features, + self.noncached_features, + self.handle_features, + self.compute_workers, + ) + + logger.info( + f'Finished computing {self.derive_features} for ' + f'{self.input_file_info}' + ) def data_fill(self, t, t_slice, f_index, f): """Place single extracted / computed chunk in final data array @@ -1775,8 +1972,10 @@ def serial_data_fill(self, shifted_time_chunks): self.data_fill(t, ts, f_index, f) interval = int(np.ceil(len(shifted_time_chunks) / 10)) if t % interval == 0: - logger.info(f'Added {t + 1} of {len(shifted_time_chunks)} ' - 'chunks to final data array') + logger.info( + f'Added {t + 1} of {len(shifted_time_chunks)} ' + 'chunks to final data array' + ) self._raw_data.pop(t) def parallel_data_fill(self, shifted_time_chunks, max_workers=None): @@ -1792,9 +1991,15 @@ def parallel_data_fill(self, shifted_time_chunks, max_workers=None): max available workers will be used. If 1 cached data will be loaded in serial """ - self.data = np.zeros((self.grid_shape[0], self.grid_shape[1], - self.n_tsteps, len(self.features)), - dtype=np.float32) + self.data = np.zeros( + ( + self.grid_shape[0], + self.grid_shape[1], + self.n_tsteps, + len(self.features), + ), + dtype=np.float32, + ) if max_workers == 1: self.serial_data_fill(shifted_time_chunks) @@ -1808,22 +2013,28 @@ def parallel_data_fill(self, shifted_time_chunks, max_workers=None): future = exe.submit(self.data_fill, t, ts, f_index, f) futures[future] = {'t': t, 'fidx': f_index} - logger.info(f'Started adding {len(futures)} chunks ' - f'to data array in {dt.now() - now}.') + logger.info( + f'Started adding {len(futures)} chunks ' + f'to data array in {dt.now() - now}.' + ) interval = int(np.ceil(len(futures) / 10)) for i, future in enumerate(as_completed(futures)): try: future.result() except Exception as e: - msg = (f'Error adding ({futures[future]["t"]}, ' - f'{futures[future]["fidx"]}) chunk to ' - 'final data array.') + msg = ( + f'Error adding ({futures[future]["t"]}, ' + f'{futures[future]["fidx"]}) chunk to ' + 'final data array.' + ) logger.exception(msg) raise RuntimeError(msg) from e if i % interval == 0: - logger.debug(f'Added {i+1} out of {len(futures)} ' - 'chunks to final data array') + logger.debug( + f'Added {i+1} out of {len(futures)} ' + 'chunks to final data array' + ) logger.info('Finished building data array') @abstractmethod @@ -1864,66 +2075,54 @@ def lin_bc(self, bc_files, threshold=0.1): completed = [] for idf, feature in enumerate(self.features): - dset_scalar = f'{feature}_scalar' - dset_adder = f'{feature}_adder' for fp in bc_files: - with Resource(fp) as res: - lat = np.expand_dims(res['latitude'], axis=-1) - lon = np.expand_dims(res['longitude'], axis=-1) - lat_lon_bc = np.dstack((lat, lon)) - lat_lon_0 = self.lat_lon[:1, :1] - diff = lat_lon_bc - lat_lon_0 - diff = np.hypot(diff[..., 0], diff[..., 1]) - idy, idx = np.where(diff == diff.min()) - slice_y = slice(idy[0], idy[0] + self.shape[0]) - slice_x = slice(idx[0], idx[0] + self.shape[1]) - - if diff.min() > threshold: - msg = ('The DataHandler top left coordinate of {} ' - 'appears to be {} away from the nearest ' - 'bias correction coordinate of {} from {}. ' - 'Cannot apply bias correction.' 
- .format(lat_lon_0, diff.min(), - lat_lon_bc[idy, idx], - os.path.basename(fp))) + if feature not in completed: + scalar, adder = get_spatial_bc_factors( + lat_lon=self.lat_lon, + feature_name=feature, + bias_fp=fp, + threshold=threshold, + ) + + if scalar.shape[-1] == 1: + scalar = np.repeat(scalar, self.shape[2], axis=2) + adder = np.repeat(adder, self.shape[2], axis=2) + elif scalar.shape[-1] == 12: + idm = self.time_index.month.values - 1 + scalar = scalar[..., idm] + adder = adder[..., idm] + else: + msg = ( + 'Can only accept bias correction factors ' + 'with last dim equal to 1 or 12 but ' + 'received bias correction factors with ' + 'shape {}'.format(scalar.shape) + ) logger.error(msg) raise RuntimeError(msg) - check = (dset_scalar in res.dsets - and dset_adder in res.dsets - and feature not in completed) - if check: - scalar = res[dset_scalar, slice_y, slice_x] - adder = res[dset_adder, slice_y, slice_x] - - if scalar.shape[-1] == 1: - scalar = np.repeat(scalar, self.shape[2], axis=2) - adder = np.repeat(adder, self.shape[2], axis=2) - elif scalar.shape[-1] == 12: - idm = self.time_index.month.values - 1 - scalar = scalar[..., idm] - adder = adder[..., idm] - else: - msg = ('Can only accept bias correction factors ' - 'with last dim equal to 1 or 12 but ' - 'received bias correction factors with ' - 'shape {}'.format(scalar.shape)) - logger.error(msg) - raise RuntimeError(msg) - - logger.info('Bias correcting "{}" with linear ' - 'correction from "{}"' - .format(feature, os.path.basename(fp))) - self.data[..., idf] *= scalar - self.data[..., idf] += adder - completed.append(feature) + logger.info( + 'Bias correcting "{}" with linear ' + 'correction from "{}"'.format( + feature, os.path.basename(fp) + ) + ) + self.data[..., idf] *= scalar + self.data[..., idf] += adder + completed.append(feature) class DataHandlerNC(DataHandler): """Data Handler for NETCDF data""" - CHUNKS = {'XTIME': 100, 'XLAT': 150, 'XLON': 150, - 'south_north': 150, 'west_east': 150, 'Time': 100} + CHUNKS: ClassVar[dict] = { + 'XTIME': 100, + 'XLAT': 150, + 'XLON': 150, + 'south_north': 150, + 'west_east': 150, + 'Time': 100, + } """CHUNKS sets the chunk sizes to extract from the data in each dimension. Chunk sizes that approximately match the data volume being extracted typically results in the most efficient IO.""" @@ -1958,8 +2157,9 @@ def extract_workers(self): proc_mem /= len(self.time_chunks) n_procs = len(self.time_chunks) * len(self.extract_features) n_procs = int(np.ceil(n_procs)) - extract_workers = estimate_max_workers(self._extract_workers, proc_mem, - n_procs) + extract_workers = estimate_max_workers( + self._extract_workers, proc_mem, n_procs + ) return extract_workers @classmethod @@ -1985,8 +2185,11 @@ def source_handler(cls, file_paths, **kwargs): data : xarray.Dataset """ time_key = get_time_dim_name(file_paths[0]) - default_kws = {'combine': 'nested', 'concat_dim': time_key, - 'chunks': cls.CHUNKS} + default_kws = { + 'combine': 'nested', + 'concat_dim': time_key, + 'chunks': cls.CHUNKS, + } default_kws.update(kwargs) return xr.open_mfdataset(file_paths, **default_kws) @@ -2020,8 +2223,10 @@ def get_file_times(cls, file_paths, **kwargs): elif hasattr(handle, 'times'): time_index = np_to_pd_times(handle.times.values) else: - msg = (f'Could not get time_index for {file_paths}. ' - 'Assuming time independence.') + msg = ( + f'Could not get time_index for {file_paths}. ' + 'Assuming time independence.' 
+ ) time_index = None logger.warning(msg) warnings.warn(msg) @@ -2049,8 +2254,11 @@ def get_time_index(cls, file_paths, max_workers=None, **kwargs): time_index : pd.Datetimeindex List of times as a Datetimeindex """ - max_workers = (len(file_paths) if max_workers is None - else np.min((max_workers, len(file_paths)))) + max_workers = ( + len(file_paths) + if max_workers is None + else np.min((max_workers, len(file_paths))) + ) if max_workers == 1: return cls.get_file_times(file_paths, **kwargs) ti = {} @@ -2061,8 +2269,10 @@ def get_time_index(cls, file_paths, max_workers=None, **kwargs): future = exe.submit(cls.get_file_times, [f], **kwargs) futures[future] = {'idx': i, 'file': f} - logger.info(f'Started building time index from {len(file_paths)} ' - f'files in {dt.now() - now}.') + logger.info( + f'Started building time index from {len(file_paths)} ' + f'files in {dt.now() - now}.' + ) for i, future in enumerate(as_completed(futures)): try: @@ -2070,12 +2280,13 @@ def get_time_index(cls, file_paths, max_workers=None, **kwargs): if val is not None: ti[futures[future]['idx']] = list(val) except Exception as e: - msg = ('Error while getting time index from file ' - f'{futures[future]["file"]}.') + msg = ( + 'Error while getting time index from file ' + f'{futures[future]["file"]}.' + ) logger.exception(msg) raise RuntimeError(msg) from e - logger.debug( - f'Stored {i+1} out of {len(futures)} file times') + logger.debug(f'Stored {i+1} out of {len(futures)} file times') times = np.concatenate(list(ti.values())) return pd.DatetimeIndex(sorted(set(times))) @@ -2103,12 +2314,19 @@ def feature_registry(cls): 'Pressure_(.*)m': PressureNC, 'PotentialTemp_(.*)m': PotentialTempNC, 'PT_(.*)m': PotentialTempNC, - 'topography': 'HGT'} + 'topography': 'HGT', + } return registry @classmethod - def extract_feature(cls, file_paths, raster_index, feature, - time_slice=slice(None), **kwargs): + def extract_feature( + cls, + file_paths, + raster_index, + feature, + time_slice=slice(None), + **kwargs, + ): """Extract single feature from data source. The requested feature can match exactly to one found in the source data or can have a matching prefix with a suffix specifying the height or pressure level @@ -2137,8 +2355,10 @@ def extract_feature(cls, file_paths, raster_index, feature, Data array for extracted feature (spatial_1, spatial_2, temporal) """ - logger.debug(f'Extracting {feature} with time_slice={time_slice}, ' - f'raster_index={raster_index}, kwargs={kwargs}.') + logger.debug( + f'Extracting {feature} with time_slice={time_slice}, ' + f'raster_index={raster_index}, kwargs={kwargs}.' 
+ ) handle = cls.source_handler(file_paths, **kwargs) f_info = Feature(feature, handle) interp_height = f_info.height @@ -2147,19 +2367,28 @@ def extract_feature(cls, file_paths, raster_index, feature, if feature in handle or feature.lower() in handle: feat_key = feature if feature in handle else feature.lower() - fdata = cls.direct_extract(handle, feat_key, raster_index, - time_slice) + fdata = cls.direct_extract( + handle, feat_key, raster_index, time_slice + ) elif basename in handle or basename.lower() in handle: feat_key = basename if basename in handle else basename.lower() if interp_height is not None: fdata = Interpolator.interp_var_to_height( - handle, feat_key, raster_index, np.float32(interp_height), - time_slice) + handle, + feat_key, + raster_index, + np.float32(interp_height), + time_slice, + ) elif interp_pressure is not None: fdata = Interpolator.interp_var_to_pressure( - handle, feat_key, raster_index, - np.float32(interp_pressure), time_slice) + handle, + feat_key, + raster_index, + np.float32(interp_pressure), + time_slice, + ) else: msg = f'{feature} cannot be extracted from source data.' @@ -2244,12 +2473,14 @@ def get_closest_lat_lon(lat_lon, target): col index for closest lat/lon to target lat/lon """ # shape of ll2 is (n, 2) where axis=1 is (lat, lon) - ll2 = np.vstack((lat_lon[..., 0].flatten(), - lat_lon[..., 1].flatten())).T + ll2 = np.vstack( + (lat_lon[..., 0].flatten(), lat_lon[..., 1].flatten()) + ).T tree = KDTree(ll2) _, i = tree.query(np.array(target)) - row, col = np.where((lat_lon[..., 0] == ll2[i, 0]) - & (lat_lon[..., 1] == ll2[i, 1])) + row, col = np.where( + (lat_lon[..., 0] == ll2[i, 0]) & (lat_lon[..., 1] == ll2[i, 1]) + ) row = row[0] col = col[0] return row, col @@ -2272,8 +2503,9 @@ def compute_raster_index(cls, file_paths, target, grid_shape): list List of slices corresponding to extracted data region """ - lat_lon = cls.get_lat_lon(file_paths[:1], [slice(None), slice(None)], - invert_lat=False) + lat_lon = cls.get_lat_lon( + file_paths[:1], [slice(None), slice(None)], invert_lat=False + ) cls._check_grid_extent(target, grid_shape, lat_lon) row, col = cls.get_closest_lat_lon(lat_lon, target) @@ -2291,8 +2523,10 @@ def compute_raster_index(cls, file_paths, target, grid_shape): else: row_end = row + grid_shape[0] row_start = row - raster_index = [slice(row_start, row_end), - slice(col, col + grid_shape[1])] + raster_index = [ + slice(row_start, row_end), + slice(col, col + grid_shape[1]), + ] cls._validate_raster_shape(target, grid_shape, lat_lon, raster_index) return raster_index @@ -2316,14 +2550,20 @@ def _check_grid_extent(cls, target, grid_shape, lat_lon): min_lon = np.min(lat_lon[..., 1]) max_lat = np.max(lat_lon[..., 0]) max_lon = np.max(lat_lon[..., 1]) - logger.debug('Calculating raster index from WRF file ' - f'for shape {grid_shape} and target {target}') - logger.debug(f'lat/lon (min, max): {min_lat}/{min_lon}, ' - f'{max_lat}/{max_lon}') - msg = (f'target {target} out of bounds with min lat/lon ' - f'{min_lat}/{min_lon} and max lat/lon {max_lat}/{max_lon}') - assert (min_lat <= target[0] <= max_lat - and min_lon <= target[1] <= max_lon), msg + logger.debug( + 'Calculating raster index from WRF file ' + f'for shape {grid_shape} and target {target}' + ) + logger.debug( + f'lat/lon (min, max): {min_lat}/{min_lon}, ' f'{max_lat}/{max_lon}' + ) + msg = ( + f'target {target} out of bounds with min lat/lon ' + f'{min_lat}/{min_lon} and max lat/lon {max_lat}/{max_lon}' + ) + assert ( + min_lat <= target[0] <= max_lat and min_lon <= target[1] 
<= max_lon + ), msg @classmethod def _validate_raster_shape(cls, target, grid_shape, lat_lon, raster_index): @@ -2343,16 +2583,20 @@ def _validate_raster_shape(cls, target, grid_shape, lat_lon, raster_index): raster_index : list List of slices selecting region from entire available grid. """ - if (raster_index[0].stop > lat_lon.shape[0] - or raster_index[1].stop > lat_lon.shape[1] - or raster_index[0].start < 0 - or raster_index[1].start < 0): - msg = (f'Invalid target {target}, shape {grid_shape}, and raster ' - f'{raster_index} for data domain of size ' - f'{lat_lon.shape[:-1]} with lower left corner ' - f'({np.min(lat_lon[..., 0])}, {np.min(lat_lon[..., 1])}) ' - f' and upper right corner ({np.max(lat_lon[..., 0])}, ' - f'{np.max(lat_lon[..., 1])}).') + if ( + raster_index[0].stop > lat_lon.shape[0] + or raster_index[1].stop > lat_lon.shape[1] + or raster_index[0].start < 0 + or raster_index[1].start < 0 + ): + msg = ( + f'Invalid target {target}, shape {grid_shape}, and raster ' + f'{raster_index} for data domain of size ' + f'{lat_lon.shape[:-1]} with lower left corner ' + f'({np.min(lat_lon[..., 0])}, {np.min(lat_lon[..., 1])}) ' + f' and upper right corner ({np.max(lat_lon[..., 0])}, ' + f'{np.max(lat_lon[..., 1])}).' + ) raise ValueError(msg) def get_raster_index(self): @@ -2365,23 +2609,33 @@ def get_raster_index(self): raster_index : np.ndarray 2D array of grid indices """ - self.raster_file = (self.raster_file if self.raster_file is None - else self.raster_file.replace('.txt', '.npy')) + self.raster_file = ( + self.raster_file + if self.raster_file is None + else self.raster_file.replace('.txt', '.npy') + ) if self.raster_file is not None and os.path.exists(self.raster_file): - logger.debug(f'Loading raster index: {self.raster_file} ' - f'for {self.input_file_info}') + logger.debug( + f'Loading raster index: {self.raster_file} ' + f'for {self.input_file_info}' + ) raster_index = np.load(self.raster_file, allow_pickle=True) raster_index = list(raster_index) else: - check = (self.grid_shape is not None and self.target is not None) - msg = ('Must provide raster file or shape + target to get ' - 'raster index') + check = self.grid_shape is not None and self.target is not None + msg = ( + 'Must provide raster file or shape + target to get ' + 'raster index' + ) assert check, msg - raster_index = self.compute_raster_index(self.file_paths, - self.target, - self.grid_shape) - logger.debug('Found raster index with row, col slices: {}' - .format(raster_index)) + raster_index = self.compute_raster_index( + self.file_paths, self.target, self.grid_shape + ) + logger.debug( + 'Found raster index with row, col slices: {}'.format( + raster_index + ) + ) if self.raster_file is not None: basedir = os.path.dirname(self.raster_file) @@ -2396,13 +2650,19 @@ def get_raster_index(self): class DataHandlerNCforCC(DataHandlerNC): """Data Handler for NETCDF climate change data""" - CHUNKS = {'time': 5, 'lat': 20, 'lon': 20} + CHUNKS: ClassVar[dict] = {'time': 5, 'lat': 20, 'lon': 20} """CHUNKS sets the chunk sizes to extract from the data in each dimension. Chunk sizes that approximately match the data volume being extracted typically results in the most efficient IO.""" - def __init__(self, *args, nsrdb_source_fp=None, nsrdb_agg=1, - nsrdb_smoothing=0, **kwargs): + def __init__( + self, + *args, + nsrdb_source_fp=None, + nsrdb_agg=1, + nsrdb_smoothing=0, + **kwargs, + ): """Initialize NETCDF data handler for climate change data. 
Parameters @@ -2456,7 +2716,8 @@ def feature_registry(cls): 'Temperature_(.*)': TempNCforCC, 'temperature_2m': Tas, 'temperature_max_2m': TasMax, - 'temperature_min_2m': TasMin} + 'temperature_min_2m': TasMin, + } return registry @classmethod @@ -2505,8 +2766,10 @@ def run_data_extraction(self): # clearsky ghi is extracted at the proper starting time index so # the time chunks should start at 0 tc0 = self.time_chunks[0].start - cs_ghi_time_chunks = [slice(tc.start - tc0, tc.stop - tc0, tc.step) - for tc in self.time_chunks] + cs_ghi_time_chunks = [ + slice(tc.start - tc0, tc.stop - tc0, tc.step) + for tc in self.time_chunks + ] for it, tslice in enumerate(cs_ghi_time_chunks): self._raw_data[it]['clearsky_ghi'] = cs_ghi[..., tslice] @@ -2523,22 +2786,30 @@ def get_clearsky_ghi(self): shape is (lat, lon, time) where time is daily average values. """ - msg = ('Need nsrdb_source_fp input arg as a valid filepath to ' - 'retrieve clearsky_ghi (maybe for clearsky_ratio) but ' - 'received: {}'.format(self._nsrdb_source_fp)) + msg = ( + 'Need nsrdb_source_fp input arg as a valid filepath to ' + 'retrieve clearsky_ghi (maybe for clearsky_ratio) but ' + 'received: {}'.format(self._nsrdb_source_fp) + ) assert self._nsrdb_source_fp is not None, msg assert os.path.exists(self._nsrdb_source_fp), msg - msg = ('Can only handle source CC data in hourly frequency but ' - 'received daily frequency of {}hrs (should be 24) ' - 'with raw time index: {}' - .format(self.time_freq_hours, self.raw_time_index)) + msg = ( + 'Can only handle source CC data in hourly frequency but ' + 'received daily frequency of {}hrs (should be 24) ' + 'with raw time index: {}'.format( + self.time_freq_hours, self.raw_time_index + ) + ) assert self.time_freq_hours == 24.0, msg - msg = ('Can only handle source CC data with temporal_slice.step == 1 ' - 'but received: {}'.format(self.temporal_slice.step)) - assert ((self.temporal_slice.step is None) - | (self.temporal_slice.step == 1)), msg + msg = ( + 'Can only handle source CC data with temporal_slice.step == 1 ' + 'but received: {}'.format(self.temporal_slice.step) + ) + assert (self.temporal_slice.step is None) | ( + self.temporal_slice.step == 1 + ), msg with Resource(self._nsrdb_source_fp) as res: ti_nsrdb = res.time_index @@ -2564,10 +2835,15 @@ def get_clearsky_ghi(self): if len(i.shape) == 1: i = np.expand_dims(i, axis=1) - logger.info('Extracting clearsky_ghi data from "{}" with time slice ' - '{} and {} locations with agg factor {}.' 
- .format(os.path.basename(self._nsrdb_source_fp), - t_slice, i.shape[0], i.shape[1])) + logger.info( + 'Extracting clearsky_ghi data from "{}" with time slice ' + '{} and {} locations with agg factor {}.'.format( + os.path.basename(self._nsrdb_source_fp), + t_slice, + i.shape[0], + i.shape[1], + ) + ) cs_shape = i.shape with Resource(self._nsrdb_source_fp) as res: @@ -2576,8 +2852,9 @@ def get_clearsky_ghi(self): cs_ghi = cs_ghi.reshape((len(cs_ghi), *cs_shape)) cs_ghi = cs_ghi.mean(axis=-1) - windows = np.array_split(np.arange(len(cs_ghi)), - len(cs_ghi) // (24 // time_freq)) + windows = np.array_split( + np.arange(len(cs_ghi)), len(cs_ghi) // (24 // time_freq) + ) cs_ghi = [cs_ghi[window].mean(axis=0) for window in windows] cs_ghi = np.vstack(cs_ghi) cs_ghi = cs_ghi.reshape((len(cs_ghi), *tuple(self.grid_shape))) @@ -2586,23 +2863,28 @@ def get_clearsky_ghi(self): if self.invert_lat: cs_ghi = cs_ghi[::-1] - logger.info('Smoothing nsrdb clearsky ghi with a factor of {}' - .format(self._nsrdb_smoothing)) + logger.info( + 'Smoothing nsrdb clearsky ghi with a factor of {}'.format( + self._nsrdb_smoothing + ) + ) for iday in range(cs_ghi.shape[-1]): - cs_ghi[..., iday] = gaussian_filter(cs_ghi[..., iday], - self._nsrdb_smoothing, - mode='nearest') + cs_ghi[..., iday] = gaussian_filter( + cs_ghi[..., iday], self._nsrdb_smoothing, mode='nearest' + ) if cs_ghi.shape[-1] < t_end_target: n = int(np.ceil(t_end_target / cs_ghi.shape[-1])) cs_ghi = np.repeat(cs_ghi, n, axis=2) cs_ghi = cs_ghi[..., :t_end_target] - logger.info('Reshaped clearsky_ghi data to final shape {} to ' - 'correspond with CC daily average data over source ' - 'temporal_slice {} with (lat, lon) grid shape of {}' - .format(cs_ghi.shape, self.temporal_slice, - self.grid_shape)) + logger.info( + 'Reshaped clearsky_ghi data to final shape {} to ' + 'correspond with CC daily average data over source ' + 'temporal_slice {} with (lat, lon) grid shape of {}'.format( + cs_ghi.shape, self.temporal_slice, self.grid_shape + ) + ) return cs_ghi @@ -2637,8 +2919,10 @@ def source_handler(cls, file_paths, **kwargs): @classmethod def get_full_domain(cls, file_paths): """Get target and shape for largest domain possible""" - msg = ('You must either provide the target+shape inputs or an ' - 'existing raster_file input.') + msg = ( + 'You must either provide the target+shape inputs or an ' + 'existing raster_file input.' 
+ ) logger.error(msg) raise ValueError(msg) @@ -2685,12 +2969,19 @@ def feature_registry(cls): 'P_(.*)m': 'pressure_(.*)m', 'topography': TopoH5, 'cloud_mask': CloudMaskH5, - 'clearsky_ratio': ClearSkyRatioH5} + 'clearsky_ratio': ClearSkyRatioH5, + } return registry @classmethod - def extract_feature(cls, file_paths, raster_index, feature, - time_slice=slice(None), **kwargs): + def extract_feature( + cls, + file_paths, + raster_index, + feature, + time_slice=slice(None), + **kwargs, + ): """Extract single feature from data source Parameters @@ -2715,15 +3006,17 @@ def extract_feature(cls, file_paths, raster_index, feature, logger.info(f'Extracting {feature} with kwargs={kwargs}') handle = cls.source_handler(file_paths, **kwargs) try: - fdata = handle[(feature, time_slice, - *tuple([raster_index.flatten()]))] + fdata = handle[ + (feature, time_slice, *tuple([raster_index.flatten()])) + ] except ValueError as e: msg = f'{feature} cannot be extracted from source data' logger.exception(msg) raise ValueError(msg) from e - fdata = fdata.reshape((-1, raster_index.shape[0], - raster_index.shape[1])) + fdata = fdata.reshape( + (-1, raster_index.shape[0], raster_index.shape[1]) + ) fdata = np.transpose(fdata, (1, 2, 0)) return fdata.astype(np.float32) @@ -2738,21 +3031,27 @@ def get_raster_index(self): 2D array of grid indices """ if self.raster_file is not None and os.path.exists(self.raster_file): - logger.debug(f'Loading raster index: {self.raster_file} ' - f'for {self.input_file_info}') + logger.debug( + f'Loading raster index: {self.raster_file} ' + f'for {self.input_file_info}' + ) raster_index = np.loadtxt(self.raster_file).astype(np.uint32) else: - check = (self.grid_shape is not None and self.target is not None) - msg = ('Must provide raster file or shape + target to get ' - 'raster index') + check = self.grid_shape is not None and self.target is not None + msg = ( + 'Must provide raster file or shape + target to get ' + 'raster index' + ) assert check, msg - logger.debug('Calculating raster index from WTK file ' - f'for shape {self.grid_shape} and target ' - f'{self.target}') + logger.debug( + 'Calculating raster index from WTK file ' + f'for shape {self.grid_shape} and target ' + f'{self.target}' + ) handle = self.source_handler(self.file_paths[0]) - raster_index = handle.get_raster_index(self.target, - self.grid_shape, - max_delta=self.max_delta) + raster_index = handle.get_raster_index( + self.target, self.grid_shape, max_delta=self.max_delta + ) if self.raster_file is not None: basedir = os.path.dirname(self.raster_file) if not os.path.exists(basedir): @@ -2773,10 +3072,12 @@ class DataHandlerH5WindCC(DataHandlerH5): # model but are not part of the synthetic output and are not sent to the # discriminator. These are case-insensitive and follow the Unix shell-style # wildcard format. - TRAIN_ONLY_FEATURES = ('temperature_max_*m', - 'temperature_min_*m', - 'relativehumidity_max_*m', - 'relativehumidity_min_*m') + TRAIN_ONLY_FEATURES = ( + 'temperature_max_*m', + 'temperature_min_*m', + 'relativehumidity_max_*m', + 'relativehumidity_min_*m', + ) def __init__(self, *args, **kwargs): """ @@ -2791,17 +3092,22 @@ def __init__(self, *args, **kwargs): t_shape = sample_shape[-1] if len(sample_shape) == 2: - logger.info('Found 2D sample shape of {}. Adding spatial dim of 24' - .format(sample_shape)) + logger.info( + 'Found 2D sample shape of {}. 
Adding spatial dim of 24'.format( + sample_shape + ) + ) sample_shape = (*sample_shape, 24) t_shape = sample_shape[-1] kwargs['sample_shape'] = sample_shape if t_shape < 24 or t_shape % 24 != 0: - msg = ('Climate Change DataHandler can only work with temporal ' - 'sample shapes that are one or more days of hourly data ' - '(e.g. 24, 48, 72...). The requested temporal sample ' - 'shape was: {}'.format(t_shape)) + msg = ( + 'Climate Change DataHandler can only work with temporal ' + 'sample shapes that are one or more days of hourly data ' + '(e.g. 24, 48, 72...). The requested temporal sample ' + 'shape was: {}'.format(t_shape) + ) logger.error(msg) raise RuntimeError(msg) @@ -2816,24 +3122,31 @@ def __init__(self, *args, **kwargs): def run_daily_averages(self): """Calculate daily average data and store as attribute.""" - msg = ('Data needs to be hourly with at least 24 hours, but data ' - 'shape is {}.'.format(self.data.shape)) + msg = ( + 'Data needs to be hourly with at least 24 hours, but data ' + 'shape is {}.'.format(self.data.shape) + ) assert self.data.shape[2] % 24 == 0, msg assert self.data.shape[2] > 24, msg n_data_days = int(self.data.shape[2] / 24) - daily_data_shape = (self.data.shape[0:2] + (n_data_days,) - + (self.data.shape[3],)) + daily_data_shape = ( + self.data.shape[0:2] + (n_data_days,) + (self.data.shape[3],) + ) - logger.info('Calculating daily average datasets for {} training ' - 'data days.'.format(n_data_days)) + logger.info( + 'Calculating daily average datasets for {} training ' + 'data days.'.format(n_data_days) + ) self.daily_data = np.zeros(daily_data_shape, dtype=np.float32) - self.daily_data_slices = np.array_split(np.arange(self.data.shape[2]), - n_data_days) - self.daily_data_slices = [slice(x[0], x[-1] + 1) - for x in self.daily_data_slices] + self.daily_data_slices = np.array_split( + np.arange(self.data.shape[2]), n_data_days + ) + self.daily_data_slices = [ + slice(x[0], x[-1] + 1) for x in self.daily_data_slices + ] for idf, fname in enumerate(self.features): for d, t_slice in enumerate(self.daily_data_slices): if '_max_' in fname: @@ -2844,11 +3157,14 @@ def run_daily_averages(self): self.daily_data[:, :, d, idf] = tmp[:, :] else: tmp = daily_temporal_coarsening( - self.data[:, :, t_slice, idf], temporal_axis=2) + self.data[:, :, t_slice, idf], temporal_axis=2 + ) self.daily_data[:, :, d, idf] = tmp[:, :, 0] - logger.info('Finished calculating daily average datasets for {} ' - 'training data days.'.format(n_data_days)) + logger.info( + 'Finished calculating daily average datasets for {} ' + 'training data days.'.format(n_data_days) + ) def _normalize_data(self, feature_index, mean, std): """Normalize data with initialized mean and standard deviation for a @@ -2876,15 +3192,16 @@ def feature_registry(cls): dict Method registry """ - registry = {'U_(.*)m': UWind, - 'V_(.*)m': VWind, - 'lat_lon': LatLonH5, - 'topography': TopoH5, - 'temperature_max_(.*)m': 'temperature_(.*)m', - 'temperature_min_(.*)m': 'temperature_(.*)m', - 'relativehumidity_max_(.*)m': 'relativehumidity_(.*)m', - 'relativehumidity_min_(.*)m': 'relativehumidity_(.*)m', - } + registry = { + 'U_(.*)m': UWind, + 'V_(.*)m': VWind, + 'lat_lon': LatLonH5, + 'topography': TopoH5, + 'temperature_max_(.*)m': 'temperature_(.*)m', + 'temperature_min_(.*)m': 'temperature_(.*)m', + 'relativehumidity_max_(.*)m': 'relativehumidity_(.*)m', + 'relativehumidity_min_(.*)m': 'relativehumidity_(.*)m', + } return registry def get_observation_index(self): @@ -2909,11 +3226,13 @@ def 
get_observation_index(self): t_slice_hourly = slice(t_slice_0.start, t_slice_1.stop) t_slice_daily = slice(rand_day_ind, rand_day_ind + n_days) - obs_ind_hourly = tuple([*spatial_slice, t_slice_hourly, - np.arange(len(self.features))]) + obs_ind_hourly = tuple( + [*spatial_slice, t_slice_hourly, np.arange(len(self.features))] + ) - obs_ind_daily = tuple([*spatial_slice, t_slice_daily, - np.arange(len(self.features))]) + obs_ind_daily = tuple( + [*spatial_slice, t_slice_daily, np.arange(len(self.features))] + ) return obs_ind_hourly, obs_ind_daily @@ -2960,9 +3279,11 @@ def split_data(self, data=None): if data is not None: self.data = data - midnight_ilocs = np.where((self.time_index.hour == 0) - & (self.time_index.minute == 0) - & (self.time_index.second == 0))[0] + midnight_ilocs = np.where( + (self.time_index.hour == 0) + & (self.time_index.minute == 0) + & (self.time_index.second == 0) + )[0] n_val_obs = int(np.ceil(self.val_split * len(midnight_ilocs))) val_split_index = midnight_ilocs[n_val_obs] @@ -3003,11 +3324,13 @@ def __init__(self, *args, **kwargs): required = ['ghi', 'clearsky_ghi', 'clearsky_ratio'] missing = [dset for dset in required if dset not in args[1]] if any(missing): - msg = ('Cannot initialize DataHandlerH5SolarCC without required ' - 'features {}. All three are necessary to get the daily ' - 'average clearsky ratio (ghi sum / clearsky ghi sum), even ' - 'though only the clearsky ratio will be passed to the GAN.' - .format(required)) + msg = ( + 'Cannot initialize DataHandlerH5SolarCC without required ' + 'features {}. All three are necessary to get the daily ' + 'average clearsky ratio (ghi sum / clearsky ghi sum), even ' + 'though only the clearsky ratio will be passed to the ' + 'GAN.'.format(required) + ) logger.error(msg) raise KeyError(msg) @@ -3030,7 +3353,8 @@ def feature_registry(cls): 'lat_lon': LatLonH5, 'cloud_mask': CloudMaskH5, 'clearsky_ratio': ClearSkyRatioH5, - 'topography': TopoH5} + 'topography': TopoH5, + } return registry def run_daily_averages(self): @@ -3042,24 +3366,31 @@ def run_daily_averages(self): instantaneous hourly clearsky ratios """ - msg = ('Data needs to be hourly with at least 24 hours, but data ' - 'shape is {}.'.format(self.data.shape)) + msg = ( + 'Data needs to be hourly with at least 24 hours, but data ' + 'shape is {}.'.format(self.data.shape) + ) assert self.data.shape[2] % 24 == 0, msg assert self.data.shape[2] > 24, msg n_data_days = int(self.data.shape[2] / 24) - daily_data_shape = (self.data.shape[0:2] + (n_data_days,) - + (self.data.shape[3],)) + daily_data_shape = ( + self.data.shape[0:2] + (n_data_days,) + (self.data.shape[3],) + ) - logger.info('Calculating daily average datasets for {} training ' - 'data days.'.format(n_data_days)) + logger.info( + 'Calculating daily average datasets for {} training ' + 'data days.'.format(n_data_days) + ) self.daily_data = np.zeros(daily_data_shape, dtype=np.float32) - self.daily_data_slices = np.array_split(np.arange(self.data.shape[2]), - n_data_days) - self.daily_data_slices = [slice(x[0], x[-1] + 1) - for x in self.daily_data_slices] + self.daily_data_slices = np.array_split( + np.arange(self.data.shape[2]), n_data_days + ) + self.daily_data_slices = [ + slice(x[0], x[-1] + 1) for x in self.daily_data_slices + ] i_ghi = self.features.index('ghi') i_cs = self.features.index('clearsky_ghi') @@ -3068,7 +3399,8 @@ def run_daily_averages(self): for d, t_slice in enumerate(self.daily_data_slices): for idf in range(self.data.shape[-1]): self.daily_data[:, :, d, idf] = 
daily_temporal_coarsening( - self.data[:, :, t_slice, idf], temporal_axis=2)[:, :, 0] + self.data[:, :, t_slice, idf], temporal_axis=2 + )[:, :, 0] # note that this ratio of daily irradiance sums is not the same as # the average of hourly ratios. @@ -3079,26 +3411,32 @@ def run_daily_averages(self): # remove ghi and clearsky ghi from feature set. These shouldn't be used # downstream for solar cc and keeping them confuses the batch handler - logger.info('Finished calculating daily average clearsky_ratio, ' - 'removing ghi and clearsky_ghi from the ' - 'DataHandlerH5SolarCC feature list.') - ifeats = np.array([i for i in range(len(self.features)) - if i not in (i_ghi, i_cs)]) + logger.info( + 'Finished calculating daily average clearsky_ratio, ' + 'removing ghi and clearsky_ghi from the ' + 'DataHandlerH5SolarCC feature list.' + ) + ifeats = np.array( + [i for i in range(len(self.features)) if i not in (i_ghi, i_cs)] + ) self.data = self.data[..., ifeats] self.daily_data = self.daily_data[..., ifeats] self.features.remove('ghi') self.features.remove('clearsky_ghi') - logger.info('Finished calculating daily average datasets for {} ' - 'training data days.'.format(n_data_days)) + logger.info( + 'Finished calculating daily average datasets for {} ' + 'training data days.'.format(n_data_days) + ) # pylint: disable=W0223 class DataHandlerDC(DataHandler): """Data-centric data handler""" - def get_observation_index(self, temporal_weights=None, - spatial_weights=None): + def get_observation_index( + self, temporal_weights=None, spatial_weights=None + ): """Randomly gets weighted spatial sample and time sample Parameters @@ -3117,22 +3455,25 @@ def get_observation_index(self, temporal_weights=None, Used to get single observation like self.data[observation_index] """ if spatial_weights is not None: - spatial_slice = weighted_box_sampler(self.data, - self.sample_shape[:2], - weights=spatial_weights) + spatial_slice = weighted_box_sampler( + self.data, self.sample_shape[:2], weights=spatial_weights + ) else: - spatial_slice = uniform_box_sampler(self.data, - self.sample_shape[:2]) + spatial_slice = uniform_box_sampler( + self.data, self.sample_shape[:2] + ) if temporal_weights is not None: - temporal_slice = weighted_time_sampler(self.data, - self.sample_shape[2], - weights=temporal_weights) + temporal_slice = weighted_time_sampler( + self.data, self.sample_shape[2], weights=temporal_weights + ) else: - temporal_slice = uniform_time_sampler(self.data, - self.sample_shape[2]) + temporal_slice = uniform_time_sampler( + self.data, self.sample_shape[2] + ) return tuple( - [*spatial_slice, temporal_slice, np.arange(len(self.features))]) + [*spatial_slice, temporal_slice, np.arange(len(self.features))] + ) def get_next(self, temporal_weights=None, spatial_weights=None): """Get data for observation using weighted random observation index. 
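A minimal usage sketch of the weighted sampling added here (illustrative only, not part of the patch): it assumes the weight arguments are 1D probability vectors over temporal/spatial bins, as consumed by weighted_time_sampler and weighted_box_sampler, and elides the DataHandlerDC construction arguments since they are not shown in this diff.

    import numpy as np

    # Start from uniform weights; a data-centric training loop could update
    # these based on per-bin model performance.
    n_time_bins = 12
    n_space_bins = 16
    temporal_weights = np.ones(n_time_bins) / n_time_bins
    spatial_weights = np.ones(n_space_bins) / n_space_bins

    # handler = DataHandlerDC(...)  # construction args elided
    # obs = handler.get_next(temporal_weights=temporal_weights,
    #                        spatial_weights=spatial_weights)
    # obs.shape -> (spatial_1, spatial_2, temporal, n_features)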
@@ -3154,7 +3495,8 @@ def get_next(self, temporal_weights=None, spatial_weights=None): (spatial_1, spatial_2, temporal, features) """ self.current_obs_index = self.get_observation_index( - temporal_weights=temporal_weights, spatial_weights=spatial_weights) + temporal_weights=temporal_weights, spatial_weights=spatial_weights + ) observation = self.data[self.current_obs_index] return observation diff --git a/sup3r/qa/qa.py b/sup3r/qa/qa.py index c9ae287e6..8622520e5 100644 --- a/sup3r/qa/qa.py +++ b/sup3r/qa/qa.py @@ -338,7 +338,7 @@ def output_handler_class(self): elif self.output_type == 'h5': return Resource - def bias_correct_source_data(self, data, source_feature): + def bias_correct_source_data(self, data, lat_lon, source_feature): """Bias correct data using a method defined by the bias_correct_method input to ForwardPassStrategy @@ -347,6 +347,10 @@ def bias_correct_source_data(self, data, source_feature): data : np.ndarray Any source data to be bias corrected, with the feature channel in the last axis. + lat_lon : np.ndarray + Latitude longitude array for the given data. Used to get the + correct bc factors for the appropriate domain. + (n_lats, n_lons, 2) source_feature : str | list The source feature name corresponding to the output feature name @@ -383,7 +387,7 @@ def bias_correct_source_data(self, data, source_feature): 'function: {} with kwargs: {}' .format(source_feature, method, feature_kwargs)) - data = method(data, **feature_kwargs) + data = method(data, lat_lon, **feature_kwargs) return data @@ -403,6 +407,7 @@ def get_source_dset(self, feature, source_feature): Low-res source input data including optional bias correction """ + lat_lon = self.source_handler.lat_lon if 'windspeed' in feature and len(source_feature) == 2: u_feat, v_feat = source_feature logger.info('For sup3r output feature "{}", retrieving u/v ' @@ -412,13 +417,13 @@ def get_source_dset(self, feature, source_feature): v_idf = self.source_handler.features.index(v_feat) u_true = self.source_handler.data[..., u_idf] v_true = self.source_handler.data[..., v_idf] - u_true = self.bias_correct_source_data(u_true, u_feat) - v_true = self.bias_correct_source_data(v_true, v_feat) + u_true = self.bias_correct_source_data(u_true, lat_lon, u_feat) + v_true = self.bias_correct_source_data(v_true, lat_lon, v_feat) data_true = np.hypot(u_true, v_true) else: idf = self.source_handler.features.index(source_feature) data_true = self.source_handler.data[..., idf] - data_true = self.bias_correct_source_data(data_true, + data_true = self.bias_correct_source_data(data_true, lat_lon, source_feature) return data_true diff --git a/sup3r/utilities/era_downloader.py b/sup3r/utilities/era_downloader.py new file mode 100644 index 000000000..8aff26114 --- /dev/null +++ b/sup3r/utilities/era_downloader.py @@ -0,0 +1,1214 @@ +"""Download ERA5 file for the given year and month + +NOTE: To use this you need to have cdsapi package installed and a ~/.cdsapirc +file with a url and api key. 
Follow the instructions here: +https://cds.climate.copernicus.eu/api-how-to +""" + +import logging +import os +from calendar import monthrange +from concurrent.futures import ( + ProcessPoolExecutor, + ThreadPoolExecutor, + as_completed, +) +from glob import glob +from typing import ClassVar +from warnings import warn + +import numpy as np +import pandas as pd +import xarray as xr +from netCDF4 import Dataset + +from sup3r.utilities.interpolate_log_profile import LogLinInterpolator + +logger = logging.getLogger(__name__) + +try: + import cdsapi + + CDS_API_CLIENT = cdsapi.Client() +except ImportError as e: + msg = f'Could not import cdsapi package. {e}' + logger.error(msg) + + +class EraDownloader: + """Class to handle ERA5 downloading, variable renaming, file combination, + and interpolation.""" + + msg = ( + 'To download ERA5 data you need to have a ~/.cdsapirc file ' + 'with a valid url and api key. Follow the instructions here: ' + 'https://cds.climate.copernicus.eu/api-how-to' + ) + req_file = os.path.join(os.path.expanduser('~'), '.cdsapirc') + assert os.path.exists(req_file), msg + + DEFAULT_RENAMED_VARS: ClassVar[list] = [ + 'zg', + 'orog', + 'u', + 'v', + 'u_10m', + 'v_10m', + 'u_100m', + 'v_100m', + ] + DEFAULT_DOWNLOAD_VARS: ClassVar[list] = [ + '10m_u_component_of_wind', + '10m_v_component_of_wind', + '100m_u_component_of_wind', + '100m_v_component_of_wind', + 'u_component_of_wind', + 'v_component_of_wind', + ] + + SFC_VARS: ClassVar[list] = [ + '10m_u_component_of_wind', + '10m_v_component_of_wind', + '100m_u_component_of_wind', + '100m_v_component_of_wind', + 'surface_pressure', + '2m_temperature', + 'geopotential', + ] + LEVEL_VARS: ClassVar[list] = [ + 'u_component_of_wind', + 'v_component_of_wind', + 'geopotential', + 'temperature', + ] + NAME_MAP: ClassVar[dict] = { + 'u10': 'u_10m', + 'v10': 'v_10m', + 'u100': 'u_100m', + 'v100': 'v_100m', + 't': 'temperature', + 't2m': 'temperature_2m', + 'u': 'u', + 'v': 'v', + 'sp': 'pressure_0m', + } + + def __init__( + self, + year, + month, + area, + levels, + combined_out_pattern, + interp_out_pattern=None, + run_interp=True, + overwrite=False, + required_shape=None, + variables=None, + ): + """Initialize the class. + + Parameters + ---------- + year : int + Year of data to download. + month : int + Month of data to download. + area : list + Domain area of the data to download. + [max_lat, min_lon, min_lat, max_lon] + levels : list + List of pressure levels to download. + combined_out_pattern : str + Pattern for combined monthly output file. Must include year and + month format keys. e.g. 'era5_{year}_{month}_combined.nc' + interp_out_pattern : str | None + Pattern for interpolated monthly output file. Must include year and + month format keys. e.g. 'era5_{year}_{month}_interp.nc' + run_interp : bool + Whether to run interpolation after downloading and combining files. + overwrite : bool + Whether to overwrite existing files. + required_shape : tuple | None + Required shape of data to download. Used to check downloaded data. + Should be (n_levels, n_lats, n_lons). If None, no check is + performed. + variables : list | None + Variables to download. If None this defaults to just gepotential + and wind components. 
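As a usage sketch (illustrative, not part of this patch), a single month could be downloaded, combined, and interpolated with the class above as follows. The area, pressure levels, and output patterns are placeholder values, and a valid ~/.cdsapirc file is required as noted in the module docstring.

    from sup3r.utilities.era_downloader import EraDownloader

    downloader = EraDownloader(
        year=2017,
        month=1,
        area=[50, -110, 30, -90],      # [max_lat, min_lon, min_lat, max_lon]
        levels=[1000, 950, 900, 850],  # placeholder pressure levels
        combined_out_pattern='./era5_{year}_{month}_combined.nc',
        interp_out_pattern='./era5_{year}_{month}_interp.nc',
        variables=['u', 'v'],
    )
    # Download surface and pressure-level files, rename variables, combine
    # them, and (since run_interp defaults to True) interpolate to the
    # default output heights.
    downloader.get_monthly_file()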
+ """ + self.year = year + self.month = month + self.area = area + self.levels = levels + self.run_interp = run_interp + self.overwrite = overwrite + self.combined_out_pattern = combined_out_pattern + self.variables = ( + variables if variables is not None else self.DEFAULT_DOWNLOAD_VARS + ) + self.days = [ + str(n).zfill(2) + for n in np.arange(1, monthrange(year, month)[1] + 1) + ] + self.hours = [str(n).zfill(2) + ":00" for n in range(0, 24)] + + if required_shape is None or len(required_shape) == 3: + self.required_shape = required_shape + elif len(required_shape) == 2 and len(levels) != required_shape[0]: + self.required_shape = (len(levels), *required_shape) + else: + msg = f'Received weird required_shape: {required_shape}.' + logger.error(msg) + raise OSError(msg) + + self.interp_file = None + if interp_out_pattern is not None and run_interp: + self.interp_file = interp_out_pattern.format( + year=year, month=str(month).zfill(2) + ) + os.makedirs(os.path.dirname(self.interp_file), exist_ok=True) + + self.combined_file = combined_out_pattern.format( + year=year, month=str(month).zfill(2) + ) + os.makedirs(os.path.dirname(self.combined_file), exist_ok=True) + basedir = os.path.dirname(self.combined_file) + self.surface_file = os.path.join( + basedir, f'sfc_{year}_{str(month).zfill(2)}.nc' + ) + self.level_file = os.path.join( + basedir, f'levels_{year}_{str(month).zfill(2)}.nc' + ) + self.sfc_file_variables = [] + self.level_file_variables = [] + self.check_good_vars(self.variables) + self.prep_var_lists(variables) + + msg = ( + 'Initialized EraDownloader with: ' + f'year={self.year}, month={self.month}, area={self.area}, ' + f'levels={self.levels}, variables={self.variables}' + ) + logger.info(msg) + + @classmethod + def init_dims(cls, old_ds, new_ds, dims): + """Initialize dimensions in new dataset from old dataset + + Parameters + ---------- + old_ds : Dataset + Dataset() object from old file + new_ds : Dataset + Dataset() object for new file + dims : tuple + Tuple of dimensions. e.g. ('time', 'latitude', 'longitude') + + Returns + ------- + new_ds : Dataset + Dataset() object for new file with dimensions initialized. + """ + for var in dims: + new_ds.createDimension(var, len(old_ds[var])) + _ = new_ds.createVariable(var, old_ds[var].dtype, dimensions=var) + new_ds[var][:] = old_ds[var][:] + new_ds[var].units = old_ds[var].units + return new_ds + + @classmethod + def get_tmp_file(cls, file): + """Get temp file for given file. Then only needed variables will be + written to the given file.""" + tmp_file = file.replace(".nc", "_tmp.nc") + return tmp_file + + def check_good_vars(self, variables): + """Make sure requested variables are valid. + + Parameters + ---------- + variables : list + List of variables to download. Can be any of ['u', 'v', 'pressure', + temperature'] + """ + valid_vars = ['u', 'v', 'pressure', 'temperature'] + good = all(var in valid_vars for var in variables) + if not good: + msg = ( + f'Received variables {variables} not in valid variables ' + f'list {valid_vars}' + ) + logger.error(msg) + raise OSError(msg) + + def _prep_var_lists(self, variables): + """Add all downloadable variables for the generic requested variables. + e.g. 
if variable = 'u' add all downloadable u variables to list.""" + d_vars = [] + vars = variables.copy() + for i, v in enumerate(vars): + if v in ('u', 'v'): + vars[i] = f'{v}_' + for var in vars: + for d_var in self.SFC_VARS + self.LEVEL_VARS: + if var in d_var: + d_vars.append(d_var) + return d_vars + + def prep_var_lists(self, variables): + """Create surface and level variable lists based on requested + variables.""" + variables = self._prep_var_lists(variables) + for var in variables: + if var in self.SFC_VARS: + self.sfc_file_variables.append(var) + elif var in self.LEVEL_VARS: + self.level_file_variables.append(var) + else: + msg = f'Requested {var} is not available for download.' + logger.warning(msg) + warn(msg) + + def download_process_combine(self): + """Run the download routine.""" + sfc_check = len(self.sfc_file_variables) > 0 + level_check = ( + len(self.level_file_variables) > 0 and self.levels is not None + ) + if self.level_file_variables: + msg = ( + f'{self.level_file_variables} requested but no levels' + ' were provided.' + ) + if self.levels is None: + logger.warning(msg) + warn(msg) + if sfc_check: + self.download_surface_file() + if level_check: + self.download_levels_file() + if sfc_check and level_check: + self.process_and_combine() + + def download_levels_file(self): + """Download file with requested pressure levels""" + if not os.path.exists(self.level_file) or self.overwrite: + if 'geopotential' not in self.level_file_variables: + self.level_file_variables.append('geopotential') + msg = ( + f'Downloading {self.level_file_variables} to ' + f'{self.level_file}.' + ) + logger.info(msg) + CDS_API_CLIENT.retrieve( + 'reanalysis-era5-pressure-levels', + { + 'product_type': 'reanalysis', + 'format': 'netcdf', + 'variable': self.level_file_variables, + 'pressure_level': self.levels, + 'year': self.year, + 'month': self.month, + 'day': self.days, + 'time': self.hours, + 'area': self.area, + }, + self.level_file, + ) + else: + logger.info(f'File already exists: {self.level_file}.') + + def download_surface_file(self): + """Download surface file""" + if not os.path.exists(self.surface_file) or self.overwrite: + if 'geopotential' not in self.sfc_file_variables: + self.sfc_file_variables.append('geopotential') + msg = ( + f'Downloading {self.sfc_file_variables} to ' + f'{self.surface_file}.' + ) + logger.info(msg) + CDS_API_CLIENT.retrieve( + 'reanalysis-era5-single-levels', + { + 'product_type': 'reanalysis', + 'format': 'netcdf', + 'variable': self.sfc_file_variables, + 'year': self.year, + 'month': self.month, + 'day': self.days, + 'time': self.hours, + 'area': self.area, + }, + self.surface_file, + ) + else: + logger.info(f'File already exists: {self.surface_file}.') + + def process_surface_file(self): + """Rename variables and convert geopotential to geopotential height.""" + + dims = ('time', 'latitude', 'longitude') + tmp_file = self.get_tmp_file(self.surface_file) + with Dataset(self.surface_file, "r") as old_ds: + with Dataset(tmp_file, "w") as ds: + ds = self.init_dims(old_ds, ds, dims) + + ds = self.convert_z('orog', 'Orography', old_ds, ds) + + ds = self.map_vars(old_ds, ds) + os.system(f'mv {tmp_file} {self.surface_file}') + logger.info( + f'Finished processing {self.surface_file}. Moved ' + f'{tmp_file} to {self.surface_file}.' 
+ ) + + def map_vars(self, old_ds, ds): + """Map variables from old dataset to new dataset + + Parameters + ---------- + old_ds : Dataset + Dataset() object from old file + ds : Dataset + Dataset() object for new file + + Returns + ------- + ds : Dataset + Dataset() object for new file with new variables written. + """ + for old_name, new_name in self.NAME_MAP.items(): + if old_name in old_ds.variables: + _ = ds.createVariable( + new_name, + np.float32, + dimensions=old_ds[old_name].dimensions, + ) + vals = old_ds.variables[old_name][:] + if 'temperature' in new_name: + vals -= 273.15 + ds.variables[new_name][:] = vals + return ds + + def convert_z(self, standard_name, long_name, old_ds, ds): + """Convert z to given height variable + + Parameters + ---------- + standard_name : str + New variable name. e.g. 'zg' or 'orog' + long_name : str + Long name for new variable. e.g. 'Geopotential Height' or + 'Orography' + old_ds : Dataset + Dataset() object from tmp file + ds : Dataset + Dataset() object for new file + + Returns + ------- + ds : Dataset + Dataset() object for new file with new height variable written. + """ + + _ = ds.createVariable( + standard_name, np.float32, dimensions=old_ds['z'].dimensions + ) + ds.variables[standard_name][:] = old_ds['z'][:] / 9.81 + ds.variables[standard_name].long_name = long_name + ds.variables[standard_name].standard_name = 'zg' + ds.variables[standard_name].units = 'm' + return ds + + def process_level_file(self): + """Convert geopotential to geopotential height.""" + + dims = ('time', 'level', 'latitude', 'longitude') + tmp_file = self.get_tmp_file(self.level_file) + with Dataset(self.level_file, "r") as old_ds: + with Dataset(tmp_file, "w") as ds: + ds = self.init_dims(old_ds, ds, dims) + + ds = self.convert_z('zg', 'Geopotential Height', old_ds, ds) + + ds = self.map_vars(old_ds, ds) + + if 'pressure' in self.variables: + tmp = np.zeros(ds.variables['zg'].shape) + for i in range(tmp.shape[1]): + tmp[:, i, :, :] = ds.variables['level'][i] * 100 + _ = ds.createVariable( + 'pressure', np.float32, dimensions=dims + ) + ds.variables['pressure'][:] = tmp[...] + ds.variables['pressure'].long_name = 'Pressure' + ds.variables['pressure'].units = 'Pa' + + os.system(f'mv {tmp_file} {self.level_file}') + logger.info( + f'Finished processing {self.level_file}. Moved ' + f'{tmp_file} to {self.level_file}.' + ) + + def process_and_combine(self): + """Process variables and combine.""" + + if not os.path.exists(self.combined_file) or self.overwrite: + logger.info(f'Processing {self.level_file}.') + self.process_level_file() + logger.info(f'Processing {self.surface_file}.') + self.process_surface_file() + logger.info( + f'Combining {self.level_file} and {self.surface_file} ' + f'to {self.combined_file}.' + ) + with xr.open_mfdataset([self.level_file, self.surface_file]) as ds: + ds.to_netcdf(self.combined_file) + logger.info(f'Finished writing {self.combined_file}') + os.remove(self.level_file) + os.remove(self.surface_file) + + def good_file(self, file, required_shape): + """Check if file has the required shape and variables. + + Parameters + ---------- + file : str + Name of file to check for required variables and shape + required_shape : tuple + Required shape for data. Should be (n_levels, n_lats, n_lons). + + Returns + ------- + bool + Whether or not data has required shape and variables. 
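An illustrative call of the file check used by good_file (the path and required_shape values are placeholders):

    from sup3r.utilities.era_downloader import EraDownloader

    out = EraDownloader.check_single_file(
        './era5_2017_01_combined.nc',   # placeholder path
        check_nans=True,
        check_heights=True,
        max_interp_height=200,
        required_shape=(4, 100, 120),   # placeholder (n_levels, n_lats, n_lons)
    )
    good_vars, good_shape, good_hgts, nan_pct = out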
+        """
+        out = self.check_single_file(
+            file,
+            check_nans=False,
+            check_heights=False,
+            required_shape=required_shape,
+        )
+        good_vars, good_shape, _, _ = out
+        check = good_vars and good_shape
+        return check
+
+    def check_existing_files(self):
+        """If files exist already check them for good shape and required
+        variables. Remove them if there was a problem so we can continue with
+        routine from scratch."""
+        if os.path.exists(self.combined_file):
+            try:
+                check = self.good_file(self.combined_file, self.required_shape)
+                if not check:
+                    msg = f'Bad file: {self.combined_file}'
+                    logger.error(msg)
+                    raise OSError(msg)
+                else:
+                    if os.path.exists(self.level_file):
+                        os.remove(self.level_file)
+                    if os.path.exists(self.surface_file):
+                        os.remove(self.surface_file)
+                    logger.info(
+                        f'{self.combined_file} already exists and '
+                        f'overwrite={self.overwrite}. Skipping.'
+                    )
+            except Exception as e:
+                logger.info(f'Something wrong with {self.combined_file}. {e}')
+                if os.path.exists(self.combined_file):
+                    os.remove(self.combined_file)
+                check = self.interp_file is not None and os.path.exists(
+                    self.interp_file
+                )
+                if check:
+                    os.remove(self.interp_file)
+
+    def run_interpolation(self, max_workers=None, **kwargs):
+        """Run interpolation to get the final file. Runs log interpolation
+        up to max_log_height (usually 100m) and linear interpolation above
+        this."""
+        LogLinInterpolator.run(
+            infile=self.combined_file,
+            outfile=self.interp_file,
+            max_workers=max_workers,
+            variables=self.variables,
+            overwrite=self.overwrite,
+            **kwargs,
+        )
+
+    def get_monthly_file(self, interp_workers=None, **interp_kwargs):
+        """Download level and surface files, process variables, and combine
+        processed files. Includes checks for shape and variables and option to
+        interpolate."""
+
+        if os.path.exists(self.combined_file) and self.overwrite:
+            os.remove(self.combined_file)
+
+        self.check_existing_files()
+
+        if not os.path.exists(self.combined_file):
+            self.download_process_combine()
+
+        if self.run_interp:
+            self.run_interpolation(max_workers=interp_workers, **interp_kwargs)
+
+        if os.path.exists(self.interp_file):
+            if self.already_pruned(self.interp_file):
+                logger.info(f'{self.interp_file} pruned already.')
+            else:
+                self.prune_output(self.interp_file)
+
+    @classmethod
+    def all_months_exist(cls, year, file_pattern):
+        """Check if all months in the requested year exist.
+
+        Parameters
+        ----------
+        year : int
+            Year of data to download.
+        file_pattern : str
+            Pattern for monthly output file. Must include year and month format
+            keys. e.g. 'era5_{year}_{month}_combined.nc'
+
+        Returns
+        -------
+        bool
+            True if all months in the requested year exist.
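For example (paths are placeholders), all_months_exist pairs with make_yearly_file, defined further below, which itself asserts this check before combining the monthly files:

    from sup3r.utilities.era_downloader import EraDownloader

    pattern = './era5_{year}_{month}_combined.nc'   # placeholder monthly files
    if EraDownloader.all_months_exist(2017, pattern):
        EraDownloader.make_yearly_file(2017, pattern,
                                       './era5_2017_combined.nc')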
+ """ + return all( + os.path.exists( + file_pattern.format(year=year, month=str(month).zfill(2)) + ) + for month in range(1, 13) + ) + + @classmethod + def already_pruned(cls, infile): + """Check if file has been pruned already.""" + + keep_vars = ( + 'u_', + 'v_', + 'pressure_', + 'temperature_', + 'orog', + 'time', + 'latitude', + 'longitude', + ) + pruned = True + with Dataset(infile, 'r') as ds: + for var in ds.variables: + if not any(name in var for name in keep_vars): + logger.info(f'Pruning {var} in {infile}.') + pruned = False + return pruned + + @classmethod + def prune_output(cls, infile): + """Prune output file to keep just single level variables""" + + logger.info(f'Pruning {infile}.') + tmp_file = cls.get_tmp_file(infile) + keep_vars = ('u_', 'v_', 'pressure_', 'temperature_', 'orog') + with Dataset(infile, 'r') as old_ds: + with Dataset(tmp_file, 'w') as new_ds: + new_ds = cls.init_dims( + old_ds, new_ds, ('time', 'latitude', 'longitude') + ) + for var in old_ds.variables: + if any(name in var for name in keep_vars): + old_var = old_ds[var] + vals = old_var[:] + _ = new_ds.createVariable( + var, old_var.dtype, dimensions=old_var.dimensions + ) + new_ds[var][:] = vals + if hasattr(old_var, 'units'): + new_ds[var].units = old_var.units + if hasattr(old_var, 'standard_name'): + standard_name = old_var.standard_name + new_ds[var].standard_name = standard_name + if hasattr(old_var, 'long_name'): + new_ds[var].long_name = old_var.long_name + os.system(f'mv {tmp_file} {infile}') + logger.info( + f'Finished pruning variables in {infile}. Moved ' + f'{tmp_file} to {infile}.' + ) + + @classmethod + def run_month( + cls, + year, + month, + area, + levels, + combined_out_pattern, + interp_out_pattern=None, + run_interp=True, + overwrite=False, + required_shape=None, + interp_workers=None, + variables=None, + **interp_kwargs, + ): + """Run routine for all months in the requested year. + + Parameters + ---------- + year : int + Year of data to download. + month : int + Month of data to download. + area : list + Domain area of the data to download. + [max_lat, min_lon, min_lat, max_lon] + levels : list + List of pressure levels to download. + combined_out_pattern : str + Pattern for combined monthly output file. Must include year and + month format keys. e.g. 'era5_{year}_{month}_combined.nc' + interp_out_pattern : str | None + Pattern for interpolated monthly output file. Must include year and + month format keys. e.g. 'era5_{year}_{month}_interp.nc' + run_interp : bool + Whether to run interpolation after downloading and combining files. + overwrite : bool + Whether to overwrite existing files. + required_shape : tuple | None + Required shape of data to download. Used to check downloaded data. + Should be (n_levels, n_lats, n_lons). If None, no check is + performed. + interp_workers : int | None + Max number of workers to use for interpolation. + variables : list | None + Variables to download. If None this defaults to just gepotential + and wind components. 
+ **interp_kwargs : dict + Keyword args for LogLinInterpolator.run() + """ + downloader = cls( + year=year, + month=month, + area=area, + levels=levels, + combined_out_pattern=combined_out_pattern, + interp_out_pattern=interp_out_pattern, + run_interp=run_interp, + overwrite=overwrite, + required_shape=required_shape, + variables=variables, + ) + downloader.get_monthly_file( + interp_workers=interp_workers, **interp_kwargs + ) + + @classmethod + def run_year( + cls, + year, + area, + levels, + combined_out_pattern, + combined_yearly_file, + interp_out_pattern=None, + interp_yearly_file=None, + run_interp=True, + overwrite=False, + required_shape=None, + max_workers=None, + interp_workers=None, + variables=None, + **interp_kwargs, + ): + """Run routine for all months in the requested year. + + Parameters + ---------- + year : int + Year of data to download. + area : list + Domain area of the data to download. + [max_lat, min_lon, min_lat, max_lon] + levels : list + List of pressure levels to download. + combined_out_pattern : str + Pattern for combined monthly output file. Must include year and + month format keys. e.g. 'era5_{year}_{month}_combined.nc' + combined_yearly_file : str + Name of yearly file made from monthly combined files. + interp_out_pattern : str | None + Pattern for interpolated monthly output file. Must include year and + month format keys. e.g. 'era5_{year}_{month}_interp.nc' + interp_yearly_file : str + Name of yearly file made from monthly interp files. + run_interp : bool + Whether to run interpolation after downloading and combining files. + overwrite : bool + Whether to overwrite existing files. + required_shape : tuple | None + Required shape of data to download. Used to check downloaded data. + Should be (n_levels, n_lats, n_lons). If None, no check is + performed. + max_workers : int + Max number of workers to use for downloading and processing monthly + files. + interp_workers : int | None + Max number of workers to use for interpolation. + variables : list | None + Variables to download. If None this defaults to just gepotential + and wind components. + **interp_kwargs : dict + Keyword args for LogLinInterpolator.run() + """ + if max_workers == 1: + for month in range(1, 13): + cls.run_month( + year=year, + month=month, + area=area, + levels=levels, + combined_out_pattern=combined_out_pattern, + interp_out_pattern=interp_out_pattern, + run_interp=run_interp, + overwrite=overwrite, + required_shape=required_shape, + interp_workers=interp_workers, + variables=variables, + **interp_kwargs, + ) + else: + futures = {} + with ThreadPoolExecutor(max_workers=max_workers) as exe: + for month in range(1, 13): + future = exe.submit( + cls.run_month, + year=year, + month=month, + area=area, + levels=levels, + combined_out_pattern=combined_out_pattern, + interp_out_pattern=interp_out_pattern, + run_interp=run_interp, + overwrite=overwrite, + required_shape=required_shape, + interp_workers=interp_workers, + variables=variables, + **interp_kwargs, + ) + futures[future] = {'year': year, 'month': month} + logger.info( + f'Submitted future for year {year} and month ' + f'{month}.' + ) + for future in as_completed(futures): + future.result() + v = futures[future] + logger.info( + f'Finished future for year {v["year"]} and month ' + f'{v["month"]}.' 
+ ) + + cls.make_yearly_file(year, combined_out_pattern, combined_yearly_file) + + if run_interp: + cls.make_yearly_file(year, interp_out_pattern, interp_yearly_file) + + @classmethod + def make_yearly_file(cls, year, file_pattern, yearly_file): + """Combine monthly files into a single file. + + Parameters + ---------- + year : int + Year of monthly data to make into a yearly file. + file_pattern : str + File pattern for monthly files. Must have year and month format + keys. e.g. './era_uv_{year}_{month}_combined.nc' + yearly_file : str + Name of yearly file made from monthly files. + """ + msg = ( + f'Not all monthly files with file_patten {file_pattern} for ' + f'year {year} exist.' + ) + assert cls.all_months_exist(year, file_pattern), msg + + files = [ + file_pattern.format(year=year, month=str(month).zfill(2)) + for month in range(1, 13) + ] + + if not os.path.exists(yearly_file): + with xr.open_mfdataset(files) as res: + logger.info(f'Combining {files}') + os.makedirs(os.path.dirname(yearly_file), exist_ok=True) + res.to_netcdf(yearly_file) + logger.info(f'Saved {yearly_file}') + else: + logger.info(f'{yearly_file} already exists.') + + @classmethod + def _check_single_file( + cls, + res, + var_list=None, + check_nans=True, + check_heights=True, + max_interp_height=200, + required_shape=None, + max_workers=10, + ): + """Make sure given files include the given variables. Check for NaNs + and required shape. + + Parameters + ---------- + res : xr.open_dataset() object + opened xarray data handler. + var_list : list + List of variables to check. + check_nans : bool + Whether to check data for NaNs. + check_heights : bool + Whether to check for heights above max interpolation height. + max_interp_height : int + Maximum height for interpolated output. Need raw heights above this + to avoid extrapolation. + required_shape : None | tuple + Required shape for data. Should be (n_levels, n_lats, n_lons). + If None the shape check will be skipped. + max_workers : int | None + Max number of workers to use in height check routine. + + Returns + ------- + good_vars : bool + Whether file includes all given variables + good_shape : bool + Whether shape matches required shape + good_hgts : bool + Whether there exists a height above the max interpolation height + for each spatial location and timestep + nan_pct : float + Percent of data which consists of NaNs across all given variables. + """ + good_vars = all(var in res for var in var_list) + res_shape = ( + *res['level'].shape, + *res['latitude'].shape, + *res['longitude'].shape, + ) + good_shape = ( + 'NA' if required_shape is None else (res_shape == required_shape) + ) + good_hgts = ( + 'NA' + if not check_heights + else cls.check_heights( + res, + max_interp_height=max_interp_height, + max_workers=max_workers, + ) + ) + nan_pct = ( + 'NA' if not check_nans else cls.get_nan_pct(res, var_list=var_list) + ) + + if not good_vars: + mask = np.array([var not in res for var in var_list]) + missing_vars = var_list[mask] + logger.error(f'Missing variables: {missing_vars}.') + if good_shape != 'NA' and not good_shape: + logger.error(f'Bad shape: {res_shape} != {required_shape}.') + + return good_vars, good_shape, good_hgts, nan_pct + + @classmethod + def check_heights(cls, res, max_interp_height=200, max_workers=10): + """Make sure there are heights higher than max interpolation height + + Parameters + ---------- + res : xr.open_dataset() object + opened xarray data handler. + max_interp_height : int + Maximum height for interpolated output. 
Need raw heights above this + to avoid extrapolation. + max_workers : int | None + Max number of workers to use for process pool height check. + + Returns + ------- + bool + Whether there is a height above max_interp_height for every spatial + location and timestep + """ + gp = res['zg'].values + sfc_hgt = np.repeat( + res['orog'].values[:, np.newaxis, ...], gp.shape[1], axis=1 + ) + heights = gp - sfc_hgt + heights = heights.reshape(heights.shape[0], heights.shape[1], -1) + checks = [] + logger.info( + f'Checking heights with max_interp_height={max_interp_height}.' + ) + + if max_workers == 1: + for idt in range(heights.shape[0]): + checks.append( + cls._check_heights_single_ts( + heights[idt], max_interp_height=max_interp_height + ) + ) + msg = f'Finished check for {idt + 1} of {heights.shape[0]}.' + logger.debug(msg) + else: + futures = [] + with ProcessPoolExecutor(max_workers=max_workers) as exe: + for idt in range(heights.shape[0]): + future = exe.submit( + cls._check_heights_single_ts, + heights[idt], + max_interp_height=max_interp_height, + ) + futures.append(future) + msg = ( + f'Submitted height check for {idt + 1} of ' + f'{heights.shape[0]}' + ) + logger.info(msg) + for i, future in enumerate(as_completed(futures)): + checks.append(future.result()) + msg = ( + f'Finished height check for {i + 1} of ' + f'{heights.shape[0]}' + ) + logger.info(msg) + + return all(checks) + + @classmethod + def _check_heights_single_ts(cls, heights, max_interp_height=200): + """Make sure there are heights higher than max interpolation height for + a single timestep + + Parameters + ---------- + heights : ndarray + Array of heights for single timestep and all spatial locations + max_interp_height : int + Maximum height for interpolated output. Need raw heights above this + to avoid extrapolation. + + Returns + ------- + bool + Whether there is a height above max_interp_height for every spatial + location + """ + checks = [any(h > max_interp_height) for h in heights.T] + return all(checks) + + @classmethod + def get_nan_pct(cls, res, var_list=None): + """Get percentage of data which consists of NaNs, across the given + variables + + Parameters + ---------- + res : xr.open_dataset() object + opened xarray data handler. + var_list : list + List of variables to check. + If None: ['zg', 'orog', 'u', 'v', 'u_10m', 'v_10m', + 'u_100m', 'v_100m'] + + Returns + ------- + nan_pct : float + Percent of data which consists of NaNs across all given variables. + """ + elem_count = 0 + nan_count = 0 + for var in var_list: + logger.info(f'Checking NaNs for {var}.') + nans = np.isnan(res[var].values) + if nans.any(): + nan_count += nans.sum() + elem_count += nans.size + return 100 * nan_count / elem_count + + @classmethod + def check_single_file( + cls, + file, + var_list=None, + check_nans=True, + check_heights=True, + max_interp_height=200, + required_shape=None, + max_workers=10, + ): + """Make sure given files include the given variables. Check for NaNs + and required shape. + + Parameters + ---------- + file : str + Name of file to check. + var_list : list + List of variables to check. + check_nans : bool + Whether to check data for NaNs. + check_heights : bool + Whether to check for heights above max interpolation height. + max_interp_height : int + Maximum height for interpolated output. Need raw heights above this + to avoid extrapolation. + required_shape : None | tuple + Required shape for data. Should be (n_levels, n_lats, n_lons). + If None the shape check will be skipped. 
+ max_workers : int | None + Max number of workers to use for process pool height check. + + Returns + ------- + good_vars : bool + Whether file includes all given variables + good_shape : bool + Whether shape matches required shape + good_hgts : bool + Whether there is a height above max_interp_height for every spatial + location at every timestep. + nan_pct : float + Percent of data which consists of NaNs across all given variables. + """ + good = True + nan_pct = None + good_shape = None + good_vars = None + good_hgts = None + var_list = ( + var_list if var_list is not None else cls.DEFAULT_RENAMED_VARS + ) + try: + res = xr.open_dataset(file) + except Exception as e: + msg = f'Unable to open {file}. {e}' + logger.warning(msg) + warn(msg) + good = False + + if good: + out = cls._check_single_file( + res, + var_list, + check_nans=check_nans, + check_heights=check_heights, + max_interp_height=max_interp_height, + required_shape=required_shape, + max_workers=max_workers, + ) + good_vars, good_shape, good_hgts, nan_pct = out + return good_vars, good_shape, good_hgts, nan_pct + + @classmethod + def run_files_checks( + cls, + file_pattern, + var_list=None, + required_shape=None, + check_nans=True, + check_heights=True, + max_interp_height=200, + max_workers=None, + height_check_workers=10, + ): + """Make sure given files include the given variables. Check for NaNs + and required shape. + + Parameters + ---------- + file_pattern : str | list + glob-able file pattern for files to check. + var_list : list | None + List of variables to check. If None: + ['zg', 'orog', 'u', 'v', 'u_10m', 'v_10m', 'u_100m', 'v_100m'] + required_shape : None | tuple + Required shape for data. Should include (n_levels, n_lats, n_lons). + If None the shape check will be skipped. + check_nans : bool + Whether to check data for NaNs. + check_heights : bool + Whether to check for heights above max interpolation height. + max_interp_height : int + Maximum height for interpolated output. Need raw heights above this + to avoid extrapolation. + max_workers : int | None + Number of workers to use for thread pool file checks. + height_check_workers : int | None + Number of workers to use for process pool height check. + + Returns + ------- + df : pd.DataFrame + DataFrame describing file check results. Has columns ['file', + 'good_vars', 'good_shape', 'good_hgts', 'nan_pct'] + + """ + if isinstance(file_pattern, str): + files = glob(file_pattern) + else: + files = file_pattern + df = pd.DataFrame( + columns=['file', 'good_vars', 'good_shape', 'good_hgts', 'nan_pct'] + ) + df['file'] = [os.path.basename(file) for file in files] + if max_workers == 1: + for i, file in enumerate(files): + logger.info(f'Checking {file}.') + out = cls.check_single_file( + file, + var_list=var_list, + check_nans=check_nans, + check_heights=check_heights, + max_interp_height=max_interp_height, + max_workers=height_check_workers, + required_shape=required_shape, + ) + df.at[i, df.columns[1:]] = out + logger.info(f'Finished checking {file}.') + else: + futures = {} + with ThreadPoolExecutor(max_workers=max_workers) as exe: + for i, file in enumerate(files): + future = exe.submit( + cls.check_single_file, + file=file, + var_list=var_list, + check_nans=check_nans, + check_heights=check_heights, + max_interp_height=max_interp_height, + max_workers=height_check_workers, + required_shape=required_shape, + ) + msg = ( + f'Submitted file check future for {file}. Future ' + f'{i + 1} of {len(files)}.' 
+ ) + logger.info(msg) + futures[future] = i + for i, future in enumerate(as_completed(futures)): + out = future.result() + df.at[futures[future], df.columns[1:]] = out + msg = ( + f'Finished checking {df["file"].iloc[futures[future]]}.' + f' Future {i + 1} of {len(files)}.' + ) + logger.info(msg) + return df diff --git a/sup3r/utilities/interpolate_log_profile.py b/sup3r/utilities/interpolate_log_profile.py new file mode 100644 index 000000000..bd70d9697 --- /dev/null +++ b/sup3r/utilities/interpolate_log_profile.py @@ -0,0 +1,723 @@ +"""Rescale ERA5 wind components according to log profile""" + +import logging +import os +from concurrent.futures import ( + ProcessPoolExecutor, + ThreadPoolExecutor, + as_completed, +) +from glob import glob +from warnings import warn + +import numpy as np +import xarray as xr +from netCDF4 import Dataset +from rex import init_logger +from scipy.interpolate import interp1d +from scipy.optimize import curve_fit + +from sup3r.utilities.interpolation import Interpolator + +init_logger(__name__, log_level='DEBUG') +init_logger('sup3r', log_level='DEBUG') + + +logger = logging.getLogger(__name__) + + +class LogLinInterpolator: + """Open ERA5 file, log interpolate wind components between 0 - + max_log_height, linearly interpolate components above max_log_height + meters, and save to file""" + + DEFAULT_OUTPUT_HEIGHTS = { + 'u': [40, 80, 120, 160, 200], + 'v': [40, 80, 120, 160, 200], + 'temperature': [10, 40, 80, 100, 120, 160, 200], + 'pressure': [0, 100, 200], + } + + def __init__( + self, + infile, + outfile, + output_heights=None, + variables=None, + max_log_height=100, + ): + """Initialize log interpolator. + + Parameters + ---------- + infile : str + Path to ERA5 data to use for windspeed log interpolation. Assumed + to contain zg, orog, and at least u/v at 10m and 100m. + outfile : str + Path to save output after log interpolation. + output_heights : None | dict + Dictionary of heights to interpolate to for each variables. + If None this defaults to DEFAULT_OUTPUT_HEIGHTS. + variables : list + List of variables to interpolate. If None this defaults to ['u', + 'v'] + max_log_height : int + Maximum height to use for log interpolation. Above this linear + interpolation will be used. + """ + self.infile = infile + self.outfile = outfile + + msg = ( + 'output_heights must be a dictionary with variables as keys ' + f'and lists of heights as values. Received: {output_heights}.' + ) + assert output_heights is None or isinstance(output_heights, dict), msg + + self.new_heights = output_heights or self.DEFAULT_OUTPUT_HEIGHTS + self.max_log_height = max_log_height + self.variables = ['u', 'v'] if variables is None else variables + self.data_dict = {} + self.new_data = {} + + msg = f'{self.infile} does not exist. Skipping.' + assert os.path.exists(self.infile), msg + + msg = ( + f'Initializing {self.__class__.__name__} with infile={infile}, ' + f'outfile={outfile}, new_heights={self.new_heights}, ' + f'variables={variables}.' + ) + logger.info(msg) + + def _load_single_var(self, variable): + """Load ERA5 data for the given variable. + + Parameters + ---------- + variable : str + Name of variable to load. (e.g. u, v, temperature) + + Returns + ------- + heights : ndarray + Array of heights for the given variable. Includes heights from + variables at single levels (e.g. u_10m). + var_arr : ndarray + Array of values for the given variable. Includes values from single + level fields for the given variable. (e.g. 
u_10m) + """ + logger.info(f'Loading {self.infile} for {variable}.') + with xr.open_dataset(self.infile) as res: + gp = res['zg'].values + sfc_hgt = np.repeat( + res['orog'].values[:, np.newaxis, ...], gp.shape[1], axis=1 + ) + heights = gp - sfc_hgt + + input_heights = [] + for var in res: + if f'{variable}_' in var: + height = var.split(f'{variable}_')[-1].strip('m') + input_heights.append(height) + + var_arr = [] + height_arr = [] + shape = (heights.shape[0], 1, *heights.shape[2:]) + for height in input_heights: + var_arr.append( + res[f'{variable}_{height}m'].values[:, np.newaxis, ...] + ) + height_arr.append(np.full(shape, height, dtype=np.float32)) + + if variable in res: + var_arr.append(res[f'{variable}'].values) + height_arr.append(heights) + var_arr = np.concatenate(var_arr, axis=1) + heights = np.concatenate(height_arr, axis=1) + + fixed_level_mask = np.full(heights.shape[1], True) + if variable in ('u', 'v'): + fixed_level_mask[:] = False + for i, _ in enumerate(input_heights): + fixed_level_mask[i] = True + + return heights, var_arr, fixed_level_mask + + def load(self): + """Load ERA5 data and create data arrays""" + self.data_dict = {} + for var in self.variables: + self.data_dict[var] = {} + out = self._load_single_var(var) + self.data_dict[var]['heights'] = out[0] + self.data_dict[var]['data'] = out[1] + self.data_dict[var]['mask'] = out[2] + + def interpolate_vars(self, max_workers=None): + """Interpolate u/v wind components below 100m using log profile. + Interpolate non wind data linearly.""" + for var, arrs in self.data_dict.items(): + max_log_height = self.max_log_height + if var not in ('u', 'v'): + max_log_height = -np.inf + logger.info( + f'Interpolating {var} to heights = {self.new_heights[var]}.' + ) + + self.new_data[var] = self.interp_var_to_height( + var_array=arrs['data'], + lev_array=arrs['heights'], + levels=self.new_heights[var], + fixed_level_mask=arrs['mask'], + max_log_height=max_log_height, + max_workers=max_workers, + ) + + def save_output(self): + """Save interpolated data to outfile""" + dirname = os.path.dirname(self.outfile) + os.makedirs(dirname, exist_ok=True) + os.system(f'cp {self.infile} {self.outfile}') + ds = Dataset(self.outfile, 'a') + logger.info(f'Creating {self.outfile}.') + for var, data in self.new_data.items(): + for i, height in enumerate(self.new_heights[var]): + name = f'{var}_{height}m' + logger.info(f'Adding {name} to {self.outfile}.') + if name not in ds.variables: + _ = ds.createVariable( + name, + np.float32, + dimensions=('time', 'latitude', 'longitude'), + ) + ds.variables[name][:] = data[i, ...] + ds.variables[name].long_name = f'{height} meter {var}' + + units = None + if 'u_' in var or 'v_' in var: + units = 'm s**-1' + if 'pressure' in var: + units = 'Pa' + if 'temperature' in var: + units = 'C' + + if units is not None: + ds.variables[name].units = units + + ds.close() + logger.info(f'Saved interpolated output to {self.outfile}.') + + @classmethod + def run( + cls, + infile, + outfile, + output_heights=None, + variables=None, + max_log_height=100, + overwrite=False, + max_workers=None, + ): + """Run interpolation and save output + + Parameters + ---------- + infile : str + Path to ERA5 data to use for windspeed log interpolation. Assumed + to contain zg, orog, and at least u/v at 10m and 100m. + outfile : str + Path to save output after log interpolation. + output_heights : None | list + Heights to interpolate to. If None this defaults to [10, 40, 80, + 100, 120, 160, 200]. 
+ variables : list + List of variables to interpolate. If None this defaults to u and v. + max_log_height : int + Maximum height to use for log interpolation. Above this linear + interpolation will be used. + max_workers : None | int + Number of workers to use for interpolating over timesteps. + overwrite : bool + Whether to overwrite existing files. + """ + log_interp = cls( + infile, + outfile, + output_heights=output_heights, + variables=variables, + max_log_height=max_log_height, + ) + if os.path.exists(outfile) and not overwrite: + logger.info( + f'{outfile} already exists and overwrite=False. ' 'Skipping.' + ) + else: + log_interp.load() + log_interp.interpolate_vars(max_workers=max_workers) + log_interp.save_output() + + @classmethod + def run_multiple( + cls, + infiles, + out_dir, + output_heights=None, + max_log_height=100, + overwrite=False, + variables=None, + max_workers=None, + ): + """Run interpolation and save output + + Parameters + ---------- + infiles : str | list + Glob-able path or to ERA5 data or list of files to use for + windspeed log interpolation. Assumed to contain zg, orog, and at + least u/v at 10m. + out_dir : str + Path to save output directory after log interpolation. + output_heights : None | list + Heights to interpolate to. If None this defaults to [40, 80]. + max_log_height : int + Maximum height to use for log interpolation. Above this linear + interpolation will be used. + variables : list + List of variables to interpolate. If None this defaults to u and v. + overwrite : bool + Whether to overwrite existing outfile. + max_workers : None | int + Number of workers to use for interpolating over timesteps. + """ + futures = [] + if isinstance(infiles, str): + infiles = glob(infiles) + if max_workers == 1: + for _, file in enumerate(infiles): + outfile = os.path.basename(file).replace( + '.nc', '_all_interp.nc' + ) + outfile = os.path.join(out_dir, outfile) + cls.run( + file, + outfile, + output_heights=output_heights, + max_log_height=max_log_height, + overwrite=overwrite, + variables=variables, + ) + + else: + with ThreadPoolExecutor(max_workers=max_workers) as exe: + for i, file in enumerate(infiles): + outfile = os.path.basename(file).replace( + '.nc', '_all_interp.nc' + ) + outfile = os.path.join(out_dir, outfile) + futures.append( + exe.submit( + cls.run, + file, + outfile, + output_heights=output_heights, + variables=variables, + max_log_height=max_log_height, + overwrite=overwrite, + ) + ) + logger.info( + f'{i + 1} of {len(infiles)} futures submitted.' + ) + for i, future in enumerate(as_completed(futures)): + future.result() + logger.info(f'{i + 1} of {len(futures)} futures complete.') + + @classmethod + def pbl_interp_to_height( + cls, + lev_array, + var_array, + levels, + fixed_level_mask=None, + max_log_height=100, + ): + """Fit ws log law to data below max_log_height. + + Parameters + ---------- + lev_array : ndarray + 1D Array of height values corresponding to the wrf source + data in the same shape as var_array. + var_array : ndarray + 1D Array of variable data, for example u-wind in a 1D array of + shape + levels : float | list + level or levels to interpolate to (e.g. final desired hub heights + above surface elevation) + fixed_level_mask : ndarray | None + Optional mask to use only fixed levels. Fixed levels are those that + were not computed from pressure levels but instead added along with + wind components at explicit heights (e.g u_10m, v_10m, u_100m, + v_100m) + max_log_height : int + Max height for using log interpolation. 
+ + Returns + ------- + values : ndarray + Array of interpolated windspeed values below max_log_height. + good : bool + Check if log interpolation went without issue. + """ + + def ws_log_profile(z, a, b): + return a * np.log(z) + b + + lev_array_samp = lev_array.copy() + var_array_samp = var_array.copy() + if fixed_level_mask is not None: + lev_array_samp = lev_array_samp[fixed_level_mask] + var_array_samp = var_array_samp[fixed_level_mask] + + good = True + levels = np.array(levels) + lev_mask = (0 < levels) & (levels <= max_log_height) + var_mask = (0 < lev_array_samp) & (lev_array_samp <= max_log_height) + + try: + popt, _ = curve_fit( + ws_log_profile, + lev_array_samp[var_mask], + var_array_samp[var_mask], + ) + log_ws = ws_log_profile(levels[lev_mask], *popt) + except Exception as e: + msg = ( + 'Log interp failed with (h, ws) = ' + f'({lev_array_samp[var_mask]}, ' + f'{var_array_samp[var_mask]}). {e} ' + 'Using linear interpolation.' + ) + good = False + logger.warning(msg) + warn(msg) + log_ws = interp1d( + lev_array[var_mask], + var_array[var_mask], + fill_value='extrapolate', + )(levels[lev_mask]) + return log_ws, good + + @classmethod + def _interp_var_to_height( + cls, + lev_array, + var_array, + levels, + fixed_level_mask=None, + max_log_height=100, + ): + """Fit ws log law to wind data below max_log_height and linearly + interpolate data above. Linearly interpolate non wind data. + + Parameters + ---------- + lev_array : ndarray + 1D Array of height values corresponding to the wrf source + data in the same shape as var_array. + var_array : ndarray + 1D Array of variable data, for example u-wind in a 1D array of + shape + levels : float | list + level or levels to interpolate to (e.g. final desired hub heights + above surface elevation) + fixed_level_mask : ndarray | None + Optional mask to use only fixed levels. Fixed levels are those that + were not computed from pressure levels but instead added along with + wind components at explicit heights (e.g u_10m, v_10m, u_100m, + v_100m) + max_log_height : int + Max height for using log interpolation. + + Returns + ------- + values : ndarray + Array of interpolated data values at the requested heights. + good : bool + Check if interpolation went without issue. + """ + levels = np.array(levels) + + log_ws = None + lin_ws = None + good = True + + hgt_check = any(levels < max_log_height) and any( + lev_array < max_log_height + ) + if hgt_check: + log_ws, good = cls.pbl_interp_to_height( + lev_array, + var_array, + levels, + fixed_level_mask=fixed_level_mask, + max_log_height=max_log_height, + ) + + if any(levels > max_log_height): + lev_mask = levels >= max_log_height + var_mask = lev_array >= max_log_height + if len(lev_array[var_mask]) > 1: + lin_ws = interp1d( + lev_array[var_mask], + var_array[var_mask], + fill_value='extrapolate', + )(levels[lev_mask]) + elif len(lev_array) > 1: + msg = ( + 'Requested interpolation levels are outside the ' + f'available range: lev_array={lev_array}, ' + f'levels={levels}. Using linear extrapolation.' + ) + lin_ws = interp1d( + lev_array, var_array, fill_value='extrapolate' + )(levels[lev_mask]) + good = False + logger.warning(msg) + warn(msg) + else: + msg = ( + 'Data seems to be all NaNs. Something may have gone ' + 'wrong during download.' 
+ ) + raise OSError(msg) + + if log_ws is not None and lin_ws is not None: + out = np.concatenate([log_ws, lin_ws]) + + if log_ws is not None and lin_ws is None: + out = log_ws + + if lin_ws is not None and log_ws is None: + out = lin_ws + + if log_ws is None and lin_ws is None: + msg = ( + f'No interpolation was performed for lev_array={lev_array} ' + f'and levels={levels}' + ) + raise RuntimeError(msg) + + return out, good + + @classmethod + def _get_timestep_interp_input(cls, lev_array, var_array, idt): + """Get interpolation input for given timestep + + Parameters + ---------- + lev_array : ndarray + 1D Array of height values corresponding to the wrf source + data in the same shape as var_array. + var_array : ndarray + 1D Array of variable data, for example u-wind in a 1D array of + shape + idt : int + Time index to interpolate + + Returns + ------- + h_t : ndarray + 1D array of height values for the requested time + v_t : ndarray + 1D array of variable data for the requested time + mask : ndarray + 1D array of bool values masking nans and heights < 0 + + """ + array_shape = var_array.shape + shape = (array_shape[-3], np.product(array_shape[-2:])) + h_t = lev_array[idt].reshape(shape).T + var_t = var_array[idt].reshape(shape).T + mask = ~np.isnan(h_t) & ~np.isnan(var_t) + + return h_t, var_t, mask + + @classmethod + def interp_single_ts( + cls, + hgt_t, + var_t, + mask, + levels, + fixed_level_mask=None, + max_log_height=100, + ): + """Perform interpolation for a single timestep specified by the index + idt + + Parameters + ---------- + hgt_t : ndarray + 1D Array of height values for a specific time. + var_t : ndarray + 1D Array of variable data for a specific time. + mask : ndarray + 1D Array of bool values to mask out nans and heights below 0. + levels : float | list + level or levels to interpolate to (e.g. final desired hub heights + above surface elevation) + fixed_level_mask : ndarray | None + Optional mask to use only fixed levels. Fixed levels are those + that were not computed from pressure levels but instead added along + with wind components at explicit heights (e.g u_10m, v_10m, u_100m, + v_100m) + max_log_height : int + Max height for using log interpolation. + + Returns + ------- + out_array : ndarray + Array of interpolated values. + """ + + # Interp each vertical column of height and var to requested levels + zip_iter = zip(hgt_t, var_t, mask) + out_array = [] + checks = [] + for h, var, mask in zip_iter: + val, check = cls._interp_var_to_height( + h[mask], + var[mask], + levels, + fixed_level_mask=fixed_level_mask[mask], + max_log_height=max_log_height, + ) + out_array.append(val) + checks.append(check) + return np.array(out_array), np.array(checks) + + @classmethod + def interp_var_to_height( + cls, + var_array, + lev_array, + levels, + fixed_level_mask=None, + max_log_height=100, + max_workers=None, + ): + """Interpolate data array to given level(s) based on h_array. + Interpolation is done using windspeed log profile and is done for every + 'z' column of [var, h] data. + + Parameters + ---------- + var_array : ndarray + Array of variable data, for example u-wind in a 4D array of shape + (time, vertical, lat, lon) + lev_array : ndarray + Array of height values corresponding to the wrf source + data in the same shape as var_array. 
lev_array should be + the geopotential height corresponding to every var_array index + relative to the surface elevation (subtract the elevation at the + surface from the geopotential height) + levels : float | list + level or levels to interpolate to (e.g. final desired hub heights + above surface elevation) + fixed_level_mask : ndarray | None + Optional mask to use only fixed levels. Fixed levels are those + that were not computed from pressure levels but instead added along + with wind components at explicit heights (e.g u_10m, v_10m, u_100m, + v_100m) + max_log_height : int + Max height for using log interpolation. + max_workers : None | int + Number of workers to use for interpolating over timesteps. + + Returns + ------- + out_array : ndarray + Array of interpolated values. + """ + lev_array, levels = Interpolator.prep_level_interp( + var_array, lev_array, levels + ) + + array_shape = var_array.shape + + # Flatten h_array and var_array along lat, long axis + shape = (len(levels), array_shape[-4], np.product(array_shape[-2:])) + out_array = np.zeros(shape, dtype=np.float32).T + total_checks = [] + + # iterate through time indices + futures = {} + if max_workers == 1: + for idt in range(array_shape[0]): + h_t, v_t, mask = cls._get_timestep_interp_input( + lev_array, var_array, idt + ) + out, checks = cls.interp_single_ts( + h_t, + v_t, + mask, + levels=levels, + fixed_level_mask=fixed_level_mask, + max_log_height=max_log_height, + ) + out_array[:, idt, :] = out + total_checks.append(checks) + + logger.info( + f'{idt + 1} of {array_shape[0]} timesteps finished.' + ) + + else: + with ProcessPoolExecutor(max_workers=max_workers) as exe: + for idt in range(array_shape[0]): + h_t, v_t, mask = cls._get_timestep_interp_input( + lev_array, var_array, idt + ) + future = exe.submit( + cls.interp_single_ts, + h_t, + v_t, + mask, + levels=levels, + fixed_level_mask=fixed_level_mask, + max_log_height=max_log_height, + ) + futures[future] = idt + logger.info( + f'{idt + 1} of {array_shape[0]} futures submitted.' 
+ ) + for i, future in enumerate(as_completed(futures)): + out, checks = future.result() + out_array[:, futures[future], :] = out + total_checks.append(checks) + logger.info(f'{i + 1} of {len(futures)} futures complete.') + + total_checks = np.concatenate(total_checks) + good_count = total_checks.sum() + total_count = len(total_checks) + logger.info( + 'Percent of points interpolated without issue: ' + f'{100 * good_count / total_count:.2f}' + ) + + # Reshape out_array + if isinstance(levels, (float, np.float32, int)): + shape = (1, array_shape[-4], array_shape[-2], array_shape[-1]) + out_array = out_array.T.reshape(shape) + else: + shape = ( + len(levels), + array_shape[-4], + array_shape[-2], + array_shape[-1], + ) + out_array = out_array.T.reshape(shape) + + return out_array diff --git a/sup3r/utilities/interpolation.py b/sup3r/utilities/interpolation.py index 77d3439da..47e1650db 100644 --- a/sup3r/utilities/interpolation.py +++ b/sup3r/utilities/interpolation.py @@ -48,8 +48,9 @@ def calc_height(cls, data, raster_index, time_slice=slice(None)): # Terrain Height (m) hgt = data['HGT'][(time_slice, *tuple(raster_index))] if gp.shape != hgt.shape: - hgt = np.repeat(np.expand_dims(hgt, axis=1), gp.shape[-3], - axis=1) + hgt = np.repeat( + np.expand_dims(hgt, axis=1), gp.shape[-3], axis=1 + ) hgt = gp / 9.81 - hgt del gp @@ -65,11 +66,15 @@ def calc_height(cls, data, raster_index, time_slice=slice(None)): del gp else: - msg = ('Need either PHB/PH/HGT or zg/orog in data to perform ' - 'height interpolation') + msg = ( + 'Need either PHB/PH/HGT or zg/orog in data to perform ' + 'height interpolation' + ) raise ValueError(msg) - logger.debug('Spatiotemporally averaged height levels: ' - f'{list(np.nanmean(np.array(hgt), axis=(0, 2, 3)))}') + logger.debug( + 'Spatiotemporally averaged height levels: ' + f'{list(np.nanmean(np.array(hgt), axis=(0, 2, 3)))}' + ) return np.array(hgt) @classmethod @@ -169,9 +174,9 @@ def calc_pressure(cls, data, var, raster_index, time_slice=slice(None)): return p_array @classmethod - def interp_to_level(cls, var_array, lev_array, levels): - """Interpolate var_array to given level(s) based on h_array. - Interpolation is linear and done for every 'z' column of [var, h] data. + def prep_level_interp(cls, var_array, lev_array, levels): + """Prepare var_array interpolation. Check level ranges and add noise to + mask locations. Parameters ---------- @@ -191,17 +196,24 @@ def interp_to_level(cls, var_array, lev_array, levels): Returns ------- - out_array : ndarray - Array of interpolated values. + lev_array : ndarray + Array of levels with noise added to mask locations. + levels : list + List of levels to interpolate to. """ - msg = ('Input arrays must be the same shape.' - f'\nvar_array: {var_array.shape}' - f'\nh_array: {lev_array.shape}') + msg = ( + 'Input arrays must be the same shape.' + f'\nvar_array: {var_array.shape}' + f'\nh_array: {lev_array.shape}' + ) assert var_array.shape == lev_array.shape, msg - levels = ([levels] if isinstance(levels, (int, float, np.float32)) - else levels) + levels = ( + [levels] + if isinstance(levels, (int, float, np.float32)) + else levels + ) if np.isnan(lev_array).all(): msg = 'All pressure level height data is NaN!' 
@@ -210,14 +222,18 @@ def interp_to_level(cls, var_array, lev_array, levels): nans = np.isnan(lev_array) logger.debug('Level array shape: {}'.format(lev_array.shape)) - bad_min = min(levels) < lev_array[:, 0, :, :] - bad_max = max(levels) > lev_array[:, -1, :, :] + + lowest_height = np.min(lev_array[0, ...]) + highest_height = np.max(lev_array[0, ...]) + bad_min = min(levels) < lowest_height + bad_max = max(levels) > highest_height if nans.any(): - msg = ('Approximately {:.2f}% of the vertical level ' - 'array is NaN. Data will be interpolated or extrapolated ' - 'past these NaN values.' - .format(100 * nans.sum() / nans.size)) + msg = ( + 'Approximately {:.2f}% of the vertical level ' + 'array is NaN. Data will be interpolated or extrapolated ' + 'past these NaN values.'.format(100 * nans.sum() / nans.size) + ) logger.warning(msg) warn(msg) @@ -226,41 +242,78 @@ def interp_to_level(cls, var_array, lev_array, levels): # does not correspond to the lowest or highest height. Interpolation # can be performed without issue in this case. if bad_min.any(): - msg = ('Approximately {:.2f}% of the lowest vertical levels ' - '(maximum value of {:.3f}, minimum value of {:.3f}) ' - 'were greater than the minimum requested level: {}' - .format(100 * bad_min.sum() / bad_min.size, - lev_array[:, 0, :, :].max(), - lev_array[:, 0, :, :].min(), min(levels))) + msg = ( + 'Approximately {:.2f}% of the lowest vertical levels ' + '(maximum value of {:.3f}, minimum value of {:.3f}) ' + 'were greater than the minimum requested level: {}'.format( + 100 * bad_min.sum() / bad_min.size, + lev_array[:, 0, :, :].max(), + lev_array[:, 0, :, :].min(), + min(levels), + ) + ) logger.warning(msg) warn(msg) if bad_max.any(): - msg = ('Approximately {:.2f}% of the highest vertical levels ' - '(minimum value of {:.3f}, maximum value of {:.3f}) ' - 'were lower than the maximum requested level: {}' - .format(100 * bad_max.sum() / bad_max.size, - lev_array[:, -1, :, :].min(), - lev_array[:, -1, :, :].max(), - max(levels))) + msg = ( + 'Approximately {:.2f}% of the highest vertical levels ' + '(minimum value of {:.3f}, maximum value of {:.3f}) ' + 'were lower than the maximum requested level: {}'.format( + 100 * bad_max.sum() / bad_max.size, + lev_array[:, -1, :, :].min(), + lev_array[:, -1, :, :].max(), + max(levels), + ) + ) logger.warning(msg) warn(msg) - array_shape = var_array.shape - - # Flatten h_array and var_array along lat, long axis - shape = (len(levels), array_shape[-4], np.product(array_shape[-2:])) - out_array = np.zeros(shape, dtype=np.float32).T - # if multiple vertical levels have identical heights at the desired # interpolation level, interpolation to that value will fail because # linear slope will be NaN. This is most common if you have multiple # pressure levels at zero height at the surface in the case that the # data didnt provide underground data. for level in levels: - mask = (lev_array == level) + mask = lev_array == level lev_array[mask] += np.random.uniform(-1e-5, 0, size=mask.sum()) + return lev_array, levels + + @classmethod + def interp_to_level(cls, var_array, lev_array, levels): + """Interpolate var_array to given level(s) based on lev_array. + Interpolation is linear and done for every 'z' column of [var, h] data. + + Parameters + ---------- + var_array : ndarray + Array of variable data, for example u-wind in a 4D array of shape + (time, vertical, lat, lon) + lev_array : ndarray + Array of height or pressure values corresponding to the wrf source + data in the same shape as var_array. 
If this is height and the + requested levels are hub heights above surface, lev_array should be + the geopotential height corresponding to every var_array index + relative to the surface elevation (subtract the elevation at the + surface from the geopotential height) + levels : float | list + level or levels to interpolate to (e.g. final desired hub heights + above surface elevation) + + Returns + ------- + out_array : ndarray + Array of interpolated values. + """ + lev_array, levels = cls.prep_level_interp(var_array, lev_array, levels) + + array_shape = var_array.shape + + # Flatten h_array and var_array along lat, long axis + shape = (len(levels), array_shape[-4], np.product(array_shape[-2:])) + out_array = np.zeros(shape, dtype=np.float32).T + # iterate through time indices for idt in range(array_shape[0]): shape = (array_shape[-3], np.product(array_shape[-2:])) @@ -271,23 +324,34 @@ def interp_to_level(cls, var_array, lev_array, levels): # Interp each vertical column of height and var to requested levels zip_iter = zip(h_tmp, var_tmp, not_nan) out_array[:, idt, :] = np.array( - [interp1d(h[mask], var[mask], fill_value='extrapolate')(levels) - for h, var, mask in zip_iter], dtype=np.float32) + [ + interp1d(h[mask], var[mask], fill_value='extrapolate')( + levels + ) + for h, var, mask in zip_iter + ], + dtype=np.float32, + ) # Reshape out_array if isinstance(levels, (float, np.float32, int)): shape = (1, array_shape[-4], array_shape[-2], array_shape[-1]) out_array = out_array.T.reshape(shape) else: - shape = (len(levels), array_shape[-4], array_shape[-2], - array_shape[-1]) + shape = ( + len(levels), + array_shape[-4], + array_shape[-2], + array_shape[-1], + ) out_array = out_array.T.reshape(shape) return out_array @classmethod - def interp_var_to_height(cls, data, var, raster_index, heights, - time_slice=slice(None)): + def interp_var_to_height( + cls, data, var, raster_index, heights, time_slice=slice(None) + ): """Interpolate var_array to given level(s) based on h_array. Interpolation is linear and done for every 'z' column of [var, h] data. @@ -313,8 +377,10 @@ def interp_var_to_height(cls, data, var, raster_index, heights, raster_index = [0, *raster_index] logger.debug(f'Interpolating {var} to heights (meters): {heights}') hgt = cls.calc_height(data, raster_index, time_slice) - logger.info(f'Computed height array with min/max: {np.nanmin(hgt)} / ' - f'{np.nanmax(hgt)}') + logger.info( + f'Computed height array with min/max: {np.nanmin(hgt)} / ' + f'{np.nanmax(hgt)}' + ) if data[var].dims == ('plev',): arr = np.array(data[var]) arr = np.expand_dims(arr, axis=(0, 2, 3)) @@ -328,8 +394,9 @@ def interp_var_to_height(cls, data, var, raster_index, heights, return cls.interp_to_level(arr, hgt, heights)[0] @classmethod - def interp_var_to_pressure(cls, data, var, raster_index, pressures, - time_slice=slice(None)): + def interp_var_to_pressure( + cls, data, var, raster_index, pressures, time_slice=slice(None) + ): """Interpolate var_array to given level(s) based on h_array. Interpolation is linear and done for every 'z' column of [var, h] data. 
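By contrast, the plain linear path in Interpolator.interp_to_level (above) reduces, per vertical column, to a single interp1d call over the non-NaN [height, value] pairs. A minimal sketch with a hypothetical column (the real method does this for every flattened spatial location and timestep after prep_level_interp checks ranges and perturbs duplicate levels):

# Minimal per-column sketch of the linear interpolation inside
# Interpolator.interp_to_level (illustrative values only).
import numpy as np
from scipy.interpolate import interp1d

hgt_column = np.array([12.0, 55.0, 130.0, 300.0])  # hypothetical heights above ground (m)
var_column = np.array([3.1, 4.4, 5.0, 5.3])        # hypothetical variable values, e.g. u-wind
levels = [40, 100]                                 # requested hub heights (m)

# Drop NaN pairs before interpolating, as the real code does with its mask
mask = ~np.isnan(hgt_column) & ~np.isnan(var_column)
out = interp1d(hgt_column[mask], var_column[mask],
               fill_value='extrapolate')(levels)
print(out)  # interpolated values at 40 m and 100 m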
@@ -362,5 +429,6 @@ def interp_var_to_pressure(cls, data, var, raster_index, pressures, p_levels = cls.calc_pressure(data, var, raster_index, time_slice) - return cls.interp_to_level(arr[:, ::-1], p_levels[:, ::-1], - pressures)[0] + return cls.interp_to_level(arr[:, ::-1], p_levels[:, ::-1], pressures)[ + 0 + ] diff --git a/sup3r/utilities/pytest.py b/sup3r/utilities/pytest.py index 879539921..7c8581c74 100644 --- a/sup3r/utilities/pytest.py +++ b/sup3r/utilities/pytest.py @@ -1,6 +1,7 @@ # -*- coding: utf-8 -*- """Utilities used for pytests""" import os + import numpy as np import xarray as xr @@ -13,6 +14,8 @@ def make_fake_nc_files(td, input_file, n_files): Parameters ---------- + td : str + Temporary directory input_file : str File to use as template for all dummy files n_files : int @@ -23,27 +26,66 @@ def make_fake_nc_files(td, input_file, n_files): fake_files : list List of dummy files """ - fake_dates = [f'2014-10-01_{str(i).zfill(2)}_00_00' - for i in range(n_files)] - fake_times = [f'2014-10-01 {str(i).zfill(2)}:00:00' - for i in range(n_files)] + fake_dates = [ + f'2014-10-01_{str(i).zfill(2)}_00_00' for i in range(n_files) + ] + fake_times = [ + f'2014-10-01 {str(i).zfill(2)}:00:00' for i in range(n_files) + ] fake_files = [os.path.join(td, f'input_{date}') for date in fake_dates] for i in range(n_files): input_dset = xr.open_dataset(input_file) with xr.Dataset(input_dset) as dset: - dset['Times'][:] = np.array([fake_times[i].encode('ASCII')], - dtype='|S19') + dset['Times'][:] = np.array( + [fake_times[i].encode('ASCII')], dtype='|S19' + ) dset['XTIME'][:] = i dset.to_netcdf(fake_files[i]) return fake_files +def make_fake_multi_time_nc_files(td, input_file, n_steps, n_files): + """Make dummy nc file with multiple timesteps + + Parameters + ---------- + td : str + Temporary directory + input_file : str + File to use as template for timesteps in dummy file + n_steps : int + Number of timesteps across all files + n_files : int + Number of files to split all timsteps across + + Returns + ------- + fake_file : str + multi timestep dummy file + """ + fake_files = make_fake_nc_files(td, input_file, n_steps) + fake_files = np.array_split(fake_files, n_files) + dummy_files = [] + for i, files in enumerate(fake_files): + dummy_file = os.path.join( + td, f'multi_timestep_file_{str(i).zfill(3)}.nc' + ) + dummy_files.append(dummy_file) + with xr.open_mfdataset( + files, combine='nested', concat_dim='Time' + ) as dset: + dset.to_netcdf(dummy_file) + return dummy_files + + def make_fake_era_files(td, input_file, n_files): """Make dummy era files with increasing times. ERA files have a different naming convention than WRF. 
Parameters ---------- + td : str + Temporary directory input_file : str File to use as template for all dummy files n_files : int @@ -54,16 +96,19 @@ def make_fake_era_files(td, input_file, n_files): fake_files : list List of dummy files """ - fake_dates = [f'2014-10-01_{str(i).zfill(2)}_00_00' - for i in range(n_files)] - fake_times = [f'2014-10-01 {str(i).zfill(2)}:00:00' - for i in range(n_files)] + fake_dates = [ + f'2014-10-01_{str(i).zfill(2)}_00_00' for i in range(n_files) + ] + fake_times = [ + f'2014-10-01 {str(i).zfill(2)}:00:00' for i in range(n_files) + ] fake_files = [os.path.join(td, f'input_{date}') for date in fake_dates] for i in range(n_files): input_dset = xr.open_dataset(input_file) with xr.Dataset(input_dset) as dset: - dset['Times'][:] = np.array([fake_times[i].encode('ASCII')], - dtype='|S19') + dset['Times'][:] = np.array( + [fake_times[i].encode('ASCII')], dtype='|S19' + ) dset['XTIME'][:] = i dset = dset.rename({'U': 'u', 'V': 'v'}) dset.to_netcdf(fake_files[i]) @@ -122,8 +167,9 @@ def make_fake_h5_chunks(td): gids = np.arange(np.product(shape[:2])) gids = gids.reshape(shape[:2]) - low_res_times = pd_date_range('20220101', '20220103', freq='3600s', - inclusive='left') + low_res_times = pd_date_range( + '20220101', '20220103', freq='3600s', inclusive='left' + ) t_slices_lr = [slice(0, 24), slice(24, None)] t_slices_hr = [slice(0, 48), slice(48, None)] @@ -131,22 +177,41 @@ def make_fake_h5_chunks(td): s_slices_lr = [slice(0, 5), slice(5, 10)] s_slices_hr = [slice(0, 25), slice(25, 50)] - out_pattern = os.path.join(td, 'fp_out_{i}_{j}_{k}.h5') + out_pattern = os.path.join(td, 'fp_out_{t}_{i}_{j}.h5') out_files = [] - for i, (slice_lr, slice_hr) in enumerate(zip(t_slices_lr, t_slices_hr)): - for j, (s1_lr, s1_hr) in enumerate(zip(s_slices_lr, s_slices_hr)): - for k, (s2_lr, s2_hr) in enumerate(zip(s_slices_lr, s_slices_hr)): - out_file = out_pattern.format(i=i, j=j, k=k) + for t, (slice_lr, slice_hr) in enumerate(zip(t_slices_lr, t_slices_hr)): + for i, (s1_lr, s1_hr) in enumerate(zip(s_slices_lr, s_slices_hr)): + for j, (s2_lr, s2_hr) in enumerate(zip(s_slices_lr, s_slices_hr)): + out_file = out_pattern.format( + t=str(t).zfill(3), + i=str(i).zfill(3), + j=str(j).zfill(3), + ) out_files.append(out_file) OutputHandlerH5.write_output( - data[s1_hr, s2_hr, slice_hr, :], features, - low_res_lat_lon[s1_lr, s2_lr], low_res_times[slice_lr], - out_file, meta_data=model_meta_data, max_workers=1, - gids=gids[s1_hr, s2_hr]) + data[s1_hr, s2_hr, slice_hr, :], + features, + low_res_lat_lon[s1_lr, s2_lr], + low_res_times[slice_lr], + out_file, + meta_data=model_meta_data, + max_workers=1, + gids=gids[s1_hr, s2_hr], + ) - out = (out_files, data, ws_true, wd_true, features, t_slices_lr, - t_slices_hr, s_slices_lr, s_slices_hr, low_res_lat_lon, - low_res_times) + out = ( + out_files, + data, + ws_true, + wd_true, + features, + t_slices_lr, + t_slices_hr, + s_slices_lr, + s_slices_hr, + low_res_lat_lon, + low_res_times, + ) return out @@ -181,17 +246,20 @@ def make_fake_cs_ratio_files(td, low_res_times, low_res_lat_lon, gan_meta): os.makedirs(chunk_dir) for idt, timestamp in enumerate(low_res_times): - fn = ('sup3r_chunk_{}_{}.h5' - .format(str(idt).zfill(6), str(0).zfill(6))) + fn = 'sup3r_chunk_{}_{}.h5'.format(str(idt).zfill(6), str(0).zfill(6)) out_file = os.path.join(chunk_dir, fn) fps.append(out_file) cs_ratio = np.random.uniform(0, 1, (20, 20, 1, 1)) cs_ratio = np.repeat(cs_ratio, 24, axis=2) - OutputHandlerH5.write_output(cs_ratio, ['clearsky_ratio'], - low_res_lat_lon, - 
[timestamp], - out_file, max_workers=1, - meta_data=gan_meta) + OutputHandlerH5.write_output( + cs_ratio, + ['clearsky_ratio'], + low_res_lat_lon, + [timestamp], + out_file, + max_workers=1, + meta_data=gan_meta, + ) return fps, fp_pattern diff --git a/sup3r/utilities/utilities.py b/sup3r/utilities/utilities.py index 03c2eb601..01a1d925a 100644 --- a/sup3r/utilities/utilities.py +++ b/sup3r/utilities/utilities.py @@ -26,6 +26,30 @@ logger = logging.getLogger(__name__) +def windspeed_log_law(z, a, b, c): + """Windspeed log profile. + + Parameters + ---------- + z : float + Height above ground in meters + a : float + Proportional to friction velocity + b : float + Related to zero-plane displacement in meters (height above the ground + at which zero mean wind speed is achieved as a result of flow obstacles + such as trees or buildings) + c : float + Proportional to stability term. + + Returns + ------- + ws : float + Value of windspeed at a given height. + """ + return a * np.log(z + b) + c + + def get_time_dim_name(filepath): """Get the name of the time dimension in the given file. This is specifically for netcdf files. diff --git a/tests/data/test_era_co_2012.nc b/tests/data/test_era_co_2012.nc new file mode 100644 index 000000000..695b50eee Binary files /dev/null and b/tests/data/test_era_co_2012.nc differ diff --git a/tests/data_handling/test_data_handling_h5.py b/tests/data_handling/test_data_handling_h5.py index ab7085f09..96c529ff3 100644 --- a/tests/data_handling/test_data_handling_h5.py +++ b/tests/data_handling/test_data_handling_h5.py @@ -99,7 +99,7 @@ def test_data_caching(): cache_pattern = os.path.join(td, 'cached_features_h5') handler = DataHandler(input_files[0], features, cache_pattern=cache_pattern, - overwrite_cache=True, + overwrite_cache=True, val_split=0.05, **dh_kwargs) assert handler.data is None @@ -115,6 +115,7 @@ def test_data_caching(): handler = DataHandler(input_files[0], features, cache_pattern=cache_pattern, overwrite_cache=True, load_cached=True, + val_split=0.05, **dh_kwargs) assert handler.data is not None assert handler.val_data is not None @@ -283,7 +284,8 @@ def test_spatiotemporal_normalization(): def test_data_extraction(): """Test data extraction class""" - handler = DataHandler(input_files[0], features, **dh_kwargs) + handler = DataHandler(input_files[0], features, val_split=0.05, + **dh_kwargs) assert handler.data.shape == (shape[0], shape[1], handler.data.shape[2], len(features)) assert handler.data.dtype == np.dtype(np.float32) @@ -293,7 +295,7 @@ def test_data_extraction(): def test_hr_coarsening(): """Test spatial coarsening of the high res field""" handler = DataHandler(input_files[0], features, hr_spatial_coarsen=2, - **dh_kwargs) + val_split=0.05, **dh_kwargs) assert handler.data.shape == (shape[0] // 2, shape[1] // 2, handler.data.shape[2], len(features)) assert handler.data.dtype == np.dtype(np.float32) @@ -304,7 +306,7 @@ def test_hr_coarsening(): if os.path.exists(cache_pattern): os.system(f'rm {cache_pattern}') handler = DataHandler(input_files[0], features, hr_spatial_coarsen=2, - cache_pattern=cache_pattern, + cache_pattern=cache_pattern, val_split=0.05, overwrite_cache=True, **dh_kwargs) assert handler.data is None handler.load_cached_data() @@ -322,7 +324,8 @@ def test_validation_batching(): for input_file in input_files: dh_kwargs_new = dh_kwargs.copy() dh_kwargs_new['sample_shape'] = (sample_shape[0], sample_shape[1], 1) - data_handler = DataHandler(input_file, features, **dh_kwargs_new) + data_handler = DataHandler(input_file, 
features, val_split=0.05, + **dh_kwargs_new) data_handlers.append(data_handler) batch_handler = SpatialBatchHandler([data_handler], **bh_kwargs) @@ -348,7 +351,8 @@ def test_temporal_coarsening(method, t_enhance): data_handlers = [] for input_file in input_files: - data_handler = DataHandler(input_file, features, **dh_kwargs) + data_handler = DataHandler(input_file, features, val_split=0.05, + **dh_kwargs) data_handlers.append(data_handler) max_workers = 1 bh_kwargs_new = bh_kwargs.copy() diff --git a/tests/data_handling/test_data_handling_nc.py b/tests/data_handling/test_data_handling_nc.py index 19e58ac93..59ada2965 100644 --- a/tests/data_handling/test_data_handling_nc.py +++ b/tests/data_handling/test_data_handling_nc.py @@ -298,7 +298,7 @@ def test_spatiotemporal_normalization(): def test_data_extraction(): """Test data extraction class""" - handler = DataHandler(INPUT_FILE, features, **dh_kwargs) + handler = DataHandler(INPUT_FILE, features, val_split=0.05, **dh_kwargs) assert handler.data.shape == (shape[0], shape[1], handler.data.shape[2], len(features)) assert handler.data.dtype == np.dtype(np.float32) diff --git a/tests/forward_pass/test_forward_pass.py b/tests/forward_pass/test_forward_pass.py index e712206c0..d6c54612f 100644 --- a/tests/forward_pass/test_forward_pass.py +++ b/tests/forward_pass/test_forward_pass.py @@ -15,7 +15,8 @@ from sup3r.models import LinearInterp, Sup3rGan, WindGan from sup3r.pipeline.forward_pass import ForwardPass, ForwardPassStrategy from sup3r.preprocessing.data_handling import DataHandlerNC -from sup3r.utilities.pytest import make_fake_nc_files +from sup3r.utilities.pytest import (make_fake_nc_files, + make_fake_multi_time_nc_files) FP_WTK = os.path.join(TEST_DATA_DIR, 'test_wtk_co_2012.h5') TARGET_COORD = (39.01, -105.15) @@ -93,6 +94,74 @@ def test_fwp_nc_cc(log=False): s_enhance * fwp_chunk_shape[1]) +def test_fwp_single_ts_vs_multi_ts_input_files(): + """Test forward pass handler output for spatial only model.""" + + fp_gen = os.path.join(CONFIG_DIR, 'spatial/gen_2x_2f.json') + fp_disc = os.path.join(CONFIG_DIR, 'spatial/disc.json') + + Sup3rGan.seed() + model = Sup3rGan(fp_gen, fp_disc, learning_rate=1e-4) + _ = model.generate(np.ones((4, 10, 10, len(FEATURES)))) + model.meta['training_features'] = FEATURES + model.meta['output_features'] = ['U_100m', 'V_100m'] + model.meta['s_enhance'] = 2 + model.meta['t_enhance'] = 1 + with tempfile.TemporaryDirectory() as td: + input_files = make_fake_nc_files(td, INPUT_FILE, 8) + out_dir = os.path.join(td, 's_gan') + model.save(out_dir) + + cache_pattern = os.path.join(td, 'cache') + out_files = os.path.join(td, 'out_{file_id}_single_ts.nc') + + max_workers = 1 + input_handler_kwargs = dict( + target=target, shape=shape, + temporal_slice=temporal_slice, + worker_kwargs=dict(max_workers=max_workers), + cache_pattern=cache_pattern, + overwrite_cache=True) + single_ts_handler = ForwardPassStrategy( + input_files, model_kwargs={'model_dir': out_dir}, + fwp_chunk_shape=fwp_chunk_shape, + spatial_pad=1, temporal_pad=1, + input_handler_kwargs=input_handler_kwargs, out_pattern=out_files, + worker_kwargs=dict(max_workers=max_workers)) + single_ts_forward_pass = ForwardPass(single_ts_handler) + single_ts_forward_pass.run(single_ts_handler, node_index=0) + + input_files = make_fake_multi_time_nc_files(td, INPUT_FILE, 8, 2) + + cache_pattern = os.path.join(td, 'cache') + out_files = os.path.join(td, 'out_{file_id}_multi_ts.nc') + + max_workers = 1 + input_handler_kwargs = dict( + target=target, shape=shape, + 
temporal_slice=temporal_slice, + worker_kwargs=dict(max_workers=max_workers), + cache_pattern=cache_pattern, + overwrite_cache=True) + multi_ts_handler = ForwardPassStrategy( + input_files, model_kwargs={'model_dir': out_dir}, + fwp_chunk_shape=fwp_chunk_shape, + spatial_pad=1, temporal_pad=1, + input_handler_kwargs=input_handler_kwargs, out_pattern=out_files, + worker_kwargs=dict(max_workers=max_workers)) + multi_ts_forward_pass = ForwardPass(multi_ts_handler) + multi_ts_forward_pass.run(multi_ts_handler, node_index=0) + + kwargs = {'combine': 'nested', 'concat_dim': 'Time'} + with xr.open_mfdataset(single_ts_handler.out_files, + **kwargs) as single_ts: + with xr.open_mfdataset(multi_ts_handler.out_files, + **kwargs) as multi_ts: + for feat in model.meta['output_features']: + assert np.array_equal(single_ts[feat].values, + multi_ts[feat].values) + + def test_fwp_spatial_only(): """Test forward pass handler output for spatial only model.""" diff --git a/tests/output/test_output_handling.py b/tests/output/test_output_handling.py index 0365a793d..6f85c9e99 100644 --- a/tests/output/test_output_handling.py +++ b/tests/output/test_output_handling.py @@ -1,26 +1,31 @@ """Output method tests""" import json -import numpy as np import os -import tensorflow as tf import tempfile + +import numpy as np import pandas as pd +import tensorflow as tf +from rex import ResourceX, init_logger from sup3r import __version__ -from sup3r.postprocessing.file_handling import OutputHandlerNC, OutputHandlerH5 from sup3r.postprocessing.collection import Collector -from sup3r.utilities.utilities import invert_uv, transform_rotate_wind +from sup3r.postprocessing.file_handling import OutputHandlerH5, OutputHandlerNC from sup3r.utilities.pytest import make_fake_h5_chunks - -from rex import ResourceX, init_logger +from sup3r.utilities.utilities import invert_uv, transform_rotate_wind def test_get_lat_lon(): """Check that regridding works correctly""" low_res_lats = np.array([[1, 1, 1], [0, 0, 0]]) low_res_lons = np.array([[-120, -100, -80], [-120, -100, -80]]) - lat_lon = np.concatenate([np.expand_dims(low_res_lats, axis=-1), - np.expand_dims(low_res_lons, axis=-1)], axis=-1) + lat_lon = np.concatenate( + [ + np.expand_dims(low_res_lats, axis=-1), + np.expand_dims(low_res_lons, axis=-1), + ], + axis=-1, + ) shape = (4, 6) new_lat_lon = OutputHandlerNC.get_lat_lon(lat_lon, shape) @@ -50,14 +55,17 @@ def test_invert_uv(): """Make sure inverse uv transform returns inputs""" lats = np.array([[1, 1, 1], [0, 0, 0]]) lons = np.array([[-120, -100, -80], [-120, -100, -80]]) - lat_lon = np.concatenate([np.expand_dims(lats, axis=-1), - np.expand_dims(lons, axis=-1)], axis=-1) + lat_lon = np.concatenate( + [np.expand_dims(lats, axis=-1), np.expand_dims(lons, axis=-1)], axis=-1 + ) windspeed = np.random.rand(lat_lon.shape[0], lat_lon.shape[1], 5) winddirection = 360 * np.random.rand(lat_lon.shape[0], lat_lon.shape[1], 5) - u, v = transform_rotate_wind(np.array(windspeed, dtype=np.float32), - np.array(winddirection, dtype=np.float32), - lat_lon) + u, v = transform_rotate_wind( + np.array(windspeed, dtype=np.float32), + np.array(winddirection, dtype=np.float32), + lat_lon, + ) ws, wd = invert_uv(u, v, lat_lon) @@ -65,9 +73,11 @@ def test_invert_uv(): assert np.allclose(winddirection, wd) lat_lon = lat_lon[::-1] - u, v = transform_rotate_wind(np.array(windspeed, dtype=np.float32), - np.array(winddirection, dtype=np.float32), - lat_lon) + u, v = transform_rotate_wind( + np.array(windspeed, dtype=np.float32), + np.array(winddirection, 
dtype=np.float32), + lat_lon, + ) ws, wd = invert_uv(u, v, lat_lon) @@ -81,13 +91,15 @@ def test_invert_uv_inplace(): lats = np.array([[1, 1, 1], [0, 0, 0]]) lons = np.array([[-120, -100, -80], [-120, -100, -80]]) - lat_lon = np.concatenate([np.expand_dims(lats, axis=-1), - np.expand_dims(lons, axis=-1)], axis=-1) + lat_lon = np.concatenate( + [np.expand_dims(lats, axis=-1), np.expand_dims(lons, axis=-1)], axis=-1 + ) u = np.random.rand(lat_lon.shape[0], lat_lon.shape[1], 5) v = np.random.rand(lat_lon.shape[0], lat_lon.shape[1], 5) - data = np.concatenate([np.expand_dims(u, axis=-1), - np.expand_dims(v, axis=-1)], axis=-1) + data = np.concatenate( + [np.expand_dims(u, axis=-1), np.expand_dims(v, axis=-1)], axis=-1 + ) OutputHandlerH5.invert_uv_features(data, ['U_100m', 'V_100m'], lat_lon) ws, wd = invert_uv(u, v, lat_lon) @@ -96,8 +108,9 @@ def test_invert_uv_inplace(): assert np.allclose(data[..., 1], wd) lat_lon = lat_lon[::-1] - data = np.concatenate([np.expand_dims(u, axis=-1), - np.expand_dims(v, axis=-1)], axis=-1) + data = np.concatenate( + [np.expand_dims(u, axis=-1), np.expand_dims(v, axis=-1)], axis=-1 + ) OutputHandlerH5.invert_uv_features(data, ['U_100m', 'V_100m'], lat_lon) ws, wd = invert_uv(u, v, lat_lon) @@ -113,8 +126,19 @@ def test_h5_out_and_collect(): fp_out = os.path.join(td, 'out_combined.h5') out = make_fake_h5_chunks(td) - (out_files, data, ws_true, wd_true, features, _, - t_slices_hr, _, s_slices_hr, _, low_res_times) = out + ( + out_files, + data, + ws_true, + wd_true, + features, + _, + t_slices_hr, + _, + s_slices_hr, + _, + low_res_times, + ) = out Collector.collect(out_files, fp_out, features=features) with ResourceX(fp_out) as fh: @@ -132,15 +156,18 @@ def test_h5_out_and_collect(): if s1_idx == s2_idx == 0: combined_ti += list(fh_i.time_index) - ws_i = np.transpose(data[s1_hr, s2_hr, t_hr, 0], - axes=(2, 0, 1)) - wd_i = np.transpose(data[s1_hr, s2_hr, t_hr, 1], - axes=(2, 0, 1)) + ws_i = np.transpose( + data[s1_hr, s2_hr, t_hr, 0], axes=(2, 0, 1) + ) + wd_i = np.transpose( + data[s1_hr, s2_hr, t_hr, 1], axes=(2, 0, 1) + ) ws_i = ws_i.reshape(48, 625) wd_i = wd_i.reshape(48, 625) assert np.allclose(ws_i, fh_i['windspeed_100m'], atol=0.01) - assert np.allclose(wd_i, fh_i['winddirection_100m'], - atol=0.1) + assert np.allclose( + wd_i, fh_i['winddirection_100m'], atol=0.1 + ) for k, v in fh_i.global_attrs.items(): assert k in fh.global_attrs, k @@ -182,7 +209,7 @@ def test_h5_collect_mask(log=False): Collector.collect(out_files, fp_out, features=features) indices = np.arange(np.product(data.shape[:2])) - indices = indices[-len(indices) // 2:] + indices = indices[slice(-len(indices) // 2, None)] removed = [] for _ in range(10): removed.append(np.random.choice(indices)) @@ -193,9 +220,14 @@ def test_h5_collect_mask(log=False): mask_meta['gid'][:] = np.arange(len(mask_meta)) mask_meta.to_csv(mask_file, index=False) - Collector.collect(out_files, fp_out_mask, features=features, - target_final_meta_file=mask_file, - max_workers=1, join_times=False) + Collector.collect( + out_files, + fp_out_mask, + features=features, + target_final_meta_file=mask_file, + max_workers=1, + join_times=False, + ) with ResourceX(fp_out_mask) as fh: mask_meta = pd.read_csv(mask_file, dtype=np.float32) assert np.array_equal(mask_meta['gid'], fh.meta.index.values) @@ -203,5 +235,6 @@ def test_h5_collect_mask(log=False): assert np.array_equal(mask_meta['latitude'], fh.meta['latitude']) with ResourceX(fp_out) as fh_o: - assert np.array_equal(fh_o['windspeed_100m', :, mask_slice], - 
fh['windspeed_100m']) + assert np.array_equal( + fh_o['windspeed_100m', :, mask_slice], fh['windspeed_100m'] + ) diff --git a/tests/training/test_bias_correction.py b/tests/training/test_bias_correction.py index f33ba95ec..07a7d7282 100644 --- a/tests/training/test_bias_correction.py +++ b/tests/training/test_bias_correction.py @@ -13,6 +13,7 @@ from sup3r.bias.bias_transforms import local_linear_bc, monthly_local_linear_bc from sup3r.models import Sup3rGan from sup3r.pipeline.forward_pass import ForwardPass, ForwardPassStrategy +from sup3r.preprocessing.data_handling import DataHandlerNCforCC from sup3r.qa.qa import Sup3rQa FP_NSRDB = os.path.join(TEST_DATA_DIR, 'test_nsrdb_co_2018.h5') @@ -179,6 +180,7 @@ def test_linear_transform(): """Test the linear bc transform method""" calc = LinearCorrection(FP_NSRDB, FP_CC, 'ghi', 'rsds', TARGET, SHAPE, bias_handler='DataHandlerNCforCC') + lat_lon = calc.bias_dh.lat_lon with tempfile.TemporaryDirectory() as td: fp_out = os.path.join(td, 'bc.h5') out = calc.run(knn=1, threshold=0.6, fill_extend=False, max_workers=1, @@ -187,7 +189,7 @@ def test_linear_transform(): adder = out['rsds_adder'] test_data = np.ones_like(scalar) with pytest.warns(): - out = local_linear_bc(test_data, 'rsds', fp_out, + out = local_linear_bc(test_data, lat_lon, 'rsds', fp_out, lr_padded_slice=None, out_range=None) out = calc.run(knn=1, threshold=0.6, fill_extend=True, max_workers=1, @@ -195,7 +197,7 @@ def test_linear_transform(): scalar = out['rsds_scalar'] adder = out['rsds_adder'] test_data = np.ones_like(scalar) - out = local_linear_bc(test_data, 'rsds', fp_out, + out = local_linear_bc(test_data, lat_lon, 'rsds', fp_out, lr_padded_slice=None, out_range=None) assert np.allclose(out, scalar + adder) @@ -205,15 +207,15 @@ def test_linear_transform(): out_mask = too_big | too_small assert out_mask.any() - out = local_linear_bc(test_data, 'rsds', fp_out, + out = local_linear_bc(test_data, lat_lon, 'rsds', fp_out, lr_padded_slice=None, out_range=out_range) assert np.allclose(out[too_big], np.max(out_range)) assert np.allclose(out[too_small], np.min(out_range)) lr_slice = (slice(1, 2), slice(2, 3), slice(None)) - sliced_out = local_linear_bc(test_data[lr_slice], 'rsds', fp_out, - lr_padded_slice=lr_slice, + sliced_out = local_linear_bc(test_data[lr_slice], lat_lon[lr_slice], + 'rsds', fp_out, lr_padded_slice=lr_slice, out_range=out_range) assert np.allclose(out[lr_slice], sliced_out) @@ -223,6 +225,7 @@ def test_montly_linear_transform(): calc = MonthlyLinearCorrection(FP_NSRDB, FP_CC, 'ghi', 'rsds', TARGET, SHAPE, bias_handler='DataHandlerNCforCC') + lat_lon = calc.bias_dh.lat_lon _, base_ti = calc.get_base_data(calc.base_fps, calc.base_dset, 5, calc.base_handler, daily_reduction='avg') @@ -234,7 +237,7 @@ def test_montly_linear_transform(): adder = out['rsds_adder'] test_data = np.ones((scalar.shape[0], scalar.shape[1], len(base_ti))) with pytest.warns(): - out = monthly_local_linear_bc(test_data, 'rsds', fp_out, + out = monthly_local_linear_bc(test_data, lat_lon, 'rsds', fp_out, lr_padded_slice=None, time_index=base_ti, temporal_avg=True, @@ -245,7 +248,7 @@ def test_montly_linear_transform(): truth = np.expand_dims(truth, axis=-1) assert np.allclose(truth, out) - out = monthly_local_linear_bc(test_data, 'rsds', fp_out, + out = monthly_local_linear_bc(test_data, lat_lon, 'rsds', fp_out, lr_padded_slice=None, time_index=base_ti, temporal_avg=False, @@ -293,6 +296,9 @@ def test_fwp_integration(): os.path.join(TEST_DATA_DIR, 'orog_test.nc'), os.path.join(TEST_DATA_DIR, 
'zg_test.nc')] + lat_lon = DataHandlerNCforCC(input_files, features=[], target=target, + shape=shape).lat_lon + Sup3rGan.seed() model = Sup3rGan(fp_gen, fp_disc, learning_rate=1e-4) _ = model.generate(np.ones((4, 10, 10, 6, len(features)))) @@ -314,6 +320,8 @@ def test_fwp_integration(): f.create_dataset('U_100m_adder', data=adder) f.create_dataset('V_100m_scalar', data=scalar) f.create_dataset('V_100m_adder', data=adder) + f.create_dataset('latitude', data=lat_lon[..., 0]) + f.create_dataset('longitude', data=lat_lon[..., 1]) bias_correct_kwargs = {'U_100m': {'feature_name': 'U_100m', 'bias_fp': bias_fp}, @@ -369,6 +377,8 @@ def test_qa_integration(): os.path.join(TEST_DATA_DIR, 'orog_test.nc'), os.path.join(TEST_DATA_DIR, 'zg_test.nc')] + lat_lon = DataHandlerNCforCC(input_files, features=[]).lat_lon + with tempfile.TemporaryDirectory() as td: bias_fp = os.path.join(td, 'bc.h5') @@ -384,6 +394,8 @@ def test_qa_integration(): f.create_dataset('U_100m_adder', data=adder) f.create_dataset('V_100m_scalar', data=scalar) f.create_dataset('V_100m_adder', data=adder) + f.create_dataset('latitude', data=lat_lon[..., 0]) + f.create_dataset('longitude', data=lat_lon[..., 1]) qa_kw = {'s_enhance': 3, 't_enhance': 4, diff --git a/tests/training/test_train_gan.py b/tests/training/test_train_gan.py index 885ad1803..17bbfb75e 100644 --- a/tests/training/test_train_gan.py +++ b/tests/training/test_train_gan.py @@ -265,7 +265,12 @@ def test_train_st(n_epoch=2, log=False): batch_handler = BatchHandler([handler], batch_size=2, s_enhance=3, t_enhance=4, - n_batches=2) + n_batches=2, + worker_kwargs=dict(max_workers=1)) + + assert batch_handler.norm_workers == 1 + assert batch_handler.stats_workers == 1 + assert batch_handler.load_workers == 1 with tempfile.TemporaryDirectory() as td: # test that training works and reduces loss diff --git a/tests/utilities/test_utilities.py b/tests/utilities/test_utilities.py index af725fbf7..27c06fa96 100644 --- a/tests/utilities/test_utilities.py +++ b/tests/utilities/test_utilities.py @@ -1,29 +1,77 @@ # -*- coding: utf-8 -*- """pytests for general utilities""" -import numpy as np -from scipy.interpolate import interp1d -import pytest -import matplotlib.pyplot as plt import os import tempfile +import matplotlib.pyplot as plt +import numpy as np +import pytest +import xarray as xr from rex import Resource, init_logger +from scipy.interpolate import interp1d +from sup3r import TEST_DATA_DIR +from sup3r.postprocessing.collection import Collector from sup3r.postprocessing.file_handling import OutputHandler -from sup3r.utilities.utilities import (get_chunk_slices, - uniform_time_sampler, - weighted_time_sampler, - weighted_box_sampler, - uniform_box_sampler, - spatial_coarsening, - transform_rotate_wind, - st_interp) +from sup3r.utilities.interpolate_log_profile import LogLinInterpolator from sup3r.utilities.regridder import RegridOutput -from sup3r.postprocessing.collection import Collector -from sup3r import TEST_DATA_DIR - +from sup3r.utilities.utilities import ( + get_chunk_slices, + spatial_coarsening, + st_interp, + transform_rotate_wind, + uniform_box_sampler, + uniform_time_sampler, + weighted_box_sampler, + weighted_time_sampler, +) FP_WTK = os.path.join(TEST_DATA_DIR, 'test_wtk_co_2012.h5') +FP_ERA = os.path.join(TEST_DATA_DIR, 'test_era_co_2012.nc') + + +def test_log_interp(log=True): + """Make sure log interp generates reasonable output (e.g. 
between input
+    levels)"""
+    if log:
+        init_logger('sup3r', log_level='DEBUG')
+    with tempfile.TemporaryDirectory() as tmpdir:
+        outfile = f'{tmpdir}/uv_interp.nc'
+        infile = f'{tmpdir}/uv_input.nc'
+        tmp = xr.open_dataset(FP_ERA)
+        tmp = tmp.isel(time=slice(0, 100))
+        tmp.to_netcdf(infile)
+        tmp.close()
+        LogLinInterpolator.run(
+            infile,
+            outfile,
+            output_heights={'u': [40], 'v': [40]},
+            variables=['u', 'v'],
+            max_workers=1,
+        )
+
+        def between_check(first, mid, second):
+            return (first < mid < second) or (second < mid < first)
+
+        out = xr.open_dataset(outfile)
+        input = xr.open_dataset(infile)
+        u_check = all(
+            between_check(lower, mid, higher)
+            for lower, mid, higher in zip(
+                input['u_10m'].values.flatten(),
+                out['u_40m'].values.flatten(),
+                input['u_100m'].values.flatten(),
+            )
+        )
+        v_check = all(
+            between_check(lower, mid, higher)
+            for lower, mid, higher in zip(
+                input['v_10m'].values.flatten(),
+                out['v_40m'].values.flatten(),
+                input['v_100m'].values.flatten(),
+            )
+        )
+        assert u_check and v_check
 
 
 def test_regridding(log=False):
@@ -45,23 +93,29 @@
         target_meta = target_meta.sample(frac=1, random_state=0)
         target_meta.to_csv(shuffled_meta_path, index=False)
 
-        regrid_output = RegridOutput(source_files=[FP_WTK],
-                                     out_pattern=out_pattern,
-                                     target_meta=shuffled_meta_path,
-                                     heights=heights, k_neighbors=4,
-                                     worker_kwargs={'regrid_workers': 1,
-                                                    'query_workers': 1},
-                                     incremental=True, n_chunks=10,
-                                     max_nodes=2)
+        regrid_output = RegridOutput(
+            source_files=[FP_WTK],
+            out_pattern=out_pattern,
+            target_meta=shuffled_meta_path,
+            heights=heights,
+            k_neighbors=4,
+            worker_kwargs={'regrid_workers': 1, 'query_workers': 1},
+            incremental=True,
+            n_chunks=10,
+            max_nodes=2,
+        )
         for node_index in range(regrid_output.nodes):
             regrid_output.run(node_index=node_index)
 
-        Collector.collect(regrid_output.out_files,
-                          collect_file,
-                          regrid_output.output_features,
-                          target_final_meta_file=meta_path,
-                          join_times=False,
-                          n_writes=2, max_workers=1)
+        Collector.collect(
+            regrid_output.out_files,
+            collect_file,
+            regrid_output.output_features,
+            target_final_meta_file=meta_path,
+            join_times=False,
+            n_writes=2,
+            max_workers=1,
+        )
         with Resource(collect_file) as out_res:
             for height in heights:
                 ws_name = f'windspeed_{height}m'
@@ -112,7 +166,6 @@ def test_weighted_box_sampler():
     weights_3[5] = 0.5
 
     for _ in range(100):
-
         slice_1, _ = weighted_box_sampler(data, shape, weights_1)
         assert chunks[0][0] <= slice_1.start <= chunks[0][-1]
 
@@ -120,8 +173,10 @@
         assert chunks[-1][0] <= slice_2.start <= chunks[-1][-1]
 
         slice_3, _ = weighted_box_sampler(data, shape, weights_3)
-        assert (chunks[2][0] <= slice_3.start <= chunks[2][-1]
-                or chunks[5][0] <= slice_3.start <= chunks[5][-1])
+        assert (
+            chunks[2][0] <= slice_3.start <= chunks[2][-1]
+            or chunks[5][0] <= slice_3.start <= chunks[5][-1]
+        )
 
     data = np.zeros((2, 100, 1))
     shape = (2, 10)
@@ -139,7 +194,6 @@ def test_weighted_box_sampler():
     weights_3[5] = 0.5
 
    for _ in range(100):
-
         _, slice_1 = weighted_box_sampler(data, shape, weights_1)
         assert chunks[0][0] <= slice_1.start <= chunks[0][-1]
 
@@ -147,8 +201,10 @@
         assert chunks[-1][0] <= slice_2.start <= chunks[-1][-1]
 
         _, slice_3 = weighted_box_sampler(data, shape, weights_3)
-        assert (chunks[2][0] <= slice_3.start <= chunks[2][-1]
-                or chunks[5][0] <= slice_3.start <= chunks[5][-1])
+        assert (
+            chunks[2][0] <= slice_3.start <= chunks[2][-1]
+            or chunks[5][0] <= slice_3.start <= chunks[5][-1]
+        )
 
     shape = (1, 1)
     weights = np.zeros(np.product(data.shape))
@@ -177,7 +233,6 @@ def test_weighted_time_sampler():
    weights_3[5] = 0.5
 
    for _ in range(100):
-
         slice_1 = weighted_time_sampler(data, shape, weights_1)
         assert chunks[0][0] <= slice_1.start <= chunks[0][-1]
 
@@ -185,8 +240,10 @@
         assert chunks[-1][0] <= slice_2.start <= chunks[-1][-1]
 
         slice_3 = weighted_time_sampler(data, 10, weights_3)
-        assert (chunks[2][0] <= slice_3.start <= chunks[2][-1]
-                or chunks[5][0] <= slice_3.start <= chunks[5][-1])
+        assert (
+            chunks[2][0] <= slice_3.start <= chunks[2][-1]
+            or chunks[5][0] <= slice_3.start <= chunks[5][-1]
+        )
 
     shape = 1
     weights = np.zeros(data.shape[2])
@@ -249,15 +306,16 @@ def test_s_enhance_5D(s_enhance):
             for f in range(arr.shape[4]):
                 for i_lr in range(coarse.shape[1]):
                     for j_lr in range(coarse.shape[2]):
-
                         i_hr = i_lr * s_enhance
                         i_hr = slice(i_hr, i_hr + s_enhance)
 
                         j_hr = j_lr * s_enhance
                         j_hr = slice(j_hr, j_hr + s_enhance)
 
-                        assert np.allclose(coarse[o, i_lr, j_lr, t, f],
-                                           arr[o, i_hr, j_hr, t, f].mean())
+                        assert np.allclose(
+                            coarse[o, i_lr, j_lr, t, f],
+                            arr[o, i_hr, j_hr, t, f].mean(),
+                        )
 
 
 @pytest.mark.parametrize('s_enhance', [1, 2, 4, 5])
@@ -270,15 +328,15 @@ def test_s_enhance_4D(s_enhance):
         for f in range(arr.shape[3]):
             for i_lr in range(coarse.shape[1]):
                 for j_lr in range(coarse.shape[2]):
-
                     i_hr = i_lr * s_enhance
                     i_hr = slice(i_hr, i_hr + s_enhance)
 
                     j_hr = j_lr * s_enhance
                     j_hr = slice(j_hr, j_hr + s_enhance)
 
-                    assert np.allclose(coarse[o, i_lr, j_lr, f],
-                                       arr[o, i_hr, j_hr, f].mean())
+                    assert np.allclose(
+                        coarse[o, i_lr, j_lr, f], arr[o, i_hr, j_hr, f].mean()
+                    )
 
 
 @pytest.mark.parametrize('s_enhance', [1, 2, 4, 5])
@@ -291,15 +349,15 @@ def test_s_enhance_4D_no_obs(s_enhance):
         for f in range(arr.shape[3]):
             for i_lr in range(coarse.shape[0]):
                 for j_lr in range(coarse.shape[1]):
-
                     i_hr = i_lr * s_enhance
                     i_hr = slice(i_hr, i_hr + s_enhance)
 
                     j_hr = j_lr * s_enhance
                     j_hr = slice(j_hr, j_hr + s_enhance)
 
-                    assert np.allclose(coarse[i_lr, j_lr, t, f],
-                                       arr[i_hr, j_hr, t, f].mean())
+                    assert np.allclose(
+                        coarse[i_lr, j_lr, t, f], arr[i_hr, j_hr, t, f].mean()
+                    )
 
 
 @pytest.mark.parametrize('s_enhance', [1, 2, 4, 5])
@@ -311,31 +369,34 @@ def test_s_enhance_3D_no_obs(s_enhance):
     for f in range(arr.shape[2]):
         for i_lr in range(coarse.shape[0]):
             for j_lr in range(coarse.shape[1]):
-
                 i_hr = i_lr * s_enhance
                 i_hr = slice(i_hr, i_hr + s_enhance)
 
                 j_hr = j_lr * s_enhance
                 j_hr = slice(j_hr, j_hr + s_enhance)
 
-                assert np.allclose(coarse[i_lr, j_lr, f],
-                                   arr[i_hr, j_hr, f].mean())
+                assert np.allclose(
+                    coarse[i_lr, j_lr, f], arr[i_hr, j_hr, f].mean()
+                )
 
 
 def test_transform_rotate():
     """Make sure inverse uv transform returns inputs"""
     lats = np.array([[1, 1, 1], [0, 0, 0]])
     lons = np.array([[-120, -100, -80], [-120, -100, -80]])
-    lat_lon = np.concatenate([np.expand_dims(lats, axis=-1),
-                              np.expand_dims(lons, axis=-1)], axis=-1)
+    lat_lon = np.concatenate(
+        [np.expand_dims(lats, axis=-1), np.expand_dims(lons, axis=-1)], axis=-1
+    )
     windspeed = np.ones((lat_lon.shape[0], lat_lon.shape[1], 1))
 
     # wd = 0 -> u = 0 and v = -1
     winddirection = np.zeros((lat_lon.shape[0], lat_lon.shape[1], 1))
 
-    u, v = transform_rotate_wind(np.array(windspeed, dtype=np.float32),
-                                 np.array(winddirection, dtype=np.float32),
-                                 lat_lon)
+    u, v = transform_rotate_wind(
+        np.array(windspeed, dtype=np.float32),
+        np.array(winddirection, dtype=np.float32),
+        lat_lon,
+    )
     u_target = np.zeros(u.shape)
     u_target[...] = 0
     v_target = np.zeros(v.shape)
@@ -348,9 +409,11 @@
     winddirection = np.zeros((lat_lon.shape[0], lat_lon.shape[1], 1))
     winddirection[...] = 90
 
-    u, v = transform_rotate_wind(np.array(windspeed, dtype=np.float32),
-                                 np.array(winddirection, dtype=np.float32),
-                                 lat_lon)
+    u, v = transform_rotate_wind(
+        np.array(windspeed, dtype=np.float32),
+        np.array(winddirection, dtype=np.float32),
+        lat_lon,
+    )
     u_target = np.zeros(u.shape)
     u_target[...] = -1
     v_target = np.zeros(v.shape)
@@ -363,9 +426,11 @@
     winddirection = np.zeros((lat_lon.shape[0], lat_lon.shape[1], 1))
     winddirection[...] = 270
 
-    u, v = transform_rotate_wind(np.array(windspeed, dtype=np.float32),
-                                 np.array(winddirection, dtype=np.float32),
-                                 lat_lon)
+    u, v = transform_rotate_wind(
+        np.array(windspeed, dtype=np.float32),
+        np.array(winddirection, dtype=np.float32),
+        lat_lon,
+    )
     u_target = np.zeros(u.shape)
     u_target[...] = 1
     v_target = np.zeros(v.shape)
@@ -378,9 +443,11 @@
     winddirection = np.zeros((lat_lon.shape[0], lat_lon.shape[1], 1))
     winddirection[...] = 180
 
-    u, v = transform_rotate_wind(np.array(windspeed, dtype=np.float32),
-                                 np.array(winddirection, dtype=np.float32),
-                                 lat_lon)
+    u, v = transform_rotate_wind(
+        np.array(windspeed, dtype=np.float32),
+        np.array(winddirection, dtype=np.float32),
+        lat_lon,
+    )
     u_target = np.zeros(u.shape)
     u_target[...] = 0
     v_target = np.zeros(v.shape)
@@ -393,9 +460,11 @@
     winddirection = np.zeros((lat_lon.shape[0], lat_lon.shape[1], 1))
     winddirection[...] = 45
 
-    u, v = transform_rotate_wind(np.array(windspeed, dtype=np.float32),
-                                 np.array(winddirection, dtype=np.float32),
-                                 lat_lon)
+    u, v = transform_rotate_wind(
+        np.array(windspeed, dtype=np.float32),
+        np.array(winddirection, dtype=np.float32),
+        lat_lon,
+    )
     u_target = np.zeros(u.shape)
     u_target[...] = -1 / np.sqrt(2)
     v_target = np.zeros(v.shape)
@@ -409,7 +478,7 @@
 def test_st_interpolation(plot=False):
     """Test spatiotemporal linear interpolation"""
 
     X, Y, T = np.meshgrid(np.arange(10), np.arange(10), np.arange(1, 11))
-    arr = 100 * np.exp(-((X - 5)**2 + (Y - 5)**2) / T)
+    arr = 100 * np.exp(-((X - 5) ** 2 + (Y - 5) ** 2) / T)
 
     s_interp = st_interp(arr, s_enhance=3, t_enhance=1)