From 6156d7de783024149e867891ae69370665c11826 Mon Sep 17 00:00:00 2001 From: Jonathan Protzenko Date: Wed, 6 Dec 2023 20:23:19 -0800 Subject: [PATCH] Redo substitutions and binder representation for safer usage --- lib/Ast.ml | 6 +++--- lib/AstToCFlat.ml | 2 +- lib/AstToCStar.ml | 10 ++++------ lib/AstToMiniRust.ml | 2 +- lib/Builtin.ml | 11 ++++++----- lib/Bundles.ml | 4 ++-- lib/Checker.ml | 40 ++++++++++++++++++++++++++-------------- lib/DeBruijn.ml | 38 ++++++++++++++++++++++---------------- lib/Inlining.ml | 6 +++--- lib/InputAstToAst.ml | 4 ++-- lib/Monomorphization.ml | 16 +++++++++++++--- lib/PrintAst.ml | 14 +++++++++----- lib/Simplify.ml | 9 ++++----- 13 files changed, 96 insertions(+), 66 deletions(-) diff --git a/lib/Ast.ml b/lib/Ast.ml index 6cfe1add..ed8bf363 100644 --- a/lib/Ast.ml +++ b/lib/Ast.ml @@ -420,7 +420,7 @@ and files = and decl = | DFunction of calling_convention option * flag list * int * int * typ * lident * binders_w * expr_w | DGlobal of flag list * lident * int * typ * expr_w - | DExternal of calling_convention option * flag list * int * lident * typ * string list + | DExternal of calling_convention option * flag list * int * int * lident * typ * string list (** String list: only for pretty-printing purposes, names of the first few * known arguments. *) | DType of lident * flag list * int * int * type_def @@ -639,7 +639,7 @@ let with_type typ node = let lid_of_decl = function | DFunction (_, _, _, _, _, lid, _, _) | DGlobal (_, lid, _, _, _) - | DExternal (_, _, _, lid, _, _) + | DExternal (_, _, _, _, lid, _, _) | DType (lid, _, _, _, _) -> lid @@ -647,7 +647,7 @@ let flags_of_decl = function | DFunction (_, flags, _, _, _, _, _, _) | DGlobal (flags, _, _, _, _) | DType (_, flags, _, _, _) - | DExternal (_, flags, _, _, _, _) -> + | DExternal (_, flags, _, _, _, _, _) -> flags let tuple_lid = [ "K" ], "" diff --git a/lib/AstToCFlat.ml b/lib/AstToCFlat.ml index 8a70b0ec..b45112ca 100644 --- a/lib/AstToCFlat.ml +++ b/lib/AstToCFlat.ml @@ -1078,7 +1078,7 @@ let mk_decl env (d: decl): env * CF.decl list = CF.Global (name, size, body, post_init, public) ] - | DExternal (_, _, _, lid, t, _) -> + | DExternal (_, _, _, _, lid, t, _) -> let name = GlobalNames.to_c_name env.names lid in match t with | TArrow _ -> diff --git a/lib/AstToCStar.ml b/lib/AstToCStar.ml index 193e59f5..061fc5f2 100644 --- a/lib/AstToCStar.ml +++ b/lib/AstToCStar.ml @@ -359,12 +359,10 @@ and mk_expr env in_stmt e = CStar.Constant c | EApp ({ node = ETApp (e, cgs, ts); _ }, es) when !Options.allow_tapps || whitelisted_tapp e -> - assert (cgs = []); - unit_to_void env e es (List.map (fun t -> CStar.Type (mk_type env t)) ts) + unit_to_void env e (cgs @ es) (List.map (fun t -> CStar.Type (mk_type env t)) ts) | ETApp (e, cgs, ts) when !Options.allow_tapps || whitelisted_tapp e -> - assert (cgs = []); - CStar.Call (mk_expr env e, List.map (fun t -> CStar.Type (mk_type env t)) ts) + unit_to_void env e cgs (List.map (fun t -> CStar.Type (mk_type env t)) ts) | EApp ({ node = EOp (op, w); _ }, [ _; _ ]) when is_arith op w -> fst (mk_arith env e) @@ -903,7 +901,7 @@ and mk_declaration m env d: (CStar.decl * _) option = mk_type env t, mk_expr env false body), []) - | DExternal (cc, flags, n, name, t, pp) -> + | DExternal (cc, flags, _, n, name, t, pp) -> if LidSet.mem name env.ifdefs || n > 0 then None else @@ -1006,7 +1004,7 @@ let mk_ifdefs_set files: LidSet.t = inherit [_] reduce as super method private zero = LidSet.empty method private plus = LidSet.union - method! visit_DExternal _env _cc (flags: flag list) _n (name: lident) _t _hints: LidSet.t = + method! visit_DExternal _env _cc (flags: flag list) _n_cg _n (name: lident) _t _hints: LidSet.t = if List.mem Common.IfDef flags then LidSet.singleton name else diff --git a/lib/AstToMiniRust.ml b/lib/AstToMiniRust.ml index 88406d50..7ced9c77 100644 --- a/lib/AstToMiniRust.ml +++ b/lib/AstToMiniRust.ml @@ -823,7 +823,7 @@ let translate_decl env (d: Ast.decl) = let env = push_global env lid (name, typ) in env, Some (MiniRust.Constant { name; typ; body; public }) - | Ast.DExternal (_, _, type_parameters, lid, t, _param_names) -> + | Ast.DExternal (_, _, _, type_parameters, lid, t, _param_names) -> let name = translate_unknown_lid lid in let env = push_global env lid (name, make_poly (translate_type env t) type_parameters) in env, None diff --git a/lib/Builtin.ml b/lib/Builtin.ml index 5fcd4626..613397a6 100644 --- a/lib/Builtin.ml +++ b/lib/Builtin.ml @@ -10,10 +10,10 @@ open Helpers let t_string = TQualified (["Prims"], "string") let mk_binop m n t = - DExternal (None, [ ], 0, (m, n), TArrow (t, TArrow (t, t)), [ "x"; "y" ]) + DExternal (None, [ ], 0, 0, (m, n), TArrow (t, TArrow (t, t)), [ "x"; "y" ]) let mk_val ?(flags=[]) ?(nvars=0) m n t = - DExternal (None, flags, nvars, (m, n), t, []) + DExternal (None, flags, 0, nvars, (m, n), t, []) let prims: file = let t = TInt K.CInt in @@ -351,15 +351,16 @@ let addendum = [ ] let make_abstract_function_or_global = function - | DFunction (cc, flags, _, n, t, name, bs, _) -> + | DFunction (cc, flags, n_cg, n, t, name, bs, _) -> let t = fold_arrow (List.map (fun b -> b.typ) bs) t in + assert (n_cg = 0); if n = 0 then - Some (DExternal (cc, flags, 0, name, t, List.map (fun x -> x.node.name) bs)) + Some (DExternal (cc, flags, 0, 0, name, t, List.map (fun x -> x.node.name) bs)) else None | DGlobal (flags, name, n, t, _) when not (List.mem Common.Macro flags) -> if n = 0 then - Some (DExternal (None, flags, 0, name, t, [])) + Some (DExternal (None, flags, 0, 0, name, t, [])) else None | DType (name, flags, _, _, _) when List.mem Common.AbstractStruct flags -> diff --git a/lib/Bundles.ml b/lib/Bundles.ml index f44c908a..d3ad6ada 100644 --- a/lib/Bundles.ml +++ b/lib/Bundles.ml @@ -54,8 +54,8 @@ let make_one_bundle (bundle: Bundle.t) (files: file list) (used: (int * Bundle.t DGlobal (add_if name flags, name, n, typ, body) | DType (lid, flags, n_cgs, n, def) -> DType (lid, add_if lid flags, n_cgs, n, def) - | DExternal (cc, flags, n, lid, t, pp) -> - DExternal (cc, add_if lid flags, n, lid, t, pp) + | DExternal (cc, flags, n_cg, n, lid, t, pp) -> + DExternal (cc, add_if lid flags, n_cg, n, lid, t, pp) in (* Match a file against the given list of patterns. *) diff --git a/lib/Checker.ml b/lib/Checker.ml index 7fb849e3..3d337fe5 100644 --- a/lib/Checker.ml +++ b/lib/Checker.ml @@ -130,7 +130,7 @@ let populate_env files = Warn.fatal_error "%a is polymorphic\n" plid lid; let t = List.fold_right (fun b t2 -> TArrow (b.typ, t2)) binders ret in { env with globals = M.add lid t env.globals } - | DExternal (_, _, _, lid, typ, _) -> + | DExternal (_, _, _, _, lid, typ, _) -> { env with globals = M.add lid typ env.globals } ) env decls ) empty files @@ -523,17 +523,36 @@ and best_buffer_type l t1 e2 = and infer' env e = + let infer_app t es = + let t_ret, t_args = flatten_arrow t in + if List.length t_args = 0 then + checker_error env "This is not a function:\n%a" pexpr e; + if List.length es > List.length t_args then + checker_error env "Too many arguments for application:\n%a" pexpr e; + let t_args, t_remaining_args = KList.split (List.length es) t_args in + ignore (List.map2 (check_or_infer env) t_args es); + fold_arrow t_remaining_args t_ret + in + match e.node with - | ETApp (e, cs, ts) -> - begin match e.node with + | ETApp (e0, cs, ts) -> + begin match e0.node with | EOp ((K.Eq | K.Neq), _) -> (* Special incorrect encoding of polymorphic equalities *) let t = KList.one ts in TArrow (t, TArrow (t, TBool)) | _ -> - let t = infer env e in - let t = DeBruijn.subst_ctn env.n_cgs cs t in - DeBruijn.subst_tn ts t + let t = infer env e0 in + KPrint.bprintf "infer-cg: t=%a\n" ptyp t; + let diff = List.length env.locals - env.n_cgs in + let t = DeBruijn.subst_tn ts t in + KPrint.bprintf "infer-cg: subst_tn --> %a\n" ptyp t; + let t = DeBruijn.subst_ctn diff cs t in + KPrint.bprintf "infer-cg: subst_ctn --> %a\n" ptyp t; + (* Now type-check the application itself, after substitution *) + let t = infer_app t cs in + KPrint.bprintf "infer-cg: infer_app --> %a\n" ptyp t; + t end | EPolyComp (_, t) -> @@ -578,14 +597,7 @@ and infer' env e = let _ = List.map (infer env) es in TAny else - let t_ret, t_args = flatten_arrow t in - if List.length t_args = 0 then - checker_error env "This is not a function:\n%a" pexpr e; - if List.length es > List.length t_args then - checker_error env "Too many arguments for application:\n%a" pexpr e; - let t_args, t_remaining_args = KList.split (List.length es) t_args in - ignore (List.map2 (check_or_infer env) t_args es); - fold_arrow t_remaining_args t_ret + infer_app t es | ELet (binder, body, cont) -> let t = check_or_infer (locate env (In binder.node.name)) binder.typ body in diff --git a/lib/DeBruijn.ml b/lib/DeBruijn.ml index 769af2d3..25512676 100644 --- a/lib/DeBruijn.ml +++ b/lib/DeBruijn.ml @@ -261,28 +261,32 @@ class map_counting_cg = object i, i' + 1 end -let cg_of_expr n_cg (_, i') e = +(* Converting an expression into a suitable const generic usable in types, knowing + `diff = n_cg - n_binders`, where + - n_cg is the total number of const generics in the current function / type, + and + - n_binders is the total number of expression binders traversed (including + const generics) *) +let cg_of_expr diff e = match e.node with | EBound k -> - let level = i' - k - 1 in - assert (n_cg - level - 1 >= 0); - `Var (n_cg - level - 1) + assert (k - diff > 0); + `Var (k - diff) | EConstant (w, s) -> `Const (w, s) | _ -> failwith "Unsuitable const generic" (* Substitute const generics *) -class subst_c (n_cg: int) (c: expr) = object (self) +class subst_c (diff: int) (c: expr) = object (self) inherit map_counting_cg method! visit_TCgArray ((i, _) as env) t j = let t = self#visit_typ env t in - match cg_of_expr n_cg env c with + match cg_of_expr diff c with | `Var v' -> (* we substitute v' for i in [ t; j ] *) if j = i then - (* lift i c boils down to this since we never cross any binders *) - TCgArray (t, v' + i) + TCgArray (t, v' + i (* = lift_cg i v' *)) else TCgArray (t, if j < i then j else j-1) | `Const (w, s) -> @@ -295,20 +299,22 @@ class subst_c (n_cg: int) (c: expr) = object (self) EBound (if j < i then j else j-1) end -let subst_ce n_cg c = (new subst_c n_cg c)#visit_expr_w -let subst_ct n_cg c = (new subst_c n_cg c)#visit_typ +(* Both of these function receive a cg debruijn index, whereas the argument c is + an expression that is in the expression debruijn space -- hence the extra diff + parameter to go one the latter to the former. *) +let subst_ce diff c i = (new subst_c diff c)#visit_expr_w (i, i + diff) +let subst_ct diff c i = (new subst_c diff c)#visit_typ (i, i + diff) -(*let subst_cen n_cg cs e = +let subst_cen diff cs t = let l = List.length cs in KList.fold_lefti (fun i body arg -> let k = l - i - 1 in - subst_ce n_cg arg k body - ) e cs*) + subst_ce diff arg k body + ) t cs -let subst_ctn n_cg cs t = +let subst_ctn diff cs t = let l = List.length cs in KList.fold_lefti (fun i body arg -> let k = l - i - 1 in - subst_ct n_cg arg (k, k) body + subst_ct diff arg k body ) t cs - diff --git a/lib/Inlining.ml b/lib/Inlining.ml index af22646d..3e5684ce 100644 --- a/lib/Inlining.ml +++ b/lib/Inlining.ml @@ -410,7 +410,7 @@ let cross_call_analysis files = limited, this is still useful e.g. in the presence of function pointers. *) (visit true)#visit_expr_w () e - | DExternal (_, _, _, _, t, _) -> + | DExternal (_, _, _, _, _, t, _) -> (visit false)#visit_typ () t | DType (_, flags, _, _, d) -> if not (List.mem Common.AbstractStruct flags) then @@ -465,8 +465,8 @@ let cross_call_analysis files = DFunction (cc, adjust flags, n_cgs, n, t, name, bs, e) | DGlobal (flags, name, n, t, e) -> DGlobal (adjust flags, name, n, t, e) - | DExternal (cc, flags, n, name, t, hints) -> - DExternal (cc, adjust flags, n, name, t, hints) + | DExternal (cc, flags, n_cg, n, name, t, hints) -> + DExternal (cc, adjust flags, n_cg, n, name, t, hints) | DType (name, flags, n_cgs, n, def) -> DType (name, adjust flags, n_cgs, n, def) ) decls diff --git a/lib/InputAstToAst.ml b/lib/InputAstToAst.ml index 9395902c..d17fd1bb 100644 --- a/lib/InputAstToAst.ml +++ b/lib/InputAstToAst.ml @@ -48,9 +48,9 @@ let rec mk_decl = function | I.DTypeFlat (name, flags, n, fields) -> DType (name, flags, 0, n, Flat (mk_tfields_opt fields)) | I.DExternal (cc, flags, name, t) -> - DExternal (cc, flags, 0, name, mk_typ t, []) + DExternal (cc, flags, 0, 0, name, mk_typ t, []) | I.DExternal2 (cc, flags, name, t, arg_names) -> - DExternal (cc, flags, 0, name, mk_typ t, arg_names) + DExternal (cc, flags, 0, 0, name, mk_typ t, arg_names) | I.DTypeVariant (name, flags, n, branches) -> DType (name, flags, 0, n, Variant (List.map (fun (ident, fields) -> ident, mk_tfields fields) branches)) diff --git a/lib/Monomorphization.ml b/lib/Monomorphization.ml index 55fecf33..7592031f 100644 --- a/lib/Monomorphization.ml +++ b/lib/Monomorphization.ml @@ -443,7 +443,14 @@ let functions files = ) decls method! visit_ETApp env e cgs ts = - assert (cgs = []); + let fail_if () = + if cgs <> [] then + Warn.fatal_error "TODO: e=%a\ncgs=%a\nts=%a\n%a\n" + pexpr e + pexprs cgs + ptyps ts + pexpr (with_type TUnit (ETApp (e, cgs, ts))); + in match e.node with | EQualified lid -> begin try @@ -452,12 +459,14 @@ let functions files = with Not_found -> match Hashtbl.find map lid with | exception Not_found -> - (* External function. Bail. *) + (* External function. Bail. Leave cgs -- treated as normal + arguments when going to C. C'est la vie. *) if !Options.allow_tapps || AstToCStar.whitelisted_tapp e then super#visit_ETApp env e cgs ts else (self#visit_expr env e).node | `Function (cc, flags, n_cgs, n, ret, name, binders, body) -> + fail_if (); (* Need to generate a new instance. *) if n <> List.length ts then begin KPrint.bprintf "%a is not fully type-applied!\n" plid lid; @@ -481,6 +490,7 @@ let functions files = EQualified (Gen.register_def current_file lid ts name def) | `Global (flags, name, n, t, body) -> + fail_if (); if n <> List.length ts then begin KPrint.bprintf "%a is not fully type-applied!\n" plid lid; (self#visit_expr env e).node @@ -609,7 +619,7 @@ let equalities files = EQualified (Gen.register_def current_file eq_lid [ t ] instance_lid def) | K.PEq -> (* assume val __eq__t: t -> t -> bool *) - let def () = DExternal (None, [], 0, instance_lid, eq_typ', [ "x"; "y" ]) in + let def () = DExternal (None, [], 0, 0, instance_lid, eq_typ', [ "x"; "y" ]) in EQualified (Gen.register_def current_file eq_lid [ t ] instance_lid def) in diff --git a/lib/PrintAst.ml b/lib/PrintAst.ml index 27a77b1b..6d9a1b82 100644 --- a/lib/PrintAst.ml +++ b/lib/PrintAst.ml @@ -13,16 +13,19 @@ open Common let arrow = string "->" let lambda = fancystring "λ" 1 -let print_app f head g arguments = +let print_app_ empty f head g arguments = group ( f head ^^ jump ( if List.length arguments = 0 then - utf8string "😱" + utf8string empty else separate_map (break 1) g arguments ) ) +let print_app x = print_app_ "😱" x +let print_cg_app x = print_app_ "□" x + let rec print_decl = function | DFunction (cc, flags, n_cg, n, typ, name, binders, body) -> let cc = match cc with Some cc -> print_cc cc ^^ break1 | None -> empty in @@ -36,11 +39,12 @@ let rec print_decl = function print_expr body ) - | DExternal (cc, flags, n, name, typ, _) -> + | DExternal (cc, flags, n_cg, n, name, typ, _) -> let cc = match cc with Some cc -> print_cc cc ^^ break1 | None -> empty in print_flags flags ^/^ group (cc ^^ string "external" ^/^ string (string_of_lident name) ^/^ langle ^^ int n ^^ rangle ^^ colon) ^^ + langle ^^ string "cg: " ^^ int n_cg ^^ rangle ^^ jump (print_typ typ) | DGlobal (flags, name, n, typ, expr) -> @@ -225,9 +229,9 @@ and print_expr { node; typ } = | EApp (e, es) -> print_app print_expr e print_expr es | ETApp (e, es, ts) -> - print_app (fun (e, ts) -> + print_cg_app (fun (e, ts) -> print_app print_expr e (fun t -> group (langle ^/^ print_typ t ^/^ rangle)) ts - ) (e, ts) print_expr es + ) (e, ts) (fun e -> brackets (brackets (print_expr e))) es | ELet (binder, e1, e2) -> group (print_let_binding (binder, e1) ^/^ string "in") ^^ hardline ^^ group (print_expr e2) diff --git a/lib/Simplify.ml b/lib/Simplify.ml index ebd512f7..fff1f2ed 100644 --- a/lib/Simplify.ml +++ b/lib/Simplify.ml @@ -247,7 +247,7 @@ let remove_unused_parameters = object (self) let binders = KList.filter_mapi (fun i b -> if unused i then None else Some b) binders in DFunction (cc, flags, n_cgs, n, ret, name, binders, body) - method! visit_DExternal (parameter_table, _ as env) cc flags n name t hints = + method! visit_DExternal (parameter_table, _ as env) cc flags n_cg n name t hints = let ret, args = flatten_arrow t in let hints = KList.filter_mapi (fun i arg -> if unused parameter_table dummy_lid args i then @@ -263,7 +263,7 @@ let remove_unused_parameters = object (self) ) args in let ret = self#visit_typ env ret in let t = fold_arrow args ret in - DExternal (cc, flags, n, name, t, hints) + DExternal (cc, flags, n_cg, n, name, t, hints) method! visit_TArrow (parameter_table, _ as env) t1 t2 = (* Important: the only entries in `parameter_table` are those which are @@ -1014,9 +1014,8 @@ and hoist_expr loc pos e = let mk node = { node; typ = e.typ } in match e.node with | ETApp (e, cgs, ts) -> - assert (cgs = []); let lhs, e = hoist_expr loc Unspecified e in - lhs, mk (ETApp (e, [], ts)) + lhs, mk (ETApp (e, cgs, ts)) | EBufNull | EAbort _ @@ -1370,7 +1369,7 @@ let record_toplevel_names = object (self) method! visit_DFunction env _ flags _ _ _ name _ _ = self#record env ~is_type:false ~is_external:false flags name - method! visit_DExternal env _ flags _ name _ _ = + method! visit_DExternal env _ flags _ _ name _ _ = self#record env ~is_type:false ~is_external:true flags name val forward = Hashtbl.create 41