Skip to content

Commit

Permalink
Refactor passes dag_in_context tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Alex-Fischman committed May 29, 2024
1 parent 51ad6d9 commit 02b0735
Show file tree
Hide file tree
Showing 13 changed files with 299 additions and 157 deletions.
81 changes: 51 additions & 30 deletions dag_in_context/src/add_context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,12 +65,13 @@ impl ContextCache {
}

pub struct UnionsAnd<T> {
unions: Vec<(String, String)>,
// marked as public but you probably want `get_unions`
pub unions: Vec<(String, String)>,
pub value: T,
}

impl<T> UnionsAnd<T> {
fn get_unions(&self) -> String {
pub fn get_unions(&self) -> String {
use std::fmt::Write;

self.unions
Expand All @@ -83,35 +84,35 @@ impl<T> UnionsAnd<T> {
}

impl TreeProgram {
pub fn add_context(&self) -> TreeProgram {
self.add_context_internal(Expr::func_get_ctx, &mut ContextCache::new())
pub fn add_context(&self) -> UnionsAnd<TreeProgram> {
self.add_context_internal(Expr::func_get_ctx, ContextCache::new())
}

/// add stand-in variables for all the contexts in the program
/// useful for testing if you don't care about context in the test
pub fn add_symbolic_ctx(&self) -> TreeProgram {
self.add_context_internal(
|_| Assumption::dummy(),
&mut ContextCache::new_symbolic_ctx(),
)
pub fn add_symbolic_ctx(&self) -> UnionsAnd<TreeProgram> {
self.add_context_internal(|_| Assumption::dummy(), ContextCache::new_symbolic_ctx())
}

pub fn add_dummy_ctx(&self) -> TreeProgram {
self.add_context_internal(|_| Assumption::dummy(), &mut ContextCache::new_dummy_ctx())
pub fn add_dummy_ctx(&self) -> UnionsAnd<TreeProgram> {
self.add_context_internal(|_| Assumption::dummy(), ContextCache::new_dummy_ctx())
}

fn add_context_internal(
&self,
func: impl Fn(&RcExpr) -> Assumption,
cache: &mut ContextCache,
) -> TreeProgram {
TreeProgram {
functions: self
.functions
.iter()
.map(|f| f.add_ctx_with_cache(func(f), cache))
.collect(),
entry: self.entry.add_ctx_with_cache(func(&self.entry), cache),
mut cache: ContextCache,
) -> UnionsAnd<TreeProgram> {
let entry = self.entry.add_ctx_with_cache(func(&self.entry), &mut cache);
let functions = self
.functions
.iter()
.map(|f| f.add_ctx_with_cache(func(f), &mut cache))
.collect();
let value = TreeProgram { functions, entry };
UnionsAnd {
value,
unions: cache.unions,
}
}
}
Expand All @@ -124,28 +125,48 @@ impl Expr {
Assumption::InFunc(name.clone())
}

pub fn func_add_ctx(self: &RcExpr) -> RcExpr {
pub fn func_add_ctx(self: &RcExpr) -> UnionsAnd<RcExpr> {
let Expr::Function(name, arg_ty, ret_ty, body) = self.as_ref() else {
panic!("Expected Function, got {:?}", self);
};
RcExpr::new(Expr::Function(
let mut cache = ContextCache::new();
let value = RcExpr::new(Expr::Function(
name.clone(),
arg_ty.clone(),
ret_ty.clone(),
body.add_ctx_with_cache(self.func_get_ctx(), &mut ContextCache::new()),
))
body.add_ctx_with_cache(self.func_get_ctx(), &mut cache),
));
UnionsAnd {
value,
unions: cache.unions,
}
}

pub fn add_dummy_ctx(self: &RcExpr) -> RcExpr {
self.add_ctx_with_cache(Assumption::dummy(), &mut ContextCache::new_dummy_ctx())
pub fn add_dummy_ctx(self: &RcExpr) -> UnionsAnd<RcExpr> {
let mut cache = ContextCache::new_dummy_ctx();
let value = self.add_ctx_with_cache(Assumption::dummy(), &mut cache);
UnionsAnd {
value,
unions: cache.unions,
}
}

pub fn add_symbolic_ctx(self: &RcExpr) -> RcExpr {
self.add_ctx_with_cache(Assumption::dummy(), &mut ContextCache::new_symbolic_ctx())
pub fn add_symbolic_ctx(self: &RcExpr) -> UnionsAnd<RcExpr> {
let mut cache = ContextCache::new_symbolic_ctx();
let value = self.add_ctx_with_cache(Assumption::dummy(), &mut cache);
UnionsAnd {
value,
unions: cache.unions,
}
}

pub fn add_ctx(self: &RcExpr, current_ctx: Assumption) -> RcExpr {
self.add_ctx_with_cache(current_ctx, &mut ContextCache::new())
pub fn add_ctx(self: &RcExpr, current_ctx: Assumption) -> UnionsAnd<RcExpr> {
let mut cache = ContextCache::new();
let value = self.add_ctx_with_cache(current_ctx, &mut cache);
UnionsAnd {
value,
unions: cache.unions,
}
}

fn add_ctx_with_cache(
Expand Down
45 changes: 26 additions & 19 deletions dag_in_context/src/interval_analysis.rs
Original file line number Diff line number Diff line change
Expand Up @@ -85,11 +85,17 @@ fn test_add_constant_fold() -> crate::Result {
let expr2 = int_ty(3, emptyt()).add_ctx(Assumption::dummy());

egglog_test(
&format!("{expr}"),
&format!("(check (= {expr} {expr2}))"),
&format!("{}\n{}", expr.value, expr.get_unions()),
&format!(
"{}\n{}\n(check (= {} {}))",
expr2.value,
expr2.get_unions(),
expr.value,
expr2.value
),
vec![
expr.to_program(emptyt(), base(intt())),
expr2.to_program(emptyt(), base(intt())),
expr.value.to_program(emptyt(), base(intt())),
expr2.value.to_program(emptyt(), base(intt())),
],
val_empty(),
intv(3),
Expand Down Expand Up @@ -193,12 +199,12 @@ fn context_if() -> crate::Result {
let f = function("main", base(intt()), base(boolt()), z.clone()).func_with_arg_types();
let prog = f.to_program(base(intt()), base(boolt()));
let with_context = prog.add_context();
let term = with_context.entry.func_body().unwrap();
let term = with_context.value.entry.func_body().unwrap();

egglog_test(
&format!("{with_context}"),
&format!("{}\n{}", with_context.value, with_context.get_unions()),
&format!("(check (= {term} (Const (Bool false) (Base (IntT)) somectx)))"),
vec![with_context],
vec![with_context.value],
intv(4),
val_bool(false),
vec![],
Expand All @@ -211,12 +217,12 @@ fn simple_less_than() -> crate::Result {
let cond = less_eq(int_ty(0, base(intt())), int(-1));
let prog = program!(function("main", base(intt()), base(boolt()), cond.clone()),);
let with_context = prog.add_context();
let term = with_context.entry.func_body().unwrap();
let term = with_context.value.entry.func_body().unwrap();

egglog_test(
&format!("{with_context}"),
&format!("{}\n{}", with_context.value, with_context.get_unions()),
&format!("(check (= {term} (Const (Bool false) (Base (IntT)) somectx)))"),
vec![with_context],
vec![with_context.value],
intv(4),
val_bool(false),
vec![],
Expand All @@ -239,15 +245,15 @@ fn context_if_rev() -> crate::Result {
let f = function("main", base(intt()), base(boolt()), z.clone()).func_with_arg_types();
let prog = f.to_program(base(intt()), base(boolt()));
let with_context = prog.add_context();
let term = with_context.entry.func_body().unwrap();
let term = with_context.value.entry.func_body().unwrap();

egglog_test(
&format!("{with_context}"),
&format!("{}\n{}", with_context.value, with_context.get_unions()),
&format!(
"
(check (= {term} (Const (Bool false) (Base (IntT)) (InFunc \"main\"))))"
),
vec![with_context],
vec![with_context.value],
intv(4),
val_bool(false),
vec![],
Expand Down Expand Up @@ -302,7 +308,7 @@ fn context_if_with_state() -> crate::Result {
.with_arg_types()
.add_context();

let body_with_ctx = prog.entry.func_body().unwrap();
let body_with_ctx = prog.value.entry.func_body().unwrap();

let expected = single(tprint(
ttrue_ty(input_type.clone()),
Expand All @@ -311,13 +317,14 @@ fn context_if_with_state() -> crate::Result {
.add_ctx(Assumption::InFunc("main".to_string()));

egglog_test(
&format!("{prog}"),
&format!("{}\n{}", prog.value, prog.get_unions()),
&format!(
"
(check (= {expected} {body_with_ctx}))
"
"{}\n{}\n(check (= {} {body_with_ctx}))",
expected.value,
expected.get_unions(),
expected.value
),
vec![prog],
vec![prog.value],
val_vec(vec![intv(4), statev()]),
val_vec(vec![statev()]),
vec!["true".to_string()],
Expand Down
4 changes: 2 additions & 2 deletions dag_in_context/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -177,8 +177,8 @@ pub fn check_roundtrip_egraph(program: &TreeProgram) {
DefaultCostModel,
);

let original_with_ctx = program.add_dummy_ctx();
let res_with_ctx = res.add_dummy_ctx();
let original_with_ctx = program.add_dummy_ctx().value;
let res_with_ctx = res.add_dummy_ctx().value;

if !are_progs_eq(original_with_ctx.clone(), res_with_ctx.clone()) {
eprintln!("Original program: {}", tree_to_svg(&original_with_ctx));
Expand Down
44 changes: 33 additions & 11 deletions dag_in_context/src/optimizations/function_inlining.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use std::{
use egglog::Term;

use crate::{
add_context::UnionsAnd,
print_with_intermediate_helper,
schema::{Expr, RcExpr, TreeProgram},
to_egglog::TreeToEgglog,
Expand Down Expand Up @@ -45,21 +46,35 @@ fn get_calls_with_cache(

// Pairs a call with its equivalent inlined body, using the passed-in function -> body map
// to look up the body
fn subst_call(call: &RcExpr, func_to_body: &HashMap<String, &RcExpr>) -> CallBody {
fn subst_call(
call: &RcExpr,
func_to_body: &HashMap<String, &RcExpr>,
unions: &mut Vec<(String, String)>,
) -> CallBody {
if let Expr::Call(func_name, args) = call.as_ref() {
let unions_and_value = Expr::subst(args, func_to_body[func_name]);
unions.extend(unions_and_value.unions);
CallBody {
call: call.clone(),
body: Expr::subst(args, func_to_body[func_name]),
body: unions_and_value.value,
}
} else {
panic!("Tried to substitute non-calls.")
}
}

// Generates a list of (call, body) pairs (in a CallBody) that can be unioned
pub fn function_inlining_pairs(program: &TreeProgram, iterations: usize) -> Vec<CallBody> {
pub fn function_inlining_pairs(
program: &TreeProgram,
iterations: usize,
) -> UnionsAnd<Vec<CallBody>> {
let mut unions = Vec::new();

if iterations == 0 {
return vec![];
return UnionsAnd {
unions,
value: Vec::new(),
};
}

let mut all_funcs = vec![program.entry.clone()];
Expand All @@ -85,7 +100,7 @@ pub fn function_inlining_pairs(program: &TreeProgram, iterations: usize) -> Vec<

let mut inlined_calls = calls
.iter()
.map(|call| subst_call(call, &func_name_to_body))
.map(|call| subst_call(call, &func_name_to_body, &mut unions))
.collect::<Vec<_>>();

// Repeat! Get calls and subst for each new substituted body.
Expand All @@ -105,24 +120,28 @@ pub fn function_inlining_pairs(program: &TreeProgram, iterations: usize) -> Vec<
// Only work on new calls, added from the new inlines
new_inlines = new_calls
.iter()
.map(|call| subst_call(call, &func_name_to_body))
.map(|call| subst_call(call, &func_name_to_body, &mut unions))
.collect::<Vec<CallBody>>();
inlined_calls.extend(new_inlines.clone());
}

inlined_calls
UnionsAnd {
unions,
value: inlined_calls,
}
}

// Returns a formatted string of (union call body) for each pair
pub fn print_function_inlining_pairs(
function_inlining_pairs: Vec<CallBody>,
function_inlining_pairs: UnionsAnd<Vec<CallBody>>,
printed: &mut String,
tree_state: &mut TreeToEgglog,
term_cache: &mut HashMap<Term, String>,
) -> String {
let inlined_calls = "(relation InlinedCall (String Expr))";
// Get unions and mark each call as inlined for extraction purposes
let printed_pairs = function_inlining_pairs
.value
.iter()
.map(|cb| {
if let Expr::Call(callee, _) = cb.call.as_ref() {
Expand Down Expand Up @@ -164,7 +183,10 @@ pub fn print_function_inlining_pairs(
})
.collect::<Vec<_>>()
.join("\n");
format!("{inlined_calls} {printed_pairs}")
format!(
"{inlined_calls} {printed_pairs} {}",
function_inlining_pairs.get_unions()
)
}

// Check that function inling pairs produces the right number of pairs for
Expand Down Expand Up @@ -207,7 +229,7 @@ fn test_function_inlining_pairs() {

// No more iterations!

assert_eq!(pairs.len(), 6)
assert_eq!(pairs.value.len(), 6)
}

// Infinite recursion should produce as many pairs as iterations
Expand All @@ -225,6 +247,6 @@ fn test_inf_recursion_function_inlining_pairs() {

for iterations in 0..10 {
let pairs = function_inlining_pairs(&program, iterations);
assert_eq!(pairs.len(), iterations);
assert_eq!(pairs.value.len(), iterations);
}
}
Loading

0 comments on commit 02b0735

Please sign in to comment.