Skip to content

Commit

Permalink
Merge pull request #46 from metadsl/try-pr
Browse files Browse the repository at this point in the history
  • Loading branch information
saulshanabrook authored Sep 14, 2023
2 parents f3cddb4 + 1c1b231 commit a020cd4
Show file tree
Hide file tree
Showing 18 changed files with 148 additions and 133 deletions.
21 changes: 12 additions & 9 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,12 @@ jobs:
- "3.9"
- "3.8"
steps:
- uses: actions/checkout@v4
- name: Setup python ${{ matrix.py }}
uses: actions/setup-python@v2
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.py }}
- uses: actions/checkout@v2
cache: "pip"
- name: Cache cargo
uses: actions/cache@v3
with:
Expand All @@ -48,10 +49,11 @@ jobs:
mypy:
runs-on: ubuntu-latest
steps:
- uses: actions/setup-python@v2
- uses: actions/checkout@v4
- uses: actions/setup-python@v4
with:
python-version: "3.10"
- uses: actions/checkout@v2
cache: "pip"
- name: Cache cargo
uses: actions/cache@v3
with:
Expand All @@ -68,10 +70,11 @@ jobs:
docs:
runs-on: ubuntu-latest
steps:
- uses: actions/setup-python@v2
- uses: actions/checkout@v4
- uses: actions/setup-python@v4
with:
python-version: "3.10"
- uses: actions/checkout@v2
cache: "pip"
- name: Install graphviz
run: |
sudo apt-get update
Expand All @@ -93,7 +96,7 @@ jobs:
runs-on: ubuntu-latest
needs: [test]
steps:
- uses: actions/checkout@v3
- uses: actions/checkout@v4
- uses: PyO3/maturin-action@v1.40.1
with:
manylinux: auto
Expand All @@ -110,7 +113,7 @@ jobs:
runs-on: windows-latest
needs: [test]
steps:
- uses: actions/checkout@v3
- uses: actions/checkout@v4
- uses: PyO3/maturin-action@v1.40.1
with:
command: build
Expand All @@ -126,7 +129,7 @@ jobs:
runs-on: macos-latest
needs: [test]
steps:
- uses: actions/checkout@v3
- uses: actions/checkout@v4
- uses: PyO3/maturin-action@v1.40.1
with:
command: build
Expand Down
2 changes: 1 addition & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ crate-type = ["cdylib"]

[dependencies]
pyo3 = { version = "0.18.1", features = ["extension-module"] }
egglog = { git = "https://github.com/egraphs-good/egglog", rev = "c83fc750878755eb610a314da90f9273b3bfe25d" }
egglog = { git = "https://github.com/egraphs-good/egglog", rev = "4d67f262a6f27aa5cfb62a2cfc7df968959105df" }
# egglog = { git = "https://github.com/oflatt/egg-smol", rev = "f6df3ff831b65405665e1751b0ef71c61b025432" }
# egglog = { git = "https://github.com/saulshanabrook/egg-smol", rev = "c01695618ed4de2fbfa8116476e208bc1ca86612" }

Expand Down
6 changes: 6 additions & 0 deletions docs/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,18 @@ _This project uses semantic versioning. Before 1.0.0, this means that every brea

## Unreleased

- Bump [egglog dep](https://github.com/egraphs-good/egglog/compare/c83fc750878755eb610a314da90f9273b3bfe25d...4d67f262a6f27aa5cfb62a2cfc7df968959105df)

### Breaking Changes

- Switches `RunReport` to include more granular timings

### New Features

- Add ability to pass `seminaive` flag to Egraph to replicate `--naive` CLI flag [#48](https://github.com/metadsl/egglog-python/pull/48)
- Add ability to inline leaves $n$ times instead of just once for visualization [#48](https://github.com/metadsl/egglog-python/pull/48)
- Add `Relation` and `PrintOverallStatistics` low level commands [#46](https://github.com/metadsl/egglog-python/pull/46)
- Adds `count-matches` and `replace` string commands [#46](https://github.com/metadsl/egglog-python/pull/46)

### Bug fixes

Expand Down
12 changes: 6 additions & 6 deletions docs/tutorials/getting-started.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -485,12 +485,12 @@
"A, B, C = Matrix.named(\"A\"), Matrix.named(\"B\"), Matrix.named(\"C\")\n",
"# Set each to be a square matrix of the given dimension\n",
"egraph.register(\n",
" set_(A.nrows()).to(n),\n",
" set_(A.ncols()).to(n),\n",
" set_(B.nrows()).to(m),\n",
" set_(B.ncols()).to(m),\n",
" set_(C.nrows()).to(p),\n",
" set_(C.ncols()).to(p),\n",
" union(A.nrows()).with_(n),\n",
" union(A.ncols()).with_(n),\n",
" union(B.nrows()).with_(m),\n",
" union(B.ncols()).with_(m),\n",
" union(C.nrows()).with_(p),\n",
" union(C.ncols()).with_(p),\n",
")\n",
"# Create an example which should equal the kronecker product of A and B\n",
"ex1 = kron(Matrix.identity(n), B) @ kron(A, Matrix.identity(m))\n",
Expand Down
29 changes: 23 additions & 6 deletions python/egglog/bindings.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -232,16 +232,20 @@ class IdentSort:
@final
class RunReport:
updated: bool
search_time: timedelta
apply_time: timedelta
rebuild_time: timedelta
search_time_per_rule: dict[str, timedelta]
apply_time_per_rule: dict[str, timedelta]
search_time_per_ruleset: dict[str, timedelta]
apply_time_per_ruleset: dict[str, timedelta]
rebuild_time_per_ruleset: dict[str, timedelta]

def __init__(
self,
updated: bool,
search_time: timedelta,
apply_time: timedelta,
rebuild_time: timedelta,
search_time_per_rule: dict[str, timedelta],
apply_time_per_rule: dict[str, timedelta],
search_time_per_ruleset: dict[str, timedelta],
apply_time_per_ruleset: dict[str, timedelta],
rebuild_time_per_ruleset: dict[str, timedelta],
) -> None: ...

@final
Expand Down Expand Up @@ -425,6 +429,17 @@ class Include:
class CheckProof:
def __init__(self) -> None: ...

@final
class Relation:
constructor: str
inputs: list[str]

def __init__(self, constructor: str, inputs: list[str]) -> None: ...

@final
class PrintOverallStatistics:
def __init__(self) -> None: ...

_Command = (
SetOption
| Datatype
Expand All @@ -450,6 +465,8 @@ _Command = (
| Fail
| Include
| CheckProof
| Relation
| PrintOverallStatistics
)

def termdag_term_to_expr(termdag: TermDag, term: _Term) -> _Expr: ...
9 changes: 9 additions & 0 deletions python/egglog/builtins.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,10 @@ class String(Expr):
def __init__(self, value: str):
...

@BUILTINS.method(egg_fn="replace")
def replace(self, old: StringLike, new: StringLike) -> String: # type: ignore[empty-body]
...


@BUILTINS.function(egg_fn="+")
def join(*strings: StringLike) -> String: # type: ignore[empty-body]
Expand Down Expand Up @@ -150,6 +154,11 @@ def to_string(self) -> String: # type: ignore[empty-body]
converter(int, i64, i64)


@BUILTINS.function(egg_fn="count-matches")
def count_matches(s: StringLike, pattern: StringLike) -> i64: # type: ignore[empty-body]
...


f64Like = Union[float, "f64"]


Expand Down
10 changes: 8 additions & 2 deletions python/egglog/declarations.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,7 @@ def register_function_callable(
default: Optional[ExprDecl],
merge: Optional[ExprDecl],
merge_action: list[bindings._Action],
is_relation: bool = False,
) -> Iterable[bindings._Command]:
"""
Registers a callable with the given egg name. The callable's function needs to be registered
Expand All @@ -296,7 +297,7 @@ def register_function_callable(
egg_name = egg_name or ref.generate_egg_name()
self._decl.register_callable_ref(ref, egg_name)
self._decl.set_function_decl(ref, fn_decl)
return fn_decl.to_commands(self, egg_name, cost, default, merge, merge_action)
return fn_decl.to_commands(self, egg_name, cost, default, merge, merge_action, is_relation)

def register_constant_callable(
self, ref: ConstantCallableRef, type_ref: JustTypeRef, egg_name: Optional[str]
Expand Down Expand Up @@ -487,6 +488,7 @@ def to_commands(
default: Optional[ExprDecl] = None,
merge: Optional[ExprDecl] = None,
merge_action: list[bindings._Action] = [],
is_relation: bool = False,
) -> Iterable[bindings._Command]:
if self.var_arg_type is not None:
raise NotImplementedError("egglog does not support variable arguments yet.")
Expand All @@ -499,7 +501,11 @@ def to_commands(
arg_sorts.append(arg_sort)
return_sort, cmds = mod_decls.register_sort(self.return_type.to_just())
yield from cmds

if is_relation:
assert not default and not merge and not merge_action and not cost
assert return_sort == "Unit"
yield bindings.Relation(egg_name, arg_sorts)
return
egg_fn_decl = bindings.FunctionDecl(
egg_name,
bindings.Schema(arg_sorts, return_sort),
Expand Down
9 changes: 8 additions & 1 deletion python/egglog/egraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -571,7 +571,14 @@ def relation(self, name: str, /, *tps: type, egg_fn: Optional[str] = None) -> Ca
arg_types, None, tuple(None for _ in tps), TypeRefWithVars("Unit"), mutates_first_arg=False
)
commands = self._mod_decls.register_function_callable(
FunctionRef(name), fn_decl, egg_fn, cost=None, default=None, merge=None, merge_action=[]
FunctionRef(name),
fn_decl,
egg_fn,
cost=None,
default=None,
merge=None,
merge_action=[],
is_relation=True,
)
self._process_commands(commands)
return cast(Callable[..., Unit], RuntimeFunction(self._mod_decls, name))
Expand Down
2 changes: 1 addition & 1 deletion python/egglog/examples/lambda.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def freer(t: Term) -> StringSet:
set_(freer(t)).to(fv1 | fv2 | fv3)
),
# eval
rule(eq(t).to(Term.val(v))).then(set_(t.eval()).to(v)),
rule(eq(t).to(Term.val(v))).then(union(t.eval()).with_(v)),
rule(eq(t).to(t1 + t2), eq(Val(i1)).to(t1.eval()), eq(Val(i2)).to(t2.eval())).then(
union(t.eval()).with_(Val(i1 + i2))
),
Expand Down
12 changes: 6 additions & 6 deletions python/egglog/examples/matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,12 +154,12 @@ def kron(a: Matrix, b: Matrix) -> Matrix: # type: ignore[empty-body]

# Set each to be a square matrix of the given dimension
egraph.register(
set_(A.nrows()).to(n),
set_(A.ncols()).to(n),
set_(B.nrows()).to(m),
set_(B.ncols()).to(m),
set_(C.nrows()).to(p),
set_(C.ncols()).to(p),
union(A.nrows()).with_(n),
union(A.ncols()).with_(n),
union(B.nrows()).with_(m),
union(B.ncols()).with_(m),
union(C.nrows()).with_(p),
union(C.ncols()).with_(p),
)
# Create an example which should equal the kronecker product of A and B
ex1 = egraph.let("ex1", kron(Matrix.identity(n), B) @ kron(A, Matrix.identity(m)))
Expand Down
14 changes: 7 additions & 7 deletions python/egglog/examples/resolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@ def __invert__(self) -> Bool: # type: ignore[empty-body]
p, a, b, c, as_, bs = vars_("p a b c as bs", Bool)
egraph.register(
# clauses are assumed in the normal form (or a (or b (or c False)))
set_(~F).to(T),
set_(~T).to(F),
union(~F).with_(T),
union(~T).with_(F),
# "Solving" negation equations
rule(eq(~p).to(T)).then(union(p).with_(F)),
rule(eq(~p).to(F)).then(union(p).with_(T)),
Expand All @@ -56,7 +56,7 @@ def __invert__(self) -> Bool: # type: ignore[empty-body]
eq(T).to(a | as_),
eq(T).to(~a | bs),
).then(
set_(as_ | bs).to(T),
union(as_ | bs).with_(T),
),
)

Expand All @@ -71,11 +71,11 @@ def pred(x: i64Like) -> Bool: # type: ignore[empty-body]
p1 = egraph.let("p1", pred(1))
p2 = egraph.let("p2", pred(2))
egraph.register(
set_(p1 | (~p2 | F)).to(T),
set_(p2 | (~p0 | F)).to(T),
set_(p0 | (~p1 | F)).to(T),
union(p1 | (~p2 | F)).with_(T),
union(p2 | (~p0 | F)).with_(T),
union(p0 | (~p1 | F)).with_(T),
union(p1).with_(F),
set_(~p0 | (~p1 | (p2 | F))).to(T),
union(~p0 | (~p1 | (p2 | F))).with_(T),
)
egraph.run(10)
egraph.check(T != F)
Expand Down
8 changes: 4 additions & 4 deletions python/egglog/exp/array_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@

def extract_py(e: Expr) -> Any:
egraph = EGraph.current()
egraph.push()
# print(e)
egraph.register(e)
egraph.run((run() * 30).saturate())
Expand Down Expand Up @@ -60,8 +61,10 @@ def extract_py(e: Expr) -> Any:
except EggSmolError:
other_versions = egraph.extract_multiple(final_object, 10)
other_verions_str = "\n\n".join(map(str, other_versions))
egraph.graphviz().render(view=True)
raise Exception(f"Failed to extract:\n{other_verions_str}")
# print(res)
egraph.pop()
return res


Expand Down Expand Up @@ -1235,10 +1238,7 @@ def _linalg(x: NDArray, full_matrices: Bool):
# to analyze `any(((astype(unique_counts(NDArray.var("y"))[Int(1)], DType.float64) / NDArray.scalar(Value.float(Float(150.0))) < NDArray.scalar(Value.int(Int(0)))).bool()``
##


@array_api_module.function
def greater_zero(value: Value) -> Unit:
...
greater_zero = array_api_module.relation("greater_zero", Value)


# @array_api_module.function
Expand Down
Loading

0 comments on commit a020cd4

Please sign in to comment.