Skip to content

Commit

Permalink
Remove LoopContextUnionsAnd, just pass ContextCache around (fixes sub…
Browse files Browse the repository at this point in the history
…st bug)
  • Loading branch information
Alex-Fischman committed Jun 3, 2024
1 parent 64d2186 commit f551d04
Show file tree
Hide file tree
Showing 10 changed files with 184 additions and 249 deletions.
107 changes: 38 additions & 69 deletions dag_in_context/src/add_context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@ use crate::{
pub struct ContextCache {
with_ctx: HashMap<(*const Expr, AssumptionRef), RcExpr>,
symbol_gen: HashMap<(*const Expr, AssumptionRef), String>,
loop_contexts: LoopContextUnionsAnd<()>,
/// Information for generating the unions for loop contexts
placeholder: usize,
loop_context_unions: Vec<(Assumption, Assumption)>,
/// When true, don't add context- instead, make fresh query variables
/// and put these in place of context
symbolic_ctx: bool,
Expand Down Expand Up @@ -46,7 +48,8 @@ impl ContextCache {
ContextCache {
with_ctx: HashMap::new(),
symbol_gen: HashMap::new(),
loop_contexts: LoopContextUnionsAnd::new(),
placeholder: 0,
loop_context_unions: Vec::new(),
symbolic_ctx: false,
dummy_ctx: false,
}
Expand All @@ -56,7 +59,8 @@ impl ContextCache {
ContextCache {
with_ctx: HashMap::new(),
symbol_gen: HashMap::new(),
loop_contexts: LoopContextUnionsAnd::new(),
placeholder: 0,
loop_context_unions: Vec::new(),
symbolic_ctx: true,
dummy_ctx: false,
}
Expand All @@ -66,61 +70,17 @@ impl ContextCache {
ContextCache {
with_ctx: HashMap::new(),
symbol_gen: HashMap::new(),
loop_contexts: LoopContextUnionsAnd::new(),
placeholder: 0,
loop_context_unions: Vec::new(),
symbolic_ctx: false,
dummy_ctx: true,
}
}

pub fn get_unions(&self) -> String {
self.loop_contexts.get_unions()
}
}

// not a tuple to prevent auto-impls of Clone, Debug, etc.
pub struct LoopContextUnionsAnd<T> {
var: usize,
// marked as public but you probably want `get_unions`
pub unions: Vec<(Assumption, Assumption)>,
pub value: T,
}

impl Default for LoopContextUnionsAnd<()> {
fn default() -> Self {
Self::new()
}
}

impl LoopContextUnionsAnd<()> {
pub fn new() -> LoopContextUnionsAnd<()> {
LoopContextUnionsAnd {
var: 0,
unions: Vec::new(),
value: (),
}
}
}

impl<T> LoopContextUnionsAnd<T> {
fn new_placeholder(&mut self) -> Assumption {
let placeholder = Assumption::InFunc(format!(" loop_ctx_{}", self.var));
self.var += 1;
placeholder
}

pub fn swap_value<S>(self, value: S) -> (LoopContextUnionsAnd<S>, T) {
let LoopContextUnionsAnd {
var,
unions,
value: old,
} = self;
(LoopContextUnionsAnd { var, unions, value }, old)
}

pub fn get_unions(&self) -> String {
use std::fmt::Write;

self.unions
self.loop_context_unions
.iter()
.fold(String::new(), |mut output, (a, b)| {
let _ = writeln!(output, "(union {a} {b})");
Expand All @@ -134,7 +94,7 @@ impl<T> LoopContextUnionsAnd<T> {
tree_state: &mut TreeToEgglog,
term_cache: &mut HashMap<Term, String>,
) -> String {
self.unions
self.loop_context_unions
.iter()
.map(|(a, b)| {
let internal_a = a.to_egglog_internal(tree_state);
Expand All @@ -158,36 +118,47 @@ impl<T> LoopContextUnionsAnd<T> {
.collect::<Vec<_>>()
.join("\n")
}

pub fn new_placeholder(&mut self) -> Assumption {
let placeholder = Assumption::InFunc(format!(" loop_ctx_{}", self.placeholder));
self.placeholder += 1;
placeholder
}

pub fn push_loop_context_union(&mut self, a: Assumption, b: Assumption) {
self.loop_context_unions.push((a, b))
}
}

impl TreeProgram {
pub fn add_context(&self) -> LoopContextUnionsAnd<TreeProgram> {
pub fn add_context(&self) -> (TreeProgram, ContextCache) {
self.add_context_internal(Expr::func_get_ctx, ContextCache::new())
}

/// add stand-in variables for all the contexts in the program
/// useful for testing if you don't care about context in the test
pub fn add_symbolic_ctx(&self) -> LoopContextUnionsAnd<TreeProgram> {
pub fn add_symbolic_ctx(&self) -> TreeProgram {
self.add_context_internal(|_| Assumption::dummy(), ContextCache::new_symbolic_ctx())
.0
}

pub fn add_dummy_ctx(&self) -> LoopContextUnionsAnd<TreeProgram> {
pub fn add_dummy_ctx(&self) -> TreeProgram {
self.add_context_internal(|_| Assumption::dummy(), ContextCache::new_dummy_ctx())
.0
}

fn add_context_internal(
&self,
func: impl Fn(&RcExpr) -> Assumption,
mut cache: ContextCache,
) -> LoopContextUnionsAnd<TreeProgram> {
) -> (TreeProgram, ContextCache) {
let entry = self.entry.add_ctx_with_cache(func(&self.entry), &mut cache);
let functions = self
.functions
.iter()
.map(|f| f.add_ctx_with_cache(func(f), &mut cache))
.collect();
let value = TreeProgram { functions, entry };
cache.loop_contexts.swap_value(value).0
(TreeProgram { functions, entry }, cache)
}
}

Expand All @@ -199,7 +170,7 @@ impl Expr {
Assumption::InFunc(name.clone())
}

pub fn func_add_ctx(self: &RcExpr) -> LoopContextUnionsAnd<RcExpr> {
pub fn func_add_ctx(self: &RcExpr) -> (RcExpr, ContextCache) {
let Expr::Function(name, arg_ty, ret_ty, body) = self.as_ref() else {
panic!("Expected Function, got {:?}", self);
};
Expand All @@ -210,25 +181,23 @@ impl Expr {
ret_ty.clone(),
body.add_ctx_with_cache(self.func_get_ctx(), &mut cache),
));
cache.loop_contexts.swap_value(value).0
(value, cache)
}

pub fn add_dummy_ctx(self: &RcExpr) -> LoopContextUnionsAnd<RcExpr> {
pub fn add_dummy_ctx(self: &RcExpr) -> RcExpr {
let mut cache = ContextCache::new_dummy_ctx();
let value = self.add_ctx_with_cache(Assumption::dummy(), &mut cache);
cache.loop_contexts.swap_value(value).0
self.add_ctx_with_cache(Assumption::dummy(), &mut cache)
}

pub fn add_symbolic_ctx(self: &RcExpr) -> LoopContextUnionsAnd<RcExpr> {
pub fn add_symbolic_ctx(self: &RcExpr) -> RcExpr {
let mut cache = ContextCache::new_symbolic_ctx();
let value = self.add_ctx_with_cache(Assumption::dummy(), &mut cache);
cache.loop_contexts.swap_value(value).0
self.add_ctx_with_cache(Assumption::dummy(), &mut cache)
}

pub fn add_ctx(self: &RcExpr, current_ctx: Assumption) -> LoopContextUnionsAnd<RcExpr> {
pub fn add_ctx(self: &RcExpr, current_ctx: Assumption) -> (RcExpr, ContextCache) {
let mut cache = ContextCache::new();
let value = self.add_ctx_with_cache(current_ctx, &mut cache);
cache.loop_contexts.swap_value(value).0
(value, cache)
}

pub fn add_ctx_with_cache(
Expand Down Expand Up @@ -259,14 +228,14 @@ impl Expr {
Expr::Arg(ty, _oldctx) => RcExpr::new(Expr::Arg(ty.clone(), context_to_add)),
// create new contexts for let, loop, and if
Expr::DoWhile(inputs, pred_and_body) => {
let placeholder = cache.loop_contexts.new_placeholder();
let placeholder = cache.new_placeholder();

let new_inputs = inputs.add_ctx_with_cache(current_ctx.clone(), cache);
let new_pred_and_body =
pred_and_body.add_ctx_with_cache(placeholder.clone(), cache);

let new_ctx = Assumption::InLoop(new_inputs.clone(), new_pred_and_body.clone());
cache.loop_contexts.unions.push((placeholder, new_ctx));
cache.push_loop_context_union(placeholder, new_ctx);

RcExpr::new(Expr::DoWhile(new_inputs, new_pred_and_body))
}
Expand Down
57 changes: 26 additions & 31 deletions dag_in_context/src/interval_analysis.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,23 +79,20 @@ fn constant_fold() -> crate::Result {
#[test]
fn test_add_constant_fold() -> crate::Result {
use crate::ast::*;
let expr = add(int(1), int(2))
let (expr, cache) = add(int(1), int(2))
.with_arg_types(emptyt(), base(intt()))
.add_ctx(Assumption::dummy());
let expr2 = int_ty(3, emptyt()).add_ctx(Assumption::dummy());
let (expr2, cache2) = int_ty(3, emptyt()).add_ctx(Assumption::dummy());

egglog_test(
&format!("{}\n{}", expr.value, expr.get_unions()),
&format!("{expr}\n{}", cache.get_unions()),
&format!(
"{}\n{}\n(check (= {} {}))",
expr2.value,
expr2.get_unions(),
expr.value,
expr2.value
"{expr2}\n{}\n(check (= {expr} {expr2}))",
cache2.get_unions(),
),
vec![
expr.value.to_program(emptyt(), base(intt())),
expr2.value.to_program(emptyt(), base(intt())),
expr.to_program(emptyt(), base(intt())),
expr2.to_program(emptyt(), base(intt())),
],
val_empty(),
intv(3),
Expand Down Expand Up @@ -198,13 +195,13 @@ fn context_if() -> crate::Result {

let f = function("main", base(intt()), base(boolt()), z.clone()).func_with_arg_types();
let prog = f.to_program(base(intt()), base(boolt()));
let with_context = prog.add_context();
let term = with_context.value.entry.func_body().unwrap();
let (with_context, cache) = prog.add_context();
let term = with_context.entry.func_body().unwrap();

egglog_test(
&format!("{}\n{}", with_context.value, with_context.get_unions()),
&format!("{with_context}\n{}", cache.get_unions()),
&format!("(check (= {term} (Const (Bool false) (Base (IntT)) somectx)))"),
vec![with_context.value],
vec![with_context],
intv(4),
val_bool(false),
vec![],
Expand All @@ -216,13 +213,13 @@ fn simple_less_than() -> crate::Result {
// 0 <= input
let cond = less_eq(int_ty(0, base(intt())), int(-1));
let prog = program!(function("main", base(intt()), base(boolt()), cond.clone()),);
let with_context = prog.add_context();
let term = with_context.value.entry.func_body().unwrap();
let (with_context, cache) = prog.add_context();
let term = with_context.entry.func_body().unwrap();

egglog_test(
&format!("{}\n{}", with_context.value, with_context.get_unions()),
&format!("{with_context}\n{}", cache.get_unions()),
&format!("(check (= {term} (Const (Bool false) (Base (IntT)) somectx)))"),
vec![with_context.value],
vec![with_context],
intv(4),
val_bool(false),
vec![],
Expand All @@ -244,16 +241,16 @@ fn context_if_rev() -> crate::Result {

let f = function("main", base(intt()), base(boolt()), z.clone()).func_with_arg_types();
let prog = f.to_program(base(intt()), base(boolt()));
let with_context = prog.add_context();
let term = with_context.value.entry.func_body().unwrap();
let (with_context, cache) = prog.add_context();
let term = with_context.entry.func_body().unwrap();

egglog_test(
&format!("{}\n{}", with_context.value, with_context.get_unions()),
&format!("{with_context}\n{}", cache.get_unions()),
&format!(
"
(check (= {term} (Const (Bool false) (Base (IntT)) (InFunc \"main\"))))"
),
vec![with_context.value],
vec![with_context],
intv(4),
val_bool(false),
vec![],
Expand Down Expand Up @@ -303,28 +300,26 @@ fn context_if_with_state() -> crate::Result {
body.clone(),
)
.func_with_arg_types();
let prog = f
let (prog, cache) = f
.to_program(input_type.clone(), output_type.clone())
.with_arg_types()
.add_context();

let body_with_ctx = prog.value.entry.func_body().unwrap();
let body_with_ctx = prog.entry.func_body().unwrap();

let expected = single(tprint(
let (expected, expected_cache) = single(tprint(
ttrue_ty(input_type.clone()),
get(input_arg.clone(), 1),
))
.add_ctx(Assumption::InFunc("main".to_string()));

egglog_test(
&format!("{}\n{}", prog.value, prog.get_unions()),
&format!("{prog}\n{}", cache.get_unions()),
&format!(
"{}\n{}\n(check (= {} {body_with_ctx}))",
expected.value,
expected.get_unions(),
expected.value
"{expected}\n{}\n(check (= {expected} {body_with_ctx}))",
expected_cache.get_unions(),
),
vec![prog.value],
vec![prog],
val_vec(vec![intv(4), statev()]),
val_vec(vec![statev()]),
vec!["true".to_string()],
Expand Down
Loading

0 comments on commit f551d04

Please sign in to comment.