Skip to content

Commit

Permalink
Adds support for multigraphs (#42)
Browse files Browse the repository at this point in the history
* Adds support for multigraphs

* Refactors `_is_edge_attr_match`

* Filters relations by __label__ during `_lookup`

* Bundles relation attributes together for lookup

* Refactors and adds inline docs

* Adds tests for multigraph support

* Cleans up inline docs

* Removes slicing list twice to avoid two copies in memory

* Supports WHERE clause for relationships in multigraphs

* Adds test for multigraph with WHERE clause on single edge
  • Loading branch information
jackboyla authored May 14, 2024
1 parent a6d0587 commit 78173a3
Show file tree
Hide file tree
Showing 2 changed files with 217 additions and 32 deletions.
126 changes: 94 additions & 32 deletions grandcypher/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
"""

from typing import Dict, List, Callable, Tuple
from typing import Dict, List, Callable, Tuple, Union
from collections import OrderedDict
import random
import string
Expand Down Expand Up @@ -119,7 +119,6 @@
| LEFT_ANGLE? "-[" CNAME ":" TYPE "*" MIN_HOP ".." MAX_HOP "]-" RIGHT_ANGLE?
LEFT_ANGLE : "<"
RIGHT_ANGLE : ">"
EQUAL : "="
Expand Down Expand Up @@ -198,14 +197,16 @@ def _is_node_attr_match(

@lru_cache()
def _is_edge_attr_match(
motif_edge_id: Tuple[str, str],
host_edge_id: Tuple[str, str],
motif: nx.Graph,
host: nx.Graph,
motif_edge_id: Tuple[str, str, Union[int, str]],
host_edge_id: Tuple[str, str, Union[int, str]],
motif: Union[nx.Graph, nx.MultiDiGraph],
host: Union[nx.Graph, nx.MultiDiGraph]
) -> bool:
"""
Check if an edge in the host graph matches the attributes in the motif.
This also check the __labels__ of edges.
Check if an edge in the host graph matches the attributes in the motif,
including the special '__labels__' set attribute.
This function formats edges into
nx.MultiDiGraph format i.e {0: first_relation, 1: ...}.
Arguments:
motif_edge_id (str): The motif edge ID
Expand All @@ -215,23 +216,50 @@ def _is_edge_attr_match(
Returns:
bool: True if the host edge matches the attributes in the motif
"""
motif_edge = motif.edges[motif_edge_id]
host_edge = host.edges[host_edge_id]
motif_u, motif_v = motif_edge_id
host_u, host_v = host_edge_id

# Format edges for both DiGraph and MultiDiGraph
motif_edges = _get_edge_attributes(motif, motif_u, motif_v)
host_edges = _get_edge_attributes(host, host_u, host_v)

for attr, val in motif_edge.items():
# Aggregate all __labels__ into one set
motif_edges = _aggregate_edge_labels(motif_edges)
host_edges = _aggregate_edge_labels(host_edges)

for attr, val in motif_edges.items():
if attr == "__labels__":
if val and val - host_edge.get("__labels__", set()):
if val and val - host_edges.get("__labels__", set()):
return False
continue
if host_edge.get(attr) != val:
if host_edges.get(attr) != val:
return False

return True


def _get_entity_from_host(host: nx.DiGraph, entity_name, entity_attribute=None):
def _get_edge_attributes(graph: Union[nx.Graph, nx.MultiDiGraph], u, v) -> Dict:
"""
Retrieve edge attributes from a graph, handling both Graph and MultiDiGraph.
"""
if isinstance(graph, nx.MultiDiGraph):
return graph[u][v]
return {0: graph[u][v]} # Mock single edge for DiGraph

def _aggregate_edge_labels(edges: Dict) -> Dict:
"""
Aggregate '__labels__' attributes from edges into a single set.
"""
aggregated = {"__labels__": set()}
for edge_id, attrs in edges.items():
if "__labels__" in attrs and attrs["__labels__"]:
aggregated["__labels__"].update(attrs["__labels__"])
elif "__labels__" not in attrs:
aggregated[edge_id] = attrs
return aggregated

def _get_entity_from_host(host: Union[nx.DiGraph, nx.MultiDiGraph], entity_name, entity_attribute=None):
if entity_name in host.nodes():
# We are looking for a node mapping in the target graph:
if entity_attribute:
Expand All @@ -248,7 +276,10 @@ def _get_entity_from_host(host: nx.DiGraph, entity_name, entity_attribute=None):
return None # print(f"Nothing found for {entity_name} {entity_attribute}")
if entity_attribute:
# looking for edge attribute:
return edge_data.get(entity_attribute, None)
if isinstance(host, nx.MultiDiGraph):
return [r.get(entity_attribute, None) for r in edge_data.values()]
else:
return edge_data.get(entity_attribute, None)
else:
return host.get_edge_data(*entity_name)

Expand Down Expand Up @@ -279,7 +310,7 @@ def inner(match: dict, host: nx.DiGraph, return_endges: list) -> bool:


def cond_(should_be, entity_id, operator, value) -> CONDITION:
def inner(match: dict, host: nx.DiGraph, return_endges: list) -> bool:
def inner(match: dict, host: Union[nx.DiGraph, nx.MultiDiGraph], return_endges: list) -> bool:
host_entity_id = entity_id.split(".")
if host_entity_id[0] in match:
host_entity_id[0] = match[host_entity_id[0]]
Expand All @@ -290,7 +321,13 @@ def inner(match: dict, host: nx.DiGraph, return_endges: list) -> bool:
else:
raise IndexError(f"Entity {host_entity_id} not in graph.")
try:
val = operator(_get_entity_from_host(host, *host_entity_id), value)
if isinstance(host, nx.MultiDiGraph):
# if any of the relations between nodes satisfies condition, return True
r_vals = _get_entity_from_host(host, *host_entity_id)
val = any(operator(r_val, value) for r_val in r_vals)
else:
val = operator(_get_entity_from_host(host, *host_entity_id), value)

except:
val = False
if val != should_be:
Expand Down Expand Up @@ -323,7 +360,7 @@ def __init__(self, target_graph: nx.Graph, limit=None):
self._target_graph = target_graph
self._paths = []
self._where_condition: CONDITION = None
self._motif = nx.DiGraph()
self._motif = nx.MultiDiGraph()
self._matches = None
self._matche_paths = None
self._return_requests = []
Expand Down Expand Up @@ -383,9 +420,9 @@ def _lookup(self, data_paths: List[str], offset_limit) -> Dict[str, List]:
ret.append(path)

else:
mapping_u, mapping_v = self._return_edges[data_path]
mapping_u, mapping_v = self._return_edges[data_path.split('.')[0]]
# We are looking for an edge mapping in the target graph:
is_hop = self._motif.edges[(mapping_u, mapping_v)]["__is_hop__"]
is_hop = self._motif.edges[(mapping_u, mapping_v, 0)]["__is_hop__"]
ret = (
_get_edge(
self._target_graph, mapping, match_path, mapping_u, mapping_v
Expand All @@ -395,13 +432,38 @@ def _lookup(self, data_paths: List[str], offset_limit) -> Dict[str, List]:
ret = (r[0] if is_hop else r for r in ret)
# we keep the original list if len > 2 (edge hop 2+)

# Get all edge labels from the motif -- this is used to filter the relations for multigraphs
motif_edge_labels = set()
for edge in self._motif.get_edge_data(mapping_u, mapping_v).values():
if edge.get('__labels__', None):
motif_edge_labels.update(edge['__labels__'])

if entity_attribute:
# Get the correct entity from the target host graph,
# and then return the attribute:
ret = (r.get(entity_attribute, None) for r in ret)
if isinstance(self._motif, nx.MultiDiGraph) and len(motif_edge_labels) > 0:
# filter the retrieved edge(s) based on the motif edge labels
filtered_ret = []
for r in ret:

if any([i.get('__labels__', None).issubset(motif_edge_labels) for i in r.values()]):
filtered_ret.append(r)

ret = filtered_ret

# get the attribute from the retrieved edge(s)
ret_with_attr = []
for r in ret:
r_attr = {}
for i, v in r.items():
r_attr[i] = v.get(entity_attribute, None)
ret_with_attr.append(r_attr)

ret = ret_with_attr

result[data_path] = list(ret)[offset_limit]


return result

def return_clause(self, clause):
Expand Down Expand Up @@ -606,7 +668,7 @@ def _is_limit(self, count):
# Check if limit reached
return self._limit and count >= (self._limit + self._skip)

def _edge_hop_motifs(self, motif: nx.DiGraph) -> List[Tuple[nx.Graph, dict]]:
def _edge_hop_motifs(self, motif: nx.MultiDiGraph) -> List[Tuple[nx.Graph, dict]]:
"""generate a list of edge-hop-expanded motif with edge-hop-map.
Arguments:
Expand All @@ -618,29 +680,29 @@ def _edge_hop_motifs(self, motif: nx.DiGraph) -> List[Tuple[nx.Graph, dict]]:
where a real edge path can have more than 2 element (hop >= 2)
or it can have 2 same element (hop = 0).
"""
new_motif = nx.DiGraph()
new_motif = nx.MultiDiGraph()
for n in motif.nodes:
if motif.out_degree(n) == 0 and motif.in_degree(n) == 0:
new_motif.add_node(n, **motif.nodes[n])
motifs: List[Tuple[nx.DiGraph, dict]] = [(new_motif, {})]
for u, v in motif.edges:
for u, v, k in motif.edges: # OutMultiEdgeView([('a', 'b', 0)])
new_motifs = []
min_hop = motif.edges[u, v]["__min_hop__"]
max_hop = motif.edges[u, v]["__max_hop__"]
edge_type = motif.edges[u, v]["__labels__"]
min_hop = motif.edges[u, v, k]["__min_hop__"]
max_hop = motif.edges[u, v, k]["__max_hop__"]
edge_type = motif.edges[u, v, k]["__labels__"]
hops = []
if min_hop == 0:
new_motif = nx.DiGraph()
new_motif = nx.MultiDiGraph()
new_motif.add_node(u, **motif.nodes[u])
new_motifs.append((new_motif, {(u, v): (u, u)}))
elif min_hop >= 1:
for _ in range(1, min_hop):
hops.append(shortuuid())
for _ in range(max(min_hop, 1), max_hop):
new_edges = [u] + hops + [v]
new_motif = nx.DiGraph()
new_motif = nx.MultiDiGraph()
new_motif.add_edges_from(
list(zip(new_edges[:-1], new_edges[1:])), __labels__=edge_type
zip(new_edges, new_edges[1:]), __labels__=edge_type
)
new_motif.add_node(u, **motif.nodes[u])
new_motif.add_node(v, **motif.nodes[v])
Expand Down
123 changes: 123 additions & 0 deletions grandcypher/test_queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -825,6 +825,129 @@ def test_order_by_with_non_returned_field(self):
assert res["n.name"] == ["Carol", "Alice", "Bob"]


class TestMultigraphRelations:
def test_query_with_multiple_relations(self):
host = nx.MultiDiGraph()
host.add_node("a", name="Alice", age=25)
host.add_node("b", name="Bob", age=30)
host.add_node("c", name="Charlie", age=25)
host.add_node("d", name="Diana", age=25)

# Adding edges with labels for different types of relationship_type
host.add_edge("a", "b", __labels__={"friends"})
host.add_edge("a", "b", __labels__={"colleagues"})
host.add_edge("a", "c", __labels__={"colleagues"})
host.add_edge("b", "d", __labels__={"family"})
host.add_edge("c", "d", __labels__={"family"})
host.add_edge("c", "d", __labels__={"friends"})
host.add_edge("d", "a", __labels__={"friends"})
host.add_edge("d", "a", __labels__={"colleagues"})

qry = """
MATCH (n)-[r:friends]->(m)
RETURN n.name, m.name
"""
res = GrandCypher(host).run(qry)
assert res["n.name"] == ['Alice', 'Charlie', 'Diana']
assert res["m.name"] == ['Bob', 'Diana', 'Alice']

def test_multiple_edges_specific_attribute(self):
host = nx.MultiDiGraph()
host.add_node("a", name="Alice", age=30)
host.add_node("b", name="Bob", age=30)
host.add_edge("a", "b", __labels__={"colleague"}, years=3)
host.add_edge("a", "b", __labels__={"friend"}, years=5)
host.add_edge("a", "b", __labels__={"enemy"}, hatred=10)

qry = """
MATCH (a)-[r:friend]->(b)
RETURN a.name, b.name, r.years
"""
res = GrandCypher(host).run(qry)
assert res["a.name"] == ["Alice"]
assert res["b.name"] == ["Bob"]
assert res["r.years"] == [{0: 3, 1: 5, 2: None}] # should return None when attr is missing

def test_edge_directionality(self):
host = nx.MultiDiGraph()
host.add_node("a", name="Alice", age=25)
host.add_node("b", name="Bob", age=30)
host.add_edge("a", "b", __labels__={"friend"}, years=1)
host.add_edge("b", "a", __labels__={"colleague"}, years=2)
host.add_edge("b", "a", __labels__={"mentor"}, years=4)

qry = """
MATCH (a)-[r]->(b)
RETURN a.name, b.name, r.__labels__, r.years
"""
res = GrandCypher(host).run(qry)
assert res["a.name"] == ["Alice", "Bob"]
assert res["b.name"] == ["Bob", "Alice"]
assert res["r.__labels__"] == [{0: {'friend'}}, {0: {'colleague'}, 1: {'mentor'}}]
assert res["r.years"] == [{0: 1}, {0: 2, 1: 4}]


def test_query_with_missing_edge_attribute(self):
host = nx.MultiDiGraph()
host.add_node("a", name="Alice", age=30)
host.add_node("b", name="Bob", age=40)
host.add_node("c", name="Charlie", age=50)
host.add_edge("a", "b", __labels__={"friend"}, years=3)
host.add_edge("a", "c", __labels__={"colleague"}, years=10)
host.add_edge("b", "c", __labels__={"colleague"}, duration=10)
host.add_edge("b", "c", __labels__={"mentor"}, years=2)

qry = """
MATCH (a)-[r:colleague]->(b)
RETURN a.name, b.name, r.duration
"""
res = GrandCypher(host).run(qry)
assert res["a.name"] == ["Alice", "Bob"]
assert res["b.name"] == ["Charlie", "Charlie"]
assert res["r.duration"] == [{0: None}, {0: 10, 1: None}] # should return None when attr is missing

qry = """
MATCH (a)-[r:colleague]->(b)
RETURN a.name, b.name, r.years
"""
res = GrandCypher(host).run(qry)
assert res["a.name"] == ["Alice", "Bob"]
assert res["b.name"] == ["Charlie", "Charlie"]
assert res["r.years"] == [{0: 10}, {0: None, 1: 2}]

qry = """
MATCH (a)-[r]->(b)
RETURN a.name, b.name, r.__labels__, r.duration
"""
res = GrandCypher(host).run(qry)
assert res["a.name"] == ['Alice', 'Alice', 'Bob']
assert res["b.name"] == ['Bob', 'Charlie', 'Charlie']
assert res["r.__labels__"] == [{0: {'friend'}}, {0: {'colleague'}}, {0: {'colleague'}, 1: {'mentor'}}]
assert res["r.duration"] == [{0: None}, {0: None}, {0: 10, 1: None}]

def test_multigraph_single_edge_where(self):
host = nx.MultiDiGraph()
host.add_node("a", name="Alice", age=25)
host.add_node("b", name="Bob", age=30)
host.add_node("c", name="Christine", age=30)
host.add_edge("a", "b", __labels__={"friend"}, years=1, friendly="very")
host.add_edge("b", "a", __labels__={"colleague"}, years=2)
host.add_edge("b", "a", __labels__={"mentor"}, years=4)
host.add_edge("b", "c", __labels__={"chef"}, years=12)

qry = """
MATCH (a)-[r]->(b)
WHERE r.friendly == "very" OR r.years == 2
RETURN a.name, b.name, r.__labels__, r.years, r.friendly
"""
res = GrandCypher(host).run(qry)
assert res["a.name"] == ["Alice", "Bob"]
assert res["b.name"] == ["Bob", "Alice"]
assert res["r.__labels__"] == [{0: {'friend'}}, {0: {'colleague'}, 1: {'mentor'}}]
assert res["r.years"] == [{0: 1}, {0: 2, 1: 4}]
assert res["r.friendly"] == [{0: 'very'}, {0: None, 1: None}]


class TestVariableLengthRelationship:
def test_single_variable_length_relationship(self):
host = nx.DiGraph()
Expand Down

0 comments on commit 78173a3

Please sign in to comment.