concat plugin class to handle typical case. update group by plugin
nickzoic committed Jul 11, 2024
1 parent e6faae9 commit c603d2f
Showing 3 changed files with 64 additions and 30 deletions.
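The common thread across the three files below is a new preprocess() hook: the pipeline calls it on each input chunk just before process(), so a plugin can accumulate metadata (such as the set of input columns) as data flows through. A minimal sketch of the resulting call order, using a simplified stand-in rather than the real CountESS base classes:

    # Sketch of the plugin call order after this commit (simplified stand-in):
    #   prepare() once; preprocess() then process() for each chunk; finished() at the end.
    class SketchPlugin:
        def prepare(self, *_):
            self.rows = 0

        def preprocess(self, data, source, logger) -> None:
            self.rows += len(data)  # cheap bookkeeping only; may not return output

        def process(self, data, source, logger):
            yield data  # may yield zero or more output chunks

        def finished(self, source, logger):
            yield from ()  # last chance to emit buffered output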
7 changes: 4 additions & 3 deletions countess/core/pipeline.py
@@ -98,9 +98,6 @@ def is_ancestor_of(self, node):
    def is_descendant_of(self, node):
        return (self in node.child_nodes) or any((self.is_descendant_of(n) for n in node.child_nodes))

    def plugin_process(self, x):
        self.plugin.process(*x)

    def add_output_queue(self):
        queue = SentinelQueue(maxsize=3)
        self.output_queues.add(queue)
@@ -124,13 +121,15 @@ def run_multithread(self, queue: SentinelQueue, name: str, logger: Logger, row_limit: Optional[int] = None):
        assert isinstance(self.plugin, ProcessPlugin)
        for data_in in queue:
            self.counter_in += 1
            self.plugin.preprocess(data_in, name, logger)
            self.queue_output(self.plugin.process(data_in, name, logger))

    def run_subthread(self, queue: SentinelQueue, name: str, logger: Logger, row_limit: Optional[int] = None):
        assert isinstance(self.plugin, ProcessPlugin)

        for data_in in queue:
            self.counter_in += 1
            self.plugin.preprocess(data_in, name, logger)
            self.queue_output(self.plugin.process(data_in, name, logger))
        self.queue_output(self.plugin.finished(name, logger))
@@ -199,6 +198,8 @@ def prerun(self, logger: Logger, row_limit=PRERUN_ROW_LIMIT):
        assert isinstance(self.plugin, ProcessPlugin)
        parent_node.prerun(logger, row_limit)
        if parent_node.result:
            for data_in in parent_node.result:
                self.plugin.preprocess(data_in, parent_node.name, logger)
            for data_in in parent_node.result:
                self.result += list(self.plugin.process(data_in, parent_node.name, logger))
            self.result += list(self.plugin.finished(parent_node.name, logger))
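Note the ordering in prerun() above: where the run loops call preprocess() immediately before each process(), prerun() makes a complete preprocess() pass over the parent's buffered results before processing any of them, so per-run metadata is fully populated by the time the first process() call happens. A toy two-pass equivalent (illustrative only, not the CountESS API):

    import pandas as pd

    chunks = [pd.DataFrame({"x": [1, 2]}), pd.DataFrame({"y": [3]})]  # stand-in for parent_node.result

    seen_columns: set = set()
    for chunk in chunks:  # first pass: metadata, as in preprocess()
        seen_columns.update(chunk.columns)

    outputs = []
    for chunk in chunks:  # second pass: the real work, as in process()
        outputs.append(chunk)  # stand-in transformation
    # every "process" step above can already rely on the complete seen_columns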
45 changes: 43 additions & 2 deletions countess/core/plugins.py
@@ -152,6 +152,11 @@ class ProcessPlugin(BasePlugin):
    def prepare(self, sources: List[str], row_limit: Optional[int] = None):
        pass

    def preprocess(self, data, source: str, logger: Logger) -> None:
        """Called with each `data` input from `source` before `process` is called
        for that data, to set up config etc. Can't return anything."""
        pass

    def process(self, data, source: str, logger: Logger) -> Iterable[pd.DataFrame]:
        """Called with each `data` input from `source`, yields results"""
        raise NotImplementedError(f"{self.__class__}.process")
@@ -216,10 +221,45 @@ def finalize(self, logger: Logger) -> Iterable:
class PandasProcessPlugin(ProcessPlugin):
    DATAFRAME_BUFFER_SIZE = 100000

    def preprocess(self, data: pd.DataFrame, source: str, logger: Logger) -> None:
        pass

    def process(self, data: pd.DataFrame, source: str, logger: Logger) -> Iterable[pd.DataFrame]:
        raise NotImplementedError(f"{self.__class__}.process")


class PandasConcatProcessPlugin(PandasProcessPlugin):
    # Like PandasProcessPlugin but collect all the inputs together before trying to do anything
    # with them.

    def __init__(self, *a, **k) -> None:
        super().__init__(*a, **k)
        self.dataframes: list[pd.DataFrame] = []
        self.input_columns: dict[str, np.dtype] = {}

    def prepare(self, *_):
        self.dataframes = []
        self.input_columns = {}

    def preprocess(self, data: pd.DataFrame, source: str, logger: Logger) -> None:
        self.input_columns.update(get_all_columns(data))

    def process(self, data: pd.DataFrame, source: str, logger: Logger) -> Iterable:
        self.dataframes.append(data)
        print(data)
        return []

    def finalize(self, logger: Logger) -> Iterable[pd.DataFrame]:
        data_in = pd.concat(self.dataframes)
        data_out = self.process_dataframe(data_in, logger)
        if data_out is not None:
            yield data_out

    def process_dataframe(self, dataframe: pd.DataFrame, logger: Logger) -> Optional[pd.DataFrame]:
        """Override this to process a single dataframe"""
        raise NotImplementedError(f"{self.__class__}.process_dataframe()")


class PandasSimplePlugin(SimplePlugin):
    """Base class for plugins which accept and return pandas DataFrames.
    Subclassing this hides all the distracting aspects of the pipeline
@@ -230,12 +270,13 @@ class PandasSimplePlugin(SimplePlugin):
    def prepare(self, sources: list[str], row_limit: Optional[int] = None):
        self.input_columns = {}

    def preprocess(self, data: pd.DataFrame, source: str, logger: Logger) -> None:
        self.input_columns.update(get_all_columns(data))

    def process(self, data: pd.DataFrame, source: str, logger: Logger) -> Iterable[pd.DataFrame]:
        """Just deal with each dataframe as it comes. PandasSimplePlugins don't care about `source`."""
        assert isinstance(data, pd.DataFrame)

        self.input_columns.update(get_all_columns(data))

        try:
            result = self.process_dataframe(data, logger)
            if result is not None:
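PandasConcatProcessPlugin above splits the work into a buffering process() and a single process_dataframe() call on the concatenated result. A hypothetical subclass, not part of this commit, showing the intended division of labour:

    import pandas as pd

    from countess.core.plugins import PandasConcatProcessPlugin

    # Hypothetical subclass (illustrative; plugin metadata like name and version
    # is omitted): work that must see the whole input at once goes in
    # process_dataframe(); buffering and concatenation are inherited.
    class DropDuplicatesPlugin(PandasConcatProcessPlugin):
        def process_dataframe(self, dataframe: pd.DataFrame, logger) -> pd.DataFrame:
            # only safe once every chunk has arrived, because duplicates
            # can span chunk boundaries
            return dataframe.drop_duplicates()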
42 changes: 17 additions & 25 deletions countess/plugins/group_by.py
@@ -1,26 +1,23 @@
from typing import Iterable, List
from typing import Iterable, Optional

import numpy as np
import pandas as pd
from pandas.api.typing import DataFrameGroupBy  # type: ignore

from countess import VERSION
from countess.core.logger import Logger
from countess.core.parameters import ArrayParam, BooleanParam, PerColumnArrayParam, TabularMultiParam
from countess.core.plugins import PandasProcessPlugin
from countess.core.plugins import PandasConcatProcessPlugin
from countess.utils.pandas import flatten_columns, get_all_columns


class GroupByPlugin(PandasProcessPlugin):
class GroupByPlugin(PandasConcatProcessPlugin):
    """Groups a Pandas Dataframe by an arbitrary column and rolls up rows"""

    name = "Group By"
    description = "Group records by column(s) and calculate aggregates"
    version = VERSION
    link = "https://countess-project.github.io/CountESS/included-plugins/#group-by"

    input_columns: dict[str, np.dtype]

    parameters = {
        "columns": PerColumnArrayParam(
            "Columns",
@@ -39,48 +36,41 @@ class GroupByPlugin(PandasProcessPlugin):
        "join": BooleanParam("Join Back?"),
    }

    dataframes: List[pd.DataFrame]

    def __init__(self, *a, **k):
        super().__init__(*a, **k)
        self.prepare()
        self.input_columns = {}

    def prepare(self, *_):
        self.dataframes = []
        self.input_columns = {}

    def process(self, data: pd.DataFrame, source: str, logger: Logger) -> Iterable:
        # XXX should do this in two stages: group each dataframe and then combine.
        # that can wait for a more general MapReduceFinalizePlugin class though.
        assert self.dataframes is not None
        assert isinstance(self.parameters["columns"], ArrayParam)

        self.input_columns.update(get_all_columns(data))
        column_parameters = self.parameters["columns"].params

        assert isinstance(self.parameters["columns"], ArrayParam)

        if not self.parameters["join"].value:
            # Dispose of any columns we don't use in the aggregations.
            # TODO: Reindex as well?
            keep_columns = [
                col_param.label
                for col_param in column_parameters
                for col_param in self.parameters["columns"].params
                if isinstance(col_param, TabularMultiParam)
                and any(cp.value for cp in col_param.values())
                and col_param.label in data.columns
            ]
            data = data[keep_columns]

        self.dataframes.append(data)
        return []
        yield from super().process(data, source, logger)

    def finalize(self, logger: Logger) -> Iterable[pd.DataFrame]:
    def process_dataframe(self, dataframe: pd.DataFrame, logger: Logger) -> Optional[pd.DataFrame]:
        assert isinstance(self.parameters["columns"], ArrayParam)
        assert self.dataframes

        print(self.input_columns)
        self.parameters["columns"].set_column_choices(self.input_columns.keys())

        column_parameters = list(zip(self.input_columns.keys(), self.parameters["columns"]))

        index_cols = [col for col, col_param in column_parameters if col_param["index"].value]
        agg_ops = dict(
            (
@@ -91,12 +81,13 @@ def finalize(self, logger: Logger) -> Iterable[pd.DataFrame]:
            if any(pp.value for k, pp in col_param.params.items() if k != "index")
        )

        data_in = pd.concat(self.dataframes)
        data_in.reset_index([col for col in agg_ops.keys() if col in data_in.index.names], inplace=True)
        # reset any indexes which are actually columns we want to aggregate
        data_in = dataframe.reset_index([col for col in agg_ops.keys() if col in dataframe.index.names])

        try:
            # If there are no columns to index by, add a dummy column and group by that so we
            # still get a DataFrameGroupBy for the next operation
            # XXX there's probably an easier way to do this ...
            data_grouped: DataFrameGroupBy = (
                data_in.groupby(index_cols) if index_cols else data_in.assign(__temp=1).groupby("__temp")
            )
@@ -115,10 +106,11 @@ def finalize(self, logger: Logger) -> Iterable[pd.DataFrame]:

            if self.parameters["join"].value:
                if index_cols:
                    yield data_in.merge(data_out, how="left", left_on=index_cols, right_on=index_cols)
                    return data_in.merge(data_out, how="left", left_on=index_cols, right_on=index_cols)
                else:
                    yield data_in.assign(**data_out.to_dict("records")[0])
                    return data_in.assign(**data_out.to_dict("records")[0])
            else:
                yield data_out
                return data_out
        except (KeyError, ValueError) as exc:
            logger.exception(exc)
            return None

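Two pandas idioms in the Group By code above are worth a standalone illustration: the dummy-column trick that keeps groupby() usable when no index columns are selected, and the "Join Back?" merge of aggregates onto the original rows. A plain-pandas sketch, independent of CountESS:

    import pandas as pd

    df = pd.DataFrame({"k": ["a", "a", "b"], "v": [1, 2, 3]})

    # With no index columns, group on a constant column so the next step still
    # receives a DataFrameGroupBy; the result is a single aggregate row.
    whole = df.assign(__temp=1).groupby("__temp").agg({"v": "sum"})  # v == 6

    # "Join Back?": compute per-group aggregates, then left-merge them onto the
    # original rows so every row carries its group's summary values.
    agg = df.groupby("k").agg(v_sum=("v", "sum"))
    joined = df.merge(agg, how="left", left_on="k", right_index=True)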