Skip to content

Commit

Permalink
Merge pull request #242 from egraphs-good/fix-loopnest
Browse files Browse the repository at this point in the history
Working loopnest example
  • Loading branch information
saulshanabrook authored Dec 6, 2024
2 parents 9bb83d4 + 5c73999 commit 8b17580
Show file tree
Hide file tree
Showing 7 changed files with 64 additions and 82 deletions.
3 changes: 3 additions & 0 deletions docs/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@ _This project uses semantic versioning_
- Fix pretty printing of lambda functions
- Add support for subsuming rewrite generated by default function and method definitions
- Add better error message when using @function in class (thanks @shinawy)
- Add error message if the `@method` decorator is in the wrong place
- Subsume lambda functions after replacing
- Add working loopnest test

## 8.0.1 (2024-10-24)

Expand Down
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,8 @@ filterwarnings = [
"error",
"ignore::numba.core.errors.NumbaPerformanceWarning",
"ignore::pytest_benchmark.logger.PytestBenchmarkWarning",
# https://github.com/manzt/anywidget/blob/d38bb3f5f9cfc7e49e2ff1aa1ba994d66327cb02/pyproject.toml#L120
"ignore:Deprecated in traitlets 4.1, use the instance .metadata:DeprecationWarning",
]

[tool.coverage.report]
Expand Down
3 changes: 2 additions & 1 deletion python/egglog/builtins.py
Original file line number Diff line number Diff line change
Expand Up @@ -532,7 +532,8 @@ def _convert_function(a: FunctionType) -> UnstableFn:
transformed_fn = functionalize(a, value_to_annotation)
assert isinstance(transformed_fn, partial)
return UnstableFn(
function(ruleset=get_current_ruleset(), use_body_as_name=True)(transformed_fn.func), *transformed_fn.args
function(ruleset=get_current_ruleset(), use_body_as_name=True, subsume=True)(transformed_fn.func),
*transformed_fn.args,
)


Expand Down
8 changes: 8 additions & 0 deletions python/egglog/egraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -570,6 +570,10 @@ def _generate_class_decls( # noqa: C901,PLR0912
fn = fn.fget
case _:
ref = InitRef(cls_name) if is_init else MethodRef(cls_name, method_name)
if isinstance(fn, _WrappedMethod):
msg = f"{cls_name}.{method_name} Add the @method(...) decorator above @classmethod or @property"

raise ValueError(msg) # noqa: TRY004
special_function_name: SpecialFunctions | None = (
"fn-partial" if egg_fn == "unstable-fn" else "fn-app" if egg_fn == "unstable-app" else None
)
Expand Down Expand Up @@ -1373,10 +1377,14 @@ def saturate(
"""
Saturate the egraph, running the given schedule until the egraph is saturated.
It serializes the egraph at each step and returns a widget to visualize the egraph.
If an `expr` is passed, it's also extracted after each run and printed
"""
from .visualizer_widget import VisualizerWidget

def to_json() -> str:
if expr is not None:
print(self.extract(expr), "\n")
return self._serialize(**kwargs).to_json()

egraphs = [to_json()]
Expand Down
13 changes: 8 additions & 5 deletions python/egglog/exp/array_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,7 @@ def single(cls, i: Int) -> TupleInt:
return TupleInt(Int(1), lambda _: i)

@classmethod
def range(cls, stop: Int) -> TupleInt:
def range(cls, stop: IntLike) -> TupleInt:
return TupleInt(stop, lambda i: i)

@classmethod
Expand Down Expand Up @@ -346,7 +346,6 @@ def _tuple_int(
ti: TupleInt,
ti2: TupleInt,
):
remaining = TupleInt(k - 1, lambda i: idx_fn(i + 1)).filter(filter_f)
return [
rewrite(TupleInt(i, idx_fn).length()).to(i),
rewrite(TupleInt(i, idx_fn)[i2]).to(idx_fn(i2)),
Expand All @@ -367,7 +366,11 @@ def _tuple_int(
# filter TODO: could be written as fold w/ generic types
rewrite(TupleInt(0, idx_fn).filter(filter_f)).to(TupleInt(0, idx_fn)),
rewrite(TupleInt(Int(k), idx_fn).filter(filter_f)).to(
TupleInt.if_(filter_f(value := idx_fn(Int(k))), TupleInt.single(value) + remaining, remaining),
TupleInt.if_(
filter_f(value := idx_fn(Int(k - 1))),
(remaining := TupleInt(k - 1, idx_fn).filter(filter_f)) + TupleInt.single(value),
remaining,
),
ne(k).to(i64(0)),
),
# Empty
Expand All @@ -386,13 +389,13 @@ def var(cls, name: StringLike) -> TupleTupleInt: ...

def __init__(self, length: IntLike, idx_fn: Callable[[Int], TupleInt]) -> None: ...

@classmethod
@method(subsume=True)
@classmethod
def single(cls, i: TupleInt) -> TupleTupleInt:
return TupleTupleInt(Int(1), lambda _: i)

@classmethod
@method(subsume=True)
@classmethod
def from_vec(cls, vec: Vec[Int]) -> TupleInt:
return TupleInt(vec.length(), partial(index_vec_int, vec))

Expand Down
75 changes: 2 additions & 73 deletions python/egglog/exp/array_api_loopnest.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
from egglog import *
from egglog.exp.array_api import *

__all__ = ["LoopNestAPI", "OptionalLoopNestAPI", "ShapeAPI"]


class ShapeAPI(Expr):
def __init__(self, dims: TupleIntLike) -> None: ...
Expand Down Expand Up @@ -105,76 +107,3 @@ def _loopnest_api_ruleset(
yield rewrite(lna.indices, subsume=True).to(
tuple_tuple_int_product(tuple_int_map_tuple_int(lna.get_dims(), TupleInt.range))
)


@function(ruleset=array_api_ruleset, subsume=True)
def linalg_norm(X: NDArray, axis: TupleIntLike) -> NDArray:
    """Euclidean (L2) norm of ``X`` reduced over the dimensions listed in ``axis``.

    The result keeps every dimension of ``X`` *not* named in ``axis``; each
    output element is ``sqrt(sum(|x|^2))`` taken over the selected dimensions.
    """
    # Dims that survive into the output: everything except the reduced axes.
    out_shape = ShapeAPI(X.shape).deselect(axis).to_tuple()
    # Dims that are summed away by the loop nest below.
    inner_shape = ShapeAPI(X.shape).select(axis).to_tuple()

    return NDArray(
        out_shape,
        X.dtype,
        # For each output index `k`, fold over every inner index `i`,
        # accumulating |X[i + k]|^2, then take the square root.
        lambda k: sqrt(
            LoopNestAPI.from_tuple(inner_shape)
            .unwrap()
            .fold(lambda acc, i: acc + real(conj(elem := X[i + k]) * elem), init=0.0)
        ).to_value(),
    )


# %%
# egraph = EGraph(save_egglog_string=True)

# egraph.register(val.shape)
# egraph.run(array_api_ruleset.saturate())
# egraph.extract_multiple(val.shape, 10)

# %%

# Symbolic input array with an assumed concrete shape, so shape queries
# on the norm result can saturate to concrete Ints.
X = NDArray.var("X")
assume_shape(X, (3, 2, 3, 4))
# Reduce over axes (0, 1); the remaining dims of the result are (3, 4).
val = linalg_norm(X, (0, 1))
egraph = EGraph()
# Bind the last output dim (expected to saturate to 4) for inspection below.
x = egraph.let("x", val.shape[2])
# egraph.display(n_inline_leaves=0)
# egraph.extract(x)
# egraph.saturate(array_api_ruleset, expr=x, split_functions=[Int, TRUE, FALSE], n_inline_leaves=0)
# egraph.run(array_api_ruleset.saturate())
# egraph.extract(x)
# egraph.display()


# %%

# x = xs[-2]
# # %%
# decls = x.__egg_decls__
# # RuntimeExpr.__from_values__(x.__egg_decls__, x.__egg_typed_expr__.expr.args[1].expr.args[1])

# # %%
# # x.__egg_typed_expr__.expr.args[1].expr.args[1] # %%

# # %%
# # egraph.extract(RuntimeExpr.__from_values__(x.__egg_decls__, x.__egg_typed_expr__.expr.args[1].expr.args[1]))


# from egglog import pretty

# decl = (
# x.__egg_typed_expr__.expr.args[1]
# .expr.args[2]
# .expr.args[0]
# .expr.args[1]
# .expr.call.args[0]
# .expr.call.args[0]
# .expr.call.args[0]
# )

# # pprint.pprint(decl)

# print(pretty.pretty_decl(decls, decl.expr))

# # %%
42 changes: 39 additions & 3 deletions python/tests/test_array_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,10 @@
from sklearn import config_context, datasets
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis

from egglog.egraph import set_current_ruleset
from egglog.exp.array_api import *
from egglog.exp.array_api_jit import jit
from egglog.exp.array_api_loopnest import *
from egglog.exp.array_api_numba import array_api_numba_schedule
from egglog.exp.array_api_program_gen import *

Expand Down Expand Up @@ -68,6 +70,41 @@ def test_reshape_vec_noop():
egraph.check(eq(res).to(x))


def test_filter():
    """Filtering ``range(5)`` by ``< 2`` keeps exactly {0, 1}, so the length is 2."""
    with set_current_ruleset(array_api_ruleset):
        filtered_length = TupleInt.range(5).filter(lambda i: i < 2).length()
        check_eq(filtered_length, Int(2), array_api_schedule)


@function(ruleset=array_api_ruleset, subsume=True)
def linalg_norm(X: NDArray, axis: TupleIntLike) -> NDArray:
    """Euclidean (L2) norm of `X` reduced over the dimensions listed in `axis`."""
    # peel off the outer shape for result array
    outshape = ShapeAPI(X.shape).deselect(axis).to_tuple()
    # get only the inner shape for reduction
    reduce_axis = ShapeAPI(X.shape).select(axis).to_tuple()

    return NDArray(
        outshape,
        X.dtype,
        # Each output element k is the sqrt of the sum of |X[i + k]|^2 over
        # the reduced index space, accumulated by the loop-nest fold.
        lambda k: sqrt(
            LoopNestAPI.from_tuple(reduce_axis)
            .unwrap()
            .fold(lambda carry, i: carry + real(conj(x := X[i + k]) * x), init=0.0)
        ).to_value(),
    )


class TestLoopNest:
    def test_shape(self):
        """Reducing a (3, 2, 3, 4) array over axes (0, 1) leaves shape (3, 4)."""
        arr = NDArray.var("X")
        assume_shape(arr, (3, 2, 3, 4))
        result = linalg_norm(arr, (0, 1))

        # Two dimensions survive the reduction ...
        check_eq(result.shape.length(), Int(2), array_api_schedule)
        # ... and they are the original trailing dims, 3 and 4.
        check_eq(result.shape[0], Int(3), array_api_schedule)
        check_eq(result.shape[1], Int(4), array_api_schedule)


# This test happens in different steps. Each will be benchmarked and saved as a snapshot.
# The next step will load the old snapshot and run their test on it.

Expand All @@ -80,7 +117,6 @@ def run_lda(x, y):

iris = datasets.load_iris()
X_np, y_np = (iris.data, iris.target)
res_np = run_lda(X_np, y_np)


def _load_py_snapshot(fn: Callable, var: str | None = None) -> Any:
Expand Down Expand Up @@ -165,7 +201,7 @@ def test_source_optimized(self, snapshot_py, benchmark):
optimized_expr = simplify_lda(egraph, expr)
fn_program = ndarray_function_two(optimized_expr, NDArray.var("X"), NDArray.var("y"))
py_object = benchmark(load_source, fn_program, egraph)
assert np.allclose(py_object(X_np, y_np), res_np)
assert np.allclose(py_object(X_np, y_np), run_lda(X_np, y_np))
assert egraph.eval(fn_program.statements) == snapshot_py

@pytest.mark.parametrize(
Expand All @@ -180,7 +216,7 @@ def test_source_optimized(self, snapshot_py, benchmark):
)
def test_execution(self, fn, benchmark):
# warmup once for numba
assert np.allclose(res_np, fn(X_np, y_np))
assert np.allclose(run_lda(X_np, y_np), fn(X_np, y_np))
benchmark(fn, X_np, y_np)


Expand Down

0 comments on commit 8b17580

Please sign in to comment.