Skip to content

Commit

Permalink
Fix subst types, improve egglog output
Browse files Browse the repository at this point in the history
  • Loading branch information
Alex-Fischman committed Jun 3, 2024
1 parent 45ef932 commit 64d2186
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 33 deletions.
27 changes: 20 additions & 7 deletions dag_in_context/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,8 @@ pub(crate) fn print_with_intermediate_vars(termdag: &TermDag, term: Term) -> Str
}

// It is expected that program has context added
pub fn build_program(program: &LoopContextUnionsAnd<TreeProgram>, optimize: bool) -> String {
pub fn build_program(program: LoopContextUnionsAnd<TreeProgram>, optimize: bool) -> String {
let (mut unions, program) = program.swap_value(());
let mut printed = String::new();

// Create a global cache for generating intermediate variables
Expand All @@ -127,8 +128,9 @@ pub fn build_program(program: &LoopContextUnionsAnd<TreeProgram>, optimize: bool
} else {
function_inlining::print_function_inlining_pairs(
function_inlining::function_inlining_pairs(
&program.value,
&program,
config::FUNCTION_INLINING_ITERATIONS,
&mut unions,
),
&mut printed,
&mut tree_state,
Expand All @@ -137,12 +139,12 @@ pub fn build_program(program: &LoopContextUnionsAnd<TreeProgram>, optimize: bool
};

// Generate program egglog
let term = program.value.to_egglog_with(&mut tree_state);
let term = program.to_egglog_with(&mut tree_state);
let res =
print_with_intermediate_helper(&tree_state.termdag, term, &mut term_cache, &mut printed);

let loop_context_unions =
program.get_unions_with_sharing(&mut printed, &mut tree_state, &mut term_cache);
unions.get_unions_with_sharing(&mut printed, &mut tree_state, &mut term_cache);

let prologue = prologue();

Expand All @@ -154,11 +156,21 @@ pub fn build_program(program: &LoopContextUnionsAnd<TreeProgram>, optimize: bool

format!(
"
; Prologue
{prologue}
; Program nodes
{printed}
; Program root
(let PROG {res})
; Loop context unions
{loop_context_unions}
; Function inlining unions
{function_inlining_unions}
; Schedule
{schedule}
"
)
Expand All @@ -176,7 +188,7 @@ pub fn are_progs_eq(program1: TreeProgram, program2: TreeProgram) -> bool {
pub fn check_roundtrip_egraph(program: &TreeProgram) {
let mut termdag = egglog::TermDag::default();
let egglog_prog = build_program(
&LoopContextUnionsAnd::new().swap_value(program.clone()).0,
LoopContextUnionsAnd::new().swap_value(program.clone()).0,
false,
);
log::info!("Running egglog program...");
Expand Down Expand Up @@ -204,8 +216,9 @@ pub fn check_roundtrip_egraph(program: &TreeProgram) {

// It is expected that program has context added
pub fn optimize(
program: &LoopContextUnionsAnd<TreeProgram>,
program: LoopContextUnionsAnd<TreeProgram>,
) -> std::result::Result<TreeProgram, egglog::Error> {
let original_program = program.value.clone();
let egglog_prog = build_program(program, true);
log::info!("Running egglog program...");
let mut egraph = egglog::EGraph::default();
Expand All @@ -214,7 +227,7 @@ pub fn optimize(
let (serialized, unextractables) = serialized_egraph(egraph);
let mut termdag = egglog::TermDag::default();
let (_res_cost, res) = extract(
&program.value,
&original_program,
serialized,
unextractables,
&mut termdag,
Expand Down
25 changes: 9 additions & 16 deletions dag_in_context/src/optimizations/function_inlining.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,11 +52,9 @@ fn subst_call(
unions: &mut LoopContextUnionsAnd<()>,
) -> CallBody {
if let Expr::Call(func_name, args) = call.as_ref() {
let unions_and_value = Expr::subst(args, func_to_body[func_name]);
unions.unions.extend(unions_and_value.unions);
CallBody {
call: call.clone(),
body: unions_and_value.value,
body: Expr::subst(args, func_to_body[func_name], unions),
}
} else {
panic!("Tried to substitute non-calls.")
Expand All @@ -67,11 +65,10 @@ fn subst_call(
pub fn function_inlining_pairs(
program: &TreeProgram,
iterations: usize,
) -> LoopContextUnionsAnd<Vec<CallBody>> {
let mut unions = LoopContextUnionsAnd::new();

unions: &mut LoopContextUnionsAnd<()>,
) -> Vec<CallBody> {
if iterations == 0 {
return unions.swap_value(Vec::new()).0;
return vec![];
}

let mut all_funcs = vec![program.entry.clone()];
Expand All @@ -97,7 +94,7 @@ pub fn function_inlining_pairs(

let mut inlined_calls = calls
.iter()
.map(|call| subst_call(call, &func_name_to_body, &mut unions))
.map(|call| subst_call(call, &func_name_to_body, unions))
.collect::<Vec<_>>();

// Repeat! Get calls and subst for each new substituted body.
Expand All @@ -117,25 +114,24 @@ pub fn function_inlining_pairs(
// Only work on new calls, added from the new inlines
new_inlines = new_calls
.iter()
.map(|call| subst_call(call, &func_name_to_body, &mut unions))
.map(|call| subst_call(call, &func_name_to_body, unions))
.collect::<Vec<CallBody>>();
inlined_calls.extend(new_inlines.clone());
}

unions.swap_value(inlined_calls).0
inlined_calls
}

// Returns a formatted string of (union call body) for each pair
pub fn print_function_inlining_pairs(
function_inlining_pairs: LoopContextUnionsAnd<Vec<CallBody>>,
function_inlining_pairs: Vec<CallBody>,
printed: &mut String,
tree_state: &mut TreeToEgglog,
term_cache: &mut HashMap<Term, String>,
) -> String {
let inlined_calls = "(relation InlinedCall (String Expr))";
// Get unions and mark each call as inlined for extraction purposes
let printed_pairs = function_inlining_pairs
.value
.iter()
.map(|cb| {
if let Expr::Call(callee, _) = cb.call.as_ref() {
Expand Down Expand Up @@ -177,10 +173,7 @@ pub fn print_function_inlining_pairs(
})
.collect::<Vec<_>>()
.join("\n");
format!(
"{inlined_calls} {printed_pairs} {}",
function_inlining_pairs.get_unions_with_sharing(printed, tree_state, term_cache)
)
format!("{inlined_calls} {printed_pairs}")
}

// Check that function inling pairs produces the right number of pairs for
Expand Down
8 changes: 2 additions & 6 deletions dag_in_context/src/schema_helpers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -289,16 +289,12 @@ impl Expr {
}

// Substitute "arg" for Arg() in within. Also replaces context with "arg"'s context.
pub fn subst(arg: &RcExpr, within: &RcExpr) -> LoopContextUnionsAnd<RcExpr> {
pub fn subst(arg: &RcExpr, within: &RcExpr, unions: &mut LoopContextUnionsAnd<()>) -> RcExpr {
let mut subst_cache: HashMap<*const Expr, RcExpr> = HashMap::new();
let mut unions = LoopContextUnionsAnd::new();

let arg_ty = arg.get_arg_type();
let arg_ctx = arg.get_ctx();
let value =
Self::subst_with_cache(arg, &arg_ty, arg_ctx, within, &mut subst_cache, &mut unions);

unions.swap_value(value).0
Self::subst_with_cache(arg, &arg_ty, arg_ctx, within, &mut subst_cache, unions)
}

fn subst_with_cache(
Expand Down
8 changes: 4 additions & 4 deletions src/util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -355,7 +355,7 @@ impl Run {
fn optimize_bril(program: &Program) -> Result<Program, EggCCError> {
let rvsdg = Optimizer::program_to_rvsdg(program)?;
let dag = rvsdg.to_dag_encoding(true);
let optimized = dag_in_context::optimize(&dag).map_err(EggCCError::EggLog)?;
let optimized = dag_in_context::optimize(dag).map_err(EggCCError::EggLog)?;
let rvsdg2 = dag_to_rvsdg(&optimized);
let cfg = rvsdg2.to_cfg();
let bril = cfg.to_bril();
Expand Down Expand Up @@ -626,7 +626,7 @@ impl Run {
RunType::DagOptimize => {
let rvsdg = Optimizer::program_to_rvsdg(&self.prog_with_args.program)?;
let tree = rvsdg.to_dag_encoding(true);
let optimized = dag_in_context::optimize(&tree).map_err(EggCCError::EggLog)?;
let optimized = dag_in_context::optimize(tree).map_err(EggCCError::EggLog)?;
(
vec![Visualization {
result: tree_to_svg(&optimized),
Expand All @@ -639,7 +639,7 @@ impl Run {
RunType::OptimizedRvsdg => {
let rvsdg = Optimizer::program_to_rvsdg(&self.prog_with_args.program)?;
let dag = rvsdg.to_dag_encoding(true);
let optimized = dag_in_context::optimize(&dag).map_err(EggCCError::EggLog)?;
let optimized = dag_in_context::optimize(dag).map_err(EggCCError::EggLog)?;
let rvsdg = dag_to_rvsdg(&optimized);
(
vec![Visualization {
Expand All @@ -653,7 +653,7 @@ impl Run {
RunType::Egglog => {
let rvsdg = Optimizer::program_to_rvsdg(&self.prog_with_args.program)?;
let dag = rvsdg.to_dag_encoding(true);
let egglog = build_program(&dag, true);
let egglog = build_program(dag, true);
(
vec![Visualization {
result: egglog,
Expand Down

0 comments on commit 64d2186

Please sign in to comment.