Skip to content

Commit

Permalink
Adds warning when converting graphs to MultiDiGraph, and only convert…
Browse files Browse the repository at this point in the history
…s when necessary (#53)
  • Loading branch information
jackboyla authored Nov 17, 2024
1 parent 5e29751 commit af754de
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 2 deletions.
14 changes: 12 additions & 2 deletions grandcypher/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,20 @@
from collections import OrderedDict
import random
import string
import logging
from functools import lru_cache
import networkx as nx

import grandiso

from lark import Lark, Transformer, v_args, Token, Tree

logging.basicConfig(
level=logging.INFO,
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)


_OPERATORS = {
"=": lambda x, y: x == y,
Expand Down Expand Up @@ -162,7 +169,7 @@
start="start",
)

__version__ = "0.10.0"
__version__ = "0.10.1"


_ALPHABET = string.ascii_lowercase + string.digits
Expand Down Expand Up @@ -390,7 +397,10 @@ def _data_path_to_entity_name_attribute(data_path):

class _GrandCypherTransformer(Transformer):
def __init__(self, target_graph: nx.Graph, limit=None):
self._target_graph = nx.MultiDiGraph(target_graph)
self._target_graph = target_graph
if not isinstance(self._target_graph, nx.MultiDiGraph):
self._target_graph = nx.MultiDiGraph(target_graph)
logger.warning("Converting graph to MultiDiGraph")
self._entity2alias = dict()
self._alias2entity = dict()
self._paths = []
Expand Down
15 changes: 15 additions & 0 deletions grandcypher/test_queries.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import networkx as nx
import pytest
import logging

from . import _GrandCypherGrammar, _GrandCypherTransformer, GrandCypher

Expand Down Expand Up @@ -58,6 +59,20 @@ def test_simple_structural_match_returns_node_attributes(self, graph_type):
assert len(returns["A.dinnertime"]) == 2


@pytest.mark.parametrize("graph_type", ACCEPTED_GRAPH_TYPES)
def test_warning_for_non_multidigraph(self, graph_type, caplog):
host = graph_type()

with caplog.at_level(logging.WARNING):
gct = GrandCypher(host)

if isinstance(host, nx.MultiDiGraph):
assert len(caplog.records) == 0
elif isinstance(host, nx.DiGraph):
assert len(caplog.records) == 1
assert caplog.records[0].levelname == "WARNING"


class TestSimpleAPI:
@pytest.mark.parametrize("graph_type", ACCEPTED_GRAPH_TYPES)
def test_simple_api(self, graph_type):
Expand Down

0 comments on commit af754de

Please sign in to comment.