Skip to content

Commit

Permalink
fix: Stop cloning Traits! (#3736)
Browse files Browse the repository at this point in the history
# Description

## Problem\*

The `Trait` struct is quite large and was being cloned every time
`get_trait` was called in the NodeResolver, which was somewhat often.

## Summary\*

Stops the cloning of traits! The `Clone` derive is removed entirely and
when needed for ownership reasons we either re-retrieve the trait by
calling `get_trait` again (which now returns a reference), or we
temporarily take ownership of the trait's methods.

## Additional Context



## Documentation\*

Check one:
- [x] No documentation needed.
- [ ] Documentation included in this PR.
- [ ] **[Exceptional Case]** Documentation to be submitted in a separate
PR.

# PR Checklist\*

- [x] I have tested the changes locally.
- [x] I have formatted the changes with [Prettier](https://prettier.io/)
and/or `cargo fmt` on default settings.

---------

Co-authored-by: kevaundray <kevtheappdev@gmail.com>
  • Loading branch information
jfecher and kevaundray authored Dec 11, 2023
1 parent 6076e08 commit fcff412
Show file tree
Hide file tree
Showing 11 changed files with 85 additions and 64 deletions.
19 changes: 12 additions & 7 deletions compiler/noirc_frontend/src/hir/def_collector/dc_crate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -410,12 +410,15 @@ pub(crate) fn check_methods_signatures(
trait_impl_generic_count: usize,
errors: &mut Vec<(CompilationError, FileId)>,
) {
let the_trait = resolver.interner.get_trait(trait_id);

let self_type = resolver.get_self_type().expect("trait impl must have a Self type");
let self_type = resolver.get_self_type().expect("trait impl must have a Self type").clone();

// Temporarily bind the trait's Self type to self_type so we can type check
the_trait.self_type_typevar.bind(self_type.clone());
let the_trait = resolver.interner.get_trait_mut(trait_id);
the_trait.self_type_typevar.bind(self_type);

// Temporarily take the trait's methods so we can use both them and a mutable reference
// to the interner within the loop.
let trait_methods = std::mem::take(&mut the_trait.methods);

for (file_id, func_id) in impl_methods {
let impl_method = resolver.interner.function_meta(func_id);
Expand All @@ -427,7 +430,7 @@ pub(crate) fn check_methods_signatures(
// If that's the case, a `MethodNotInTrait` error has already been thrown, and we can ignore
// the impl method, since there's nothing in the trait to match its signature against.
if let Some(trait_method) =
the_trait.methods.iter().find(|method| method.name.0.contents == func_name)
trait_methods.iter().find(|method| method.name.0.contents == func_name)
{
let impl_function_type = impl_method.typ.instantiate(resolver.interner);

Expand All @@ -442,7 +445,7 @@ pub(crate) fn check_methods_signatures(
let error = DefCollectorErrorKind::MismatchTraitImplementationNumGenerics {
impl_method_generic_count,
trait_method_generic_count,
trait_name: the_trait.name.to_string(),
trait_name: resolver.interner.get_trait(trait_id).name.to_string(),
method_name: func_name.to_string(),
span: impl_method.location.span,
};
Expand Down Expand Up @@ -472,7 +475,7 @@ pub(crate) fn check_methods_signatures(
let error = DefCollectorErrorKind::MismatchTraitImplementationNumParameters {
actual_num_parameters: impl_method.parameters.0.len(),
expected_num_parameters: trait_method.arguments().len(),
trait_name: the_trait.name.to_string(),
trait_name: resolver.interner.get_trait(trait_id).name.to_string(),
method_name: func_name.to_string(),
span: impl_method.location.span,
};
Expand All @@ -498,5 +501,7 @@ pub(crate) fn check_methods_signatures(
}
}

let the_trait = resolver.interner.get_trait_mut(trait_id);
the_trait.set_methods(trait_methods);
the_trait.self_type_typevar.unbind(the_trait.self_type_typevar_id);
}
14 changes: 7 additions & 7 deletions compiler/noirc_frontend/src/hir/def_collector/dc_mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -349,7 +349,7 @@ impl<'a> ModCollector<'a> {
let name = trait_definition.name.clone();

// Create the corresponding module for the trait namespace
let id = match self.push_child_module(&name, self.file_id, false, false) {
let trait_id = match self.push_child_module(&name, self.file_id, false, false) {
Ok(local_id) => TraitId(ModuleId { krate, local_id }),
Err(error) => {
errors.push((error.into(), self.file_id));
Expand All @@ -359,7 +359,7 @@ impl<'a> ModCollector<'a> {

// Add the trait to scope so its path can be looked up later
let result =
self.def_collector.def_map.modules[self.module_id.0].declare_trait(name, id);
self.def_collector.def_map.modules[self.module_id.0].declare_trait(name, trait_id);

if let Err((first_def, second_def)) = result {
let error = DefCollectorErrorKind::Duplicate {
Expand Down Expand Up @@ -400,9 +400,9 @@ impl<'a> ModCollector<'a> {
let location = Location::new(name.span(), self.file_id);
context
.def_interner
.push_function_definition(func_id, modifiers, id.0, location);
.push_function_definition(func_id, modifiers, trait_id.0, location);

match self.def_collector.def_map.modules[id.0.local_id.0]
match self.def_collector.def_map.modules[trait_id.0.local_id.0]
.declare_function(name.clone(), func_id)
{
Ok(()) => {
Expand Down Expand Up @@ -437,7 +437,7 @@ impl<'a> ModCollector<'a> {
let stmt_id = context.def_interner.push_empty_global();

if let Err((first_def, second_def)) = self.def_collector.def_map.modules
[id.0.local_id.0]
[trait_id.0.local_id.0]
.declare_global(name.clone(), stmt_id)
{
let error = DefCollectorErrorKind::Duplicate {
Expand All @@ -451,7 +451,7 @@ impl<'a> ModCollector<'a> {
TraitItem::Type { name } => {
// TODO(nickysn or alexvitkov): implement context.def_interner.push_empty_type_alias and get an id, instead of using TypeAliasId::dummy_id()
if let Err((first_def, second_def)) = self.def_collector.def_map.modules
[id.0.local_id.0]
[trait_id.0.local_id.0]
.declare_type_alias(name.clone(), TypeAliasId::dummy_id())
{
let error = DefCollectorErrorKind::Duplicate {
Expand All @@ -473,7 +473,7 @@ impl<'a> ModCollector<'a> {
trait_def: trait_definition,
fns_with_default_impl: unresolved_functions,
};
self.def_collector.collected_traits.insert(id, unresolved);
self.def_collector.collected_traits.insert(trait_id, unresolved);
}
errors
}
Expand Down
18 changes: 9 additions & 9 deletions compiler/noirc_frontend/src/hir/resolution/resolver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -523,7 +523,7 @@ impl<'a> Resolver<'a> {
_new_variables: &mut Generics,
) -> Type {
if let Some(t) = self.lookup_trait_or_error(path) {
Type::TraitAsType(t)
Type::TraitAsType(t.id, Rc::new(t.name.to_string()))
} else {
Type::Error
}
Expand Down Expand Up @@ -938,7 +938,7 @@ impl<'a> Resolver<'a> {
| Type::Constant(_)
| Type::NamedGeneric(_, _)
| Type::NotConstant
| Type::TraitAsType(_)
| Type::TraitAsType(..)
| Type::Forall(_, _) => (),

Type::Array(length, element_type) => {
Expand Down Expand Up @@ -1498,8 +1498,8 @@ impl<'a> Resolver<'a> {
self.interner.get_struct(type_id)
}

pub fn get_trait(&self, trait_id: TraitId) -> Trait {
self.interner.get_trait(trait_id)
pub fn get_trait_mut(&mut self, trait_id: TraitId) -> &mut Trait {
self.interner.get_trait_mut(trait_id)
}

fn lookup<T: TryFromModuleDefId>(&mut self, path: Path) -> Result<T, ResolverError> {
Expand Down Expand Up @@ -1542,9 +1542,9 @@ impl<'a> Resolver<'a> {
}

/// Lookup a given trait by name/path.
fn lookup_trait_or_error(&mut self, path: Path) -> Option<Trait> {
fn lookup_trait_or_error(&mut self, path: Path) -> Option<&mut Trait> {
match self.lookup(path) {
Ok(trait_id) => Some(self.get_trait(trait_id)),
Ok(trait_id) => Some(self.get_trait_mut(trait_id)),
Err(error) => {
self.push_err(error);
None
Expand Down Expand Up @@ -1592,9 +1592,9 @@ impl<'a> Resolver<'a> {
if name == SELF_TYPE_NAME {
let the_trait = self.interner.get_trait(trait_id);

if let Some(method) = the_trait.find_method(method.clone()) {
if let Some(method) = the_trait.find_method(method.0.contents.as_str()) {
let self_type = Type::TypeVariable(
the_trait.self_type_typevar,
the_trait.self_type_typevar.clone(),
crate::TypeVariableKind::Normal,
);
return Some((HirExpression::TraitMethodReference(method), self_type));
Expand Down Expand Up @@ -1628,7 +1628,7 @@ impl<'a> Resolver<'a> {
{
let the_trait = self.interner.get_trait(trait_id);
if let Some(method) =
the_trait.find_method(path.segments.last().unwrap().clone())
the_trait.find_method(path.segments.last().unwrap().0.contents.as_str())
{
let self_type = self.resolve_type(typ.clone());
return Some((HirExpression::TraitMethodReference(method), self_type));
Expand Down
31 changes: 20 additions & 11 deletions compiler/noirc_frontend/src/hir/resolution/traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ use crate::{
def_map::{CrateDefMap, ModuleDefId, ModuleId},
Context,
},
hir_def::traits::{Trait, TraitConstant, TraitFunction, TraitImpl, TraitType},
hir_def::traits::{TraitConstant, TraitFunction, TraitImpl, TraitType},
node_interner::{FuncId, NodeInterner, TraitId},
Path, Shared, TraitItem, Type, TypeBinding, TypeVariableKind,
};
Expand Down Expand Up @@ -90,7 +90,7 @@ fn resolve_trait_methods(
});
let file = def_maps[&crate_id].file_id(unresolved_trait.module_id);

let mut res = vec![];
let mut functions = vec![];
let mut resolver_errors = vec![];
for item in &unresolved_trait.trait_def.items {
if let TraitItem::Function {
Expand Down Expand Up @@ -121,7 +121,8 @@ fn resolve_trait_methods(
});

// Ensure the trait is generic over the Self type as well
generics.push((the_trait.self_type_typevar_id, the_trait.self_type_typevar));
let the_trait = resolver.interner.get_trait(trait_id);
generics.push((the_trait.self_type_typevar_id, the_trait.self_type_typevar.clone()));

let name = name.clone();
let span: Span = name.span();
Expand Down Expand Up @@ -149,11 +150,11 @@ fn resolve_trait_methods(
default_impl_file_id: unresolved_trait.file_id,
default_impl_module_id: unresolved_trait.module_id,
};
res.push(f);
functions.push(f);
resolver_errors.extend(take_errors_filter_self_not_resolved(file, resolver));
}
}
(res, resolver_errors)
(functions, resolver_errors)
}

fn collect_trait_impl_methods(
Expand All @@ -167,15 +168,18 @@ fn collect_trait_impl_methods(
// for a particular method, the default implementation will be added at that slot.
let mut ordered_methods = Vec::new();

let the_trait = interner.get_trait(trait_id);

// check whether the trait implementation is in the same crate as either the trait or the type
let mut errors =
check_trait_impl_crate_coherence(interner, &the_trait, trait_impl, crate_id, def_maps);
check_trait_impl_crate_coherence(interner, trait_id, trait_impl, crate_id, def_maps);
// set of function ids that have a corresponding method in the trait
let mut func_ids_in_trait = HashSet::new();

for method in &the_trait.methods {
// Temporarily take ownership of the trait's methods so we can iterate over them
// while also mutating the interner
let the_trait = interner.get_trait_mut(trait_id);
let methods = std::mem::take(&mut the_trait.methods);

for method in &methods {
let overrides: Vec<_> = trait_impl
.methods
.functions
Expand All @@ -197,7 +201,7 @@ fn collect_trait_impl_methods(
));
} else {
let error = DefCollectorErrorKind::TraitMissingMethod {
trait_name: the_trait.name.clone(),
trait_name: interner.get_trait(trait_id).name.clone(),
method_name: method.name.clone(),
trait_impl_span: trait_impl.object_type.span.expect("type must have a span"),
};
Expand All @@ -221,6 +225,10 @@ fn collect_trait_impl_methods(
}
}

// Restore the methods that were taken before the for loop
let the_trait = interner.get_trait_mut(trait_id);
the_trait.set_methods(methods);

// Emit MethodNotInTrait error for methods in the impl block that
// don't have a corresponding method signature defined in the trait
for (_, func_id, func) in &trait_impl.methods.functions {
Expand Down Expand Up @@ -299,7 +307,7 @@ pub(crate) fn collect_trait_impls(

fn check_trait_impl_crate_coherence(
interner: &mut NodeInterner,
the_trait: &Trait,
trait_id: TraitId,
trait_impl: &UnresolvedTraitImpl,
current_crate: CrateId,
def_maps: &BTreeMap<CrateId, CrateDefMap>,
Expand All @@ -316,6 +324,7 @@ fn check_trait_impl_crate_coherence(
_ => CrateId::Dummy,
};

let the_trait = interner.get_trait(trait_id);
if current_crate != the_trait.crate_id && current_crate != object_crate {
let error = DefCollectorErrorKind::TraitImplOrphaned {
span: trait_impl.object_type.span.expect("object type must have a span"),
Expand Down
4 changes: 3 additions & 1 deletion compiler/noirc_frontend/src/hir/type_check/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -891,7 +891,9 @@ impl<'interner> TypeChecker<'interner> {
}
}
}
Type::TraitAsType(_trait) => {
// TODO: We should allow method calls on `impl Trait`s eventually.
// For now it is fine since they are only allowed on return types.
Type::TraitAsType(..) => {
self.errors.push(TypeCheckError::UnresolvedMethodCall {
method_name: method_name.to_string(),
object_type: object_type.clone(),
Expand Down
4 changes: 2 additions & 2 deletions compiler/noirc_frontend/src/hir/type_check/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -101,8 +101,8 @@ pub fn type_check_func(interner: &mut NodeInterner, func_id: FuncId) -> Vec<Type
if !can_ignore_ret {
let (expr_span, empty_function) = function_info(interner, function_body_id);
let func_span = interner.expr_span(function_body_id); // XXX: We could be more specific and return the span of the last stmt, however stmts do not have spans yet
if let Type::TraitAsType(t) = &declared_return_type {
if interner.lookup_trait_implementation(&function_last_type, t.id).is_err() {
if let Type::TraitAsType(trait_id, _) = &declared_return_type {
if interner.lookup_trait_implementation(&function_last_type, *trait_id).is_err() {
let error = TypeCheckError::TypeMismatchWithSource {
expected: declared_return_type.clone(),
actual: function_last_type,
Expand Down
7 changes: 4 additions & 3 deletions compiler/noirc_frontend/src/hir_def/traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ pub struct TraitType {
/// Represents a trait in the type system. Each instance of this struct
/// will be shared across all Type::Trait variants that represent
/// the same trait.
#[derive(Debug, Eq, Clone)]
#[derive(Debug, Eq)]
pub struct Trait {
/// A unique id representing this trait type. Used to check if two
/// struct traits are equal.
Expand All @@ -42,6 +42,7 @@ pub struct Trait {
pub crate_id: CrateId,

pub methods: Vec<TraitFunction>,

pub constants: Vec<TraitConstant>,
pub types: Vec<TraitType>,

Expand Down Expand Up @@ -124,9 +125,9 @@ impl Trait {
self.methods = methods;
}

pub fn find_method(&self, name: Ident) -> Option<TraitMethodId> {
pub fn find_method(&self, name: &str) -> Option<TraitMethodId> {
for (idx, method) in self.methods.iter().enumerate() {
if method.name == name {
if &method.name == name {
return Some(TraitMethodId { trait_id: self.id, method_index: idx });
}
}
Expand Down
Loading

0 comments on commit fcff412

Please sign in to comment.