Skip to content

Commit

Permalink
Replace old csv solution with fixed pandas DataFrame
Browse files Browse the repository at this point in the history
  • Loading branch information
Bernardo Porto Veronese committed Jul 25, 2024
1 parent d8338a7 commit 398b9f7
Show file tree
Hide file tree
Showing 2 changed files with 146 additions and 47 deletions.
75 changes: 28 additions & 47 deletions gladeparser/columns.py
Original file line number Diff line number Diff line change
@@ -1,47 +1,16 @@
import os
from enum import Enum
from typing import Union, List

import numpy as np
import pandas as pd
import polars as pl

COLUMNS_FILENAME = "columns.csv"

DTYPES = {
"ID": int,
"Catalog ID": str,
"Object type flag": str,
"Localization": np.float64,
"Magnitude": np.float64,
"Distance": np.float64,
"Mass": np.float64,
"Merger rate": np.float64,
}

POLARS_DTYPES = {
"ID": pl.Int32(),
"Catalog ID": pl.String(),
"Object type flag": pl.String(),
"Localization": pl.Float64(),
"Magnitude": pl.Float64(),
"Distance": pl.Float64(),
"Mass": pl.Float64(),
"Merger rate": pl.Float64(),
}

POLARS_DTYPES_OVERRIDES = {
"B flag": pl.Int8(),
"W1 flag": pl.Int8(),
"z flag": pl.Int8(),
"dist flag": pl.Int8(),
"M* flag": pl.Int8(),
}


def get_column_filepath() -> str:
dirname = os.getcwd()
return os.path.join(dirname, COLUMNS_FILENAME)
from .constants import (
COLUMN_NAMES,
GROUPS,
DTYPES,
POLARS_DTYPES,
POLARS_DTYPES_OVERRIDES,
)


class Group(str, Enum):
Expand All @@ -64,34 +33,46 @@ def values(cls):

class GLADEDescriptor:
def __init__(self):
column_filepath = get_column_filepath()
self._columns = pd.read_csv(column_filepath, index_col="Column ID")
self.column_names_col = "Column Name"
self.group_col = "Group"
data = {self.column_names_col: COLUMN_NAMES, self.group_col: GROUPS}
self._columns = pd.DataFrame(
data=data, index=list(range(1, len(COLUMN_NAMES) + 1))
)

def __str__(self):
return str(self._columns)

@property
def _column_names(self):
return self._columns[self.column_names_col]

@property
def _column_groups(self):
return self._columns[self.group_col]

@property
def groups(self) -> List[str]:
return list(dict.fromkeys(self._columns["Group"]))
return list(dict.fromkeys(self._columns[self.group_col]))

@property
def names(self) -> List[str]:
return self._columns["Column Name"].to_list()
return self._column_names.to_list()

@property
def column_dtypes(self):
dtype_list = self._columns["Group"].map(DTYPES)
dtype_list = self._column_groups.map(DTYPES)
return dict(zip(self.names, dtype_list))

@property
def polars_schema(self):
dtype_list = self._columns["Group"].map(POLARS_DTYPES)
dtype_list = self._column_groups.map(POLARS_DTYPES)
dtype_dict = dict(zip(self.names, dtype_list))
dtype_dict.update(**POLARS_DTYPES_OVERRIDES)
return pl.Schema(dtype_dict)

def _index_to_name(self, indices: List[int]) -> List[str]:
return self._columns["Column Name"][indices].to_list()
return self._column_names[indices].to_list()

def _parse_column(self, column: Union[int, str]) -> List[int]:
if isinstance(column, int):
Expand All @@ -100,9 +81,9 @@ def _parse_column(self, column: Union[int, str]) -> List[int]:
raise ValueError("column argument should be int or str")

query = (
f'Group == "{column}"'
f'{self.group_col} == "{column}"'
if column in Group.values()
else f'`Column Name` == "{column}"'
else f'`{self.column_names_col}` == "{column}"'
)
return self._columns.query(query).index.to_list()

Expand Down
118 changes: 118 additions & 0 deletions gladeparser/constants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
from numpy import float64
from polars import Int8, Int32, Float64, String

COLUMN_NAMES = [
"GLADE no",
"PGC no",
"GWGC name",
"HyperLEDA name",
"2MASS name",
"WISExSCOS name",
"SDSS-DR16Q name",
"Object type flag",
"RA",
"Dec",
"B",
"B_err",
"B flag",
"B_Abs",
"J",
"J_err",
"H",
"H_err",
"K",
"K_err",
"W1",
"W1_err",
"W2",
"W2_err",
"W1 flag",
"B_J",
"B_J err",
"z_helio",
"z_cmb",
"z flag",
"v_err",
"z_err",
"d_L",
"d_L err",
"dist flag",
"M*",
"M*_err",
"M* flag",
"Merger rate",
"Merger rate error",
]

GROUPS = [
"ID",
"Catalog ID",
"Catalog ID",
"Catalog ID",
"Catalog ID",
"Catalog ID",
"Catalog ID",
"Object type flag",
"Localization",
"Localization",
"Magnitude",
"Magnitude",
"Magnitude",
"Magnitude",
"Magnitude",
"Magnitude",
"Magnitude",
"Magnitude",
"Magnitude",
"Magnitude",
"Magnitude",
"Magnitude",
"Magnitude",
"Magnitude",
"Magnitude",
"Magnitude",
"Magnitude",
"Distance",
"Distance",
"Distance",
"Distance",
"Distance",
"Distance",
"Distance",
"Distance",
"Mass",
"Mass",
"Mass",
"Merger rate",
"Merger rate",
]

DTYPES = {
"ID": int,
"Catalog ID": str,
"Object type flag": str,
"Localization": float64,
"Magnitude": float64,
"Distance": float64,
"Mass": float64,
"Merger rate": float64,
}

POLARS_DTYPES = {
"ID": Int32(),
"Catalog ID": String(),
"Object type flag": String(),
"Localization": Float64(),
"Magnitude": Float64(),
"Distance": Float64(),
"Mass": Float64(),
"Merger rate": Float64(),
}

POLARS_DTYPES_OVERRIDES = {
"B flag": Int8(),
"W1 flag": Int8(),
"z flag": Int8(),
"dist flag": Int8(),
"M* flag": Int8(),
}

0 comments on commit 398b9f7

Please sign in to comment.