Skip to content

Commit

Permalink
Add working loopnest example
Browse files Browse the repository at this point in the history
  • Loading branch information
saulshanabrook committed Dec 6, 2024
1 parent 7848e74 commit 5c73999
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 76 deletions.
1 change: 1 addition & 0 deletions docs/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ _This project uses semantic versioning_
- Add better error message when using @function in class (thanks @shinawy)
- Add error message if `@method` decorator is in wrong place
- Subsume lambda functions after replacing
- Add working loopnest test

## 8.0.1 (2024-10-24)

Expand Down
75 changes: 2 additions & 73 deletions python/egglog/exp/array_api_loopnest.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
from egglog import *
from egglog.exp.array_api import *

__all__ = ["LoopNestAPI", "OptionalLoopNestAPI", "ShapeAPI"]


class ShapeAPI(Expr):
def __init__(self, dims: TupleIntLike) -> None: ...
Expand Down Expand Up @@ -105,76 +107,3 @@ def _loopnest_api_ruleset(
yield rewrite(lna.indices, subsume=True).to(
tuple_tuple_int_product(tuple_int_map_tuple_int(lna.get_dims(), TupleInt.range))
)


@function(ruleset=array_api_ruleset, subsume=True)
def linalg_norm(X: NDArray, axis: TupleIntLike) -> NDArray:
    """Symbolic L2 norm of ``X`` reduced over ``axis``.

    The dims named in ``axis`` are folded by a loop nest summing
    ``real(conj(v) * v)``; the remaining dims become the result shape.
    """
    shape_api = ShapeAPI(X.shape)
    # Dims NOT selected by `axis` survive into the output array.
    result_shape = shape_api.deselect(axis).to_tuple()
    # Dims selected by `axis` are iterated by the reduction loop nest.
    inner_dims = shape_api.select(axis).to_tuple()

    # Per-output-index expression: sum |X[i + k]|^2 over the inner loop
    # nest rooted at output index `k`, then take the square root.
    index_fn = lambda k: sqrt(
        LoopNestAPI.from_tuple(inner_dims)
        .unwrap()
        .fold(lambda acc, i: acc + real(conj(v := X[i + k]) * v), init=0.0)
    ).to_value()

    return NDArray(result_shape, X.dtype, index_fn)


# %%
# egraph = EGraph(save_egglog_string=True)

# egraph.register(val.shape)
# egraph.run(array_api_ruleset.saturate())
# egraph.extract_multiple(val.shape, 10)

# %%

# Scratch/demo cell: build a symbolic 4-d input, reduce over axes (0, 1)
# with linalg_norm, and bind one of the result's shape entries in an
# egraph for interactive inspection.
X = NDArray.var("X")
assume_shape(X, (3, 2, 3, 4))
val = linalg_norm(X, (0, 1))
egraph = EGraph()
# NOTE(review): reducing two of four axes should leave a rank-2 result, so
# shape[2] may be out of range — confirm this index is intentional.
x = egraph.let("x", val.shape[2])
# Alternate inspection steps, left commented for interactive use:
# egraph.display(n_inline_leaves=0)
# egraph.extract(x)
# egraph.saturate(array_api_ruleset, expr=x, split_functions=[Int, TRUE, FALSE], n_inline_leaves=0)
# egraph.run(array_api_ruleset.saturate())
# egraph.extract(x)
# egraph.display()


# %%

# x = xs[-2]
# # %%
# decls = x.__egg_decls__
# # RuntimeExpr.__from_values__(x.__egg_decls__, x.__egg_typed_expr__.expr.args[1].expr.args[1])

# # %%
# # x.__egg_typed_expr__.expr.args[1].expr.args[1] # %%

# # %%
# # egraph.extract(RuntimeExpr.__from_values__(x.__egg_decls__, x.__egg_typed_expr__.expr.args[1].expr.args[1]))


# from egglog import pretty

# decl = (
# x.__egg_typed_expr__.expr.args[1]
# .expr.args[2]
# .expr.args[0]
# .expr.args[1]
# .expr.call.args[0]
# .expr.call.args[0]
# .expr.call.args[0]
# )

# # pprint.pprint(decl)

# print(pretty.pretty_decl(decls, decl.expr))

# # %%
42 changes: 39 additions & 3 deletions python/tests/test_array_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,10 @@
from sklearn import config_context, datasets
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis

from egglog.egraph import set_current_ruleset
from egglog.exp.array_api import *
from egglog.exp.array_api_jit import jit
from egglog.exp.array_api_loopnest import *
from egglog.exp.array_api_numba import array_api_numba_schedule
from egglog.exp.array_api_program_gen import *

Expand Down Expand Up @@ -68,6 +70,41 @@ def test_reshape_vec_noop():
egraph.check(eq(res).to(x))


def test_filter():
    """filter + length on a symbolic range simplifies to the exact count."""
    # Build the expression while the array-api ruleset is current so the
    # filtering lambda is lowered under it.
    with set_current_ruleset(array_api_ruleset):
        kept = TupleInt.range(5).filter(lambda n: n < 2)
        count = kept.length()
    check_eq(count, Int(2), array_api_schedule)


@function(ruleset=array_api_ruleset, subsume=True)
def linalg_norm(X: NDArray, axis: TupleIntLike) -> NDArray:
    """Symbolic L2 norm of ``X`` over ``axis``; remaining dims are kept."""
    # Output shape: every dim of X except those being reduced.
    out_dims = ShapeAPI(X.shape).deselect(axis).to_tuple()
    # Dims the reduction loop nest iterates over.
    red_dims = ShapeAPI(X.shape).select(axis).to_tuple()
    # The loop nest is a pure expression, so it can be built once and
    # shared by every per-index lambda invocation.
    loop_nest = LoopNestAPI.from_tuple(red_dims).unwrap()

    return NDArray(
        out_dims,
        X.dtype,
        lambda k: sqrt(
            loop_nest.fold(
                lambda total, i: total + real(conj(e := X[i + k]) * e),
                init=0.0,
            )
        ).to_value(),
    )


class TestLoopNest:
    """Shape of linalg_norm's result is derivable from the loopnest rules."""

    def test_shape(self):
        # Symbolic 4-d input; reducing over axes (0, 1) leaves dims (3, 4).
        arr = NDArray.var("X")
        assume_shape(arr, (3, 2, 3, 4))
        normed = linalg_norm(arr, (0, 1))

        check_eq(normed.shape.length(), Int(2), array_api_schedule)
        for idx, expected in ((0, Int(3)), (1, Int(4))):
            check_eq(normed.shape[idx], expected, array_api_schedule)


# This test happens in different steps. Each will be benchmarked and saved as a snapshot.
# The next step will load the old snapshot and run their test on it.

Expand All @@ -80,7 +117,6 @@ def run_lda(x, y):

iris = datasets.load_iris()
X_np, y_np = (iris.data, iris.target)
res_np = run_lda(X_np, y_np)


def _load_py_snapshot(fn: Callable, var: str | None = None) -> Any:
Expand Down Expand Up @@ -165,7 +201,7 @@ def test_source_optimized(self, snapshot_py, benchmark):
optimized_expr = simplify_lda(egraph, expr)
fn_program = ndarray_function_two(optimized_expr, NDArray.var("X"), NDArray.var("y"))
py_object = benchmark(load_source, fn_program, egraph)
assert np.allclose(py_object(X_np, y_np), res_np)
assert np.allclose(py_object(X_np, y_np), run_lda(X_np, y_np))
assert egraph.eval(fn_program.statements) == snapshot_py

@pytest.mark.parametrize(
Expand All @@ -180,7 +216,7 @@ def test_source_optimized(self, snapshot_py, benchmark):
)
def test_execution(self, fn, benchmark):
# warmup once for numba
assert np.allclose(res_np, fn(X_np, y_np))
assert np.allclose(run_lda(X_np, y_np), fn(X_np, y_np))
benchmark(fn, X_np, y_np)


Expand Down

0 comments on commit 5c73999

Please sign in to comment.