From f551d041102c596d83a3c2aade374a4bfc797a0a Mon Sep 17 00:00:00 2001 From: Alex Fischman Date: Mon, 3 Jun 2024 16:39:05 -0700 Subject: [PATCH] Remove LoopContextUnionsAnd, just pass ContextCache around (fixes subst bug) --- dag_in_context/src/add_context.rs | 107 +++++++----------- dag_in_context/src/interval_analysis.rs | 57 +++++----- dag_in_context/src/lib.rs | 28 ++--- .../src/optimizations/function_inlining.rs | 20 ++-- .../src/optimizations/loop_unroll.rs | 2 +- .../src/optimizations/passthrough.rs | 13 +-- .../src/optimizations/switch_rewrites.rs | 68 ++++------- dag_in_context/src/schema_helpers.rs | 90 ++++++++------- src/rvsdg/to_dag.rs | 12 +- src/util.rs | 36 +++--- 10 files changed, 184 insertions(+), 249 deletions(-) diff --git a/dag_in_context/src/add_context.rs b/dag_in_context/src/add_context.rs index 90602699..946bee52 100644 --- a/dag_in_context/src/add_context.rs +++ b/dag_in_context/src/add_context.rs @@ -16,7 +16,9 @@ use crate::{ pub struct ContextCache { with_ctx: HashMap<(*const Expr, AssumptionRef), RcExpr>, symbol_gen: HashMap<(*const Expr, AssumptionRef), String>, - loop_contexts: LoopContextUnionsAnd<()>, + /// Information for generating the unions for loop contexts + placeholder: usize, + loop_context_unions: Vec<(Assumption, Assumption)>, /// When true, don't add context- instead, make fresh query variables /// and put these in place of context symbolic_ctx: bool, @@ -46,7 +48,8 @@ impl ContextCache { ContextCache { with_ctx: HashMap::new(), symbol_gen: HashMap::new(), - loop_contexts: LoopContextUnionsAnd::new(), + placeholder: 0, + loop_context_unions: Vec::new(), symbolic_ctx: false, dummy_ctx: false, } @@ -56,7 +59,8 @@ impl ContextCache { ContextCache { with_ctx: HashMap::new(), symbol_gen: HashMap::new(), - loop_contexts: LoopContextUnionsAnd::new(), + placeholder: 0, + loop_context_unions: Vec::new(), symbolic_ctx: true, dummy_ctx: false, } @@ -66,61 +70,17 @@ impl ContextCache { ContextCache { with_ctx: HashMap::new(), symbol_gen: HashMap::new(), - loop_contexts: LoopContextUnionsAnd::new(), + placeholder: 0, + loop_context_unions: Vec::new(), symbolic_ctx: false, dummy_ctx: true, } } - pub fn get_unions(&self) -> String { - self.loop_contexts.get_unions() - } -} - -// not a tuple to prevent auto-impls of Clone, Debug, etc. -pub struct LoopContextUnionsAnd { - var: usize, - // marked as public but you probably want `get_unions` - pub unions: Vec<(Assumption, Assumption)>, - pub value: T, -} - -impl Default for LoopContextUnionsAnd<()> { - fn default() -> Self { - Self::new() - } -} - -impl LoopContextUnionsAnd<()> { - pub fn new() -> LoopContextUnionsAnd<()> { - LoopContextUnionsAnd { - var: 0, - unions: Vec::new(), - value: (), - } - } -} - -impl LoopContextUnionsAnd { - fn new_placeholder(&mut self) -> Assumption { - let placeholder = Assumption::InFunc(format!(" loop_ctx_{}", self.var)); - self.var += 1; - placeholder - } - - pub fn swap_value(self, value: S) -> (LoopContextUnionsAnd, T) { - let LoopContextUnionsAnd { - var, - unions, - value: old, - } = self; - (LoopContextUnionsAnd { var, unions, value }, old) - } - pub fn get_unions(&self) -> String { use std::fmt::Write; - self.unions + self.loop_context_unions .iter() .fold(String::new(), |mut output, (a, b)| { let _ = writeln!(output, "(union {a} {b})"); @@ -134,7 +94,7 @@ impl LoopContextUnionsAnd { tree_state: &mut TreeToEgglog, term_cache: &mut HashMap, ) -> String { - self.unions + self.loop_context_unions .iter() .map(|(a, b)| { let internal_a = a.to_egglog_internal(tree_state); @@ -158,36 +118,47 @@ impl LoopContextUnionsAnd { .collect::>() .join("\n") } + + pub fn new_placeholder(&mut self) -> Assumption { + let placeholder = Assumption::InFunc(format!(" loop_ctx_{}", self.placeholder)); + self.placeholder += 1; + placeholder + } + + pub fn push_loop_context_union(&mut self, a: Assumption, b: Assumption) { + self.loop_context_unions.push((a, b)) + } } impl TreeProgram { - pub fn add_context(&self) -> LoopContextUnionsAnd { + pub fn add_context(&self) -> (TreeProgram, ContextCache) { self.add_context_internal(Expr::func_get_ctx, ContextCache::new()) } /// add stand-in variables for all the contexts in the program /// useful for testing if you don't care about context in the test - pub fn add_symbolic_ctx(&self) -> LoopContextUnionsAnd { + pub fn add_symbolic_ctx(&self) -> TreeProgram { self.add_context_internal(|_| Assumption::dummy(), ContextCache::new_symbolic_ctx()) + .0 } - pub fn add_dummy_ctx(&self) -> LoopContextUnionsAnd { + pub fn add_dummy_ctx(&self) -> TreeProgram { self.add_context_internal(|_| Assumption::dummy(), ContextCache::new_dummy_ctx()) + .0 } fn add_context_internal( &self, func: impl Fn(&RcExpr) -> Assumption, mut cache: ContextCache, - ) -> LoopContextUnionsAnd { + ) -> (TreeProgram, ContextCache) { let entry = self.entry.add_ctx_with_cache(func(&self.entry), &mut cache); let functions = self .functions .iter() .map(|f| f.add_ctx_with_cache(func(f), &mut cache)) .collect(); - let value = TreeProgram { functions, entry }; - cache.loop_contexts.swap_value(value).0 + (TreeProgram { functions, entry }, cache) } } @@ -199,7 +170,7 @@ impl Expr { Assumption::InFunc(name.clone()) } - pub fn func_add_ctx(self: &RcExpr) -> LoopContextUnionsAnd { + pub fn func_add_ctx(self: &RcExpr) -> (RcExpr, ContextCache) { let Expr::Function(name, arg_ty, ret_ty, body) = self.as_ref() else { panic!("Expected Function, got {:?}", self); }; @@ -210,25 +181,23 @@ impl Expr { ret_ty.clone(), body.add_ctx_with_cache(self.func_get_ctx(), &mut cache), )); - cache.loop_contexts.swap_value(value).0 + (value, cache) } - pub fn add_dummy_ctx(self: &RcExpr) -> LoopContextUnionsAnd { + pub fn add_dummy_ctx(self: &RcExpr) -> RcExpr { let mut cache = ContextCache::new_dummy_ctx(); - let value = self.add_ctx_with_cache(Assumption::dummy(), &mut cache); - cache.loop_contexts.swap_value(value).0 + self.add_ctx_with_cache(Assumption::dummy(), &mut cache) } - pub fn add_symbolic_ctx(self: &RcExpr) -> LoopContextUnionsAnd { + pub fn add_symbolic_ctx(self: &RcExpr) -> RcExpr { let mut cache = ContextCache::new_symbolic_ctx(); - let value = self.add_ctx_with_cache(Assumption::dummy(), &mut cache); - cache.loop_contexts.swap_value(value).0 + self.add_ctx_with_cache(Assumption::dummy(), &mut cache) } - pub fn add_ctx(self: &RcExpr, current_ctx: Assumption) -> LoopContextUnionsAnd { + pub fn add_ctx(self: &RcExpr, current_ctx: Assumption) -> (RcExpr, ContextCache) { let mut cache = ContextCache::new(); let value = self.add_ctx_with_cache(current_ctx, &mut cache); - cache.loop_contexts.swap_value(value).0 + (value, cache) } pub fn add_ctx_with_cache( @@ -259,14 +228,14 @@ impl Expr { Expr::Arg(ty, _oldctx) => RcExpr::new(Expr::Arg(ty.clone(), context_to_add)), // create new contexts for let, loop, and if Expr::DoWhile(inputs, pred_and_body) => { - let placeholder = cache.loop_contexts.new_placeholder(); + let placeholder = cache.new_placeholder(); let new_inputs = inputs.add_ctx_with_cache(current_ctx.clone(), cache); let new_pred_and_body = pred_and_body.add_ctx_with_cache(placeholder.clone(), cache); let new_ctx = Assumption::InLoop(new_inputs.clone(), new_pred_and_body.clone()); - cache.loop_contexts.unions.push((placeholder, new_ctx)); + cache.push_loop_context_union(placeholder, new_ctx); RcExpr::new(Expr::DoWhile(new_inputs, new_pred_and_body)) } diff --git a/dag_in_context/src/interval_analysis.rs b/dag_in_context/src/interval_analysis.rs index d6fa0d50..62fc6d96 100644 --- a/dag_in_context/src/interval_analysis.rs +++ b/dag_in_context/src/interval_analysis.rs @@ -79,23 +79,20 @@ fn constant_fold() -> crate::Result { #[test] fn test_add_constant_fold() -> crate::Result { use crate::ast::*; - let expr = add(int(1), int(2)) + let (expr, cache) = add(int(1), int(2)) .with_arg_types(emptyt(), base(intt())) .add_ctx(Assumption::dummy()); - let expr2 = int_ty(3, emptyt()).add_ctx(Assumption::dummy()); + let (expr2, cache2) = int_ty(3, emptyt()).add_ctx(Assumption::dummy()); egglog_test( - &format!("{}\n{}", expr.value, expr.get_unions()), + &format!("{expr}\n{}", cache.get_unions()), &format!( - "{}\n{}\n(check (= {} {}))", - expr2.value, - expr2.get_unions(), - expr.value, - expr2.value + "{expr2}\n{}\n(check (= {expr} {expr2}))", + cache2.get_unions(), ), vec![ - expr.value.to_program(emptyt(), base(intt())), - expr2.value.to_program(emptyt(), base(intt())), + expr.to_program(emptyt(), base(intt())), + expr2.to_program(emptyt(), base(intt())), ], val_empty(), intv(3), @@ -198,13 +195,13 @@ fn context_if() -> crate::Result { let f = function("main", base(intt()), base(boolt()), z.clone()).func_with_arg_types(); let prog = f.to_program(base(intt()), base(boolt())); - let with_context = prog.add_context(); - let term = with_context.value.entry.func_body().unwrap(); + let (with_context, cache) = prog.add_context(); + let term = with_context.entry.func_body().unwrap(); egglog_test( - &format!("{}\n{}", with_context.value, with_context.get_unions()), + &format!("{with_context}\n{}", cache.get_unions()), &format!("(check (= {term} (Const (Bool false) (Base (IntT)) somectx)))"), - vec![with_context.value], + vec![with_context], intv(4), val_bool(false), vec![], @@ -216,13 +213,13 @@ fn simple_less_than() -> crate::Result { // 0 <= input let cond = less_eq(int_ty(0, base(intt())), int(-1)); let prog = program!(function("main", base(intt()), base(boolt()), cond.clone()),); - let with_context = prog.add_context(); - let term = with_context.value.entry.func_body().unwrap(); + let (with_context, cache) = prog.add_context(); + let term = with_context.entry.func_body().unwrap(); egglog_test( - &format!("{}\n{}", with_context.value, with_context.get_unions()), + &format!("{with_context}\n{}", cache.get_unions()), &format!("(check (= {term} (Const (Bool false) (Base (IntT)) somectx)))"), - vec![with_context.value], + vec![with_context], intv(4), val_bool(false), vec![], @@ -244,16 +241,16 @@ fn context_if_rev() -> crate::Result { let f = function("main", base(intt()), base(boolt()), z.clone()).func_with_arg_types(); let prog = f.to_program(base(intt()), base(boolt())); - let with_context = prog.add_context(); - let term = with_context.value.entry.func_body().unwrap(); + let (with_context, cache) = prog.add_context(); + let term = with_context.entry.func_body().unwrap(); egglog_test( - &format!("{}\n{}", with_context.value, with_context.get_unions()), + &format!("{with_context}\n{}", cache.get_unions()), &format!( " (check (= {term} (Const (Bool false) (Base (IntT)) (InFunc \"main\"))))" ), - vec![with_context.value], + vec![with_context], intv(4), val_bool(false), vec![], @@ -303,28 +300,26 @@ fn context_if_with_state() -> crate::Result { body.clone(), ) .func_with_arg_types(); - let prog = f + let (prog, cache) = f .to_program(input_type.clone(), output_type.clone()) .with_arg_types() .add_context(); - let body_with_ctx = prog.value.entry.func_body().unwrap(); + let body_with_ctx = prog.entry.func_body().unwrap(); - let expected = single(tprint( + let (expected, expected_cache) = single(tprint( ttrue_ty(input_type.clone()), get(input_arg.clone(), 1), )) .add_ctx(Assumption::InFunc("main".to_string())); egglog_test( - &format!("{}\n{}", prog.value, prog.get_unions()), + &format!("{prog}\n{}", cache.get_unions()), &format!( - "{}\n{}\n(check (= {} {body_with_ctx}))", - expected.value, - expected.get_unions(), - expected.value + "{expected}\n{}\n(check (= {expected} {body_with_ctx}))", + expected_cache.get_unions(), ), - vec![prog.value], + vec![prog], val_vec(vec![intv(4), statev()]), val_vec(vec![statev()]), vec!["true".to_string()], diff --git a/dag_in_context/src/lib.rs b/dag_in_context/src/lib.rs index 3682d2b4..c3a470e9 100644 --- a/dag_in_context/src/lib.rs +++ b/dag_in_context/src/lib.rs @@ -8,7 +8,7 @@ use std::fmt::Write; use to_egglog::TreeToEgglog; use crate::{ - add_context::LoopContextUnionsAnd, dag2svg::tree_to_svg, interpreter::interpret_dag_prog, + add_context::ContextCache, dag2svg::tree_to_svg, interpreter::interpret_dag_prog, optimizations::function_inlining, schedule::mk_schedule, }; @@ -114,8 +114,7 @@ pub(crate) fn print_with_intermediate_vars(termdag: &TermDag, term: Term) -> Str } // It is expected that program has context added -pub fn build_program(program: LoopContextUnionsAnd, optimize: bool) -> String { - let (mut unions, program) = program.swap_value(()); +pub fn build_program(program: &TreeProgram, cache: &mut ContextCache, optimize: bool) -> String { let mut printed = String::new(); // Create a global cache for generating intermediate variables @@ -128,9 +127,9 @@ pub fn build_program(program: LoopContextUnionsAnd, optimize: bool) } else { function_inlining::print_function_inlining_pairs( function_inlining::function_inlining_pairs( - &program, + program, config::FUNCTION_INLINING_ITERATIONS, - &mut unions, + cache, ), &mut printed, &mut tree_state, @@ -144,7 +143,7 @@ pub fn build_program(program: LoopContextUnionsAnd, optimize: bool) print_with_intermediate_helper(&tree_state.termdag, term, &mut term_cache, &mut printed); let loop_context_unions = - unions.get_unions_with_sharing(&mut printed, &mut tree_state, &mut term_cache); + cache.get_unions_with_sharing(&mut printed, &mut tree_state, &mut term_cache); let prologue = prologue(); @@ -187,10 +186,7 @@ pub fn are_progs_eq(program1: TreeProgram, program2: TreeProgram) -> bool { /// Checks that the extracted program is the same as the input program. pub fn check_roundtrip_egraph(program: &TreeProgram) { let mut termdag = egglog::TermDag::default(); - let egglog_prog = build_program( - LoopContextUnionsAnd::new().swap_value(program.clone()).0, - false, - ); + let egglog_prog = build_program(program, &mut ContextCache::new(), false); log::info!("Running egglog program..."); let mut egraph = egglog::EGraph::default(); egraph.parse_and_run_program(&egglog_prog).unwrap(); @@ -204,8 +200,8 @@ pub fn check_roundtrip_egraph(program: &TreeProgram) { DefaultCostModel, ); - let original_with_ctx = program.add_dummy_ctx().value; - let res_with_ctx = res.add_dummy_ctx().value; + let original_with_ctx = program.add_dummy_ctx(); + let res_with_ctx = res.add_dummy_ctx(); if !are_progs_eq(original_with_ctx.clone(), res_with_ctx.clone()) { eprintln!("Original program: {}", tree_to_svg(&original_with_ctx)); @@ -216,10 +212,10 @@ pub fn check_roundtrip_egraph(program: &TreeProgram) { // It is expected that program has context added pub fn optimize( - program: LoopContextUnionsAnd, + program: &TreeProgram, + cache: &mut ContextCache, ) -> std::result::Result { - let original_program = program.value.clone(); - let egglog_prog = build_program(program, true); + let egglog_prog = build_program(program, cache, true); log::info!("Running egglog program..."); let mut egraph = egglog::EGraph::default(); egraph.parse_and_run_program(&egglog_prog)?; @@ -227,7 +223,7 @@ pub fn optimize( let (serialized, unextractables) = serialized_egraph(egraph); let mut termdag = egglog::TermDag::default(); let (_res_cost, res) = extract( - &original_program, + program, serialized, unextractables, &mut termdag, diff --git a/dag_in_context/src/optimizations/function_inlining.rs b/dag_in_context/src/optimizations/function_inlining.rs index 5b68b7ff..7f0e4eaa 100644 --- a/dag_in_context/src/optimizations/function_inlining.rs +++ b/dag_in_context/src/optimizations/function_inlining.rs @@ -7,7 +7,7 @@ use std::{ use egglog::Term; use crate::{ - add_context::LoopContextUnionsAnd, + add_context::ContextCache, print_with_intermediate_helper, schema::{Expr, RcExpr, TreeProgram}, to_egglog::TreeToEgglog, @@ -49,12 +49,12 @@ fn get_calls_with_cache( fn subst_call( call: &RcExpr, func_to_body: &HashMap, - unions: &mut LoopContextUnionsAnd<()>, + cache: &mut ContextCache, ) -> CallBody { if let Expr::Call(func_name, args) = call.as_ref() { CallBody { call: call.clone(), - body: Expr::subst(args, func_to_body[func_name], unions), + body: Expr::subst(args, func_to_body[func_name], cache), } } else { panic!("Tried to substitute non-calls.") @@ -65,7 +65,7 @@ fn subst_call( pub fn function_inlining_pairs( program: &TreeProgram, iterations: usize, - unions: &mut LoopContextUnionsAnd<()>, + cache: &mut ContextCache, ) -> Vec { if iterations == 0 { return vec![]; @@ -94,7 +94,7 @@ pub fn function_inlining_pairs( let mut inlined_calls = calls .iter() - .map(|call| subst_call(call, &func_name_to_body, unions)) + .map(|call| subst_call(call, &func_name_to_body, cache)) .collect::>(); // Repeat! Get calls and subst for each new substituted body. @@ -114,7 +114,7 @@ pub fn function_inlining_pairs( // Only work on new calls, added from the new inlines new_inlines = new_calls .iter() - .map(|call| subst_call(call, &func_name_to_body, unions)) + .map(|call| subst_call(call, &func_name_to_body, cache)) .collect::>(); inlined_calls.extend(new_inlines.clone()); } @@ -202,7 +202,7 @@ fn test_function_inlining_pairs() { let program = program!(main, inc_twice, inc); - let pairs = function_inlining_pairs(&program, iterations); + let pairs = function_inlining_pairs(&program, iterations, &mut ContextCache::new()); // First iteration: // call inc_twice 1 --> call inc (call inc 1) ... so the new calls are call inc (call inc 1), call inc 1 @@ -216,7 +216,7 @@ fn test_function_inlining_pairs() { // No more iterations! - assert_eq!(pairs.value.len(), 6) + assert_eq!(pairs.len(), 6) } // Infinite recursion should produce as many pairs as iterations @@ -233,7 +233,7 @@ fn test_inf_recursion_function_inlining_pairs() { .to_program(base(intt()), base(intt())); for iterations in 0..10 { - let pairs = function_inlining_pairs(&program, iterations); - assert_eq!(pairs.value.len(), iterations); + let pairs = function_inlining_pairs(&program, iterations, &mut ContextCache::new()); + assert_eq!(pairs.len(), iterations); } } diff --git a/dag_in_context/src/optimizations/loop_unroll.rs b/dag_in_context/src/optimizations/loop_unroll.rs index 850632c6..7f18918a 100644 --- a/dag_in_context/src/optimizations/loop_unroll.rs +++ b/dag_in_context/src/optimizations/loop_unroll.rs @@ -49,7 +49,7 @@ fn loop_unroll_simple() -> crate::Result { egglog_test( &format!("{prog}"), - &format!("(check (= {prog} {}))", expected.value), + &format!("(check (= {prog} {}))", expected), vec![prog.to_program(base(intt()), tuplet!(intt()))], intv(0), tuplev!(intv(8)), diff --git a/dag_in_context/src/optimizations/passthrough.rs b/dag_in_context/src/optimizations/passthrough.rs index 6cfa2763..4f45fab4 100644 --- a/dag_in_context/src/optimizations/passthrough.rs +++ b/dag_in_context/src/optimizations/passthrough.rs @@ -73,16 +73,15 @@ fn passthrough_if_predicate() -> crate::Result { ); let check = less_than(arg(), int(5)); - let build = build.to_program(base(intt()), base(boolt())).add_context(); - let check = check.to_program(base(intt()), base(boolt())).add_context(); + let (build, build_cache) = build.to_program(base(intt()), base(boolt())).add_context(); + let (check, check_cache) = check.to_program(base(intt()), base(boolt())).add_context(); egglog_test( - &format!("(let b {})\n{}", build.value, build.get_unions()), + &format!("(let b {build})\n{}", build_cache.get_unions()), &format!( - "(let c {})\n{} (check (= b c))", - check.value, - check.get_unions() + "(let c {check})\n{} (check (= b c))", + check_cache.get_unions() ), - vec![build.value, check.value], + vec![build, check], intv(3), val_bool(true), vec![], diff --git a/dag_in_context/src/optimizations/switch_rewrites.rs b/dag_in_context/src/optimizations/switch_rewrites.rs index bc294226..1ffddd17 100644 --- a/dag_in_context/src/optimizations/switch_rewrites.rs +++ b/dag_in_context/src/optimizations/switch_rewrites.rs @@ -4,11 +4,10 @@ use crate::egglog_test; #[test] fn switch_rewrite_three_quarters_and() -> crate::Result { use crate::ast::*; - use crate::schema::Assumption; let build = tif(and(tfalse(), ttrue()), empty(), int(1), int(2)) .with_arg_types(emptyt(), base(intt())) - .add_ctx(Assumption::dummy()); + .add_dummy_ctx(); let check = tif( tfalse(), @@ -17,15 +16,11 @@ fn switch_rewrite_three_quarters_and() -> crate::Result { int(2), ) .with_arg_types(emptyt(), base(intt())) - .add_ctx(Assumption::dummy()); + .add_dummy_ctx(); egglog_test( - &format!("(let build_ {})\n{}", build.value, build.get_unions()), - &format!( - "(let check_ {})\n{}\n(check (= build_ check_))", - check.value, - check.get_unions() - ), + &format!("(let build_ {build})"), + &format!("(let check_ {check})\n(check (= build_ check_))",), vec![], val_empty(), intv(2), @@ -36,11 +31,10 @@ fn switch_rewrite_three_quarters_and() -> crate::Result { #[test] fn switch_rewrite_three_quarters_or() -> crate::Result { use crate::ast::*; - use crate::schema::Assumption; let build = tif(or(tfalse(), ttrue()), empty(), int(1), int(2)) .with_arg_types(emptyt(), base(intt())) - .add_ctx(Assumption::dummy()); + .add_dummy_ctx(); let check = tif( tfalse(), @@ -49,15 +43,11 @@ fn switch_rewrite_three_quarters_or() -> crate::Result { tif(get(arg(), 0), empty(), int(1), int(2)), ) .with_arg_types(emptyt(), base(intt())) - .add_ctx(Assumption::dummy()); + .add_dummy_ctx(); egglog_test( - &format!("(let build_ {})\n{}", build.value, build.get_unions()), - &format!( - "(let check_ {})\n{}\n(check (= build_ check_))", - check.value, - check.get_unions() - ), + &format!("(let build_ {build})"), + &format!("(let check_ {check})\n(check (= build_ check_))"), vec![], val_empty(), intv(1), @@ -68,7 +58,6 @@ fn switch_rewrite_three_quarters_or() -> crate::Result { #[test] fn switch_rewrite_forward_pred() -> crate::Result { use crate::ast::*; - use crate::schema::Assumption; let ctx_ty = tuplet!(boolt()); @@ -79,20 +68,13 @@ fn switch_rewrite_forward_pred() -> crate::Result { 0, ) .add_arg_type(ctx_ty.clone()) - .add_ctx(Assumption::dummy()); + .add_dummy_ctx(); - let check = arg - .clone() - .add_arg_type(ctx_ty.clone()) - .add_ctx(Assumption::dummy()); + let check = arg.clone().add_arg_type(ctx_ty.clone()).add_dummy_ctx(); egglog_test( - &format!("(let build_ {})\n{}", build.value, build.get_unions()), - &format!( - "(let check_ {})\n{}\n(check (= build_ check_))", - check.value, - check.get_unions() - ), + &format!("(let build_ {build})"), + &format!("(let check_ {check})\n(check (= build_ check_))"), vec![], val_empty(), intv(1), @@ -103,7 +85,6 @@ fn switch_rewrite_forward_pred() -> crate::Result { #[test] fn switch_rewrite_negate_pred() -> crate::Result { use crate::ast::*; - use crate::schema::Assumption; let ctx_ty = tuplet!(boolt()); @@ -114,19 +95,15 @@ fn switch_rewrite_negate_pred() -> crate::Result { 0, ) .add_arg_type(ctx_ty.clone()) - .add_ctx(Assumption::dummy()); + .add_dummy_ctx(); let check = not(arg.clone()) .add_arg_type(ctx_ty.clone()) - .add_ctx(Assumption::dummy()); + .add_dummy_ctx(); egglog_test( - &format!("(let build_ {})\n{}", build.value, build.get_unions()), - &format!( - "(let check_ {})\n{}\n(check (= build_ check_))", - check.value, - check.get_unions() - ), + &format!("(let build_ {build})"), + &format!("(let check_ {check})\n(check (= build_ check_))"), vec![], val_empty(), intv(1), @@ -137,7 +114,6 @@ fn switch_rewrite_negate_pred() -> crate::Result { #[test] fn single_branch_switch() -> crate::Result { use crate::ast::*; - use crate::schema::Assumption; let build = switch_vec( int(1), @@ -148,19 +124,15 @@ fn single_branch_switch() -> crate::Result { ], ) .with_arg_types(emptyt(), base(intt())) - .add_ctx(Assumption::dummy()); + .add_dummy_ctx(); let check = int(1) .with_arg_types(emptyt(), base(intt())) - .add_ctx(Assumption::dummy()); + .add_dummy_ctx(); egglog_test( - &format!("(let build_ {})\n{}", build.value, build.get_unions()), - &format!( - "(let check_ {})\n{}\n(check (!= build_ check_))", - check.value, - check.get_unions() - ), + &format!("(let build_ {build})"), + &format!("(let check_ {check})\n(check (!= build_ check_))"), vec![], val_empty(), intv(1), diff --git a/dag_in_context/src/schema_helpers.rs b/dag_in_context/src/schema_helpers.rs index 12c6e3ea..24907929 100644 --- a/dag_in_context/src/schema_helpers.rs +++ b/dag_in_context/src/schema_helpers.rs @@ -6,7 +6,7 @@ use std::{ use strum_macros::EnumIter; use crate::{ - add_context::LoopContextUnionsAnd, + add_context::ContextCache, ast::{base, boolt, floatt, inif, inloop, inswitch, intt}, schema::{ Assumption, BaseType, BinaryOp, Constant, Expr, RcExpr, TernaryOp, TreeProgram, Type, @@ -289,12 +289,19 @@ impl Expr { } // Substitute "arg" for Arg() in within. Also replaces context with "arg"'s context. - pub fn subst(arg: &RcExpr, within: &RcExpr, unions: &mut LoopContextUnionsAnd<()>) -> RcExpr { + pub fn subst(arg: &RcExpr, within: &RcExpr, context_cache: &mut ContextCache) -> RcExpr { let mut subst_cache: HashMap<*const Expr, RcExpr> = HashMap::new(); let arg_ty = arg.get_arg_type(); let arg_ctx = arg.get_ctx(); - Self::subst_with_cache(arg, &arg_ty, arg_ctx, within, &mut subst_cache, unions) + Self::subst_with_cache( + arg, + &arg_ty, + arg_ctx, + within, + &mut subst_cache, + context_cache, + ) } fn subst_with_cache( @@ -303,14 +310,8 @@ impl Expr { arg_ctx: &Assumption, within: &RcExpr, subst_cache: &mut HashMap<*const Expr, RcExpr>, - unions: &mut LoopContextUnionsAnd<()>, + context_cache: &mut ContextCache, ) -> RcExpr { - let add_ctx = |expr: &RcExpr, unions: &mut LoopContextUnionsAnd<()>, assumption| { - let unions_and_value = expr.add_ctx(assumption); - unions.unions.extend(unions_and_value.unions); - unions_and_value.value - }; - if let Some(substed) = subst_cache.get(&Rc::as_ptr(within)) { return substed.clone(); } @@ -322,32 +323,32 @@ impl Expr { // Propagate through current scope Expr::Top(op, x, y, z) => Rc::new(Expr::Top( op.clone(), - Self::subst_with_cache(arg, arg_ty, arg_ctx, x, subst_cache, unions), - Self::subst_with_cache(arg, arg_ty, arg_ctx, y, subst_cache, unions), - Self::subst_with_cache(arg, arg_ty, arg_ctx, z, subst_cache, unions), + Self::subst_with_cache(arg, arg_ty, arg_ctx, x, subst_cache, context_cache), + Self::subst_with_cache(arg, arg_ty, arg_ctx, y, subst_cache, context_cache), + Self::subst_with_cache(arg, arg_ty, arg_ctx, z, subst_cache, context_cache), )), Expr::Bop(op, x, y) => Rc::new(Expr::Bop( op.clone(), - Self::subst_with_cache(arg, arg_ty, arg_ctx, x, subst_cache, unions), - Self::subst_with_cache(arg, arg_ty, arg_ctx, y, subst_cache, unions), + Self::subst_with_cache(arg, arg_ty, arg_ctx, x, subst_cache, context_cache), + Self::subst_with_cache(arg, arg_ty, arg_ctx, y, subst_cache, context_cache), )), Expr::Uop(op, x) => Rc::new(Expr::Uop( op.clone(), - Self::subst_with_cache(arg, arg_ty, arg_ctx, x, subst_cache, unions), + Self::subst_with_cache(arg, arg_ty, arg_ctx, x, subst_cache, context_cache), )), Expr::Get(x, i) => Rc::new(Expr::Get( - Self::subst_with_cache(arg, arg_ty, arg_ctx, x, subst_cache, unions), + Self::subst_with_cache(arg, arg_ty, arg_ctx, x, subst_cache, context_cache), *i, )), Expr::Alloc(amt, x, y, ty) => Rc::new(Expr::Alloc( *amt, - Self::subst_with_cache(arg, arg_ty, arg_ctx, x, subst_cache, unions), - Self::subst_with_cache(arg, arg_ty, arg_ctx, y, subst_cache, unions), + Self::subst_with_cache(arg, arg_ty, arg_ctx, x, subst_cache, context_cache), + Self::subst_with_cache(arg, arg_ty, arg_ctx, y, subst_cache, context_cache), ty.clone(), )), Expr::Call(name, x) => Rc::new(Expr::Call( name.clone(), - Self::subst_with_cache(arg, arg_ty, arg_ctx, x, subst_cache, unions), + Self::subst_with_cache(arg, arg_ty, arg_ctx, x, subst_cache, context_cache), )), Expr::Single(x) => Rc::new(Expr::Single(Self::subst_with_cache( arg, @@ -355,61 +356,62 @@ impl Expr { arg_ctx, x, subst_cache, - unions, + context_cache, ))), Expr::Concat(x, y) => Rc::new(Expr::Concat( - Self::subst_with_cache(arg, arg_ty, arg_ctx, x, subst_cache, unions), - Self::subst_with_cache(arg, arg_ty, arg_ctx, y, subst_cache, unions), + Self::subst_with_cache(arg, arg_ty, arg_ctx, x, subst_cache, context_cache), + Self::subst_with_cache(arg, arg_ty, arg_ctx, y, subst_cache, context_cache), )), Expr::If(pred, input, then, els) => { let new_pred = - Self::subst_with_cache(arg, arg_ty, arg_ctx, pred, subst_cache, unions); + Self::subst_with_cache(arg, arg_ty, arg_ctx, pred, subst_cache, context_cache); let new_input = - Self::subst_with_cache(arg, arg_ty, arg_ctx, input, subst_cache, unions); + Self::subst_with_cache(arg, arg_ty, arg_ctx, input, subst_cache, context_cache); Rc::new(Expr::If( new_pred.clone(), new_input.clone(), - add_ctx( - then, - unions, + then.add_ctx_with_cache( inif(true, new_pred.clone(), new_input.clone()), + context_cache, ), - add_ctx(els, unions, inif(false, new_pred, new_input)), + els.add_ctx_with_cache(inif(false, new_pred, new_input), context_cache), )) } Expr::Switch(pred, input, branches) => { let new_pred = - Self::subst_with_cache(arg, arg_ty, arg_ctx, pred, subst_cache, unions); + Self::subst_with_cache(arg, arg_ty, arg_ctx, pred, subst_cache, context_cache); let new_input = - Self::subst_with_cache(arg, arg_ty, arg_ctx, input, subst_cache, unions); + Self::subst_with_cache(arg, arg_ty, arg_ctx, input, subst_cache, context_cache); let new_branches = branches .iter() .enumerate() .map(|(i, branch)| { - add_ctx( - branch, - unions, + branch.add_ctx_with_cache( inswitch(i.try_into().unwrap(), new_pred.clone(), new_input.clone()), + context_cache, ) }) .collect(); Rc::new(Expr::Switch(new_pred, new_input, new_branches)) } - Expr::DoWhile(input, body) => { - let new_input = - Self::subst_with_cache(arg, arg_ty, arg_ctx, input, subst_cache, unions); - Rc::new(Expr::DoWhile( - new_input.clone(), - // It may seem odd to use the old body in the new context, but this is how - // it's done in add_ctx. - add_ctx(body, unions, inloop(new_input, body.clone())), - )) + Expr::DoWhile(input, pred_and_body) => { + let placeholder = context_cache.new_placeholder(); + + let new_inputs = + Self::subst_with_cache(arg, arg_ty, arg_ctx, input, subst_cache, context_cache); + let new_pred_and_body = + pred_and_body.add_ctx_with_cache(placeholder.clone(), context_cache); + + let new_ctx = inloop(new_inputs.clone(), new_pred_and_body.clone()); + context_cache.push_loop_context_union(placeholder, new_ctx); + + RcExpr::new(Expr::DoWhile(new_inputs, new_pred_and_body)) } Expr::Function(x, y, z, body) => Rc::new(Expr::Function( x.clone(), y.clone(), z.clone(), - Self::subst_with_cache(arg, arg_ty, arg_ctx, body, subst_cache, unions), + Self::subst_with_cache(arg, arg_ty, arg_ctx, body, subst_cache, context_cache), )), // For leaves, replace the type and context diff --git a/src/rvsdg/to_dag.rs b/src/rvsdg/to_dag.rs index e281580d..923098c6 100644 --- a/src/rvsdg/to_dag.rs +++ b/src/rvsdg/to_dag.rs @@ -17,7 +17,7 @@ use dag_in_context::schema::Constant; use crate::rvsdg::{BasicExpr, Id, Operand, RvsdgBody, RvsdgFunction, RvsdgProgram}; use bril_rs::{EffectOps, Literal, ValueOps}; use dag_in_context::{ - add_context::LoopContextUnionsAnd, + add_context::ContextCache, ast::{add, call, dowhile, function, int, less_than, program_vec, tfalse, ttrue}, schema::{RcExpr, TreeProgram, Type}, }; @@ -30,7 +30,7 @@ impl RvsdgProgram { /// Common subexpressions are shared by the same Rc in the dag encoding. /// This invariant is maintained by restore_sharing_invariant. /// Also adds context to the program. - pub fn to_dag_encoding(&self, add_context: bool) -> LoopContextUnionsAnd { + pub fn to_dag_encoding(&self, add_context: bool) -> (TreeProgram, ContextCache) { let last_function = self.functions.last().unwrap(); let rest_functions = self.functions.iter().take(self.functions.len() - 1); let res = program_vec( @@ -43,7 +43,7 @@ impl RvsdgProgram { if add_context { res.add_context() } else { - LoopContextUnionsAnd::new().swap_value(res).0 + (res, ContextCache::new()) } } } @@ -409,9 +409,9 @@ fn dag_translation_test( let prog = parse_from_string(program); let cfg = program_to_cfg(&prog); let rvsdg = cfg_to_rvsdg(&cfg).unwrap(); - let result = rvsdg.to_dag_encoding(false); + let result = rvsdg.to_dag_encoding(false).0; - assert_progs_eq(&result.value, &expected, "Resulting program is incorrect"); + assert_progs_eq(&result, &expected, "Resulting program is incorrect"); let (found_val, found_printlog) = interpret_dag_prog(&expected, &input_val); assert_eq!( @@ -425,7 +425,7 @@ fn dag_translation_test( expected_printlog, found_printlog ); - let (found_val, found_printlog) = interpret_dag_prog(&result.value, &input_val); + let (found_val, found_printlog) = interpret_dag_prog(&result, &input_val); assert_eq!( expected_val, found_val, "Resulting program produced incorrect result. Expected {:?}, found {:?}", diff --git a/src/util.rs b/src/util.rs index eb0de2da..615219d8 100644 --- a/src/util.rs +++ b/src/util.rs @@ -354,8 +354,8 @@ pub struct RunOutput { impl Run { fn optimize_bril(program: &Program) -> Result { let rvsdg = Optimizer::program_to_rvsdg(program)?; - let dag = rvsdg.to_dag_encoding(true); - let optimized = dag_in_context::optimize(dag).map_err(EggCCError::EggLog)?; + let (dag, mut cache) = rvsdg.to_dag_encoding(true); + let optimized = dag_in_context::optimize(&dag, &mut cache).map_err(EggCCError::EggLog)?; let rvsdg2 = dag_to_rvsdg(&optimized); let cfg = rvsdg2.to_cfg(); let bril = cfg.to_bril(); @@ -568,8 +568,8 @@ impl Run { } RunType::DagToRvsdg => { let rvsdg = Optimizer::program_to_rvsdg(&self.prog_with_args.program)?; - let tree = rvsdg.to_dag_encoding(true); - let rvsdg2 = dag_to_rvsdg(&tree.value); + let (tree, _cache) = rvsdg.to_dag_encoding(true); + let rvsdg2 = dag_to_rvsdg(&tree); ( vec![Visualization { result: rvsdg2.to_svg(), @@ -581,8 +581,8 @@ impl Run { } RunType::DagRoundTrip => { let rvsdg = Optimizer::program_to_rvsdg(&self.prog_with_args.program)?; - let tree = rvsdg.to_dag_encoding(true); - let rvsdg2 = dag_to_rvsdg(&tree.value); + let (tree, _cache) = rvsdg.to_dag_encoding(true); + let rvsdg2 = dag_to_rvsdg(&tree); let cfg = rvsdg2.to_cfg(); let bril = cfg.to_bril(); ( @@ -596,8 +596,8 @@ impl Run { } RunType::CheckExtractIdentical => { let rvsdg = Optimizer::program_to_rvsdg(&self.prog_with_args.program)?; - let tree = rvsdg.to_dag_encoding(true); - check_roundtrip_egraph(&tree.value); + let (tree, _cache) = rvsdg.to_dag_encoding(true); + check_roundtrip_egraph(&tree); (vec![], None) } RunType::Optimize => { @@ -613,20 +613,21 @@ impl Run { } RunType::DagConversion => { let rvsdg = Optimizer::program_to_rvsdg(&self.prog_with_args.program)?; - let tree = rvsdg.to_dag_encoding(true); + let (tree, _cache) = rvsdg.to_dag_encoding(true); ( vec![Visualization { - result: tree_to_svg(&tree.value), + result: tree_to_svg(&tree), file_extension: ".svg".to_string(), name: "".to_string(), }], - Some(Interpretable::TreeProgram(tree.value)), + Some(Interpretable::TreeProgram(tree)), ) } RunType::DagOptimize => { let rvsdg = Optimizer::program_to_rvsdg(&self.prog_with_args.program)?; - let tree = rvsdg.to_dag_encoding(true); - let optimized = dag_in_context::optimize(tree).map_err(EggCCError::EggLog)?; + let (tree, mut cache) = rvsdg.to_dag_encoding(true); + let optimized = + dag_in_context::optimize(&tree, &mut cache).map_err(EggCCError::EggLog)?; ( vec![Visualization { result: tree_to_svg(&optimized), @@ -638,8 +639,9 @@ impl Run { } RunType::OptimizedRvsdg => { let rvsdg = Optimizer::program_to_rvsdg(&self.prog_with_args.program)?; - let dag = rvsdg.to_dag_encoding(true); - let optimized = dag_in_context::optimize(dag).map_err(EggCCError::EggLog)?; + let (dag, mut cache) = rvsdg.to_dag_encoding(true); + let optimized = + dag_in_context::optimize(&dag, &mut cache).map_err(EggCCError::EggLog)?; let rvsdg = dag_to_rvsdg(&optimized); ( vec![Visualization { @@ -652,8 +654,8 @@ impl Run { } RunType::Egglog => { let rvsdg = Optimizer::program_to_rvsdg(&self.prog_with_args.program)?; - let dag = rvsdg.to_dag_encoding(true); - let egglog = build_program(dag, true); + let (dag, mut cache) = rvsdg.to_dag_encoding(true); + let egglog = build_program(&dag, &mut cache, true); ( vec![Visualization { result: egglog,