Skip to content

Commit

Permalink
Redo substitutions and binder representation for safer usage
Browse files Browse the repository at this point in the history
  • Loading branch information
msprotz committed Dec 7, 2023
1 parent b221a8c commit 6156d7d
Show file tree
Hide file tree
Showing 13 changed files with 96 additions and 66 deletions.
6 changes: 3 additions & 3 deletions lib/Ast.ml
Original file line number Diff line number Diff line change
Expand Up @@ -420,7 +420,7 @@ and files =
and decl =
| DFunction of calling_convention option * flag list * int * int * typ * lident * binders_w * expr_w
| DGlobal of flag list * lident * int * typ * expr_w
| DExternal of calling_convention option * flag list * int * lident * typ * string list
| DExternal of calling_convention option * flag list * int * int * lident * typ * string list
(** String list: only for pretty-printing purposes, names of the first few
* known arguments. *)
| DType of lident * flag list * int * int * type_def
Expand Down Expand Up @@ -639,15 +639,15 @@ let with_type typ node =
let lid_of_decl = function
| DFunction (_, _, _, _, _, lid, _, _)
| DGlobal (_, lid, _, _, _)
| DExternal (_, _, _, lid, _, _)
| DExternal (_, _, _, _, lid, _, _)
| DType (lid, _, _, _, _) ->
lid

let flags_of_decl = function
| DFunction (_, flags, _, _, _, _, _, _)
| DGlobal (flags, _, _, _, _)
| DType (_, flags, _, _, _)
| DExternal (_, flags, _, _, _, _) ->
| DExternal (_, flags, _, _, _, _, _) ->
flags

let tuple_lid = [ "K" ], ""
2 changes: 1 addition & 1 deletion lib/AstToCFlat.ml
Original file line number Diff line number Diff line change
Expand Up @@ -1078,7 +1078,7 @@ let mk_decl env (d: decl): env * CF.decl list =
CF.Global (name, size, body, post_init, public)
]

| DExternal (_, _, _, lid, t, _) ->
| DExternal (_, _, _, _, lid, t, _) ->
let name = GlobalNames.to_c_name env.names lid in
match t with
| TArrow _ ->
Expand Down
10 changes: 4 additions & 6 deletions lib/AstToCStar.ml
Original file line number Diff line number Diff line change
Expand Up @@ -359,12 +359,10 @@ and mk_expr env in_stmt e =
CStar.Constant c

| EApp ({ node = ETApp (e, cgs, ts); _ }, es) when !Options.allow_tapps || whitelisted_tapp e ->
assert (cgs = []);
unit_to_void env e es (List.map (fun t -> CStar.Type (mk_type env t)) ts)
unit_to_void env e (cgs @ es) (List.map (fun t -> CStar.Type (mk_type env t)) ts)

| ETApp (e, cgs, ts) when !Options.allow_tapps || whitelisted_tapp e ->
assert (cgs = []);
CStar.Call (mk_expr env e, List.map (fun t -> CStar.Type (mk_type env t)) ts)
unit_to_void env e cgs (List.map (fun t -> CStar.Type (mk_type env t)) ts)

| EApp ({ node = EOp (op, w); _ }, [ _; _ ]) when is_arith op w ->
fst (mk_arith env e)
Expand Down Expand Up @@ -903,7 +901,7 @@ and mk_declaration m env d: (CStar.decl * _) option =
mk_type env t,
mk_expr env false body), [])

| DExternal (cc, flags, n, name, t, pp) ->
| DExternal (cc, flags, _, n, name, t, pp) ->
if LidSet.mem name env.ifdefs || n > 0 then
None
else
Expand Down Expand Up @@ -1006,7 +1004,7 @@ let mk_ifdefs_set files: LidSet.t =
inherit [_] reduce as super
method private zero = LidSet.empty
method private plus = LidSet.union
method! visit_DExternal _env _cc (flags: flag list) _n (name: lident) _t _hints: LidSet.t =
method! visit_DExternal _env _cc (flags: flag list) _n_cg _n (name: lident) _t _hints: LidSet.t =
if List.mem Common.IfDef flags then
LidSet.singleton name
else
Expand Down
2 changes: 1 addition & 1 deletion lib/AstToMiniRust.ml
Original file line number Diff line number Diff line change
Expand Up @@ -823,7 +823,7 @@ let translate_decl env (d: Ast.decl) =
let env = push_global env lid (name, typ) in
env, Some (MiniRust.Constant { name; typ; body; public })

| Ast.DExternal (_, _, type_parameters, lid, t, _param_names) ->
| Ast.DExternal (_, _, _, type_parameters, lid, t, _param_names) ->
let name = translate_unknown_lid lid in
let env = push_global env lid (name, make_poly (translate_type env t) type_parameters) in
env, None
Expand Down
11 changes: 6 additions & 5 deletions lib/Builtin.ml
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,10 @@ open Helpers
let t_string = TQualified (["Prims"], "string")

let mk_binop m n t =
DExternal (None, [ ], 0, (m, n), TArrow (t, TArrow (t, t)), [ "x"; "y" ])
DExternal (None, [ ], 0, 0, (m, n), TArrow (t, TArrow (t, t)), [ "x"; "y" ])

let mk_val ?(flags=[]) ?(nvars=0) m n t =
DExternal (None, flags, nvars, (m, n), t, [])
DExternal (None, flags, 0, nvars, (m, n), t, [])

let prims: file =
let t = TInt K.CInt in
Expand Down Expand Up @@ -351,15 +351,16 @@ let addendum = [
]

let make_abstract_function_or_global = function
| DFunction (cc, flags, _, n, t, name, bs, _) ->
| DFunction (cc, flags, n_cg, n, t, name, bs, _) ->
let t = fold_arrow (List.map (fun b -> b.typ) bs) t in
assert (n_cg = 0);
if n = 0 then
Some (DExternal (cc, flags, 0, name, t, List.map (fun x -> x.node.name) bs))
Some (DExternal (cc, flags, 0, 0, name, t, List.map (fun x -> x.node.name) bs))
else
None
| DGlobal (flags, name, n, t, _) when not (List.mem Common.Macro flags) ->
if n = 0 then
Some (DExternal (None, flags, 0, name, t, []))
Some (DExternal (None, flags, 0, 0, name, t, []))
else
None
| DType (name, flags, _, _, _) when List.mem Common.AbstractStruct flags ->
Expand Down
4 changes: 2 additions & 2 deletions lib/Bundles.ml
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,8 @@ let make_one_bundle (bundle: Bundle.t) (files: file list) (used: (int * Bundle.t
DGlobal (add_if name flags, name, n, typ, body)
| DType (lid, flags, n_cgs, n, def) ->
DType (lid, add_if lid flags, n_cgs, n, def)
| DExternal (cc, flags, n, lid, t, pp) ->
DExternal (cc, add_if lid flags, n, lid, t, pp)
| DExternal (cc, flags, n_cg, n, lid, t, pp) ->
DExternal (cc, add_if lid flags, n_cg, n, lid, t, pp)
in

(* Match a file against the given list of patterns. *)
Expand Down
40 changes: 26 additions & 14 deletions lib/Checker.ml
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ let populate_env files =
Warn.fatal_error "%a is polymorphic\n" plid lid;
let t = List.fold_right (fun b t2 -> TArrow (b.typ, t2)) binders ret in
{ env with globals = M.add lid t env.globals }
| DExternal (_, _, _, lid, typ, _) ->
| DExternal (_, _, _, _, lid, typ, _) ->
{ env with globals = M.add lid typ env.globals }
) env decls
) empty files
Expand Down Expand Up @@ -523,17 +523,36 @@ and best_buffer_type l t1 e2 =


and infer' env e =
let infer_app t es =
let t_ret, t_args = flatten_arrow t in
if List.length t_args = 0 then
checker_error env "This is not a function:\n%a" pexpr e;
if List.length es > List.length t_args then
checker_error env "Too many arguments for application:\n%a" pexpr e;
let t_args, t_remaining_args = KList.split (List.length es) t_args in
ignore (List.map2 (check_or_infer env) t_args es);
fold_arrow t_remaining_args t_ret
in

match e.node with
| ETApp (e, cs, ts) ->
begin match e.node with
| ETApp (e0, cs, ts) ->
begin match e0.node with
| EOp ((K.Eq | K.Neq), _) ->
(* Special incorrect encoding of polymorphic equalities *)
let t = KList.one ts in
TArrow (t, TArrow (t, TBool))
| _ ->
let t = infer env e in
let t = DeBruijn.subst_ctn env.n_cgs cs t in
DeBruijn.subst_tn ts t
let t = infer env e0 in
KPrint.bprintf "infer-cg: t=%a\n" ptyp t;
let diff = List.length env.locals - env.n_cgs in
let t = DeBruijn.subst_tn ts t in
KPrint.bprintf "infer-cg: subst_tn --> %a\n" ptyp t;
let t = DeBruijn.subst_ctn diff cs t in
KPrint.bprintf "infer-cg: subst_ctn --> %a\n" ptyp t;
(* Now type-check the application itself, after substitution *)
let t = infer_app t cs in
KPrint.bprintf "infer-cg: infer_app --> %a\n" ptyp t;
t
end

| EPolyComp (_, t) ->
Expand Down Expand Up @@ -578,14 +597,7 @@ and infer' env e =
let _ = List.map (infer env) es in
TAny
else
let t_ret, t_args = flatten_arrow t in
if List.length t_args = 0 then
checker_error env "This is not a function:\n%a" pexpr e;
if List.length es > List.length t_args then
checker_error env "Too many arguments for application:\n%a" pexpr e;
let t_args, t_remaining_args = KList.split (List.length es) t_args in
ignore (List.map2 (check_or_infer env) t_args es);
fold_arrow t_remaining_args t_ret
infer_app t es

| ELet (binder, body, cont) ->
let t = check_or_infer (locate env (In binder.node.name)) binder.typ body in
Expand Down
38 changes: 22 additions & 16 deletions lib/DeBruijn.ml
Original file line number Diff line number Diff line change
Expand Up @@ -261,28 +261,32 @@ class map_counting_cg = object
i, i' + 1
end

let cg_of_expr n_cg (_, i') e =
(* Converting an expression into a suitable const generic usable in types, knowing
`diff = n_cg - n_binders`, where
- n_cg is the total number of const generics in the current function / type,
and
- n_binders is the total number of expression binders traversed (including
const generics) *)
let cg_of_expr diff e =
match e.node with
| EBound k ->
let level = i' - k - 1 in
assert (n_cg - level - 1 >= 0);
`Var (n_cg - level - 1)
assert (k - diff > 0);
`Var (k - diff)
| EConstant (w, s) ->
`Const (w, s)
| _ ->
failwith "Unsuitable const generic"

(* Substitute const generics *)
class subst_c (n_cg: int) (c: expr) = object (self)
class subst_c (diff: int) (c: expr) = object (self)
inherit map_counting_cg
method! visit_TCgArray ((i, _) as env) t j =
let t = self#visit_typ env t in
match cg_of_expr n_cg env c with
match cg_of_expr diff c with
| `Var v' ->
(* we substitute v' for i in [ t; j ] *)
if j = i then
(* lift i c boils down to this since we never cross any binders *)
TCgArray (t, v' + i)
TCgArray (t, v' + i (* = lift_cg i v' *))
else
TCgArray (t, if j < i then j else j-1)
| `Const (w, s) ->
Expand All @@ -295,20 +299,22 @@ class subst_c (n_cg: int) (c: expr) = object (self)
EBound (if j < i then j else j-1)
end

let subst_ce n_cg c = (new subst_c n_cg c)#visit_expr_w
let subst_ct n_cg c = (new subst_c n_cg c)#visit_typ
(* Both of these function receive a cg debruijn index, whereas the argument c is
an expression that is in the expression debruijn space -- hence the extra diff
parameter to go one the latter to the former. *)
let subst_ce diff c i = (new subst_c diff c)#visit_expr_w (i, i + diff)
let subst_ct diff c i = (new subst_c diff c)#visit_typ (i, i + diff)

(*let subst_cen n_cg cs e =
let subst_cen diff cs t =
let l = List.length cs in
KList.fold_lefti (fun i body arg ->
let k = l - i - 1 in
subst_ce n_cg arg k body
) e cs*)
subst_ce diff arg k body
) t cs

let subst_ctn n_cg cs t =
let subst_ctn diff cs t =
let l = List.length cs in
KList.fold_lefti (fun i body arg ->
let k = l - i - 1 in
subst_ct n_cg arg (k, k) body
subst_ct diff arg k body
) t cs

6 changes: 3 additions & 3 deletions lib/Inlining.ml
Original file line number Diff line number Diff line change
Expand Up @@ -410,7 +410,7 @@ let cross_call_analysis files =
limited, this is still useful e.g. in the presence of function
pointers. *)
(visit true)#visit_expr_w () e
| DExternal (_, _, _, _, t, _) ->
| DExternal (_, _, _, _, _, t, _) ->
(visit false)#visit_typ () t
| DType (_, flags, _, _, d) ->
if not (List.mem Common.AbstractStruct flags) then
Expand Down Expand Up @@ -465,8 +465,8 @@ let cross_call_analysis files =
DFunction (cc, adjust flags, n_cgs, n, t, name, bs, e)
| DGlobal (flags, name, n, t, e) ->
DGlobal (adjust flags, name, n, t, e)
| DExternal (cc, flags, n, name, t, hints) ->
DExternal (cc, adjust flags, n, name, t, hints)
| DExternal (cc, flags, n_cg, n, name, t, hints) ->
DExternal (cc, adjust flags, n_cg, n, name, t, hints)
| DType (name, flags, n_cgs, n, def) ->
DType (name, adjust flags, n_cgs, n, def)
) decls
Expand Down
4 changes: 2 additions & 2 deletions lib/InputAstToAst.ml
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,9 @@ let rec mk_decl = function
| I.DTypeFlat (name, flags, n, fields) ->
DType (name, flags, 0, n, Flat (mk_tfields_opt fields))
| I.DExternal (cc, flags, name, t) ->
DExternal (cc, flags, 0, name, mk_typ t, [])
DExternal (cc, flags, 0, 0, name, mk_typ t, [])
| I.DExternal2 (cc, flags, name, t, arg_names) ->
DExternal (cc, flags, 0, name, mk_typ t, arg_names)
DExternal (cc, flags, 0, 0, name, mk_typ t, arg_names)
| I.DTypeVariant (name, flags, n, branches) ->
DType (name, flags, 0, n,
Variant (List.map (fun (ident, fields) -> ident, mk_tfields fields) branches))
Expand Down
16 changes: 13 additions & 3 deletions lib/Monomorphization.ml
Original file line number Diff line number Diff line change
Expand Up @@ -443,7 +443,14 @@ let functions files =
) decls

method! visit_ETApp env e cgs ts =
assert (cgs = []);
let fail_if () =
if cgs <> [] then
Warn.fatal_error "TODO: e=%a\ncgs=%a\nts=%a\n%a\n"
pexpr e
pexprs cgs
ptyps ts
pexpr (with_type TUnit (ETApp (e, cgs, ts)));
in
match e.node with
| EQualified lid ->
begin try
Expand All @@ -452,12 +459,14 @@ let functions files =
with Not_found ->
match Hashtbl.find map lid with
| exception Not_found ->
(* External function. Bail. *)
(* External function. Bail. Leave cgs -- treated as normal
arguments when going to C. C'est la vie. *)
if !Options.allow_tapps || AstToCStar.whitelisted_tapp e then
super#visit_ETApp env e cgs ts
else
(self#visit_expr env e).node
| `Function (cc, flags, n_cgs, n, ret, name, binders, body) ->
fail_if ();
(* Need to generate a new instance. *)
if n <> List.length ts then begin
KPrint.bprintf "%a is not fully type-applied!\n" plid lid;
Expand All @@ -481,6 +490,7 @@ let functions files =
EQualified (Gen.register_def current_file lid ts name def)

| `Global (flags, name, n, t, body) ->
fail_if ();
if n <> List.length ts then begin
KPrint.bprintf "%a is not fully type-applied!\n" plid lid;
(self#visit_expr env e).node
Expand Down Expand Up @@ -609,7 +619,7 @@ let equalities files =
EQualified (Gen.register_def current_file eq_lid [ t ] instance_lid def)
| K.PEq ->
(* assume val __eq__t: t -> t -> bool *)
let def () = DExternal (None, [], 0, instance_lid, eq_typ', [ "x"; "y" ]) in
let def () = DExternal (None, [], 0, 0, instance_lid, eq_typ', [ "x"; "y" ]) in
EQualified (Gen.register_def current_file eq_lid [ t ] instance_lid def)
in

Expand Down
14 changes: 9 additions & 5 deletions lib/PrintAst.ml
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,19 @@ open Common
let arrow = string "->"
let lambda = fancystring "λ" 1

let print_app f head g arguments =
let print_app_ empty f head g arguments =
group (
f head ^^ jump (
if List.length arguments = 0 then
utf8string "😱"
utf8string empty
else
separate_map (break 1) g arguments
)
)

let print_app x = print_app_ "😱" x
let print_cg_app x = print_app_ "" x

let rec print_decl = function
| DFunction (cc, flags, n_cg, n, typ, name, binders, body) ->
let cc = match cc with Some cc -> print_cc cc ^^ break1 | None -> empty in
Expand All @@ -36,11 +39,12 @@ let rec print_decl = function
print_expr body
)

| DExternal (cc, flags, n, name, typ, _) ->
| DExternal (cc, flags, n_cg, n, name, typ, _) ->
let cc = match cc with Some cc -> print_cc cc ^^ break1 | None -> empty in
print_flags flags ^/^
group (cc ^^ string "external" ^/^ string (string_of_lident name) ^/^
langle ^^ int n ^^ rangle ^^ colon) ^^
langle ^^ string "cg: " ^^ int n_cg ^^ rangle ^^
jump (print_typ typ)

| DGlobal (flags, name, n, typ, expr) ->
Expand Down Expand Up @@ -225,9 +229,9 @@ and print_expr { node; typ } =
| EApp (e, es) ->
print_app print_expr e print_expr es
| ETApp (e, es, ts) ->
print_app (fun (e, ts) ->
print_cg_app (fun (e, ts) ->
print_app print_expr e (fun t -> group (langle ^/^ print_typ t ^/^ rangle)) ts
) (e, ts) print_expr es
) (e, ts) (fun e -> brackets (brackets (print_expr e))) es
| ELet (binder, e1, e2) ->
group (print_let_binding (binder, e1) ^/^ string "in") ^^ hardline ^^
group (print_expr e2)
Expand Down
Loading

0 comments on commit 6156d7d

Please sign in to comment.