Skip to content

Commit

Permalink
✨ No more special pred type
Browse files Browse the repository at this point in the history
  • Loading branch information
Zeta611 committed Jun 10, 2024
1 parent 0f20467 commit d14f871
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 67 deletions.
69 changes: 38 additions & 31 deletions lib/compiler.ml
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,10 @@ let rec peval : type a. (a, det) texp -> (a, det) texp =
match peval te with
| { exp = Value v; _ } -> { ty; exp = Value (uop.op v) }
| e -> { ty; exp = Uop (uop, e) })
| If_pred (pred, te_con, te_alt) -> (
match peval_pred pred with
| True -> peval { ty; exp = If_just te_con }
| False -> peval { ty; exp = If_just te_alt }
| If_pred (te_pred, te_con, te_alt) -> (
match peval te_pred with
| { exp = Value true; _ } -> peval { ty; exp = If_just te_con }
| { exp = Value false; _ } -> peval { ty; exp = If_just te_alt }
| p -> { ty; exp = If_pred (p, peval te_con, peval te_alt) })
| Call (f, args) -> (
match peval_args args with
Expand All @@ -48,9 +48,10 @@ let rec peval : type a. (a, det) texp -> (a, det) texp =
in
{ ty; exp = Call (f_dist, []) })
| If_pred_dist (p, de) -> (
match peval_pred p with
| True -> peval de
| False -> { ty; exp = Call (Dist.one (dty_of_dist_ty ty), []) }
match peval p with
| { exp = Value true; _ } -> peval de
| { exp = Value false; _ } ->
{ ty; exp = Call (Dist.one (dty_of_dist_ty ty), []) }
| p -> { ty; exp = If_pred_dist (p, peval de) })
| If_just de -> { ty; exp = If_just (peval de) }

Expand All @@ -63,23 +64,18 @@ and peval_args : type a. (a, det) args -> (a, det) args * a vargs option =
({ ty; exp = Value v } :: tl, Some ((dty_of_dat_ty ty, v) :: vargs))
| te, (tl, _) -> (te :: tl, None))

and peval_pred : pred -> pred = function
| Empty -> failwith "[Bug] Empty predicate"
| True -> True
| False -> False
| And (p, de) -> (
match peval de with
| { exp = Value true; _ } -> peval_pred p
| { exp = Value false; _ } -> False
| de -> And (p, de))
| And_not (p, de) -> (
match peval de with
| { exp = Value true; _ } -> False
| { exp = Value false; _ } -> peval_pred p
| de -> And_not (p, de))
let ( &&& ) p1 p2 : bool some_dat_det_texp =
let { ty = Dat_ty (Tyb, s1); _ } = p1 and { ty = Dat_ty (Tyb, s2); _ } = p2 in
let (Ex (ms, s)) = merge_stamps s1 s2 in
Ex
(peval
{
ty = Dat_ty (Tyb, s);
exp = Bop ({ name = "&&"; op = ( && ) }, p1, p2, ms);
})

let ( &&& ) p de = peval_pred (And (p, de))
let ( &&! ) p de = peval_pred (And_not (p, de))
let ( &&! ) p1 p2 =
p1 &&& { ty = p2.ty; exp = Uop ({ name = "not"; op = not }, p2) }

let rec score : type a. (a dist_ty, det) texp -> (a dist_ty, det) texp =
function
Expand All @@ -88,9 +84,12 @@ let rec score : type a. (a dist_ty, det) texp -> (a dist_ty, det) texp =
| { exp = Call _; _ } as e -> e

let rec compile :
type a s. env:env -> ?pred:pred -> (a, ndet) texp -> Graph.t * (a, det) texp
=
fun ~env ?(pred = Empty) { ty; exp } ->
type a s.
env:env ->
pred:((bool, s) dat_ty, det) texp ->
(a, ndet) texp ->
Graph.t * (a, det) texp =
fun ~env ~pred { ty; exp } ->
match exp with
| Value _ as exp -> (Graph.empty, { ty; exp })
| Var x -> (
Expand All @@ -107,8 +106,8 @@ let rec compile :
(g, peval { ty; exp = Uop (op, te) })
| If (e_pred, e_con, e_alt, _, _) ->
let g1, de_pred = compile ~env ~pred e_pred in
let pred_con = pred &&& de_pred in
let pred_alt = pred &&! de_pred in
let (Ex pred_con) = pred &&& de_pred in
let (Ex pred_alt) = pred &&! de_pred in
let g2, de_con = compile ~env ~pred:pred_con e_con in
let g3, de_alt = compile ~env ~pred:pred_alt e_alt in
let g = Graph.(g1 @| g2 @| g3) in
Expand Down Expand Up @@ -143,7 +142,7 @@ let rec compile :
let v = gen_vertex () in
let f1 = score de1 in
let f = { ty = f1.ty; exp = If_pred_dist (pred, f1) } in
let fvs = Id.(fv de1.exp @| fv_pred pred) in
let fvs = Id.(fv de1.exp @| fv pred.exp) in
if not (Set.is_empty (fv de2.exp)) then
failwith "[Bug] Not closed observation";
let g' =
Expand All @@ -158,7 +157,11 @@ let rec compile :
Graph.(g1 @| g2 @| g', { ty = Dat_ty (Tyu, Val); exp = Value () })

and compile_args :
type a. env -> pred -> (a, ndet) args -> Graph.t * (a, det) args =
type a s.
env ->
((bool, s) dat_ty, det) texp ->
(a, ndet) args ->
Graph.t * (a, det) args =
fun env pred args ->
match args with
| [] -> (Graph.empty, [])
Expand All @@ -177,7 +180,11 @@ let compile_program (prog : program) : Graph.t * Evaluator.query =
m "Inlined program %a" Sexp.pp_hum [%sexp (exp : Parse_tree.exp)]);

let (Ex e) = Typing.check exp in
let g, { ty; exp } = compile ~env:Id.Map.empty e in
let g, { ty; exp } =
compile ~env:Id.Map.empty
~pred:{ ty = Dat_ty (Tyb, Val); exp = Value true }
e
in
match ty with
| Dat_ty (_, Rv) -> (g, Ex { ty; exp })
| _ -> raise Query_not_found
14 changes: 4 additions & 10 deletions lib/evaluator.ml
Original file line number Diff line number Diff line change
Expand Up @@ -22,24 +22,18 @@ let rec eval_dat : type a s. Ctx.t -> ((a, s) dat_ty, det) texp -> a =
| None -> assert false)
| Bop ({ op; _ }, te1, te2, _) -> op (eval_dat ctx te1) (eval_dat ctx te2)
| Uop ({ op; _ }, te) -> op (eval_dat ctx te)
| If_pred (pred, te_con, te_alt) ->
if eval_pred ctx pred then eval_dat ctx te_con else eval_dat ctx te_alt
| If_pred (te_pred, te_con, te_alt) ->
if eval_dat ctx te_pred then eval_dat ctx te_con else eval_dat ctx te_alt
| If_just te -> eval_dat ctx te

and eval_dist : type a. Ctx.t -> (a dist_ty, det) texp -> a =
fun ctx { ty = Dist_ty dty as ty; exp } ->
match exp with
| Call (f, args) -> f.sampler (eval_args ctx args)
| If_pred_dist (pred, dist) ->
if eval_pred ctx pred then eval_dist ctx dist
if eval_dat ctx pred then eval_dist ctx dist
else eval_dist ctx { ty; exp = Call (Dist.one dty, []) }

and eval_pred (ctx : Ctx.t) : pred -> bool = function
| Empty | True -> true
| False -> false
| And (p, de) -> eval_dat ctx de && eval_pred ctx p
| And_not (p, de) -> (not (eval_dat ctx de)) && eval_pred ctx p

and eval_args : type a. Ctx.t -> (a, det) args -> a vargs =
fun ctx -> function
| [] -> []
Expand All @@ -51,7 +45,7 @@ let rec eval_pmdf :
fun ctx { ty = Dist_ty dty as ty; exp } ->
match exp with
| If_pred_dist (pred, te) ->
if eval_pred ctx pred then eval_pmdf ctx te
if eval_dat ctx pred then eval_pmdf ctx te
else eval_pmdf ctx { ty; exp = Call (Dist.one dty, []) }
| Call (f, args) ->
let pmdf (Ex (ty', v) : some_val) =
Expand Down
41 changes: 15 additions & 26 deletions lib/typed_tree.ml
Original file line number Diff line number Diff line change
Expand Up @@ -40,18 +40,10 @@ type ('a, 'b) dist = {
log_pmdf : 'b vargs -> 'a -> real;
}

(* TODO: Why args should also be det? *)
type (_, _) args =
| [] : (unit, _) args
| ( :: ) : (('a, _) dat_ty, 'd) texp * ('b, 'd) args -> ('a * 'b, 'd) args

and pred =
| Empty : pred
| True : pred
| False : pred
| And : pred * ((bool, _) dat_ty, det) texp -> pred
| And_not : pred * ((bool, _) dat_ty, det) texp -> pred

and ('a, 'd) texp = { ty : 'a ty; exp : ('a, 'd) exp }

and (_, _) exp =
Expand All @@ -73,9 +65,13 @@ and (_, _) exp =
* ('s_pred, 's_ca, 's) merge_stamp
-> (('a, 's) dat_ty, ndet) exp
| If_pred :
pred * (('a, _) dat_ty, det) texp * (('a, _) dat_ty, det) texp
((bool, _) dat_ty, det) texp
* (('a, _) dat_ty, det) texp
* (('a, _) dat_ty, det) texp
-> (('a, _) dat_ty, det) exp
| If_pred_dist : pred * ('a dist_ty, det) texp -> ('a dist_ty, det) exp
| If_pred_dist :
((bool, _) dat_ty, det) texp * ('a dist_ty, det) texp
-> ('a dist_ty, det) exp
| If_just : (('a, _) dat_ty, det) texp -> (('a, _) dat_ty, det) exp
| Let : Id.t * ('a, ndet) texp * ('b, ndet) texp -> ('b, ndet) exp
| Call : ('a, 'b) dist * ('b, 'd) args -> ('a dist_ty, 'd) exp
Expand All @@ -92,6 +88,9 @@ type _ some_texp = Ex : (_, 'd) texp -> 'd some_texp
type _ some_dat_ndet_texp =
| Ex : (('a, _) dat_ty, ndet) texp -> 'a some_dat_ndet_texp

type _ some_dat_det_texp =
| Ex : (('a, _) dat_ty, det) texp -> 'a some_dat_det_texp

type _ some_val_texp = Ex : ((_, value) dat_ty, 'd) texp -> 'd some_val_texp
type _ some_rv_texp = Ex : ((_, rv) dat_ty, 'd) texp -> 'd some_rv_texp
type _ some_dat_texp = Ex : (_ dat_ty, 'd) texp -> 'd some_dat_texp
Expand Down Expand Up @@ -184,21 +183,17 @@ let rec fv : type a. (a, det) exp -> Id.Set.t = function
| Rvar x -> Id.Set.singleton x
| Bop (_, { exp = e1; _ }, { exp = e2; _ }, _) -> Id.(fv e1 @| fv e2)
| Uop (_, { exp; _ }) -> fv exp
| If_pred (pred, { exp = e_con; _ }, { exp = e_alt; _ }) ->
Id.(fv_pred pred @| fv e_con @| fv e_alt)
| If_pred_dist (pred, { exp = e_con; _ }) -> Id.(fv_pred pred @| fv e_con)
| If_pred ({ exp = e_pred; _ }, { exp = e_con; _ }, { exp = e_alt; _ }) ->
Id.(fv e_pred @| fv e_con @| fv e_alt)
| If_pred_dist ({ exp = e_pred; _ }, { exp = e_con; _ }) ->
Id.(fv e_pred @| fv e_con)
| If_just { exp; _ } -> fv exp
| Call (_, args) -> fv_args args

and fv_args : type a. (a, det) args -> Id.Set.t = function
| [] -> Id.Set.empty
| { exp; _ } :: es -> Id.(fv exp @| fv_args es)

and fv_pred : pred -> Id.Set.t = function
| Empty | True | False -> Id.Set.empty
| And (p, { exp = de; _ }) -> Id.(fv de @| fv_pred p)
| And_not (p, { exp = de; _ }) -> Id.(fv de @| fv_pred p)

module Erased = struct
type exp =
| Value : string -> exp
Expand All @@ -220,8 +215,8 @@ module Erased = struct
fun { ty; exp } ->
match exp with
| If (pred, con, alt, _, _) -> If (of_exp pred, of_exp con, of_exp alt)
| If_pred (pred, con, alt) -> If (of_pred pred, of_exp con, of_exp alt)
| If_pred_dist (pred, con) -> If (of_pred pred, of_exp con, Value "1")
| If_pred (pred, con, alt) -> If (of_exp pred, of_exp con, of_exp alt)
| If_pred_dist (pred, con) -> If (of_exp pred, of_exp con, Value "1")
| If_just exp -> If_just (of_exp exp)
| Value v -> (
match ty with
Expand All @@ -241,11 +236,5 @@ module Erased = struct
| [] -> []
| arg :: args -> of_exp arg :: of_args args

and of_pred : pred -> exp = function
| Empty | True -> Value "true"
| False -> Value "false"
| And (pred, exp) -> Bop ("&&", of_pred pred, of_exp exp)
| And_not (pred, exp) -> Bop ("&&", of_pred pred, Uop ("not", of_exp exp))

let of_rv (Ex rv : _ some_rv_texp) = rv |> of_exp
end

0 comments on commit d14f871

Please sign in to comment.