Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

#[contracts::requires(...)] + #[contracts::ensures(...)] #128045

Draft
wants to merge 7 commits into
base: master
Choose a base branch
from
9 changes: 8 additions & 1 deletion compiler/rustc_ast/src/ast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3316,11 +3316,18 @@ pub struct Impl {
pub items: ThinVec<P<AssocItem>>,
}

#[derive(Clone, Encodable, Decodable, Debug, Default)]
pub struct FnContract {
pub requires: Option<P<Expr>>,
pub ensures: Option<P<Expr>>,
}

#[derive(Clone, Encodable, Decodable, Debug)]
pub struct Fn {
pub defaultness: Defaultness,
pub generics: Generics,
pub sig: FnSig,
pub contract: Option<P<FnContract>>,
pub body: Option<P<Block>>,
}

Expand Down Expand Up @@ -3618,7 +3625,7 @@ mod size_asserts {
static_assert_size!(Block, 32);
static_assert_size!(Expr, 72);
static_assert_size!(ExprKind, 40);
static_assert_size!(Fn, 160);
static_assert_size!(Fn, 168);
static_assert_size!(ForeignItem, 88);
static_assert_size!(ForeignItemKind, 16);
static_assert_size!(GenericArg, 24);
Expand Down
29 changes: 26 additions & 3 deletions compiler/rustc_ast/src/mut_visit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,10 @@ pub trait MutVisitor: Sized {
walk_flat_map_assoc_item(self, i, ctxt)
}

fn visit_contract(&mut self, c: &mut P<FnContract>) {
walk_contract(self, c);
}

fn visit_fn_decl(&mut self, d: &mut P<FnDecl>) {
walk_fn_decl(self, d);
}
Expand Down Expand Up @@ -973,6 +977,16 @@ fn walk_fn<T: MutVisitor>(vis: &mut T, kind: FnKind<'_>) {
}
}

fn walk_contract<T: MutVisitor>(vis: &mut T, contract: &mut P<FnContract>) {
let FnContract { requires, ensures } = contract.deref_mut();
if let Some(pred) = requires {
vis.visit_expr(pred);
}
if let Some(pred) = ensures {
vis.visit_expr(pred);
}
}

fn walk_fn_decl<T: MutVisitor>(vis: &mut T, decl: &mut P<FnDecl>) {
let FnDecl { inputs, output } = decl.deref_mut();
inputs.flat_map_in_place(|param| vis.flat_map_param(param));
Expand Down Expand Up @@ -1205,8 +1219,11 @@ impl WalkItemKind for ItemKind {
ItemKind::Const(item) => {
visit_const_item(item, vis);
}
ItemKind::Fn(box Fn { defaultness, generics, sig, body }) => {
ItemKind::Fn(box Fn { defaultness, generics, sig, contract, body }) => {
visit_defaultness(vis, defaultness);
if let Some(contract) = contract {
vis.visit_contract(contract)
};
vis.visit_fn(
FnKind::Fn(FnCtxt::Free, ident, sig, visibility, generics, body),
span,
Expand Down Expand Up @@ -1329,8 +1346,11 @@ impl WalkItemKind for AssocItemKind {
AssocItemKind::Const(item) => {
visit_const_item(item, visitor);
}
AssocItemKind::Fn(box Fn { defaultness, generics, sig, body }) => {
AssocItemKind::Fn(box Fn { defaultness, generics, sig, contract, body }) => {
visit_defaultness(visitor, defaultness);
if let Some(contract) = contract {
visitor.visit_contract(contract);
}
visitor.visit_fn(
FnKind::Fn(FnCtxt::Assoc(ctxt), ident, sig, visibility, generics, body),
span,
Expand Down Expand Up @@ -1476,8 +1496,11 @@ impl WalkItemKind for ForeignItemKind {
visitor.visit_ty(ty);
visit_opt(expr, |expr| visitor.visit_expr(expr));
}
ForeignItemKind::Fn(box Fn { defaultness, generics, sig, body }) => {
ForeignItemKind::Fn(box Fn { defaultness, generics, sig, contract, body }) => {
visit_defaultness(visitor, defaultness);
if let Some(contract) = contract {
visitor.visit_contract(contract);
}
visitor.visit_fn(
FnKind::Fn(FnCtxt::Foreign, ident, sig, visibility, generics, body),
span,
Expand Down
52 changes: 42 additions & 10 deletions compiler/rustc_ast/src/visit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,15 @@ impl BoundKind {
#[derive(Copy, Clone, Debug)]
pub enum FnKind<'a> {
/// E.g., `fn foo()`, `fn foo(&self)`, or `extern "Abi" fn foo()`.
Fn(FnCtxt, &'a Ident, &'a FnSig, &'a Visibility, &'a Generics, &'a Option<P<Block>>),
Fn(
FnCtxt,
&'a Ident,
&'a FnSig,
&'a Visibility,
&'a Generics,
&'a Option<P<FnContract>>,
&'a Option<P<Block>>,
),

/// E.g., `|x, y| body`.
Closure(&'a ClosureBinder, &'a Option<CoroutineKind>, &'a FnDecl, &'a Expr),
Expand All @@ -74,7 +82,7 @@ pub enum FnKind<'a> {
impl<'a> FnKind<'a> {
pub fn header(&self) -> Option<&'a FnHeader> {
match *self {
FnKind::Fn(_, _, sig, _, _, _) => Some(&sig.header),
FnKind::Fn(_, _, sig, _, _, _, _) => Some(&sig.header),
FnKind::Closure(..) => None,
}
}
Expand All @@ -88,7 +96,7 @@ impl<'a> FnKind<'a> {

pub fn decl(&self) -> &'a FnDecl {
match self {
FnKind::Fn(_, _, sig, _, _, _) => &sig.decl,
FnKind::Fn(_, _, sig, _, _, _, _) => &sig.decl,
FnKind::Closure(_, _, decl, _) => decl,
}
}
Expand Down Expand Up @@ -188,6 +196,9 @@ pub trait Visitor<'ast>: Sized {
fn visit_closure_binder(&mut self, b: &'ast ClosureBinder) -> Self::Result {
walk_closure_binder(self, b)
}
fn visit_contract(&mut self, c: &'ast FnContract) -> Self::Result {
walk_contract(self, c)
}
fn visit_where_predicate(&mut self, p: &'ast WherePredicate) -> Self::Result {
walk_where_predicate(self, p)
}
Expand Down Expand Up @@ -374,8 +385,8 @@ impl WalkItemKind for ItemKind {
try_visit!(visitor.visit_ty(ty));
visit_opt!(visitor, visit_expr, expr);
}
ItemKind::Fn(box Fn { defaultness: _, generics, sig, body }) => {
let kind = FnKind::Fn(FnCtxt::Free, ident, sig, vis, generics, body);
ItemKind::Fn(box Fn { defaultness: _, generics, sig, contract, body }) => {
let kind = FnKind::Fn(FnCtxt::Free, ident, sig, vis, generics, contract, body);
try_visit!(visitor.visit_fn(kind, span, id));
}
ItemKind::Mod(_unsafety, mod_kind) => match mod_kind {
Expand Down Expand Up @@ -715,8 +726,8 @@ impl WalkItemKind for ForeignItemKind {
try_visit!(visitor.visit_ty(ty));
visit_opt!(visitor, visit_expr, expr);
}
ForeignItemKind::Fn(box Fn { defaultness: _, generics, sig, body }) => {
let kind = FnKind::Fn(FnCtxt::Foreign, ident, sig, vis, generics, body);
ForeignItemKind::Fn(box Fn { defaultness: _, generics, sig, contract, body }) => {
let kind = FnKind::Fn(FnCtxt::Foreign, ident, sig, vis, generics, contract, body);
try_visit!(visitor.visit_fn(kind, span, id));
}
ForeignItemKind::TyAlias(box TyAlias {
Expand Down Expand Up @@ -800,6 +811,17 @@ pub fn walk_closure_binder<'a, V: Visitor<'a>>(
V::Result::output()
}

pub fn walk_contract<'a, V: Visitor<'a>>(visitor: &mut V, c: &'a FnContract) -> V::Result {
let FnContract { requires, ensures } = c;
if let Some(pred) = requires {
visitor.visit_expr(pred);
}
if let Some(pred) = ensures {
visitor.visit_expr(pred);
}
V::Result::output()
}

pub fn walk_where_predicate<'a, V: Visitor<'a>>(
visitor: &mut V,
predicate: &'a WherePredicate,
Expand Down Expand Up @@ -858,11 +880,20 @@ pub fn walk_fn_decl<'a, V: Visitor<'a>>(

pub fn walk_fn<'a, V: Visitor<'a>>(visitor: &mut V, kind: FnKind<'a>) -> V::Result {
match kind {
FnKind::Fn(_ctxt, _ident, FnSig { header, decl, span: _ }, _vis, generics, body) => {
FnKind::Fn(
_ctxt,
_ident,
FnSig { header, decl, span: _ },
_vis,
generics,
contract,
body,
) => {
// Identifier and visibility are visited as a part of the item.
try_visit!(visitor.visit_fn_header(header));
try_visit!(visitor.visit_generics(generics));
try_visit!(visitor.visit_fn_decl(decl));
visit_opt!(visitor, visit_contract, contract);
visit_opt!(visitor, visit_block, body);
}
FnKind::Closure(binder, coroutine_kind, decl, body) => {
Expand Down Expand Up @@ -892,8 +923,9 @@ impl WalkItemKind for AssocItemKind {
try_visit!(visitor.visit_ty(ty));
visit_opt!(visitor, visit_expr, expr);
}
AssocItemKind::Fn(box Fn { defaultness: _, generics, sig, body }) => {
let kind = FnKind::Fn(FnCtxt::Assoc(ctxt), ident, sig, vis, generics, body);
AssocItemKind::Fn(box Fn { defaultness: _, generics, sig, contract, body }) => {
let kind =
FnKind::Fn(FnCtxt::Assoc(ctxt), ident, sig, vis, generics, contract, body);
try_visit!(visitor.visit_fn(kind, span, id));
}
AssocItemKind::Type(box TyAlias {
Expand Down
19 changes: 16 additions & 3 deletions compiler/rustc_ast_lowering/src/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -314,7 +314,20 @@ impl<'hir> LoweringContext<'_, 'hir> {
hir::ExprKind::Continue(self.lower_jump_destination(e.id, *opt_label))
}
ExprKind::Ret(e) => {
let e = e.as_ref().map(|x| self.lower_expr(x));
let mut e = e.as_ref().map(|x| self.lower_expr(x));
if let Some(Some((span, fresh_ident))) = self
.contract
.as_ref()
.map(|c| c.ensures.as_ref().map(|e| (e.expr.span, e.fresh_ident)))
{
let checker_fn = self.expr_ident(span, fresh_ident.0, fresh_ident.2);
let args = if let Some(e) = e {
std::slice::from_ref(e)
} else {
std::slice::from_ref(self.expr_unit(span))
};
e = Some(self.expr_call(span, checker_fn, args));
}
Comment on lines -317 to +330
Copy link
Contributor

@oli-obk oli-obk Dec 10, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

returns kind of duplicate logic with the trailing expression, something can be improved here

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes I agree. It should also be unified with how the Try operator is implemented. I.e. all three cases (return, ?, and the tail expression) should all call some common method so that code like this can be controlled in one place.

hir::ExprKind::Ret(e)
}
ExprKind::Yeet(sub_expr) => self.lower_expr_yeet(e.span, sub_expr.as_deref()),
Expand Down Expand Up @@ -2125,7 +2138,7 @@ impl<'hir> LoweringContext<'_, 'hir> {
self.arena.alloc(self.expr_call_mut(span, e, args))
}

fn expr_call_lang_item_fn_mut(
pub(super) fn expr_call_lang_item_fn_mut(
&mut self,
span: Span,
lang_item: hir::LangItem,
Expand All @@ -2135,7 +2148,7 @@ impl<'hir> LoweringContext<'_, 'hir> {
self.expr_call_mut(span, path, args)
}

fn expr_call_lang_item_fn(
pub(super) fn expr_call_lang_item_fn(
&mut self,
span: Span,
lang_item: hir::LangItem,
Expand Down
113 changes: 109 additions & 4 deletions compiler/rustc_ast_lowering/src/item.rs
Original file line number Diff line number Diff line change
Expand Up @@ -207,9 +207,42 @@ impl<'hir> LoweringContext<'_, 'hir> {
sig: FnSig { decl, header, span: fn_sig_span },
generics,
body,
contract,
..
}) => {
self.with_new_scopes(*fn_sig_span, |this| {
assert!(this.contract.is_none());
if let Some(contract) = contract {
let requires = contract.requires.clone();
let ensures = contract.ensures.clone();
let ensures = if let Some(ens) = ensures {
// FIXME: this needs to be a fresh (or illegal) identifier to prevent
// accidental capture of a parameter or global variable.
let checker_ident: Ident =
Ident::from_str_and_span("__ensures_checker", ens.span);
let (checker_pat, checker_hir_id) = this.pat_ident_binding_mode_mut(
ens.span,
checker_ident,
hir::BindingMode::NONE,
);

Some(crate::FnContractLoweringEnsures {
expr: ens,
fresh_ident: (checker_ident, checker_pat, checker_hir_id),
})
} else {
None
};

// Note: `with_new_scopes` will reinstall the outer
// item's contract (if any) after its callback finishes.
this.contract.replace(crate::FnContractLoweringInfo {
span,
requires,
ensures,
});
}

// Note: we don't need to change the return type from `T` to
// `impl Future<Output = T>` here because lower_body
// only cares about the input argument patterns in the function
Expand Down Expand Up @@ -1051,10 +1084,82 @@ impl<'hir> LoweringContext<'_, 'hir> {
body: impl FnOnce(&mut Self) -> hir::Expr<'hir>,
) -> hir::BodyId {
self.lower_body(|this| {
(
this.arena.alloc_from_iter(decl.inputs.iter().map(|x| this.lower_param(x))),
body(this),
)
let params =
this.arena.alloc_from_iter(decl.inputs.iter().map(|x| this.lower_param(x)));
let result = body(this);

let contract = this.contract.take();

// { body }
// ==>
// { rustc_contract_requires(PRECOND); { body } }
let result: hir::Expr<'hir> = if let Some(_contract) = contract {
let lit_unit = |this: &mut LoweringContext<'_, 'hir>| {
this.expr(_contract.span, hir::ExprKind::Tup(&[]))
};

let precond: hir::Stmt<'hir> = if let Some(req) = _contract.requires {
let lowered_req = this.lower_expr_mut(&req);
let precond = this.expr_call_lang_item_fn_mut(
req.span,
hir::LangItem::ContractCheckRequires,
&*arena_vec![this; lowered_req],
);
this.stmt_expr(req.span, precond)
} else {
let u = lit_unit(this);
this.stmt_expr(_contract.span, u)
};

let (postcond_checker, _opt_ident, result) = if let Some(ens) = _contract.ensures {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

probably necessary for later in the PR, but, why is there a tuple field here that we just ignore

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(I don't think I ever actually used it in the later commits. It was more that I had the info available from when I created the identifier and I wasn't sure if I might need it, so I threaded it through to at least this point...)

let crate::FnContractLoweringEnsures { expr: ens, fresh_ident } = ens;
let lowered_ens: hir::Expr<'hir> = this.lower_expr_mut(&ens);
let postcond_checker = this.expr_call_lang_item_fn(
ens.span,
hir::LangItem::ContractBuildCheckEnsures,
&*arena_vec![this; lowered_ens],
);
let checker_binding_pat = fresh_ident.1;
(
this.stmt_let_pat(
None,
ens.span,
Some(postcond_checker),
this.arena.alloc(checker_binding_pat),
hir::LocalSource::Contract,
),
Some((fresh_ident.0, fresh_ident.2)),
{
let checker_fn =
this.expr_ident(ens.span, fresh_ident.0, fresh_ident.2);
let span = this.mark_span_with_reason(
DesugaringKind::Contract,
ens.span,
None,
);
this.expr_call_mut(
span,
checker_fn,
std::slice::from_ref(this.arena.alloc(result)),
)
},
)
} else {
let u = lit_unit(this);
(this.stmt_expr(_contract.span, u), None, result)
};

let block = this.block_all(
_contract.span,
arena_vec![this; precond, postcond_checker],
Some(this.arena.alloc(result)),
);
this.expr_block(block)
} else {
result
};

(params, result)
})
}

Expand Down
Loading
Loading