Skip to content

Commit

Permalink
Support for or-join
Browse files Browse the repository at this point in the history
  • Loading branch information
Donnype committed Jun 10, 2023
1 parent 841915d commit 97b524a
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 0 deletions.
15 changes: 15 additions & 0 deletions tests/test_datalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
Limit,
NotJoin,
OrderBy,
OrJoin,
Sample,
Sum,
Timeout,
Expand Down Expand Up @@ -269,3 +270,17 @@ def test_not_join():

statement = Where("e", "xt/id") & NotJoin("e") & Where("e", "last-name", "n") & Where("e", "name", "n")
assert statement.compile() == ":where [ (not-join [e] ) [ e :last-name n ] [ e :name n ] [ e :xt/id ]]"


def test_or_join():
statement = Where("e", "xt/id") & (OrJoin("e") & Where("e", "last-name", "n"))
assert statement.compile() == ":where [ (or-join [e] [ e :last-name n ]) [ e :xt/id ]]"

statement = Where("e", "xt/id") & (OrJoin("e") & (Where("e", "last-name", "n") & Where("e", "name", "n")))
assert statement.compile() == ":where [ (or-join [e] [ e :last-name n ] [ e :name n ]) [ e :xt/id ]]"

statement = Where("e", "xt/id") & (OrJoin("e") & Where("e", "last-name", "n") & Where("e", "name", "n"))
assert statement.compile() == ":where [ (or-join [e] [ e :last-name n ] [ e :name n ]) [ e :xt/id ]]"

statement = Where("e", "xt/id") & OrJoin("e") & Where("e", "last-name", "n") & Where("e", "name", "n")
assert statement.compile() == ":where [ (or-join [e] ) [ e :last-name n ] [ e :name n ] [ e :xt/id ]]"
29 changes: 29 additions & 0 deletions xtdb/datalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,35 @@ def __invert__(self):
raise XTDBException("Cannot use ~ on not-join")


class OrJoin(Clause):
def __init__(self, variable: str, clauses: Optional[List] = None):
self.variable = variable
self.clauses = clauses or []

def compile(self, root: bool = True, *, separator=" ") -> str:
collected = []

for clause in self.clauses:
collected.append(clause.compile(root=False, separator=separator))

if all(clause.idempotent for clause in self.clauses):
collected = list(set(collected))

if all(clause.commutative for clause in self.clauses):
collected = sorted(collected)

if root:
return f":where [(or-join{separator}[{self.variable}] {separator.join(collected)})]"

return f"(or-join{separator}[{self.variable}] {separator.join(collected)})"

def _and(self, other: Clause) -> Clause:
return OrJoin(self.variable, self.clauses + [other])

def __invert__(self):
raise XTDBException("Cannot use ~ on or-join")


class Where(Clause):
def __init__(self, document: str, field: str, value: Any = ""):
self.document = document
Expand Down

0 comments on commit 97b524a

Please sign in to comment.