From 398b9f769da1290a2b9f527517722dd8a75fda65 Mon Sep 17 00:00:00 2001 From: Bernardo Porto Veronese Date: Thu, 25 Jul 2024 16:51:28 -0300 Subject: [PATCH] Replace old csv solution with fixed pandas DataFrame --- gladeparser/columns.py | 75 ++++++++++--------------- gladeparser/constants.py | 118 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 146 insertions(+), 47 deletions(-) create mode 100644 gladeparser/constants.py diff --git a/gladeparser/columns.py b/gladeparser/columns.py index 4f5b01c..c37ce3e 100644 --- a/gladeparser/columns.py +++ b/gladeparser/columns.py @@ -1,47 +1,16 @@ -import os from enum import Enum from typing import Union, List -import numpy as np import pandas as pd import polars as pl -COLUMNS_FILENAME = "columns.csv" - -DTYPES = { - "ID": int, - "Catalog ID": str, - "Object type flag": str, - "Localization": np.float64, - "Magnitude": np.float64, - "Distance": np.float64, - "Mass": np.float64, - "Merger rate": np.float64, -} - -POLARS_DTYPES = { - "ID": pl.Int32(), - "Catalog ID": pl.String(), - "Object type flag": pl.String(), - "Localization": pl.Float64(), - "Magnitude": pl.Float64(), - "Distance": pl.Float64(), - "Mass": pl.Float64(), - "Merger rate": pl.Float64(), -} - -POLARS_DTYPES_OVERRIDES = { - "B flag": pl.Int8(), - "W1 flag": pl.Int8(), - "z flag": pl.Int8(), - "dist flag": pl.Int8(), - "M* flag": pl.Int8(), -} - - -def get_column_filepath() -> str: - dirname = os.getcwd() - return os.path.join(dirname, COLUMNS_FILENAME) +from .constants import ( + COLUMN_NAMES, + GROUPS, + DTYPES, + POLARS_DTYPES, + POLARS_DTYPES_OVERRIDES, +) class Group(str, Enum): @@ -64,34 +33,46 @@ def values(cls): class GLADEDescriptor: def __init__(self): - column_filepath = get_column_filepath() - self._columns = pd.read_csv(column_filepath, index_col="Column ID") + self.column_names_col = "Column Name" + self.group_col = "Group" + data = {self.column_names_col: COLUMN_NAMES, self.group_col: GROUPS} + self._columns = pd.DataFrame( + data=data, index=list(range(1, len(COLUMN_NAMES) + 1)) + ) def __str__(self): return str(self._columns) + @property + def _column_names(self): + return self._columns[self.column_names_col] + + @property + def _column_groups(self): + return self._columns[self.group_col] + @property def groups(self) -> List[str]: - return list(dict.fromkeys(self._columns["Group"])) + return list(dict.fromkeys(self._columns[self.group_col])) @property def names(self) -> List[str]: - return self._columns["Column Name"].to_list() + return self._column_names.to_list() @property def column_dtypes(self): - dtype_list = self._columns["Group"].map(DTYPES) + dtype_list = self._column_groups.map(DTYPES) return dict(zip(self.names, dtype_list)) @property def polars_schema(self): - dtype_list = self._columns["Group"].map(POLARS_DTYPES) + dtype_list = self._column_groups.map(POLARS_DTYPES) dtype_dict = dict(zip(self.names, dtype_list)) dtype_dict.update(**POLARS_DTYPES_OVERRIDES) return pl.Schema(dtype_dict) def _index_to_name(self, indices: List[int]) -> List[str]: - return self._columns["Column Name"][indices].to_list() + return self._column_names[indices].to_list() def _parse_column(self, column: Union[int, str]) -> List[int]: if isinstance(column, int): @@ -100,9 +81,9 @@ def _parse_column(self, column: Union[int, str]) -> List[int]: raise ValueError("column argument should be int or str") query = ( - f'Group == "{column}"' + f'{self.group_col} == "{column}"' if column in Group.values() - else f'`Column Name` == "{column}"' + else f'`{self.column_names_col}` == "{column}"' ) return self._columns.query(query).index.to_list() diff --git a/gladeparser/constants.py b/gladeparser/constants.py new file mode 100644 index 0000000..086432c --- /dev/null +++ b/gladeparser/constants.py @@ -0,0 +1,118 @@ +from numpy import float64 +from polars import Int8, Int32, Float64, String + +COLUMN_NAMES = [ + "GLADE no", + "PGC no", + "GWGC name", + "HyperLEDA name", + "2MASS name", + "WISExSCOS name", + "SDSS-DR16Q name", + "Object type flag", + "RA", + "Dec", + "B", + "B_err", + "B flag", + "B_Abs", + "J", + "J_err", + "H", + "H_err", + "K", + "K_err", + "W1", + "W1_err", + "W2", + "W2_err", + "W1 flag", + "B_J", + "B_J err", + "z_helio", + "z_cmb", + "z flag", + "v_err", + "z_err", + "d_L", + "d_L err", + "dist flag", + "M*", + "M*_err", + "M* flag", + "Merger rate", + "Merger rate error", +] + +GROUPS = [ + "ID", + "Catalog ID", + "Catalog ID", + "Catalog ID", + "Catalog ID", + "Catalog ID", + "Catalog ID", + "Object type flag", + "Localization", + "Localization", + "Magnitude", + "Magnitude", + "Magnitude", + "Magnitude", + "Magnitude", + "Magnitude", + "Magnitude", + "Magnitude", + "Magnitude", + "Magnitude", + "Magnitude", + "Magnitude", + "Magnitude", + "Magnitude", + "Magnitude", + "Magnitude", + "Magnitude", + "Distance", + "Distance", + "Distance", + "Distance", + "Distance", + "Distance", + "Distance", + "Distance", + "Mass", + "Mass", + "Mass", + "Merger rate", + "Merger rate", +] + +DTYPES = { + "ID": int, + "Catalog ID": str, + "Object type flag": str, + "Localization": float64, + "Magnitude": float64, + "Distance": float64, + "Mass": float64, + "Merger rate": float64, +} + +POLARS_DTYPES = { + "ID": Int32(), + "Catalog ID": String(), + "Object type flag": String(), + "Localization": Float64(), + "Magnitude": Float64(), + "Distance": Float64(), + "Mass": Float64(), + "Merger rate": Float64(), +} + +POLARS_DTYPES_OVERRIDES = { + "B flag": Int8(), + "W1 flag": Int8(), + "z flag": Int8(), + "dist flag": Int8(), + "M* flag": Int8(), +}