Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add start of compiling expressions to strings #48

Merged
merged 15 commits into from
Sep 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -79,3 +79,7 @@ TODO

.ipynb_checkpoints
Source.*
2
3
4
inlined
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
exclude: ^python/tests/__snapshots__/
default_language_version:
python: python3.10
ci:
Expand Down
2 changes: 1 addition & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,4 +24,4 @@ name = "egglog.bindings"

[patch.'https://github.com/egraphs-good/egraph-serialize']

egraph-serialize = { git = "https://github.com/saulshanabrook/egraph-serialize", rev = "8dc11836fdade20bf2b20c5f1e2d75bcd005767d" }
egraph-serialize = { git = "https://github.com/saulshanabrook/egraph-serialize", rev = "a3f6fef9b958a335367d80d51e028c6db886fb6e" }
11 changes: 11 additions & 0 deletions docs/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,17 @@ _This project uses semantic versioning. Before 1.0.0, this means that every brea

## Unreleased

### Breaking Changes

### New Features

- Add ability to pass `seminaive` flag to Egraph to replicate `--naive` CLI flag [#48](https://github.com/metadsl/egglog-python/pull/48)
- Add ability to inline leaves $n$ times instead of just once for visualization [#48](https://github.com/metadsl/egglog-python/pull/48)

### Bug fixes

### Uncategorized

- Added initial supported for Python objects [#31](https://github.com/metadsl/egglog-python/pull/31)

- Renamed `BaseExpr` to `Expr` for succinctness
Expand Down
4 changes: 3 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ dependencies = ["typing-extensions", "black", "graphviz"]
[project.optional-dependencies]
dev = ["pre-commit", "black", "mypy", "flake8", "isort"]

test = ["pytest", "mypy", "scikit-learn", "array_api_compat"]
test = ["pytest", "mypy", "scikit-learn", "array_api_compat", "syrupy"]

docs = [
"pydata-sphinx-theme",
Expand Down Expand Up @@ -61,6 +61,7 @@ strict_equality = true
warn_unused_configs = true
allow_redefinition = true
enable_incomplete_feature = ["Unpack", "TypeVarTuple"]
exclude = ["__snapshots__", "_build", "^conftest.py$"]

[tool.maturin]
python-source = "python"
Expand All @@ -70,3 +71,4 @@ addopts = ["--import-mode=importlib"]
testpaths = ["python"]
python_files = ["test_*.py", "test.py"]
markers = ["slow: marks tests as slow (deselect with '-m \"not slow\"')"]
norecursedirs = ["__snapshots__"]
2 changes: 1 addition & 1 deletion python/egglog/bindings.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ class EGraph:
*,
max_functions: Optional[int] = None,
max_calls_per_function: Optional[int] = None,
inline_leaves: bool = False,
n_inline_leaves: int = 0,
) -> str: ...
def save_object(self, __o: object, /) -> _Expr: ...
def load_object(self, __e: _Expr, /) -> object: ...
Expand Down
24 changes: 18 additions & 6 deletions python/egglog/egraph.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import inspect
from abc import ABC, abstractmethod
from contextvars import ContextVar, Token
from copy import deepcopy
Expand Down Expand Up @@ -104,12 +105,12 @@ class _BaseModule(ABC):
"""

# Any modules you want to depend on
deps: InitVar[list[Module]] = []
modules: InitVar[list[Module]] = []
# All dependencies flattened
_flatted_deps: list[Module] = field(init=False, default_factory=list)
_mod_decls: ModuleDeclarations = field(init=False)

def __post_init__(self, modules: list[Module] = []) -> None:
def __post_init__(self, modules: list[Module]) -> None:
included_decls = [_BUILTIN_DECLS] if _BUILTIN_DECLS else []
# Traverse all the included modules to flatten all their dependencies and add to the included declerations
for mod in modules:
Expand Down Expand Up @@ -617,7 +618,7 @@ def let(self, name: str, expr: EXPR) -> EXPR:

@dataclass
class _Builtins(_BaseModule):
def __post_init__(self, modules: list[Module] = []) -> None:
def __post_init__(self, modules: list[Module]) -> None:
"""
Register these declarations as builtins, so others can use them.
"""
Expand Down Expand Up @@ -663,13 +664,16 @@ class EGraph(_BaseModule):
Represents an EGraph instance at runtime
"""

_egraph: bindings.EGraph = field(repr=False, default_factory=bindings.EGraph)
seminaive: InitVar[bool] = True

_egraph: bindings.EGraph = field(repr=False, init=False)
# The current declarations which have been pushed to the stack
_decl_stack: list[Declarations] = field(default_factory=list, repr=False)
_token: Optional[Token[EGraph]] = None

def __post_init__(self, modules: list[Module] = []) -> None:
def __post_init__(self, modules: list[Module], seminaive) -> None:
super().__post_init__(modules)
self._egraph = bindings.EGraph(seminaive=seminaive)
for m in self._flatted_deps:
for o in m._py_objects:
self._egraph.save_object(o)
Expand Down Expand Up @@ -1430,7 +1434,15 @@ def _command_generator(gen: CommandGenerator) -> Iterable[Command]:
"""
Calls the function with variables of the type and name of the arguments.
"""
hints = get_type_hints(gen)
# Get the local scope from where the function is defined, so that we can get any type hints that are in the scope
# but not in the globals
current_frame = inspect.currentframe()
assert current_frame
register_frame = current_frame.f_back
assert register_frame
original_frame = register_frame.f_back
assert original_frame
hints = get_type_hints(gen, gen.__globals__, original_frame.f_locals)
args = (_var(p.name, hints[p.name]) for p in signature(gen).parameters.values())
return gen(*args)

Expand Down
6 changes: 4 additions & 2 deletions python/egglog/examples/sklearn_lda.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,10 @@ def fit(X, y):

with EGraph([array_api_module]) as egraph:
egraph.register(res)
egraph.run((run() * 10).saturate())
egraph.run((run() * 10).saturate())
egraph.run((run() * 10))
# egraph.run((run() * 10).saturate())
# egraph.graphviz(n_inline_leaves=3).render("3", view=True)

res = egraph.extract(expr=res)
print(res)
# egraph.display()
Loading
Loading