Skip to content

Commit

Permalink
Use union_instantiations in natlit rewrites
Browse files Browse the repository at this point in the history
Add graphviz option for debugging
  • Loading branch information
marcusrossel committed Feb 28, 2024
1 parent ccca8af commit 67e509f
Show file tree
Hide file tree
Showing 11 changed files with 72 additions and 47 deletions.
3 changes: 0 additions & 3 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,6 +1,3 @@
/build
/lakefile.olean
/lake-manifest.json
/Rust/Cargo.lock
/Rust/target
/.lake
10 changes: 7 additions & 3 deletions C/ffi.c
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@ extern egg_result c_egg_explain_congr(
rewrite* rws,
size_t rws_count,
rust_bool optimize_expl,
rust_bool gen_nat_lit_rws
rust_bool gen_nat_lit_rws,
const char* viz_path
);

// `init`: string
Expand All @@ -44,6 +45,7 @@ extern egg_result c_egg_explain_congr(
// `rw_dirs`: array of uint8_t containing the directions (cf. `rw_dir`) of rewrites
// `optimize_expl`: boolean indicating whether egg should try to shorten its explanations
// `gen_nat_lit_rws`: boolean indicating whether egg should use additional rewrites to convert between nat-lits and `Nat.zero`/`Nat.succ`
// `viz_path`: string
// return value: string explaining the rewrite sequence
lean_obj_res lean_egg_explain_congr(
lean_obj_arg init,
Expand All @@ -53,7 +55,8 @@ lean_obj_res lean_egg_explain_congr(
lean_obj_arg rw_rhss,
lean_obj_arg rw_dirs,
lean_bool optimize_expl,
lean_bool gen_nat_lit_rws
lean_bool gen_nat_lit_rws,
lean_obj_arg viz_path
) {
const char* init_c_str = lean_string_cstr(init);
const char* goal_c_str = lean_string_cstr(goal);
Expand All @@ -75,8 +78,9 @@ lean_obj_res lean_egg_explain_congr(
}
rust_bool opt_expl = lean_bool_to_rust(optimize_expl);
rust_bool nat_lit_rws = lean_bool_to_rust(gen_nat_lit_rws);
const char* viz_path_c_str = lean_string_cstr(viz_path);

egg_result result = c_egg_explain_congr(init_c_str, goal_c_str, rws, rws_count, opt_expl, nat_lit_rws);
egg_result result = c_egg_explain_congr(init_c_str, goal_c_str, rws, rws_count, opt_expl, nat_lit_rws, viz_path_c_str);
free(rws);

return lean_mk_string(result.expl);
Expand Down
15 changes: 8 additions & 7 deletions Lean/Egg/Core/Config.lean
Original file line number Diff line number Diff line change
@@ -1,11 +1,5 @@
namespace Egg.Config

inductive ExitPoint
| none
| beforeEqSat
| beforeProof
deriving BEq

structure Encoding where
eraseProofs := true
eraseLambdaDomains := false
Expand All @@ -22,8 +16,15 @@ structure Backend where
optimizeExpl := false
deriving BEq

inductive Debug.ExitPoint
| none
| beforeEqSat
| beforeProof
deriving BEq

structure Debug where
exitPoint := Config.ExitPoint.none
exitPoint : Debug.ExitPoint := .none
vizPath : Option String := none
deriving BEq

end Config
Expand Down
4 changes: 2 additions & 2 deletions Lean/Egg/Core/Gen/TcProjs.lean
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ private structure TcProj where

abbrev TcProjIndex := HashMap TcProj Source

private def TcProj.reductionRewrite (proj : TcProj) (src : Source) : MetaM (Option Rewrite) := do
private def TcProj.reductionRewrite? (proj : TcProj) (src : Source) : MetaM (Option Rewrite) := do
let app := mkAppN (.const proj.const proj.lvls) proj.args
let reduced ← withReducibleAndInstances do Expr.eta <$> reduceAll app
if app == reduced then return none
Expand Down Expand Up @@ -67,4 +67,4 @@ def genTcProjReductions (targets : Array (Congr × Source)) : MetaM Rewrites :=
for (cgr, src) in targets do
projs ← tcProjs cgr.lhs src .left projs
projs ← tcProjs cgr.rhs src .right projs
projs.toArray.filterMapM fun (proj, src) => proj.reductionRewrite src
projs.toArray.filterMapM fun (proj, src) => proj.reductionRewrite? src
8 changes: 6 additions & 2 deletions Lean/Egg/Core/Request.lean
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@ namespace Egg
@[extern "lean_egg_explain_congr"]
private opaque explainCongr
(lhs rhs : Expression) (rwNames : Array String) (lhss rhss : Array Expression)
(dirs : Array Rewrite.Directions) (optimizeExpl : Bool) (genNatLitRws : Bool) : String
(dirs : Array Rewrite.Directions) (optimizeExpl : Bool) (genNatLitRws : Bool) (vizPath : String)
: String

structure Request where
private mk ::
Expand All @@ -20,6 +21,7 @@ structure Request where
rws : Rewrites.Encoded
optimizeExpl : Bool
genNatLitRws : Bool
vizPath : String

namespace Request

Expand All @@ -30,7 +32,9 @@ def encoding (goal : Congr) (rws : Rewrites) (cfg : Config) : MetaM Request := d
rws := ← rws.encode cfg.toEncoding
optimizeExpl := cfg.optimizeExpl
genNatLitRws := cfg.genNatLitRws
vizPath := cfg.vizPath.getD ""
}

def run (r : Request) : Explanation.Raw :=
explainCongr r.lhs r.rhs r.rws.names r.rws.lhss r.rws.rhss r.rws.dirs r.optimizeExpl r.genNatLitRws
explainCongr
r.lhs r.rhs r.rws.names r.rws.lhss r.rws.rhss r.rws.dirs r.optimizeExpl r.genNatLitRws r.vizPath
12 changes: 5 additions & 7 deletions Lean/Egg/Tests/Groups.lean
Original file line number Diff line number Diff line change
Expand Up @@ -26,16 +26,14 @@ theorem inv_add_cancel_left : -a + (a + b) = b := by
theorem add_inv_cancel_left : a + (-a + b) = b := by
egg [add_assoc, zero_add, add_zero, add_left_inv, add_right_inv]

-- TODO: The test cases below should be fixed by explosion.
-- TODO: This test case should be fixed by typeclass specialization.

theorem inv_add : -(a + b) = -b + -a := by
theorem zero_inv : -(0 : G) = 0 := by
sorry -- egg [add_assoc, zero_add, add_zero, add_left_inv, add_right_inv]

-- Proof:
-- simp [Neg.neg, OfNat.ofNat]
-- rw [←add_zero (a := neg zero)]
-- rw [add_left_inv]
theorem zero_inv : -(0 : G) = 0 := by
-- TODO: The test cases below should be fixed by explosion.

theorem inv_add : -(a + b) = -b + -a := by
sorry -- egg [add_assoc, zero_add, add_zero, add_left_inv, add_right_inv]

theorem inv_inv : -(-a) = a := by
Expand Down
12 changes: 12 additions & 0 deletions Lean/Egg/Tests/NatLit.lean
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,16 @@ example : Int.ofNat (Nat.succ 1) = Int.ofNat (Nat.succ (Nat.succ Nat.zero)) := b
example (h : ∀ n, Nat.succ n = n + 1) : 1 = Nat.zero + 1 := by
egg [h]

elab "app" n:num fn:ident arg:term : term => open Lean.Elab.Term in do
let fn ← elabTerm fn none
let rec go (n : Nat) := if n = 0 then elabTerm arg none else return .app fn <| ← go (n - 1)
go n.getNat

-- Note: If we go to `61`, egg can't handle it anymore.
example : (app 60 Nat.succ (nat_lit 0)) = (nat_lit 60) := by egg

-- Note: This produces a gigantic proof.
example (f : Nat → Nat) (h : ∀ x, f x = x.succ) : 30 = app 30 f 0 := by
egg [h]

-- TODO: Add more tests involving rewrites with Nat.succ or Nat.zero.
6 changes: 4 additions & 2 deletions Lean/Egg/Tests/WIP.lean
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,11 @@ example : (∀ α (l : List α), l.length = l.length) ↔ (∀ α (l : List α),

-- For rewrites involving dependent arguments, we can easily get an incorrect motive. E.g. when
-- rewriting the condition in ite without chaning the type class instance:
set_option trace.egg true in
example : (if 0 = 0 then 0 else 1) = 0 := by
have : (0 = 0) = True := eq_self 0
rw [this]
have h1 : (0 = 0) = True := eq_self 0
have h2 : 0 = 0 := rfl
egg (config := { optimizeExpl := true }) [h1, h2, ite_congr, if_true]

-- For typeclass arguments we might be able to work around this by the following:
-- When a rewrite is applied to a term containing a typeclass argument (which we might be able to
Expand Down
5 changes: 4 additions & 1 deletion Rust/src/basic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use crate::result::*;
use crate::lean_expr::*;
use crate::nat_lit::*;

pub fn explain_congr(init: String, goal: String, rws: Vec<Rewrite<LeanExpr, NatLitAnalysis>>, optimize_expl: bool, gen_nat_lit_rws: bool) -> Res<String> {
pub fn explain_congr(init: String, goal: String, rws: Vec<Rewrite<LeanExpr, NatLitAnalysis>>, optimize_expl: bool, gen_nat_lit_rws: bool, viz_path: Option<String>) -> Res<String> {
let mut egraph: EGraph<LeanExpr, NatLitAnalysis> = Default::default();
egraph = egraph.with_explanations_enabled();
if !optimize_expl { egraph = egraph.without_explanation_length_optimization() }
Expand All @@ -16,6 +16,9 @@ pub fn explain_congr(init: String, goal: String, rws: Vec<Rewrite<LeanExpr, NatL
let mut runner = Runner::default()
.with_egraph(egraph)
.with_hook(move |runner| {
if let Some(path) = &viz_path {
runner.egraph.dot().to_dot(format!("{}/{}.dot", path, runner.iterations.len())).unwrap();
}
if runner.egraph.find(init_id) == runner.egraph.find(goal_id) {
Err("search complete".to_string())
} else {
Expand Down
11 changes: 8 additions & 3 deletions Rust/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,8 @@ pub extern "C" fn c_egg_explain_congr(
rws_ptr: *const CRewrite,
rws_count: usize,
optimize_expl: bool,
gen_nat_lit_rws: bool
gen_nat_lit_rws: bool,
viz_path_ptr: *const c_char
) -> EggResult {
// Cf. https://doc.rust-lang.org/stable/std/ffi/struct.CStr.html#examples
let init_c_str = unsafe { CStr::from_ptr(init_str_ptr) };
Expand All @@ -88,7 +89,7 @@ pub extern "C" fn c_egg_explain_congr(
let goal = String::from_utf8_lossy(goal_c_str.to_bytes()).to_string();
assert!(rws_ptr != null());
let c_rws = unsafe { std::slice::from_raw_parts(rws_ptr, rws_count) };

// Note: The `into_raw`s below are important, as otherwise Rust deallocates the string.
// TODO: I think this is a memory leak right now.

Expand All @@ -99,7 +100,11 @@ pub extern "C" fn c_egg_explain_congr(
}
let rws = rws.unwrap();

let expl = explain_congr(init, goal, rws, optimize_expl, gen_nat_lit_rws);
let viz_path_c_str = unsafe { CStr::from_ptr(viz_path_ptr) };
let raw_viz_path = String::from_utf8_lossy(viz_path_c_str.to_bytes()).to_string();
let viz_path = if raw_viz_path.is_empty() { None } else { Some(raw_viz_path) };

let expl = explain_congr(init, goal, rws, optimize_expl, gen_nat_lit_rws, viz_path);
if let Err(expl_err) = expl {
let rws_err_c_str = CString::new(expl_err.to_string()).expect("conversion of error message to C-string failed");
return EggResult { success: false, expl: rws_err_c_str.into_raw() }
Expand Down
33 changes: 16 additions & 17 deletions Rust/src/nat_lit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,12 @@ impl Analysis<LeanExpr> for NatLitAnalysis {
// This prefers `Some` value over `None`. Note that if `to` and `from` are both present,
// then they should have the same value as otherwise the merging of their e-classes indicates
// an invalid rewrite.

// TODO: We can't activate this assertion, because then egg can crashs from unsound rewrites (cf. `Tests/Soundness.lean`).
// Is there a way to gracefully fail?
//
// if let (Some(t), Some(f)) = (*to, from) { assert_eq!(t, f) }

egg::merge_max(to, from)
}

Expand All @@ -30,17 +36,13 @@ struct ToSucc {

impl Applier<LeanExpr, NatLitAnalysis> for ToSucc {

fn apply_one(&self, egraph: &mut EGraph<LeanExpr, NatLitAnalysis>, matched_id: Id, subst: &Subst, _: Option<&PatternAst<LeanExpr>>, _: Symbol) -> Vec<Id> {
fn apply_one(&self, egraph: &mut EGraph<LeanExpr, NatLitAnalysis>, _: Id, subst: &Subst, ast: Option<&PatternAst<LeanExpr>>, rule: Symbol) -> Vec<Id> {
if let Some(lit_val) = egraph[subst[self.lit_val]].data {
if lit_val > 0 {
let pred = egraph.add(LeanExpr::Nat(lit_val - 1));
let pred_lit = egraph.add(LeanExpr::Lit(pred));
let succ_name = egraph.add(LeanExpr::Str("Nat.succ".to_string()));
let succ_const = egraph.add(LeanExpr::Const(Box::new([succ_name])));
let app_succ_pred = egraph.add(LeanExpr::App([succ_const, pred_lit]));
if egraph.union(matched_id, app_succ_pred) {
return vec![app_succ_pred]
}
let ast = ast.unwrap(); // The `ast` is present when explanations are enabled, which we always do.
let res = format!("(app (const Nat.succ) (lit {}))", lit_val - 1).parse().unwrap();
let (id, _) = egraph.union_instantiations(ast, &res, subst, rule);
return vec![id]
}
}
vec![]
Expand All @@ -53,20 +55,17 @@ struct OfSucc {

impl Applier<LeanExpr, NatLitAnalysis> for OfSucc {

fn apply_one(&self, egraph: &mut EGraph<LeanExpr, NatLitAnalysis>, matched_id: Id, subst: &Subst, _: Option<&PatternAst<LeanExpr>>, _: Symbol) -> Vec<Id> {
fn apply_one(&self, egraph: &mut EGraph<LeanExpr, NatLitAnalysis>, _: Id, subst: &Subst, ast: Option<&PatternAst<LeanExpr>>, rule: Symbol) -> Vec<Id> {
if let Some(lit_val) = egraph[subst[self.lit_val]].data {
let succ = egraph.add(LeanExpr::Nat(lit_val + 1));
let succ_lit = egraph.add(LeanExpr::Lit(succ));
if egraph.union(matched_id, succ_lit) {
return vec![succ_lit]
}
let ast = ast.unwrap(); // The `ast` is present when explanations are enabled, which we always do.
let res = format!("(lit {})", lit_val + 1).parse().unwrap();
let (id, _) = egraph.union_instantiations(ast, &res, subst, rule);
return vec![id]
}
vec![]
}
}

// TODO: Mention in the thesis that this uses dynamic rewrites, which is also why we can't implement it
// as a `Egg.Rewrite` in Lean.
pub fn nat_lit_rws() -> Vec<Rewrite<LeanExpr, NatLitAnalysis>> {
let mut rws = vec![];
rws.append(&mut rewrite!("!z"; "(lit 0)" <=> "(const Nat.zero)"));
Expand Down

0 comments on commit 67e509f

Please sign in to comment.