From 2e2456927fd37a6aa5dffb718d77b269f5e1d9e4 Mon Sep 17 00:00:00 2001 From: Saul Shanabrook Date: Thu, 28 Sep 2023 09:57:14 -0400 Subject: [PATCH] Real working version! --- python/egglog/exp/program_gen.py | 13 ++++++++----- .../test_program_gen/test_to_string.py | 6 ++++-- python/tests/test_program_gen.py | 8 ++++---- 3 files changed, 16 insertions(+), 11 deletions(-) diff --git a/python/egglog/exp/program_gen.py b/python/egglog/exp/program_gen.py index 16678e13..e711abb2 100644 --- a/python/egglog/exp/program_gen.py +++ b/python/egglog/exp/program_gen.py @@ -165,16 +165,19 @@ def _compile( # of the two program_add = p1 + p2 - # Set parents - yield rule(eq(p).to(program_add), p.compile(i)).then(set_(p1.parent).to(p), set_(p2.parent).to(p)) + # Set parent of p1 + yield rule(eq(p).to(program_add), p.compile(i)).then(set_(p1.parent).to(p)) - # Compile p1, if p1 parent set + # Compile p1, if p1 parent equal yield rule(eq(p).to(program_add), p.compile(i), eq(p1.parent).to(program_add)).then(p1.compile(i)) - # Compile p2, if p1 parent not set + # Set parent of p2, once p1 compiled + yield rule(eq(p).to(program_add), p1.next_sym).then(set_(p2.parent).to(p)) + + # Compile p2, if p1 parent not equal yield rule(eq(p).to(program_add), p.compile(i), p1.parent != p).then(p2.compile(i)) - # Compile p2, if p1 parent set + # Compile p2, if p1 parent eqal yield rule(eq(p).to(program_add), eq(p1.parent).to(program_add), eq(i).to(p1.next_sym)).then(p2.compile(i)) # Set p expr to join of p1 and p2 diff --git a/python/tests/__snapshots__/test_program_gen/test_to_string.py b/python/tests/__snapshots__/test_program_gen/test_to_string.py index c2148732..cc9697b6 100644 --- a/python/tests/__snapshots__/test_program_gen/test_to_string.py +++ b/python/tests/__snapshots__/test_program_gen/test_to_string.py @@ -1,4 +1,6 @@ -_0 = -(x) +_0 = -x assert _0 > 0 _1 = _0 + x -_1 +_2 = _1 + 2 +_3 = _2 + _1 +_3 diff --git a/python/tests/test_program_gen.py b/python/tests/test_program_gen.py index 739374dc..450d93fc 100644 --- a/python/tests/test_program_gen.py +++ b/python/tests/test_program_gen.py @@ -51,19 +51,19 @@ def _rules( yield rewrite(Math(i).program).to(Program(i.to_string())) yield rewrite((y + z).program).to((y.program + " + " + z.program).assign()) yield rewrite((y * z).program).to((y.program + " * " + z.program).assign()) - yield rewrite((-y).program).to(Program("-(") + y.program + ")") + yield rewrite((-y).program).to(Program("-") + y.program) assigned_x = x.program.assign() yield rewrite(assume_pos(x).program).to(assigned_x.statement(Program("assert ") + assigned_x + " > 0")) first = assume_pos(-Math.var("x")) + Math.var("x") with egraph: - y = first + y = first + Math(2) + first egraph.register(y.program) - egraph.run(10) + egraph.run(100) p = egraph.extract(y.program) egraph.register(p) egraph.register(p.compile()) - egraph.run(40) + egraph.run(100) # egraph.display(n_inline_leaves=1) e = egraph.load_object(egraph.extract(PyObject.from_string(p.expr))) stmts = egraph.load_object(egraph.extract(PyObject.from_string(p.statements)))