Skip to content

Commit

Permalink
Handles both MultiDiGraph and DiGraph without up-conversion
Browse files Browse the repository at this point in the history
  • Loading branch information
jackboyla committed Dec 5, 2024
1 parent af754de commit 3829112
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 31 deletions.
60 changes: 43 additions & 17 deletions grandcypher/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,10 @@ def _is_edge_attr_match(
motif_edges = _get_edge_attributes(motif, motif_u, motif_v)
host_edges = _get_edge_attributes(host, host_u, host_v)

if not motif_edges or not host_edges:
# if there are no edges, they don't match
return False

# Aggregate all __labels__ into one set
motif_edges = _aggregate_edge_labels(motif_edges)
host_edges = _aggregate_edge_labels(host_edges)
Expand All @@ -262,9 +266,11 @@ def _get_edge_attributes(graph: Union[nx.Graph, nx.MultiDiGraph], u, v) -> Dict:
"""
Retrieve edge attributes from a graph, handling both Graph and MultiDiGraph.
"""
if isinstance(graph, nx.MultiDiGraph):
return graph[u][v]
return {0: graph[u][v]} # Mock single edge for DiGraph
if graph.is_multigraph():
return graph.get_edge_data(u, v)
else:
data = graph.get_edge_data(u, v)
return {0: data} # Wrap in dict to mimic MultiDiGraph structure


def _aggregate_edge_labels(edges: Dict) -> Dict:
Expand Down Expand Up @@ -294,24 +300,27 @@ def _get_entity_from_host(
return entity_name
else:
# looking for an edge:
edge_data = host.get_edge_data(*entity_name)
u, v = entity_name
edge_data = _get_edge_attributes(host, u, v)
if not edge_data:
return None # print(f"Nothing found for {entity_name} {entity_attribute}")

if entity_attribute:
# looking for edge attribute:
if isinstance(host, nx.MultiDiGraph):
return [r.get(entity_attribute, None) for r in edge_data.values()]
if host.is_multigraph():
# return a list of attribute values for all edges between u and v
return [attrs.get(entity_attribute) for attrs in edge_data.values()]
else:
return edge_data.get(entity_attribute, None)
# return the attribute value for the single edge
return edge_data[0].get(entity_attribute)
else:
return host.get_edge_data(*entity_name)
return edge_data


def _get_edge(host: nx.DiGraph, mapping, match_path, u, v):
def _get_edge(host: Union[nx.DiGraph, nx.MultiDiGraph], mapping, match_path, u, v):
edge_path = match_path[(u, v)]
return [
host.get_edge_data(mapping[u], mapping[v])
_get_edge_attributes(host, mapping[u], mapping[v])
for u, v in zip(edge_path[:-1], edge_path[1:])
]

Expand Down Expand Up @@ -353,11 +362,11 @@ def inner(
else:
raise IndexError(f"Entity {host_entity_id} not in graph.")

operator_results = []
if isinstance(host, nx.MultiDiGraph):
# if any of the relations between nodes satisfies condition, return True
r_vals = _get_entity_from_host(host, *host_entity_id)
r_vals = [r_vals] if not isinstance(r_vals, list) else r_vals
operator_results = []
for r_val in r_vals:
try:
operator_results.append(operator(r_val, value))
Expand All @@ -369,6 +378,7 @@ def inner(
val = operator(_get_entity_from_host(host, *host_entity_id), value)
except:
val = False
operator_results.append(val)

if val != should_be:
return False, operator_results
Expand Down Expand Up @@ -398,9 +408,6 @@ def _data_path_to_entity_name_attribute(data_path):
class _GrandCypherTransformer(Transformer):
def __init__(self, target_graph: nx.Graph, limit=None):
self._target_graph = target_graph
if not isinstance(self._target_graph, nx.MultiDiGraph):
self._target_graph = nx.MultiDiGraph(target_graph)
logger.warning("Converting graph to MultiDiGraph")
self._entity2alias = dict()
self._alias2entity = dict()
self._paths = []
Expand Down Expand Up @@ -754,6 +761,15 @@ def returns(self, ignore_limit=False):
for key, values in results.items()
if self._alias2entity.get(key, key) in self._return_requests
}
# HACK: convert to [None] if edge is None
for key, values in results.items():
parsed_values = []
for v in values:
if v == [{0: None}]: # edge is None
parsed_values.append([None])
else:
parsed_values.append(v)
results[key] = parsed_values

return results

Expand Down Expand Up @@ -973,8 +989,18 @@ def _edge_hop_motifs(self, motif: nx.MultiDiGraph) -> List[Tuple[nx.Graph, dict]
if motif.out_degree(n) == 0 and motif.in_degree(n) == 0:
new_motif.add_node(n, **motif.nodes[n])
motifs: List[Tuple[nx.DiGraph, dict]] = [(new_motif, {})]

if motif.is_multigraph():
edge_iter = motif.edges(keys=True)
else:
edge_iter = motif.edges(keys=False)

for u, v, k in motif.edges: # OutMultiEdgeView([('a', 'b', 0)])
for edge in edge_iter:
if motif.is_multigraph():
u, v, k = edge
else:
u, v = edge
k = 0 # Dummy key for DiGraph
new_motifs = []
min_hop = motif.edges[u, v, k]["__min_hop__"]
max_hop = motif.edges[u, v, k]["__max_hop__"]
Expand Down Expand Up @@ -1002,8 +1028,8 @@ def _edge_hop_motifs(self, motif: nx.MultiDiGraph) -> List[Tuple[nx.Graph, dict]

def _product_motifs(
self,
motifs_1: List[Tuple[nx.DiGraph, dict]],
motifs_2: List[Tuple[nx.DiGraph, dict]],
motifs_1: List[Tuple[nx.Graph, dict]],
motifs_2: List[Tuple[nx.Graph, dict]],
):
new_motifs = []
for motif_1, mapping_1 in motifs_1:
Expand Down
14 changes: 0 additions & 14 deletions grandcypher/test_queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,20 +59,6 @@ def test_simple_structural_match_returns_node_attributes(self, graph_type):
assert len(returns["A.dinnertime"]) == 2


@pytest.mark.parametrize("graph_type", ACCEPTED_GRAPH_TYPES)
def test_warning_for_non_multidigraph(self, graph_type, caplog):
host = graph_type()

with caplog.at_level(logging.WARNING):
gct = GrandCypher(host)

if isinstance(host, nx.MultiDiGraph):
assert len(caplog.records) == 0
elif isinstance(host, nx.DiGraph):
assert len(caplog.records) == 1
assert caplog.records[0].levelname == "WARNING"


class TestSimpleAPI:
@pytest.mark.parametrize("graph_type", ACCEPTED_GRAPH_TYPES)
def test_simple_api(self, graph_type):
Expand Down

0 comments on commit 3829112

Please sign in to comment.