diff --git a/utils/dd_to_csv.py b/utils/dd_to_csv.py
index a9e11328..fc1db35f 100644
--- a/utils/dd_to_csv.py
+++ b/utils/dd_to_csv.py
@@ -47,9 +47,7 @@ def parse_parameter_values_from_file(
             while data[index].strip() == "":
                 index += 1

-            param_name = data[index].replace(
-                " ' '/", ""
-            )  # param_name is followed by this pattern
+            param_name = data[index].replace(" ' '/", "")  # param_name is followed by this pattern
             index += 1

             param_data = []
@@ -64,9 +62,7 @@ def parse_parameter_values_from_file(
                     attributes = words[0].split(".")
                     attributes = [a if " " in a else a.strip("'") for a in attributes]
                 else:
-                    raise ValueError(
-                        f"Unexpected number of spaces in parameter value setting: {data[index]}"
-                    )
+                    raise ValueError(f"Unexpected number of spaces in parameter value setting: {data[index]}")

                 value = words[-1]
                 param_data.append([*attributes, value])
@@ -106,9 +102,7 @@ def parse_parameter_values_from_file(
                     text = words[1]
                     set_data.add(tuple([*attributes, text]))
                 else:
-                    raise ValueError(
-                        f"Unexpected number of spaces in set value setting: {data[index]}"
-                    )
+                    raise ValueError(f"Unexpected number of spaces in set value setting: {data[index]}")

                 index += 1

@@ -140,17 +134,11 @@ def save_data_with_headers(
         try:
             columns = headers_data[param_name]
         except KeyError:
-            raise ValueError(
-                f"Could not find mapping for {param_name} in mapping file."
-            )
+            raise ValueError(f"Could not find mapping for {param_name} in mapping file.")
         for row in param_data:
             if len(row) != len(columns):
-                raise ValueError(
-                    f"Mismatched number of columns for param {param_name} between data ({len(row)}) and mapping ({len(columns)})"
-                )
-        df = pd.DataFrame(
-            data=np.asarray(param_data)[:, 0 : len(columns)], columns=columns
-        )
+                raise ValueError(f"Mismatched number of columns for param {param_name} between data ({len(row)}) and mapping ({len(columns)})")
+        df = pd.DataFrame(data=np.asarray(param_data)[:, 0 : len(columns)], columns=columns)
         df.to_csv(os.path.join(save_dir, param_name + ".csv"), index=False)

     return
@@ -171,9 +159,7 @@ def generate_headers_by_attr() -> Dict[str, List[str]]:
     return headers_by_attr


-def convert_dd_to_tabular(
-    basedir: str, output_dir: str, headers_by_attr: Dict[str, List[str]]
-) -> None:
+def convert_dd_to_tabular(basedir: str, output_dir: str, headers_by_attr: Dict[str, List[str]]) -> None:
     dd_files = [p for p in Path(basedir).rglob("*.dd")]

     all_sets = defaultdict(list)
@@ -219,12 +205,8 @@ def convert_dd_to_tabular(

 def main(arg_list: None | list[str] = None):
     args_parser = argparse.ArgumentParser()
-    args_parser.add_argument(
-        "input_dir", type=str, help="Input directory containing .dd files."
-    )
-    args_parser.add_argument(
-        "output_dir", type=str, help="Output directory to save the .csv files in."
-    )
+    args_parser.add_argument("input_dir", type=str, help="Input directory containing .dd files.")
+    args_parser.add_argument("output_dir", type=str, help="Output directory to save the .csv files in.")
     args = args_parser.parse_args(arg_list)

     convert_dd_to_tabular(args.input_dir, args.output_dir, generate_headers_by_attr())
diff --git a/utils/run_benchmarks.py b/utils/run_benchmarks.py
index cd38bc30..afeff181 100644
--- a/utils/run_benchmarks.py
+++ b/utils/run_benchmarks.py
@@ -90,9 +90,7 @@ def run_gams_gdxdiff(
         return "Error: dd_files not in benchmark"

     # Copy GAMS scaffolding
-    scaffolding_folder = path.join(
-        path.dirname(path.realpath(__file__)), "..", "xl2times", "gams_scaffold"
-    )
+    scaffolding_folder = path.join(path.dirname(path.realpath(__file__)), "..", "xl2times", "gams_scaffold")
     shutil.copytree(scaffolding_folder, out_folder, dirs_exist_ok=True)
     # Create link to TIMES source
     if not path.exists(path.join(out_folder, "source")):
@@ -307,9 +305,7 @@ def run_all_benchmarks(
     # The rest of this script checks regressions against main
     # so skip it if we're already on main
     repo = git.Repo(".")  # pyright: ignore
-    origin = (
-        repo.remotes.origin if "origin" in repo.remotes else repo.remotes[0]
-    )  # don't assume remote is called 'origin'
+    origin = repo.remotes.origin if "origin" in repo.remotes else repo.remotes[0]  # don't assume remote is called 'origin'
     origin.fetch("main")
     if "main" not in repo.heads:
         repo.create_head("main", origin.refs.main).set_tracking_branch(origin.refs.main)
@@ -332,9 +328,7 @@ def run_all_benchmarks(
                 result = parse_result(f.readlines()[-1])
             # Use a fake runtime and GAMS result
             results_main.append((benchmark["name"], 999, "--", *result))
-        print(
-            f"Skipped running on main. Using results from {path.join(benchmarks_folder, 'out-main')}"
-        )
+        print(f"Skipped running on main. Using results from {path.join(benchmarks_folder, 'out-main')}")

     else:
         if repo.is_dirty():
@@ -396,23 +390,17 @@ def run_all_benchmarks(

     runtime_change = our_time - main_time
     print(f"Total runtime: {our_time:.2f}s (main: {main_time:.2f}s)")
-    print(
-        f"Change in runtime (negative == faster): {runtime_change:+.2f}s ({100 * runtime_change / main_time:+.1f}%)"
-    )
+    print(f"Change in runtime (negative == faster): {runtime_change:+.2f}s ({100 * runtime_change / main_time:+.1f}%)")

     our_correct = df["Correct"].sum()
     main_correct = df["M Correct"].sum()
     correct_change = our_correct - main_correct
-    print(
-        f"Change in correct rows (higher == better): {correct_change:+d} ({100 * correct_change / main_correct:+.1f}%)"
-    )
+    print(f"Change in correct rows (higher == better): {correct_change:+d} ({100 * correct_change / main_correct:+.1f}%)")

     our_additional_rows = df["Additional"].sum()
     main_additional_rows = df["M Additional"].sum()
     additional_change = our_additional_rows - main_additional_rows
-    print(
-        f"Change in additional rows: {additional_change:+d} ({100 * additional_change / main_additional_rows:+.1f}%)"
-    )
+    print(f"Change in additional rows: {additional_change:+d} ({100 * additional_change / main_additional_rows:+.1f}%)")

     if len(accu_regressions) + len(addi_regressions) + len(time_regressions) > 0:
         print()
@@ -511,10 +499,7 @@ def run_all_benchmarks(
             verbose=args.verbose,
             debug=args.debug,
         )
-        print(
-            f"Ran {args.run} in {runtime:.2f}s. {acc}% ({cor} correct, {add} additional).\n"
-            f"GAMS: {gms}"
-        )
+        print(f"Ran {args.run} in {runtime:.2f}s. {acc}% ({cor} correct, {add} additional).\n" f"GAMS: {gms}")
     else:
         run_all_benchmarks(
             benchmarks_folder,
diff --git a/xl2times/__main__.py b/xl2times/__main__.py
index 5f1f3f73..48f4d170 100644
--- a/xl2times/__main__.py
+++ b/xl2times/__main__.py
@@ -41,10 +41,7 @@ def convert_xl_to_times(
                 result = excel.extract_tables(f)
                 raw_tables.extend(result)
         pickle.dump(raw_tables, open(pickle_file, "wb"))
-    print(
-        f"Extracted {len(raw_tables)} tables,"
-        f" {sum(table.dataframe.shape[0] for table in raw_tables)} rows"
-    )
+    print(f"Extracted {len(raw_tables)} tables," f" {sum(table.dataframe.shape[0] for table in raw_tables)} rows")

     if stop_after_read:
         # Convert absolute paths to relative paths to enable comparing raw_tables.txt across machines
@@ -60,14 +57,10 @@ def convert_xl_to_times(
         transforms.normalize_tags_columns,
         transforms.remove_fill_tables,
         transforms.validate_input_tables,
-        lambda config, tables, model: [
-            transforms.remove_comment_cols(t) for t in tables
-        ],
+        lambda config, tables, model: [transforms.remove_comment_cols(t) for t in tables],
         transforms.remove_tables_with_formulas,  # slow
         transforms.normalize_column_aliases,
-        lambda config, tables, model: [
-            transforms.remove_comment_rows(config, t, model) for t in tables
-        ],
+        lambda config, tables, model: [transforms.remove_comment_rows(config, t, model) for t in tables],
         transforms.process_regions,
         transforms.generate_dummy_processes,
         transforms.process_time_slices,
@@ -105,9 +98,7 @@ def convert_xl_to_times(
         transforms.fix_topology,
         transforms.complete_dictionary,
         transforms.convert_to_string,
-        lambda config, tables, model: dump_tables(
-            tables, os.path.join(output_dir, "merged_tables.txt")
-        ),
+        lambda config, tables, model: dump_tables(tables, os.path.join(output_dir, "merged_tables.txt")),
         lambda config, tables, model: produce_times_tables(config, tables),
     ]

@@ -118,14 +109,10 @@ def convert_xl_to_times(
         output = transform(config, input, model)
         end_time = time.time()
         sep = "\n\n" + "=" * 80 + "\n" if verbose else ""
-        print(
-            f"{sep}transform {transform.__code__.co_name} took {end_time - start_time:.2f} seconds"
-        )
+        print(f"{sep}transform {transform.__code__.co_name} took {end_time - start_time:.2f} seconds")
         if verbose:
             if isinstance(output, list):
-                for table in sorted(
-                    output, key=lambda t: (t.tag, t.filename, t.sheetname, t.range)
-                ):
+                for table in sorted(output, key=lambda t: (t.tag, t.filename, t.sheetname, t.range)):
                     print(table)
             elif isinstance(output, dict):
                 for tag, df in output.items():
@@ -134,10 +121,7 @@ def convert_xl_to_times(
         input = output

     assert isinstance(output, dict)
-    print(
-        f"Conversion complete, {len(output)} tables produced,"
-        f" {sum(df.shape[0] for df in output.values())} rows"
-    )
+    print(f"Conversion complete, {len(output)} tables produced," f" {sum(df.shape[0] for df in output.values())} rows")

     return output

@@ -154,31 +138,20 @@ def write_csv_tables(tables: Dict[str, DataFrame], output_dir: str):
 def read_csv_tables(input_dir: str) -> Dict[str, DataFrame]:
     result = {}
     for filename in os.listdir(input_dir):
-        result[filename.split(".")[0]] = pd.read_csv(
-            os.path.join(input_dir, filename), dtype=str
-        )
+        result[filename.split(".")[0]] = pd.read_csv(os.path.join(input_dir, filename), dtype=str)
     return result


-def compare(
-    data: Dict[str, DataFrame], ground_truth: Dict[str, DataFrame], output_dir: str
-) -> str:
-    print(
-        f"Ground truth contains {len(ground_truth)} tables,"
-        f" {sum(df.shape[0] for _, df in ground_truth.items())} rows"
-    )
+def compare(data: Dict[str, DataFrame], ground_truth: Dict[str, DataFrame], output_dir: str) -> str:
+    print(f"Ground truth contains {len(ground_truth)} tables," f" {sum(df.shape[0] for _, df in ground_truth.items())} rows")

     missing = set(ground_truth.keys()) - set(data.keys())
-    missing_str = ", ".join(
-        [f"{x} ({ground_truth[x].shape[0]})" for x in sorted(missing)]
-    )
+    missing_str = ", ".join([f"{x} ({ground_truth[x].shape[0]})" for x in sorted(missing)])
     if len(missing) > 0:
         print(f"WARNING: Missing {len(missing)} tables: {missing_str}")

     additional_tables = set(data.keys()) - set(ground_truth.keys())
-    additional_str = ", ".join(
-        [f"{x} ({data[x].shape[0]})" for x in sorted(additional_tables)]
-    )
+    additional_str = ", ".join([f"{x} ({data[x].shape[0]})" for x in sorted(additional_tables)])
     if len(additional_tables) > 0:
         print(f"WARNING: {len(additional_tables)} additional tables: {additional_str}")
     # Additional rows starts as the sum of lengths of additional tables produced
@@ -186,9 +159,7 @@ def compare(

     total_gt_rows = 0
     total_correct_rows = 0
-    for table_name, gt_table in sorted(
-        ground_truth.items(), reverse=True, key=lambda t: len(t[1])
-    ):
+    for table_name, gt_table in sorted(ground_truth.items(), reverse=True, key=lambda t: len(t[1])):
         if table_name in data:
             data_table = data[table_name]

@@ -196,10 +167,7 @@ def compare(
             transformed_gt_cols = [col.split(".")[0] for col in gt_table.columns]
             data_cols = list(data_table.columns)
             if transformed_gt_cols != data_cols:
-                print(
-                    f"WARNING: Table {table_name} header incorrect, was"
-                    f" {data_cols}, should be {transformed_gt_cols}"
-                )
+                print(f"WARNING: Table {table_name} header incorrect, was" f" {data_cols}, should be {transformed_gt_cols}")

             # both are in string form so can be compared without any issues
             gt_rows = set(tuple(row) for row in gt_table.to_numpy().tolist())
@@ -235,31 +203,20 @@ def compare(
     return result


-def produce_times_tables(
-    config: datatypes.Config, input: Dict[str, DataFrame]
-) -> Dict[str, DataFrame]:
-    print(
-        f"produce_times_tables: {len(input)} tables incoming,"
-        f" {sum(len(value) for (_, value) in input.items())} rows"
-    )
+def produce_times_tables(config: datatypes.Config, input: Dict[str, DataFrame]) -> Dict[str, DataFrame]:
+    print(f"produce_times_tables: {len(input)} tables incoming," f" {sum(len(value) for (_, value) in input.items())} rows")
     result = {}
     used_tables = set()
     for mapping in config.times_xl_maps:
         if not mapping.xl_name in input:
-            print(
-                f"WARNING: Cannot produce table {mapping.times_name} because"
-                f" {mapping.xl_name} does not exist"
-            )
+            print(f"WARNING: Cannot produce table {mapping.times_name} because" f" {mapping.xl_name} does not exist")
         else:
             used_tables.add(mapping.xl_name)
             df = input[mapping.xl_name].copy()
             # Filter rows according to filter_rows mapping:
             for filter_col, filter_val in mapping.filter_rows.items():
                 if filter_col not in df.columns:
-                    print(
-                        f"WARNING: Cannot produce table {mapping.times_name} because"
-                        f" {mapping.xl_name} does not contain column {filter_col}"
-                    )
+                    print(f"WARNING: Cannot produce table {mapping.times_name} because" f" {mapping.xl_name} does not contain column {filter_col}")
                     # TODO break this loop and continue outer loop?
                 filter = set(x.lower() for x in {filter_val})
                 i = df[filter_col].str.lower().isin(filter)
@@ -283,12 +240,7 @@ def produce_times_tables(
             df.drop_duplicates(inplace=True)
             df.reset_index(drop=True, inplace=True)
             # TODO this is a hack. Use pd.StringDtype() so that notna() is sufficient
-            i = (
-                df[mapping.times_cols[-1]].notna()
-                & (df != "None").all(axis=1)
-                & (df != "nan").all(axis=1)
-                & (df != "").all(axis=1)
-            )
+            i = df[mapping.times_cols[-1]].notna() & (df != "None").all(axis=1) & (df != "nan").all(axis=1) & (df != "").all(axis=1)
             df = df.loc[i, mapping.times_cols]
             # Drop tables that are empty after filtering and dropping Nones:
             if len(df) == 0:
@@ -297,16 +249,12 @@ def produce_times_tables(

     unused_tables = set(input.keys()) - used_tables
     if len(unused_tables) > 0:
-        print(
-            f"WARNING: {len(unused_tables)} unused tables: {', '.join(sorted(unused_tables))}"
-        )
+        print(f"WARNING: {len(unused_tables)} unused tables: {', '.join(sorted(unused_tables))}")

     return result


-def write_dd_files(
-    tables: Dict[str, DataFrame], config: datatypes.Config, output_dir: str
-):
+def write_dd_files(tables: Dict[str, DataFrame], config: datatypes.Config, output_dir: str):
     os.makedirs(output_dir, exist_ok=True)
     for item in os.listdir(output_dir):
         if item.endswith(".dd"):
@@ -315,9 +263,7 @@ def write_dd_files(
     def convert_set(df: DataFrame):
         has_description = "TEXT" in df.columns
         for row in df.itertuples(index=False):
-            row_str = "'.'".join(
-                (str(x) for k, x in row._asdict().items() if k != "TEXT")
-            )
+            row_str = "'.'".join((str(x) for k, x in row._asdict().items() if k != "TEXT"))
             desc = f" '{row.TEXT}'" if has_description else ""
             yield f"'{row_str}'{desc}\n"

@@ -329,9 +275,7 @@ def convert_parameter(tablename: str, df: DataFrame):
             df = df.drop_duplicates(subset=query_columns, keep="last")
         for row in df.itertuples(index=False):
             val = row.VALUE
-            row_str = "'.'".join(
-                (str(x) for k, x in row._asdict().items() if k != "VALUE")
-            )
+            row_str = "'.'".join((str(x) for k, x in row._asdict().items() if k != "VALUE"))
             yield f"'{row_str}' {val}\n" if row_str else f"{val}\n"

     sets = {m.times_name for m in config.times_xl_maps if "VALUE" not in m.col_map}
@@ -412,11 +356,7 @@ def run(args) -> str | None:
         sys.exit(-1)
     elif len(args.input) == 1:
         assert os.path.isdir(args.input[0])
-        input_files = [
-            str(path)
-            for path in Path(args.input[0]).rglob("*")
-            if path.suffix in [".xlsx", ".xlsm"] and not path.name.startswith("~")
-        ]
+        input_files = [str(path) for path in Path(args.input[0]).rglob("*") if path.suffix in [".xlsx", ".xlsm"] and not path.name.startswith("~")]
         print(f"Loading {len(input_files)} files from {args.input[0]}")
     else:
         input_files = args.input
@@ -433,9 +373,7 @@ def run(args) -> str | None:
         )
         sys.exit(0)

-    tables = convert_xl_to_times(
-        input_files, args.output_dir, config, model, args.use_pkl, verbose=args.verbose
-    )
+    tables = convert_xl_to_times(input_files, args.output_dir, config, model, args.use_pkl, verbose=args.verbose)

     if args.dd:
         write_dd_files(tables, config, args.output_dir)
@@ -471,9 +409,7 @@ def parse_args(arg_list: None | list[str]) -> argparse.Namespace:
         default="",
         help="Comma-separated list of regions to include in the model",
     )
-    args_parser.add_argument(
-        "--output_dir", type=str, default="output", help="Output directory"
-    )
+    args_parser.add_argument("--output_dir", type=str, default="output", help="Output directory")
     args_parser.add_argument(
         "--ground_truth_dir",
         type=str,
diff --git a/xl2times/datatypes.py b/xl2times/datatypes.py
index 13f2d315..6e42ec70 100644
--- a/xl2times/datatypes.py
+++ b/xl2times/datatypes.py
@@ -101,9 +101,7 @@ def __eq__(self, o: object) -> bool:
             and self.dataframe.shape == o.dataframe.shape
             and (
                 len(self.dataframe) == 0  # Empty tables don't affect our output
-                or self.dataframe.sort_index(axis=1).equals(
-                    o.dataframe.sort_index(axis=1)
-                )
+                or self.dataframe.sort_index(axis=1).equals(o.dataframe.sort_index(axis=1))
             )
         )

@@ -209,9 +207,7 @@ def __init__(
             self.discard_if_empty,
             self.known_columns,
         ) = Config._read_veda_tags_info(veda_tags_file)
-        self.veda_attr_defaults, self.attr_aliases = Config._read_veda_attr_defaults(
-            veda_attr_defaults_file
-        )
+        self.veda_attr_defaults, self.attr_aliases = Config._read_veda_attr_defaults(veda_attr_defaults_file)
         # Migration in progress: use parameter mappings from times_info_file for now
         name_to_map = {m.times_name: m for m in self.times_xl_maps}
         for m in param_mappings:
@@ -234,16 +230,10 @@ def _process_times_info(
         unknown_cats = {item["gams-cat"] for item in table_info} - set(categories)
         if unknown_cats:
             print(f"WARNING: Unknown categories in times-info.json: {unknown_cats}")
-        dd_table_order = chain.from_iterable(
-            (sorted(cat_to_tables[c]) for c in categories)
-        )
+        dd_table_order = chain.from_iterable((sorted(cat_to_tables[c]) for c in categories))

         # Compute the set of all attributes, i.e. all entities with category = parameter
-        attributes = {
-            item["name"].lower()
-            for item in table_info
-            if item["gams-cat"] == "parameter"
-        }
+        attributes = {item["name"].lower() for item in table_info if item["gams-cat"] == "parameter"}

         # Compute the mapping for attributes / parameters:
         def create_mapping(entity):
@@ -252,11 +242,7 @@ def create_mapping(entity):
             xl_cols = entity["mapping"] + ["value"]  # TODO map in json
             col_map = dict(zip(times_cols, xl_cols))
             # If tag starts with UC, then the data is in UCAttributes, else Attributes
-            xl_name = (
-                "UCAttributes"
-                if entity["name"].lower().startswith("uc")
-                else "Attributes"
-            )
+            xl_name = "UCAttributes" if entity["name"].lower().startswith("uc") else "Attributes"
             return TimesXlMap(
                 times_name=entity["name"],
                 times_cols=times_cols,
@@ -267,10 +253,7 @@ def create_mapping(entity):
             )

         param_mappings = [
-            create_mapping(x)
-            for x in table_info
-            if x["gams-cat"] == "parameter"
-            and "type" not in x  # TODO Generalise derived parameters?
+            create_mapping(x) for x in table_info if x["gams-cat"] == "parameter" and "type" not in x  # TODO Generalise derived parameters?
         ]

         return dd_table_order, attributes, param_mappings
@@ -304,9 +287,7 @@ def _read_mappings(filename: str) -> List[TimesXlMap]:
                 if line == "":
                     break
                 (times, xl) = line.split(" = ")
-                (times_name, times_cols_str) = list(
-                    filter(None, re.split("\[|\]", times))
-                )
+                (times_name, times_cols_str) = list(filter(None, re.split("\[|\]", times)))
                 (xl_name, xl_cols_str) = list(filter(None, re.split("\(|\)", xl)))
                 times_cols = times_cols_str.split(",")
                 xl_cols = xl_cols_str.split(",")
@@ -318,9 +299,7 @@ def _read_mappings(filename: str) -> List[TimesXlMap]:
                 xl_cols = [s.lower() for s in xl_cols if ":" not in s]

                 # TODO remove: Filter out mappings that are not yet finished
-                if xl_name != "~TODO" and not any(
-                    c.startswith("TODO") for c in xl_cols
-                ):
+                if xl_name != "~TODO" and not any(c.startswith("TODO") for c in xl_cols):
                     col_map = {}
                     assert len(times_cols) <= len(xl_cols)
                     for index, value in enumerate(times_cols):
@@ -341,20 +320,13 @@ def _read_mappings(filename: str) -> List[TimesXlMap]:
                     dropped.append(line)

         if len(dropped) > 0:
-            print(
-                f"WARNING: Dropping {len(dropped)} mappings that are not yet complete"
-            )
+            print(f"WARNING: Dropping {len(dropped)} mappings that are not yet complete")
         return mappings

     @staticmethod
     def _read_veda_tags_info(
         veda_tags_file: str,
-    ) -> Tuple[
-        Dict[Tag, Dict[str, str]],
-        Dict[Tag, Dict[str, list]],
-        Iterable[Tag],
-        Dict[Tag, Set[str]],
-    ]:
+    ) -> Tuple[Dict[Tag, Dict[str, str]], Dict[Tag, Dict[str, list]], Iterable[Tag], Dict[Tag, Set[str]],]:
         def to_tag(s: str) -> Tag:
             # The file stores the tag name in lowercase, and without the ~
             return Tag("~" + s.upper())
@@ -367,9 +339,7 @@ def to_tag(s: str) -> Tag:
         tags = {to_tag(tag_info["tag_name"]) for tag_info in veda_tags_info}
         for tag in Tag:
             if tag not in tags:
-                print(
-                    f"WARNING: datatypes.Tag has an unknown Tag {tag} not in {veda_tags_file}"
-                )
+                print(f"WARNING: datatypes.Tag has an unknown Tag {tag} not in {veda_tags_file}")

         valid_column_names = {}
         row_comment_chars = {}
@@ -386,10 +356,7 @@ def to_tag(s: str) -> Tag:
             # Process column aliases and comment chars:
             for valid_field in tag_info["valid_fields"]:
                 valid_field_names = valid_field["aliases"]
-                if (
-                    "use_name" in valid_field
-                    and valid_field["use_name"] != valid_field["name"]
-                ):
+                if "use_name" in valid_field and valid_field["use_name"] != valid_field["name"]:
                     field_name = valid_field["use_name"]
                     valid_field_names.append(valid_field["name"])
                 else:
@@ -399,9 +366,7 @@ def to_tag(s: str) -> Tag:

                 for valid_field_name in valid_field_names:
                     valid_column_names[tag_name][valid_field_name] = field_name
-                row_comment_chars[tag_name][field_name] = valid_field[
-                    "row_ignore_symbol"
-                ]
+                row_comment_chars[tag_name][field_name] = valid_field["row_ignore_symbol"]

             # TODO: Account for differences in valid field names with base_tag
             if "base_tag" in tag_info:
@@ -431,9 +396,7 @@ def _read_veda_attr_defaults(
             "tslvl": {"DAYNITE": [], "ANNUAL": []},
         }

-        attr_aliases = {
-            attr for attr in defaults if "times-attribute" in defaults[attr]
-        }
+        attr_aliases = {attr for attr in defaults if "times-attribute" in defaults[attr]}

         for attr, attr_info in defaults.items():
             # Populate aliases by attribute dictionary
diff --git a/xl2times/excel.py b/xl2times/excel.py
index 48c85ba6..b5845404 100644
--- a/xl2times/excel.py
+++ b/xl2times/excel.py
@@ -34,24 +34,16 @@ def extract_tables(filename: str) -> List[datatypes.EmbeddedXlTable]:
             for colname in df.columns:
                 value = str(row[colname])
                 if value.startswith("~"):
-                    match = re.match(
-                        f"{datatypes.Tag.uc_sets.value}:(.*)", value, re.IGNORECASE
-                    )
+                    match = re.match(f"{datatypes.Tag.uc_sets.value}:(.*)", value, re.IGNORECASE)
                     if match:
                         parts = match.group(1).split(":")
                         if len(parts) == 2:
                             uc_sets[parts[0].strip()] = parts[1].strip()
                         else:
-                            print(
-                                f"WARNING: Malformed UC_SET in {sheet.title}, {filename}"
-                            )
+                            print(f"WARNING: Malformed UC_SET in {sheet.title}, {filename}")
                     else:
                         col_index = df.columns.get_loc(colname)
-                        sheet_tables.append(
-                            extract_table(
-                                row_index, col_index, uc_sets, df, sheet.title, filename
-                            )
-                        )
+                        sheet_tables.append(extract_table(row_index, col_index, uc_sets, df, sheet.title, filename))

         for sheet_table in sheet_tables:
             sheet_table.uc_sets = uc_sets
@@ -123,9 +115,7 @@ def extract_table(
         end_col += 1

     end_row = header_row
-    while end_row < df.shape[0] and not are_cells_all_empty(
-        df, end_row, start_col, end_col
-    ):
+    while end_row < df.shape[0] and not are_cells_all_empty(df, end_row, start_col, end_col):
         end_row += 1

     # Excel cell numbering starts at 1, while pandas starts at 0
@@ -190,8 +180,4 @@ def cell_is_empty(value) -> bool:
     :param value: Cell value.
     :return: Boolean indicating if the cells are empty.
     """
-    return (
-        value is None
-        or (isinstance(value, numpy.floating) and numpy.isnan(value))
-        or (isinstance(value, str) and len(value.strip()) == 0)
-    )
+    return value is None or (isinstance(value, numpy.floating) and numpy.isnan(value)) or (isinstance(value, str) and len(value.strip()) == 0)
diff --git a/xl2times/utils.py b/xl2times/utils.py
index a69f94b9..7654cf87 100644
--- a/xl2times/utils.py
+++ b/xl2times/utils.py
@@ -54,9 +54,7 @@ def explode(df, data_columns):
         column name for each value in each new row.
     """
     data = df[data_columns].values.tolist()
-    other_columns = [
-        colname for colname in df.columns.values if colname not in data_columns
-    ]
+    other_columns = [colname for colname in df.columns.values if colname not in data_columns]
     df = df[other_columns]
     value_column = "value"
     df = df.assign(value=data)
@@ -110,9 +108,7 @@ def merge_columns(tables: List[datatypes.EmbeddedXlTable], tag: str, colname: st
     return numpy.concatenate(columns)


-def apply_wildcards(
-    df: DataFrame, candidates: Iterable[str], wildcard_col: str, output_col: str
-):
+def apply_wildcards(df: DataFrame, candidates: Iterable[str], wildcard_col: str, output_col: str):
     """
     Apply wildcards values to a list of candidates. Wildcards are values containing '*'. For example,
     a value containing '*SOLID*' would include all the values in 'candidates' containing 'SOLID' in the middle.