From 5c73999c0a327d41b68c8e22355139d1fc8be6c1 Mon Sep 17 00:00:00 2001 From: Saul Shanabrook Date: Fri, 6 Dec 2024 13:00:25 -0500 Subject: [PATCH] Add working loopnest example --- docs/changelog.md | 1 + python/egglog/exp/array_api_loopnest.py | 75 +------------------------ python/tests/test_array_api.py | 42 +++++++++++++- 3 files changed, 42 insertions(+), 76 deletions(-) diff --git a/docs/changelog.md b/docs/changelog.md index 24d7d46..377c79c 100644 --- a/docs/changelog.md +++ b/docs/changelog.md @@ -9,6 +9,7 @@ _This project uses semantic versioning_ - Add better error message when using @function in class (thanks @shinawy) - Add error method if `@method` decorator is in wrong place - Subsumes lambda functions after replacing +- Add working loopnest test ## 8.0.1 (2024-10-24) diff --git a/python/egglog/exp/array_api_loopnest.py b/python/egglog/exp/array_api_loopnest.py index c12c67e..c35fda5 100644 --- a/python/egglog/exp/array_api_loopnest.py +++ b/python/egglog/exp/array_api_loopnest.py @@ -12,6 +12,8 @@ from egglog import * from egglog.exp.array_api import * +__all__ = ["LoopNestAPI", "OptionalLoopNestAPI", "ShapeAPI"] + class ShapeAPI(Expr): def __init__(self, dims: TupleIntLike) -> None: ... @@ -105,76 +107,3 @@ def _loopnest_api_ruleset( yield rewrite(lna.indices, subsume=True).to( tuple_tuple_int_product(tuple_int_map_tuple_int(lna.get_dims(), TupleInt.range)) ) - - -@function(ruleset=array_api_ruleset, subsume=True) -def linalg_norm(X: NDArray, axis: TupleIntLike) -> NDArray: - # peel off the outer shape for result array - outshape = ShapeAPI(X.shape).deselect(axis).to_tuple() - # get only the inner shape for reduction - reduce_axis = ShapeAPI(X.shape).select(axis).to_tuple() - - return NDArray( - outshape, - X.dtype, - lambda k: sqrt( - LoopNestAPI.from_tuple(reduce_axis) - .unwrap() - .fold(lambda carry, i: carry + real(conj(x := X[i + k]) * x), init=0.0) - ).to_value(), - ) - - -# %% -# egraph = EGraph(save_egglog_string=True) - -# egraph.register(val.shape) -# egraph.run(array_api_ruleset.saturate()) -# egraph.extract_multiple(val.shape, 10) - -# %% - -X = NDArray.var("X") -assume_shape(X, (3, 2, 3, 4)) -val = linalg_norm(X, (0, 1)) -egraph = EGraph() -x = egraph.let("x", val.shape[2]) -# egraph.display(n_inline_leaves=0) -# egraph.extract(x) -# egraph.saturate(array_api_ruleset, expr=x, split_functions=[Int, TRUE, FALSE], n_inline_leaves=0) -# egraph.run(array_api_ruleset.saturate()) -# egraph.extract(x) -# egraph.display() - - -# %% - -# x = xs[-2] -# # %% -# decls = x.__egg_decls__ -# # RuntimeExpr.__from_values__(x.__egg_decls__, x.__egg_typed_expr__.expr.args[1].expr.args[1]) - -# # %% -# # x.__egg_typed_expr__.expr.args[1].expr.args[1] # %% - -# # %% -# # egraph.extract(RuntimeExpr.__from_values__(x.__egg_decls__, x.__egg_typed_expr__.expr.args[1].expr.args[1])) - - -# from egglog import pretty - -# decl = ( -# x.__egg_typed_expr__.expr.args[1] -# .expr.args[2] -# .expr.args[0] -# .expr.args[1] -# .expr.call.args[0] -# .expr.call.args[0] -# .expr.call.args[0] -# ) - -# # pprint.pprint(decl) - -# print(pretty.pretty_decl(decls, decl.expr)) - -# # %% diff --git a/python/tests/test_array_api.py b/python/tests/test_array_api.py index b6daa22..b48fb7b 100644 --- a/python/tests/test_array_api.py +++ b/python/tests/test_array_api.py @@ -8,8 +8,10 @@ from sklearn import config_context, datasets from sklearn.discriminant_analysis import LinearDiscriminantAnalysis +from egglog.egraph import set_current_ruleset from egglog.exp.array_api import * from egglog.exp.array_api_jit import jit +from egglog.exp.array_api_loopnest import * from egglog.exp.array_api_numba import array_api_numba_schedule from egglog.exp.array_api_program_gen import * @@ -68,6 +70,41 @@ def test_reshape_vec_noop(): egraph.check(eq(res).to(x)) +def test_filter(): + with set_current_ruleset(array_api_ruleset): + x = TupleInt.range(5).filter(lambda i: i < 2).length() + check_eq(x, Int(2), array_api_schedule) + + +@function(ruleset=array_api_ruleset, subsume=True) +def linalg_norm(X: NDArray, axis: TupleIntLike) -> NDArray: + # peel off the outer shape for result array + outshape = ShapeAPI(X.shape).deselect(axis).to_tuple() + # get only the inner shape for reduction + reduce_axis = ShapeAPI(X.shape).select(axis).to_tuple() + + return NDArray( + outshape, + X.dtype, + lambda k: sqrt( + LoopNestAPI.from_tuple(reduce_axis) + .unwrap() + .fold(lambda carry, i: carry + real(conj(x := X[i + k]) * x), init=0.0) + ).to_value(), + ) + + +class TestLoopNest: + def test_shape(self): + X = NDArray.var("X") + assume_shape(X, (3, 2, 3, 4)) + val = linalg_norm(X, (0, 1)) + + check_eq(val.shape.length(), Int(2), array_api_schedule) + check_eq(val.shape[0], Int(3), array_api_schedule) + check_eq(val.shape[1], Int(4), array_api_schedule) + + # This test happens in different steps. Each will be benchmarked and saved as a snapshot. # The next step will load the old snapshot and run their test on it. @@ -80,7 +117,6 @@ def run_lda(x, y): iris = datasets.load_iris() X_np, y_np = (iris.data, iris.target) -res_np = run_lda(X_np, y_np) def _load_py_snapshot(fn: Callable, var: str | None = None) -> Any: @@ -165,7 +201,7 @@ def test_source_optimized(self, snapshot_py, benchmark): optimized_expr = simplify_lda(egraph, expr) fn_program = ndarray_function_two(optimized_expr, NDArray.var("X"), NDArray.var("y")) py_object = benchmark(load_source, fn_program, egraph) - assert np.allclose(py_object(X_np, y_np), res_np) + assert np.allclose(py_object(X_np, y_np), run_lda(X_np, y_np)) assert egraph.eval(fn_program.statements) == snapshot_py @pytest.mark.parametrize( @@ -180,7 +216,7 @@ def test_source_optimized(self, snapshot_py, benchmark): ) def test_execution(self, fn, benchmark): # warmup once for numba - assert np.allclose(res_np, fn(X_np, y_np)) + assert np.allclose(run_lda(X_np, y_np), fn(X_np, y_np)) benchmark(fn, X_np, y_np)