Aggregation functions (COUNT, SUM, MIN, MAX, AVG) (#45)
* Adds support for multigraphs

* Refactors `_is_edge_attr_match`

* Filters relations by __label__ during `_lookup`

* Bundles relation attributes together for lookup

* Refactors and adds inline docs

* Adds tests for multigraph support

* Cleans up inline docs

* Removes slicing list twice to avoid two copies in memory

* Supports WHERE clause for relationships in multigraphs

* Adds test for multigraph with WHERE clause on single edge

* Accounts for WHERE with string node attributes in MultiDiGraphs

* Unifies all unit tests to work with both DiGraphs and MultiDiGraphs

* Completes multidigraph test for WHERE on node attribute

* Supports logical OR for relationship matching (see the sketch after this list)

* Adds tests for logical OR in MATCH for relationships

* Implements aggregation functions

* Removes unused code

* Adds agg function results to `_return_requests`

* Handles `None` values appropriately for MIN and MAX

* Adds tests for agg functions and adjusts existing tests to new output

* Adds examples page

* Adds test for multiple agg functions

* Removes commented code
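
Of the features listed above, logical OR for relationship matching is the only one without a worked example elsewhere in this commit. A minimal sketch, assuming the Cypher-style `|` syntax for alternative relationship labels:

```python
from grandcypher import GrandCypher
import networkx as nx

# Small multigraph with two differently labelled parallel edges
host = nx.MultiDiGraph()
host.add_node("a", name="Alice", age=25)
host.add_node("b", name="Bob", age=30)
host.add_edge("a", "b", __labels__={"paid"}, amount=12)
host.add_edge("a", "b", __labels__={"friends"}, years=9)

# Match edges labelled either "paid" OR "friends"
qry = """
MATCH (n)-[r:paid|friends]->(m)
RETURN n.name, m.name
"""
print(GrandCypher(host).run(qry))
```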
jackboyla authored Jun 10, 2024
1 parent f6f4beb commit f1ca6f7
Showing 4 changed files with 296 additions and 21 deletions.
README.md: 3 additions & 0 deletions
@@ -27,6 +27,8 @@ RETURN A.club, B.club
""")
```

See [examples.md](examples.md) for more!

### Example Usage with SQL

Create your own "Sqlite for Neo4j"! This example uses [grand-graph](https://github.com/aplbrain/grand) to run queries in SQL:
@@ -81,6 +83,7 @@ RETURN
| Graph mutations (e.g. `DELETE`, `SET`,...) | 🛣 | |
| `DISTINCT` | ✅ Thanks @jackboyla! | |
| `ORDER BY` | ✅ Thanks @jackboyla! | |
| Aggregation functions (`COUNT`, `SUM`, `MIN`, `MAX`, `AVG`) | ✅ Thanks @jackboyla! | |

| | | |
| -------------- | -------------- | ---------------- |
examples.md: 66 additions & 0 deletions
@@ -0,0 +1,66 @@

## Multigraph

```python
from grandcypher import GrandCypher
import networkx as nx

host = nx.MultiDiGraph()
host.add_node("a", name="Alice", age=25)
host.add_node("b", name="Bob", age=30)
host.add_edge("a", "b", __labels__={"paid"}, amount=12, date="12th June")
host.add_edge("b", "a", __labels__={"paid"}, amount=6)
host.add_edge("b", "a", __labels__={"paid"}, value=14)
host.add_edge("a", "b", __labels__={"friends"}, years=9)
host.add_edge("a", "b", __labels__={"paid"}, amount=40)

qry = """
MATCH (n)-[r:paid]->(m)
RETURN n.name, m.name, r.amount
"""
res = GrandCypher(host).run(qry)
print(res)

'''
{
'n.name': ['Alice', 'Bob'],
'm.name': ['Bob', 'Alice'],
'r.amount': [{(0, 'paid'): 12, (1, 'friends'): None, (2, 'paid'): 40}, {(0, 'paid'): 6, (1, 'paid'): None}]
}
'''
```
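
The commit also adds WHERE support for relationship attributes in multigraphs; a minimal sketch against the same `host` graph, with the comparison syntax assumed from the commit notes rather than taken from this diff:

```python
qry = """
MATCH (n)-[r:paid]->(m)
WHERE r.amount > 10
RETURN n.name, m.name, r.amount
"""
res = GrandCypher(host).run(qry)
print(res)  # expected to filter on the edge attribute; the exact bundling of parallel edges is not shown here
```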

## Aggregation Functions

```python
from grandcypher import GrandCypher
import networkx as nx

host = nx.MultiDiGraph()
host.add_node("a", name="Alice", age=25)
host.add_node("b", name="Bob", age=30)
host.add_edge("a", "b", __labels__={"paid"}, amount=12, date="12th June")
host.add_edge("b", "a", __labels__={"paid"}, amount=6)
host.add_edge("b", "a", __labels__={"paid"}, value=14)
host.add_edge("a", "b", __labels__={"friends"}, years=9)
host.add_edge("a", "b", __labels__={"paid"}, amount=40)

qry = """
MATCH (n)-[r:paid]->(m)
RETURN n.name, m.name, SUM(r.amount)
"""
res = GrandCypher(host).run(qry)
print(res)

'''
{
'n.name': ['Alice', 'Bob'],
'm.name': ['Bob', 'Alice'],
'SUM(r.amount)': [{'paid': 52, 'friends': 0}, {'paid': 6}]
}
'''
```
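
This commit also adds a test for multiple aggregation functions in a single RETURN. A minimal sketch (the per-label output shape is assumed to match the SUM example above):

```python
from grandcypher import GrandCypher
import networkx as nx

g = nx.MultiDiGraph()
g.add_node("a", name="Alice")
g.add_node("b", name="Bob")
g.add_edge("a", "b", __labels__={"paid"}, amount=12)
g.add_edge("a", "b", __labels__={"paid"}, amount=40)

qry = """
MATCH (n)-[r:paid]->(m)
RETURN n.name, MIN(r.amount), MAX(r.amount)
"""
print(GrandCypher(g).run(qry))
# expected to report per-label minima and maxima,
# e.g. 'MIN(r.amount)': [{'paid': 12}] and 'MAX(r.amount)': [{'paid': 40}]
```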




grandcypher/__init__.py: 86 additions & 6 deletions
@@ -81,7 +81,13 @@
return_clause : "return"i distinct_return? entity_id ("," entity_id)*
return_clause : "return"i distinct_return? return_item ("," return_item)*
return_item : entity_id | aggregation_function | entity_id "." attribute_id
aggregation_function : AGGREGATE_FUNC "(" entity_id ( "." attribute_id )? ")"
AGGREGATE_FUNC : "COUNT" | "SUM" | "AVG" | "MAX" | "MIN"
attribute_id : CNAME
distinct_return : "DISTINCT"i
limit_clause : "limit"i NUMBER
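
For orientation, a RETURN clause that the new `return_item` / `aggregation_function` rules accept, mixing a plain attribute with an aggregation (query text only; it mirrors the examples added elsewhere in this commit):

```python
qry = """
MATCH (n)-[r:paid]->(m)
RETURN n.name, AVG(r.amount)
"""
```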
@@ -282,6 +288,7 @@ def _get_entity_from_host(
edge_data = host.get_edge_data(*entity_name)
if not edge_data:
return None # print(f"Nothing found for {entity_name} {entity_attribute}")

if entity_attribute:
# looking for edge attribute:
if isinstance(host, nx.MultiDiGraph):
@@ -376,6 +383,7 @@ def __init__(self, target_graph: nx.Graph, limit=None):
self._matche_paths = None
self._return_requests = []
self._return_edges = {}
self._aggregate_functions = []
self._distinct = False
self._order_by = None
self._order_by_attributes = set()
@@ -483,9 +491,10 @@ def _lookup(self, data_paths: List[str], offset_limit) -> Dict[str, List]:
for r in ret:
r_attr = {}
for i, v in r.items():
r_attr[i] = v.get(entity_attribute, None)
r_attr[(i, list(v.get('__labels__'))[0])] = v.get(entity_attribute, None)
# eg, [{(0, 'paid'): 70, (1, 'paid'): 90}, {(0, 'paid'): 400, (1, 'friend'): None, (2, 'paid'): 650}]
ret_with_attr.append(r_attr)

ret = ret_with_attr

result[data_path] = list(ret)[offset_limit]
@@ -497,9 +506,19 @@ def return_clause(self, clause):
# collect all entity identifiers to be returned
for item in clause:
if item:
if not isinstance(item, str):
item = str(item.value)
self._return_requests.append(item)
item = item.children[0] if isinstance(item, Tree) else item
if isinstance(item, Tree) and item.data == "aggregation_function":
func = str(item.children[0].value) # AGGREGATE_FUNC
entity = str(item.children[1].value)
if len(item.children) > 2:
entity += "." + str(item.children[2].children[0].value)
self._aggregate_functions.append((func, entity))
self._return_requests.append(entity)
else:
if not isinstance(item, str):
item = str(item.value)
self._return_requests.append(item)


def order_clause(self, order_clause):
self._order_by = []
@@ -525,12 +544,73 @@ def skip_clause(self, skip):
skip = int(skip[-1])
self._skip = skip


def aggregate(self, func, results, entity, group_keys):
# Collect data based on group keys
grouped_data = {}
for i in range(len(results[entity])):
group_tuple = tuple(results[key][i] for key in group_keys if key in results)
if group_tuple not in grouped_data:
grouped_data[group_tuple] = []
grouped_data[group_tuple].append(results[entity][i])

def _collate_data(data, unique_labels, func):
# for ["COUNT", "SUM", "AVG"], we treat None as 0
if func in ["COUNT", "SUM", "AVG"]:
collated_data = {
label: [(v or 0) for rel in data for k, v in rel.items() if k[1] == label] for label in unique_labels
}
# for ["MAX", "MIN"], we treat None as non-existent
elif func in ["MAX", "MIN"]:
collated_data = {
label: [v for rel in data for k, v in rel.items() if (k[1] == label and v is not None)] for label in unique_labels
}

return collated_data

# Apply aggregation function
aggregate_results = {}
for group, data in grouped_data.items():
# data => [{(0, 'paid'): 70, (1, 'paid'): 90}]
unique_labels = set([k[1] for rel in data for k in rel.keys()])
collated_data = _collate_data(data, unique_labels, func)
if func == "COUNT":
count_data = {label: len(data) for label, data in collated_data.items()}
aggregate_results[group] = count_data
elif func == "SUM":
sum_data = {label: sum(data) for label, data in collated_data.items()}
aggregate_results[group] = sum_data
elif func == "AVG":
sum_data = {label: sum(data) for label, data in collated_data.items()}
count_data = {label: len(data) for label, data in collated_data.items()}
avg_data = {label: sum_data[label] / count_data[label] if count_data[label] > 0 else 0 for label in sum_data}
aggregate_results[group] = avg_data
elif func == "MAX":
max_data = {label: max(data) for label, data in collated_data.items()}
aggregate_results[group] = max_data
elif func == "MIN":
min_data = {label: min(data) for label, data in collated_data.items()}
aggregate_results[group] = min_data

aggregate_results = [v for v in aggregate_results.values()]
return aggregate_results

def returns(self, ignore_limit=False):

results = self._lookup(
self._return_requests + list(self._order_by_attributes),
offset_limit=slice(0, None),
)
if len(self._aggregate_functions) > 0:
group_keys = [key for key in results.keys() if not any(key.endswith(func[1]) for func in self._aggregate_functions)]

aggregated_results = {}
for func, entity in self._aggregate_functions:
aggregated_data = self.aggregate(func, results, entity, group_keys)
func_key = f"{func}({entity})"
aggregated_results[func_key] = aggregated_data
self._return_requests.append(func_key)
results.update(aggregated_results)
if self._order_by:
results = self._apply_order_by(results)
if self._distinct:
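
To make the None handling concrete, here is a small standalone sketch (not the committed code) that mirrors what `_collate_data` does with a missing `amount`: COUNT/SUM/AVG coerce None to 0, while MAX/MIN drop it.

```python
# Parallel edges for one (n, m) pair, keyed by (edge_key, label);
# the "friends" edge has no `amount`, so its value is None.
data = [{(0, "paid"): 12, (1, "friends"): None, (2, "paid"): 40}]
labels = {k[1] for rel in data for k in rel}

# COUNT/SUM/AVG path: None becomes 0
zeroed = {lbl: [(v or 0) for rel in data for k, v in rel.items() if k[1] == lbl]
          for lbl in labels}
print(zeroed)   # values: 'paid' -> [12, 40], 'friends' -> [0]

# MAX/MIN path: None values are dropped entirely
present = {lbl: [v for rel in data for k, v in rel.items()
                 if k[1] == lbl and v is not None]
           for lbl in labels}
print(present)  # values: 'paid' -> [12, 40], 'friends' -> []
```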