[aggr-] allow ranking rows by key column #2417

Open
wants to merge 4 commits into base: develop
8 changes: 8 additions & 0 deletions tests/aggregators-cols.vdj
@@ -0,0 +1,8 @@
#!vd -p
{"sheet": "global", "col": null, "row": "disp_date_fmt", "longname": "set-option", "input": "%b %d, %Y", "keystrokes": "", "comment": null}
{"longname": "open-file", "input": "sample_data/test.jsonl", "keystrokes": "o"}
{"sheet": "test", "col": "key2", "row": "", "longname": "key-col", "input": "", "keystrokes": "!", "comment": "toggle current column as a key column"}
{"sheet": "test", "col": "key2", "row": "", "longname": "addcol-aggregate", "input": "count", "comment": "add column(s) with aggregator of rows grouped by key columns"}
{"sheet": "test", "col": "qty", "row": "", "longname": "type-float", "input": "", "keystrokes": "%", "comment": "set type of current column to float"}
{"sheet": "test", "col": "qty", "row": "", "longname": "addcol-aggregate", "input": "rank sum", "comment": "add column(s) with aggregator of rows grouped by key columns"}
{"sheet": "test", "col": "qty_sum", "row": "", "longname": "addcol-sheetrank", "input": "", "comment": "add column with the rank of each row based on its key columns"}
11 changes: 11 additions & 0 deletions tests/golden/aggregators-cols.tsv
@@ -0,0 +1,11 @@
key2 key2_count key1 qty qty_rank qty_sum test_sheetrank amt
foo 2 2016-01-01 11:00:00 1.00 1 31.00 5
0 2016-01-01 1:00 2.00 1 66.00 2 3
baz 3 4.00 1 292.00 4 43.2
#ERR 0 #ERR #ERR 1 0.00 1 #ERR #ERR
bar 2 2017-12-25 8:44 16.00 2 16.00 3 .3
baz 3 32.00 2 292.00 4 3.3
0 2018-07-27 4:44 64.00 2 66.00 2 9.1
bar 2 2018-07-27 16:44 1 16.00 3
baz 3 2018-07-27 18:44 256.00 3 292.00 4 .01
foo 2 2018-10-20 18:44 30.00 2 31.00 5 .01
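The qty_rank column above shows the rank aggregator's dense ranking within each key2 group: the three baz rows get ranks 1, 2, 3 by ascending qty, and the two foo rows get ranks 1 and 2, while qty_sum repeats the group's sum on every row. A minimal standalone sketch of that per-group ranking rule (illustrative only, not part of the patch; dense_ranks is a hypothetical helper name):

import itertools

def dense_ranks(sorted_vals):
    # equal values share a rank; each new distinct value bumps the rank by 1
    ranks = []
    for rank, (_, group) in enumerate(itertools.groupby(sorted_vals), 1):
        ranks.extend(rank for _ in group)
    return ranks

assert dense_ranks([4.0, 32.0, 256.0]) == [1, 2, 3]   # the three 'baz' qty values
assert dense_ranks([1.0, 1.0, 30.0]) == [1, 1, 2]     # tied values would share a rank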
126 changes: 116 additions & 10 deletions visidata/aggregators.py
@@ -3,9 +3,10 @@
import functools
import collections
import statistics
import itertools

from visidata import Progress, Sheet, Column, ColumnsSheet, VisiData
from visidata import vd, anytype, vlen, asyncthread, wrapply, AttrDict, date, INPROGRESS
from visidata import Progress, Sheet, Column, ColumnsSheet, VisiData, SettableColumn
from visidata import vd, anytype, vlen, asyncthread, wrapply, AttrDict, date, INPROGRESS, stacktrace, TypedExceptionWrapper

vd.help_aggregators = '''# Choose Aggregators
Start typing an aggregator name or description.
@@ -76,7 +77,7 @@ def aggregators_set(col, aggs):


class Aggregator:
    def __init__(self, name, type, funcValues=None, helpstr='foo'):
    def __init__(self, name, type, funcValues=None, helpstr=''):
        'Define aggregator `name` that calls funcValues(values)'
        self.type = type
        self.funcValues = funcValues # funcValues(values)
@@ -92,13 +93,48 @@ def aggregate(self, col, rows): # wrap builtins so they can have a .type
                return None
            raise e

class ListAggregator(Aggregator):
    '''A list aggregator returns a list of values, generally one value per input row,
    unlike an ordinary aggregator, which reduces a group of rows to a single value.
    To implement a new list aggregator, subclass ListAggregator,
    and override aggregate() and aggregate_list().'''
    def __init__(self, name, type, helpstr='', listtype=None):
        '''*listtype* determines the type of the column created by addcol_aggregate()
        for list aggregators. If it is None, the new column will match the type of the input column.'''
        super().__init__(name, type, helpstr=helpstr)
        self.listtype = listtype

    def aggregate(self, col, rows) -> list:
        '''Return a list, which can be shorter than *rows* because nulls and errors are filtered out.
        Override in subclass.'''
        vals = self.aggregate_list(col, rows)
        # filter out nulls and errors
        vals = [ v for v in vals if not col.sheet.isNullFunc()(v) ]
        return vals

    def aggregate_list(self, col, row_group) -> list:
        '''Return a list of results, one result per input row.
        *row_group* is an iterable that holds a "group" of rows to run the aggregator on.
        Rows in *row_group* are not necessarily in the same order as they are in the sheet.
        Override in subclass.'''
        vals = [ col.getTypedValue(r) for r in row_group ]
        return vals

@VisiData.api
def aggregator(vd, name, funcValues, helpstr='', *, type=None):
    '''Define simple aggregator *name* that calls ``funcValues(values)`` to aggregate *values*.
    Use *type* to force type of aggregated column (default to use type of source column).'''
    vd.aggregators[name] = Aggregator(name, type, funcValues=funcValues, helpstr=helpstr)

@VisiData.api
def aggregator_list(vd, name, helpstr='', type=anytype, listtype=anytype):
    '''Define list aggregator *name*, which produces one aggregated value per input row.
    Use *type* to force the type of the aggregated column (default: the type of the source column).
    Use *listtype* to force the type of the new column created by addcol-aggregate.
    If *listtype* is None, it will match the type of the source column.'''
    vd.aggregators[name] = ListAggregator(name, type, helpstr=helpstr, listtype=listtype)

## specific aggregator implementations

def mean(vals):
@@ -109,6 +145,13 @@ def mean(vals):
def vsum(vals):
    return sum(vals, start=type(vals[0] if len(vals) else 0)()) #1996

def stdev(vals):
    try:
        return statistics.stdev(vals)
    except statistics.StatisticsError as e:  # when vals holds only 1 element
        e.stacktrace = stacktrace()
        return TypedExceptionWrapper(None, exception=e)

# http://code.activestate.com/recipes/511478-finding-the-percentile-of-the-values/
def _percentile(N, percent, key=lambda x:x):
"""
@@ -140,10 +183,49 @@ def __init__(self, pct, helpstr=''):
    def aggregate(self, col, rows):
        return _percentile(sorted(col.getValues(rows)), self.pct/100, key=float)


def quantiles(q, helpstr):
    return [PercentileAggregator(round(100*i/q), helpstr) for i in range(1, q)]

def aggregate_groups(sheet, col, rows, aggr) -> list:
    '''Return a list containing the result of the aggregator applied to each row.
    *col* is the column whose values are aggregated (or ranked) within each group.
    *rows* is a list of visidata rows.
    *aggr* is an Aggregator object.
    Rows are grouped by their key columns. Null key column cells are considered equal,
    so nulls are grouped together. Cells with exceptions do not group together:
    each exception cell is grouped by itself, with only one row in the group.
    '''
    def _key_progress(prog):
        def identity(val):
            prog.addProgress(1)
            return val
        return identity

    with Progress(gerund='ranking', total=4*sheet.nRows) as prog:
        p = _key_progress(prog) # increment progress every time p() is called
        # compile row data, for each row a tuple: (group_key, rank_key, rownum)
        rowdata = [(sheet.rowkey(r), col.getTypedValue(r), p(rownum)) for rownum, r in enumerate(rows)]
        # sort by row key and column value to prepare for grouping
        try:
            rowdata.sort(key=p)
        except TypeError as e:
            vd.fail(f'elements in a ranking column must be comparable: {e.args[0]}')
        rowvals = []
        # group by row key
        for _, group in itertools.groupby(rowdata, key=lambda v: v[0]):
            # within a group, the rows have already been sorted by col_val
            group = list(group)
            if isinstance(aggr, ListAggregator): # for list aggregators, each row gets its own value
                aggr_vals = aggr.aggregate_list(col, [rows[rownum] for _, _, rownum in group])
                rowvals += [(rownum, v) for (_, _, rownum), v in zip(group, aggr_vals)]
            else: # for normal aggregators, each row in the group gets the same value
                aggr_val = aggr.aggregate(col, [rows[rownum] for _, _, rownum in group])
                rowvals += [(rownum, aggr_val) for _, _, rownum in group]
            prog.addProgress(len(group))
        # sort by unique rownum, to make results match the original row order
        rowvals.sort(key=p)
        rowvals = [ v for rownum, v in rowvals ]
    return rowvals

vd.aggregator('min', min, 'minimum value')
vd.aggregator('max', max, 'maximum value')
@@ -154,7 +236,7 @@ def quantiles(q, helpstr):
vd.aggregator('sum', vsum, 'sum of values')
vd.aggregator('distinct', set, 'distinct values', type=vlen)
vd.aggregator('count', lambda values: sum(1 for v in values), 'number of values', type=int)
vd.aggregator('list', list, 'list of values', type=anytype)
vd.aggregator_list('list', 'list of values', type=anytype, listtype=None)
vd.aggregator('stdev', statistics.stdev, 'standard deviation of values', type=float)

vd.aggregators['q3'] = quantiles(3, 'tertiles (33/66th pctile)')
@@ -243,7 +325,8 @@ def memo_aggregate(col, agg_choices, rows):
        for agg in aggs:
            aggval = agg.aggregate(col, rows)
            typedval = wrapply(agg.type or col.type, aggval)
            dispval = col.format(typedval)
            # limit width to limit formatting time when typedval is a long list
            dispval = col.format(typedval, width=1000)
            k = col.name+'_'+agg.name
            vd.status(f'{k}={dispval}')
            vd.memory[k] = typedval
@@ -254,14 +337,13 @@ def aggregator_choices(vd):
    return [
        AttrDict(key=agg, desc=v[0].helpstr if isinstance(v, list) else v.helpstr)
        for agg, v in vd.aggregators.items()
        if not agg.startswith('p') # skip all the percentiles, user should use q# instead
        if not (agg.startswith('p') and agg[1:].isdigit()) # skip all the percentiles like 'p10', user should use q# instead
    ]


@VisiData.api
def chooseAggregators(vd):
def chooseAggregators(vd, prompt='choose aggregators: '):
    '''Return a list of aggregator name strings chosen or entered by the user. User-entered names may be invalid.'''
    prompt = 'choose aggregators: '
    def _fmt_aggr_summary(match, row, trigger_key):
        formatted_aggrname = match.formatted.get('key', row.key) if match else row.key
        r = ' '*(len(prompt)-3)
@@ -288,10 +370,34 @@ def _fmt_aggr_summary(match, row, trigger_key):
            vd.warning(f'aggregator does not exist: {aggr}')
    return aggrs

Sheet.addCommand('+', 'aggregate-col', 'addAggregators([cursorCol], chooseAggregators())', 'add aggregator to current column')
@Sheet.api
@asyncthread
def addcol_aggregate(sheet, col, aggrnames):
    'Add a column per aggregator in *aggrnames*, aggregating *col* over rows grouped by key columns.'
    for aggrname in aggrnames:
        aggrs = vd.aggregators.get(aggrname)
        if not aggrs: continue
        aggrs = aggrs if isinstance(aggrs, list) else [aggrs]
        for aggr in aggrs:
            vals = aggregate_groups(sheet, col, sheet.rows, aggr)
            if isinstance(aggr, ListAggregator):
                t = aggr.listtype or col.type
            else:
                t = aggr.type or col.type
            c = SettableColumn(name=f'{col.name}_{aggr.name}', type=t)
            sheet.addColumnAtCursor(c)
            c.setValues(sheet.rows, *vals)

Sheet.addCommand('+', 'aggregate-col', 'addAggregators([cursorCol], chooseAggregators())', 'Add aggregator to current column')
Sheet.addCommand('z+', 'memo-aggregate', 'cursorCol.memo_aggregate(chooseAggregators(), selectedRows or rows)', 'memo result of aggregator over values in selected rows for current column')
ColumnsSheet.addCommand('g+', 'aggregate-cols', 'addAggregators(selectedRows or source[0].nonKeyVisibleCols, chooseAggregators())', 'add aggregators to selected source columns')
Sheet.addCommand('', 'addcol-aggregate', 'addcol_aggregate(cursorCol, chooseAggregators(prompt="aggregator for groups: "))', 'add column(s) with aggregator of rows grouped by key columns')

vd.addGlobals(
    ListAggregator=ListAggregator
)

vd.addMenuItems('''
Column > Add aggregator > aggregate-col
Column > Add column > aggregate > addcol-aggregate
''')
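Because ListAggregator is exported through addGlobals, a plugin can build its own per-row aggregator on top of this API. A minimal sketch under that assumption (a hypothetical 'cumsum' aggregator, not part of this PR; it assumes the grouped column holds numeric values):

from visidata import vd, anytype, ListAggregator

class CumSumAggregator(ListAggregator):
    'Running total of the column within each key-column group (illustrative only).'
    def aggregate_list(self, col, row_group):
        # note: within a group, addcol-aggregate hands rows over sorted by this column's value
        total = 0
        out = []
        for r in row_group:
            total += col.getTypedValue(r)   # assumes numeric cells
            out.append(total)
        return out

# registered the same way visidata/features/rank.py registers its 'rank' aggregator
vd.aggregators['cumsum'] = CumSumAggregator('cumsum', anytype,
        helpstr='running total within each key-column group', listtype=float)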

74 changes: 74 additions & 0 deletions visidata/features/rank.py
@@ -0,0 +1,74 @@
import itertools

from visidata import Sheet, ListAggregator, SettableColumn
from visidata import vd, anytype, asyncthread

class RankAggregator(ListAggregator):
    '''
    Rank the values of the current column within each key-column group.
    Ranks start at 1, and each distinct value gets a rank 1 higher than the
    previous distinct value (dense ranking). Tied elements share the same rank.
    '''
    def aggregate(self, col, rows) -> [int]:
        return self.aggregate_list(col, rows)

    def aggregate_list(self, col, rows) -> [int]:
        if not col.sheet.keyCols:
            vd.error('ranking requires one or more key columns')
            return None
        return self.rank(col, rows)

    def rank(self, col, rows):
        # compile row data, for each row a tuple: (group_key, rank_key, rownum)
        rowdata = [(col.sheet.rowkey(r), col.getTypedValue(r), rownum) for rownum, r in enumerate(rows)]
        # sort by row key and column value to prepare for grouping
        try:
            rowdata.sort()
        except TypeError as e:
            vd.fail(f'elements in a ranking column must be comparable: {e.args[0]}')
        rowvals = []
        # group by row key
        for _, group in itertools.groupby(rowdata, key=lambda v: v[0]):
            # within a group, the rows have already been sorted by col_val
            group = list(group)
            # rank each group individually
            group_ranks = rank_sorted_iterable([col_val for _, col_val, rownum in group])
            rowvals += [(rownum, rank) for (_, _, rownum), rank in zip(group, group_ranks)]
        # sort by unique rownum, to make rank results match the original row order
        rowvals.sort()
        rowvals = [ rank for rownum, rank in rowvals ]
        return rowvals

vd.aggregators['rank'] = RankAggregator('rank', anytype, helpstr='list of ranks, when grouping by key columns', listtype=int)

def rank_sorted_iterable(vals_sorted) -> [int]:
    '''*vals_sorted* is an iterable whose elements form one group.
    The iterable must already be sorted.'''

    ranks = []
    val_groups = itertools.groupby(vals_sorted)
    for rank, (_, val_group) in enumerate(val_groups, 1):
        for _ in val_group:
            ranks.append(rank)
    return ranks

@Sheet.api
@asyncthread
def addcol_sheetrank(sheet, rows):
    '''
    Each row is ranked within its sheet. Rows are ordered by the
    value of their key columns.
    '''
    colname = f'{sheet.name}_sheetrank'
    c = SettableColumn(name=colname, type=int)
    sheet.addColumnAtCursor(c)
    if not sheet.keyCols:
        vd.error('ranking requires one or more key columns')
        return None
    rowkeys = [(sheet.rowkey(r), rownum) for rownum, r in enumerate(rows)]
    rowkeys.sort()
    ranks = rank_sorted_iterable([rowkey for rowkey, rownum in rowkeys])
    row_ranks = sorted(zip((rownum for _, rownum in rowkeys), ranks))
    row_ranks = [rank for rownum, rank in row_ranks]
    c.setValues(sheet.rows, *row_ranks)

Sheet.addCommand('', 'addcol-sheetrank', 'sheet.addcol_sheetrank(rows)', 'add column with the rank of each row based on its key columns')
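To make the ranking semantics concrete: rank_sorted_iterable assigns dense ranks over an already-sorted iterable, and addcol_sheetrank applies it to the sheet's rowkey tuples, so rows sharing the same key values share a sheetrank. Illustrative calls (made-up values):

rank_sorted_iterable([2, 2, 5, 7, 7, 7])            # -> [1, 1, 2, 3, 3, 3]
rank_sorted_iterable(['bar', 'bar', 'baz', 'foo'])  # -> [1, 1, 2, 3]
rank_sorted_iterable([])                            # -> []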
1 change: 1 addition & 0 deletions visidata/tests/test_commands.py
@@ -116,6 +116,7 @@ def isTestableCommand(longname, cmdlist):
'sheet': '',
'col': 'Units',
'row': '5',
'addcol-aggregate': 'max',
}

@pytest.mark.usefixtures('curses_setup')