diff --git a/dag_in_context/src/lib.rs b/dag_in_context/src/lib.rs index 9a11f860..3682d2b4 100644 --- a/dag_in_context/src/lib.rs +++ b/dag_in_context/src/lib.rs @@ -114,7 +114,8 @@ pub(crate) fn print_with_intermediate_vars(termdag: &TermDag, term: Term) -> Str } // It is expected that program has context added -pub fn build_program(program: &LoopContextUnionsAnd, optimize: bool) -> String { +pub fn build_program(program: LoopContextUnionsAnd, optimize: bool) -> String { + let (mut unions, program) = program.swap_value(()); let mut printed = String::new(); // Create a global cache for generating intermediate variables @@ -127,8 +128,9 @@ pub fn build_program(program: &LoopContextUnionsAnd, optimize: bool } else { function_inlining::print_function_inlining_pairs( function_inlining::function_inlining_pairs( - &program.value, + &program, config::FUNCTION_INLINING_ITERATIONS, + &mut unions, ), &mut printed, &mut tree_state, @@ -137,12 +139,12 @@ pub fn build_program(program: &LoopContextUnionsAnd, optimize: bool }; // Generate program egglog - let term = program.value.to_egglog_with(&mut tree_state); + let term = program.to_egglog_with(&mut tree_state); let res = print_with_intermediate_helper(&tree_state.termdag, term, &mut term_cache, &mut printed); let loop_context_unions = - program.get_unions_with_sharing(&mut printed, &mut tree_state, &mut term_cache); + unions.get_unions_with_sharing(&mut printed, &mut tree_state, &mut term_cache); let prologue = prologue(); @@ -154,11 +156,21 @@ pub fn build_program(program: &LoopContextUnionsAnd, optimize: bool format!( " +; Prologue {prologue} + +; Program nodes {printed} +; Program root (let PROG {res}) + +; Loop context unions {loop_context_unions} + +; Function inlining unions {function_inlining_unions} + +; Schedule {schedule} " ) @@ -176,7 +188,7 @@ pub fn are_progs_eq(program1: TreeProgram, program2: TreeProgram) -> bool { pub fn check_roundtrip_egraph(program: &TreeProgram) { let mut termdag = egglog::TermDag::default(); let egglog_prog = build_program( - &LoopContextUnionsAnd::new().swap_value(program.clone()).0, + LoopContextUnionsAnd::new().swap_value(program.clone()).0, false, ); log::info!("Running egglog program..."); @@ -204,8 +216,9 @@ pub fn check_roundtrip_egraph(program: &TreeProgram) { // It is expected that program has context added pub fn optimize( - program: &LoopContextUnionsAnd, + program: LoopContextUnionsAnd, ) -> std::result::Result { + let original_program = program.value.clone(); let egglog_prog = build_program(program, true); log::info!("Running egglog program..."); let mut egraph = egglog::EGraph::default(); @@ -214,7 +227,7 @@ pub fn optimize( let (serialized, unextractables) = serialized_egraph(egraph); let mut termdag = egglog::TermDag::default(); let (_res_cost, res) = extract( - &program.value, + &original_program, serialized, unextractables, &mut termdag, diff --git a/dag_in_context/src/optimizations/function_inlining.rs b/dag_in_context/src/optimizations/function_inlining.rs index b77dad4d..5b68b7ff 100644 --- a/dag_in_context/src/optimizations/function_inlining.rs +++ b/dag_in_context/src/optimizations/function_inlining.rs @@ -52,11 +52,9 @@ fn subst_call( unions: &mut LoopContextUnionsAnd<()>, ) -> CallBody { if let Expr::Call(func_name, args) = call.as_ref() { - let unions_and_value = Expr::subst(args, func_to_body[func_name]); - unions.unions.extend(unions_and_value.unions); CallBody { call: call.clone(), - body: unions_and_value.value, + body: Expr::subst(args, func_to_body[func_name], unions), } } else { panic!("Tried to substitute non-calls.") @@ -67,11 +65,10 @@ fn subst_call( pub fn function_inlining_pairs( program: &TreeProgram, iterations: usize, -) -> LoopContextUnionsAnd> { - let mut unions = LoopContextUnionsAnd::new(); - + unions: &mut LoopContextUnionsAnd<()>, +) -> Vec { if iterations == 0 { - return unions.swap_value(Vec::new()).0; + return vec![]; } let mut all_funcs = vec![program.entry.clone()]; @@ -97,7 +94,7 @@ pub fn function_inlining_pairs( let mut inlined_calls = calls .iter() - .map(|call| subst_call(call, &func_name_to_body, &mut unions)) + .map(|call| subst_call(call, &func_name_to_body, unions)) .collect::>(); // Repeat! Get calls and subst for each new substituted body. @@ -117,17 +114,17 @@ pub fn function_inlining_pairs( // Only work on new calls, added from the new inlines new_inlines = new_calls .iter() - .map(|call| subst_call(call, &func_name_to_body, &mut unions)) + .map(|call| subst_call(call, &func_name_to_body, unions)) .collect::>(); inlined_calls.extend(new_inlines.clone()); } - unions.swap_value(inlined_calls).0 + inlined_calls } // Returns a formatted string of (union call body) for each pair pub fn print_function_inlining_pairs( - function_inlining_pairs: LoopContextUnionsAnd>, + function_inlining_pairs: Vec, printed: &mut String, tree_state: &mut TreeToEgglog, term_cache: &mut HashMap, @@ -135,7 +132,6 @@ pub fn print_function_inlining_pairs( let inlined_calls = "(relation InlinedCall (String Expr))"; // Get unions and mark each call as inlined for extraction purposes let printed_pairs = function_inlining_pairs - .value .iter() .map(|cb| { if let Expr::Call(callee, _) = cb.call.as_ref() { @@ -177,10 +173,7 @@ pub fn print_function_inlining_pairs( }) .collect::>() .join("\n"); - format!( - "{inlined_calls} {printed_pairs} {}", - function_inlining_pairs.get_unions_with_sharing(printed, tree_state, term_cache) - ) + format!("{inlined_calls} {printed_pairs}") } // Check that function inling pairs produces the right number of pairs for diff --git a/dag_in_context/src/schema_helpers.rs b/dag_in_context/src/schema_helpers.rs index 42918044..12c6e3ea 100644 --- a/dag_in_context/src/schema_helpers.rs +++ b/dag_in_context/src/schema_helpers.rs @@ -289,16 +289,12 @@ impl Expr { } // Substitute "arg" for Arg() in within. Also replaces context with "arg"'s context. - pub fn subst(arg: &RcExpr, within: &RcExpr) -> LoopContextUnionsAnd { + pub fn subst(arg: &RcExpr, within: &RcExpr, unions: &mut LoopContextUnionsAnd<()>) -> RcExpr { let mut subst_cache: HashMap<*const Expr, RcExpr> = HashMap::new(); - let mut unions = LoopContextUnionsAnd::new(); let arg_ty = arg.get_arg_type(); let arg_ctx = arg.get_ctx(); - let value = - Self::subst_with_cache(arg, &arg_ty, arg_ctx, within, &mut subst_cache, &mut unions); - - unions.swap_value(value).0 + Self::subst_with_cache(arg, &arg_ty, arg_ctx, within, &mut subst_cache, unions) } fn subst_with_cache( diff --git a/src/util.rs b/src/util.rs index 51e6a417..eb0de2da 100644 --- a/src/util.rs +++ b/src/util.rs @@ -355,7 +355,7 @@ impl Run { fn optimize_bril(program: &Program) -> Result { let rvsdg = Optimizer::program_to_rvsdg(program)?; let dag = rvsdg.to_dag_encoding(true); - let optimized = dag_in_context::optimize(&dag).map_err(EggCCError::EggLog)?; + let optimized = dag_in_context::optimize(dag).map_err(EggCCError::EggLog)?; let rvsdg2 = dag_to_rvsdg(&optimized); let cfg = rvsdg2.to_cfg(); let bril = cfg.to_bril(); @@ -626,7 +626,7 @@ impl Run { RunType::DagOptimize => { let rvsdg = Optimizer::program_to_rvsdg(&self.prog_with_args.program)?; let tree = rvsdg.to_dag_encoding(true); - let optimized = dag_in_context::optimize(&tree).map_err(EggCCError::EggLog)?; + let optimized = dag_in_context::optimize(tree).map_err(EggCCError::EggLog)?; ( vec![Visualization { result: tree_to_svg(&optimized), @@ -639,7 +639,7 @@ impl Run { RunType::OptimizedRvsdg => { let rvsdg = Optimizer::program_to_rvsdg(&self.prog_with_args.program)?; let dag = rvsdg.to_dag_encoding(true); - let optimized = dag_in_context::optimize(&dag).map_err(EggCCError::EggLog)?; + let optimized = dag_in_context::optimize(dag).map_err(EggCCError::EggLog)?; let rvsdg = dag_to_rvsdg(&optimized); ( vec![Visualization { @@ -653,7 +653,7 @@ impl Run { RunType::Egglog => { let rvsdg = Optimizer::program_to_rvsdg(&self.prog_with_args.program)?; let dag = rvsdg.to_dag_encoding(true); - let egglog = build_program(&dag, true); + let egglog = build_program(dag, true); ( vec![Visualization { result: egglog,