diff --git a/lib/compiler.ml b/lib/compiler.ml index 4b391f7..ecd66ff 100644 --- a/lib/compiler.ml +++ b/lib/compiler.ml @@ -25,10 +25,10 @@ let rec peval : type a. (a, det) texp -> (a, det) texp = match peval te with | { exp = Value v; _ } -> { ty; exp = Value (uop.op v) } | e -> { ty; exp = Uop (uop, e) }) - | If_pred (pred, te_con, te_alt) -> ( - match peval_pred pred with - | True -> peval { ty; exp = If_just te_con } - | False -> peval { ty; exp = If_just te_alt } + | If_pred (te_pred, te_con, te_alt) -> ( + match peval te_pred with + | { exp = Value true; _ } -> peval { ty; exp = If_just te_con } + | { exp = Value false; _ } -> peval { ty; exp = If_just te_alt } | p -> { ty; exp = If_pred (p, peval te_con, peval te_alt) }) | Call (f, args) -> ( match peval_args args with @@ -48,9 +48,10 @@ let rec peval : type a. (a, det) texp -> (a, det) texp = in { ty; exp = Call (f_dist, []) }) | If_pred_dist (p, de) -> ( - match peval_pred p with - | True -> peval de - | False -> { ty; exp = Call (Dist.one (dty_of_dist_ty ty), []) } + match peval p with + | { exp = Value true; _ } -> peval de + | { exp = Value false; _ } -> + { ty; exp = Call (Dist.one (dty_of_dist_ty ty), []) } | p -> { ty; exp = If_pred_dist (p, peval de) }) | If_just de -> { ty; exp = If_just (peval de) } @@ -63,23 +64,18 @@ and peval_args : type a. (a, det) args -> (a, det) args * a vargs option = ({ ty; exp = Value v } :: tl, Some ((dty_of_dat_ty ty, v) :: vargs)) | te, (tl, _) -> (te :: tl, None)) -and peval_pred : pred -> pred = function - | Empty -> failwith "[Bug] Empty predicate" - | True -> True - | False -> False - | And (p, de) -> ( - match peval de with - | { exp = Value true; _ } -> peval_pred p - | { exp = Value false; _ } -> False - | de -> And (p, de)) - | And_not (p, de) -> ( - match peval de with - | { exp = Value true; _ } -> False - | { exp = Value false; _ } -> peval_pred p - | de -> And_not (p, de)) +let ( &&& ) p1 p2 : bool some_dat_det_texp = + let { ty = Dat_ty (Tyb, s1); _ } = p1 and { ty = Dat_ty (Tyb, s2); _ } = p2 in + let (Ex (ms, s)) = merge_stamps s1 s2 in + Ex + (peval + { + ty = Dat_ty (Tyb, s); + exp = Bop ({ name = "&&"; op = ( && ) }, p1, p2, ms); + }) -let ( &&& ) p de = peval_pred (And (p, de)) -let ( &&! ) p de = peval_pred (And_not (p, de)) +let ( &&! ) p1 p2 = + p1 &&& { ty = p2.ty; exp = Uop ({ name = "not"; op = not }, p2) } let rec score : type a. (a dist_ty, det) texp -> (a dist_ty, det) texp = function @@ -88,9 +84,12 @@ let rec score : type a. (a dist_ty, det) texp -> (a dist_ty, det) texp = | { exp = Call _; _ } as e -> e let rec compile : - type a s. env:env -> ?pred:pred -> (a, ndet) texp -> Graph.t * (a, det) texp - = - fun ~env ?(pred = Empty) { ty; exp } -> + type a s. + env:env -> + pred:((bool, s) dat_ty, det) texp -> + (a, ndet) texp -> + Graph.t * (a, det) texp = + fun ~env ~pred { ty; exp } -> match exp with | Value _ as exp -> (Graph.empty, { ty; exp }) | Var x -> ( @@ -107,8 +106,8 @@ let rec compile : (g, peval { ty; exp = Uop (op, te) }) | If (e_pred, e_con, e_alt, _, _) -> let g1, de_pred = compile ~env ~pred e_pred in - let pred_con = pred &&& de_pred in - let pred_alt = pred &&! de_pred in + let (Ex pred_con) = pred &&& de_pred in + let (Ex pred_alt) = pred &&! de_pred in let g2, de_con = compile ~env ~pred:pred_con e_con in let g3, de_alt = compile ~env ~pred:pred_alt e_alt in let g = Graph.(g1 @| g2 @| g3) in @@ -143,7 +142,7 @@ let rec compile : let v = gen_vertex () in let f1 = score de1 in let f = { ty = f1.ty; exp = If_pred_dist (pred, f1) } in - let fvs = Id.(fv de1.exp @| fv_pred pred) in + let fvs = Id.(fv de1.exp @| fv pred.exp) in if not (Set.is_empty (fv de2.exp)) then failwith "[Bug] Not closed observation"; let g' = @@ -158,7 +157,11 @@ let rec compile : Graph.(g1 @| g2 @| g', { ty = Dat_ty (Tyu, Val); exp = Value () }) and compile_args : - type a. env -> pred -> (a, ndet) args -> Graph.t * (a, det) args = + type a s. + env -> + ((bool, s) dat_ty, det) texp -> + (a, ndet) args -> + Graph.t * (a, det) args = fun env pred args -> match args with | [] -> (Graph.empty, []) @@ -177,7 +180,11 @@ let compile_program (prog : program) : Graph.t * Evaluator.query = m "Inlined program %a" Sexp.pp_hum [%sexp (exp : Parse_tree.exp)]); let (Ex e) = Typing.check exp in - let g, { ty; exp } = compile ~env:Id.Map.empty e in + let g, { ty; exp } = + compile ~env:Id.Map.empty + ~pred:{ ty = Dat_ty (Tyb, Val); exp = Value true } + e + in match ty with | Dat_ty (_, Rv) -> (g, Ex { ty; exp }) | _ -> raise Query_not_found diff --git a/lib/evaluator.ml b/lib/evaluator.ml index ca859e1..8a6d764 100644 --- a/lib/evaluator.ml +++ b/lib/evaluator.ml @@ -22,8 +22,8 @@ let rec eval_dat : type a s. Ctx.t -> ((a, s) dat_ty, det) texp -> a = | None -> assert false) | Bop ({ op; _ }, te1, te2, _) -> op (eval_dat ctx te1) (eval_dat ctx te2) | Uop ({ op; _ }, te) -> op (eval_dat ctx te) - | If_pred (pred, te_con, te_alt) -> - if eval_pred ctx pred then eval_dat ctx te_con else eval_dat ctx te_alt + | If_pred (te_pred, te_con, te_alt) -> + if eval_dat ctx te_pred then eval_dat ctx te_con else eval_dat ctx te_alt | If_just te -> eval_dat ctx te and eval_dist : type a. Ctx.t -> (a dist_ty, det) texp -> a = @@ -31,15 +31,9 @@ and eval_dist : type a. Ctx.t -> (a dist_ty, det) texp -> a = match exp with | Call (f, args) -> f.sampler (eval_args ctx args) | If_pred_dist (pred, dist) -> - if eval_pred ctx pred then eval_dist ctx dist + if eval_dat ctx pred then eval_dist ctx dist else eval_dist ctx { ty; exp = Call (Dist.one dty, []) } -and eval_pred (ctx : Ctx.t) : pred -> bool = function - | Empty | True -> true - | False -> false - | And (p, de) -> eval_dat ctx de && eval_pred ctx p - | And_not (p, de) -> (not (eval_dat ctx de)) && eval_pred ctx p - and eval_args : type a. Ctx.t -> (a, det) args -> a vargs = fun ctx -> function | [] -> [] @@ -51,7 +45,7 @@ let rec eval_pmdf : fun ctx { ty = Dist_ty dty as ty; exp } -> match exp with | If_pred_dist (pred, te) -> - if eval_pred ctx pred then eval_pmdf ctx te + if eval_dat ctx pred then eval_pmdf ctx te else eval_pmdf ctx { ty; exp = Call (Dist.one dty, []) } | Call (f, args) -> let pmdf (Ex (ty', v) : some_val) = diff --git a/lib/typed_tree.ml b/lib/typed_tree.ml index f59ed03..6625bb4 100644 --- a/lib/typed_tree.ml +++ b/lib/typed_tree.ml @@ -40,18 +40,10 @@ type ('a, 'b) dist = { log_pmdf : 'b vargs -> 'a -> real; } -(* TODO: Why args should also be det? *) type (_, _) args = | [] : (unit, _) args | ( :: ) : (('a, _) dat_ty, 'd) texp * ('b, 'd) args -> ('a * 'b, 'd) args -and pred = - | Empty : pred - | True : pred - | False : pred - | And : pred * ((bool, _) dat_ty, det) texp -> pred - | And_not : pred * ((bool, _) dat_ty, det) texp -> pred - and ('a, 'd) texp = { ty : 'a ty; exp : ('a, 'd) exp } and (_, _) exp = @@ -73,9 +65,13 @@ and (_, _) exp = * ('s_pred, 's_ca, 's) merge_stamp -> (('a, 's) dat_ty, ndet) exp | If_pred : - pred * (('a, _) dat_ty, det) texp * (('a, _) dat_ty, det) texp + ((bool, _) dat_ty, det) texp + * (('a, _) dat_ty, det) texp + * (('a, _) dat_ty, det) texp -> (('a, _) dat_ty, det) exp - | If_pred_dist : pred * ('a dist_ty, det) texp -> ('a dist_ty, det) exp + | If_pred_dist : + ((bool, _) dat_ty, det) texp * ('a dist_ty, det) texp + -> ('a dist_ty, det) exp | If_just : (('a, _) dat_ty, det) texp -> (('a, _) dat_ty, det) exp | Let : Id.t * ('a, ndet) texp * ('b, ndet) texp -> ('b, ndet) exp | Call : ('a, 'b) dist * ('b, 'd) args -> ('a dist_ty, 'd) exp @@ -92,6 +88,9 @@ type _ some_texp = Ex : (_, 'd) texp -> 'd some_texp type _ some_dat_ndet_texp = | Ex : (('a, _) dat_ty, ndet) texp -> 'a some_dat_ndet_texp +type _ some_dat_det_texp = + | Ex : (('a, _) dat_ty, det) texp -> 'a some_dat_det_texp + type _ some_val_texp = Ex : ((_, value) dat_ty, 'd) texp -> 'd some_val_texp type _ some_rv_texp = Ex : ((_, rv) dat_ty, 'd) texp -> 'd some_rv_texp type _ some_dat_texp = Ex : (_ dat_ty, 'd) texp -> 'd some_dat_texp @@ -184,9 +183,10 @@ let rec fv : type a. (a, det) exp -> Id.Set.t = function | Rvar x -> Id.Set.singleton x | Bop (_, { exp = e1; _ }, { exp = e2; _ }, _) -> Id.(fv e1 @| fv e2) | Uop (_, { exp; _ }) -> fv exp - | If_pred (pred, { exp = e_con; _ }, { exp = e_alt; _ }) -> - Id.(fv_pred pred @| fv e_con @| fv e_alt) - | If_pred_dist (pred, { exp = e_con; _ }) -> Id.(fv_pred pred @| fv e_con) + | If_pred ({ exp = e_pred; _ }, { exp = e_con; _ }, { exp = e_alt; _ }) -> + Id.(fv e_pred @| fv e_con @| fv e_alt) + | If_pred_dist ({ exp = e_pred; _ }, { exp = e_con; _ }) -> + Id.(fv e_pred @| fv e_con) | If_just { exp; _ } -> fv exp | Call (_, args) -> fv_args args @@ -194,11 +194,6 @@ and fv_args : type a. (a, det) args -> Id.Set.t = function | [] -> Id.Set.empty | { exp; _ } :: es -> Id.(fv exp @| fv_args es) -and fv_pred : pred -> Id.Set.t = function - | Empty | True | False -> Id.Set.empty - | And (p, { exp = de; _ }) -> Id.(fv de @| fv_pred p) - | And_not (p, { exp = de; _ }) -> Id.(fv de @| fv_pred p) - module Erased = struct type exp = | Value : string -> exp @@ -220,8 +215,8 @@ module Erased = struct fun { ty; exp } -> match exp with | If (pred, con, alt, _, _) -> If (of_exp pred, of_exp con, of_exp alt) - | If_pred (pred, con, alt) -> If (of_pred pred, of_exp con, of_exp alt) - | If_pred_dist (pred, con) -> If (of_pred pred, of_exp con, Value "1") + | If_pred (pred, con, alt) -> If (of_exp pred, of_exp con, of_exp alt) + | If_pred_dist (pred, con) -> If (of_exp pred, of_exp con, Value "1") | If_just exp -> If_just (of_exp exp) | Value v -> ( match ty with @@ -241,11 +236,5 @@ module Erased = struct | [] -> [] | arg :: args -> of_exp arg :: of_args args - and of_pred : pred -> exp = function - | Empty | True -> Value "true" - | False -> Value "false" - | And (pred, exp) -> Bop ("&&", of_pred pred, of_exp exp) - | And_not (pred, exp) -> Bop ("&&", of_pred pred, Uop ("not", of_exp exp)) - let of_rv (Ex rv : _ some_rv_texp) = rv |> of_exp end