Skip to content

Commit

Permalink
Support type arguments in methods
Browse files Browse the repository at this point in the history
Respect type arguments provided to methods and static methods
(previously they were being ignored). This also fixes a long-standing
issue that's been hanging out as a TODO in `example.abra` forever.
  • Loading branch information
kengorab committed Oct 12, 2023
1 parent 08ffde0 commit fb19518
Show file tree
Hide file tree
Showing 3 changed files with 273 additions and 164 deletions.
32 changes: 13 additions & 19 deletions abra_cli/abra-files/example.abra
Original file line number Diff line number Diff line change
@@ -1,24 +1,18 @@
val hello = "hello"

type Person {
firstName: String
lastName: String
//type Foo<T> {
// func make<T>(): Foo<T> = Foo()
//}

func getFullName(self): String = self.firstName + " " + self.lastName
func foo(self): Int = hello.length
}
//val f: Foo<Float> = Foo.make<Int>()
//val _: Int = f

val p = Person(firstName: "Ken", lastName: "Gorab")
println(p.foo())
//func make<T>(): T[] = []
//val arr: Float[] = make<Int>()

//func f(i: Int) {
// println(i + hello.length)
//}
//
//f(0)
//func make<T>(t: T): T[] = []
//val arr = make<Int>("f")

// TODO: So should this
//func makeArray<U>(): U[] = []
//type Foo<T> {
// a: T[] = makeArray()
//}
func makeArray<U>(): U[] = []
type Foo<T> {
a: T[] = makeArray()
}
72 changes: 44 additions & 28 deletions abra_core/src/typechecker/typechecker2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -868,7 +868,7 @@ pub enum TypedNode {
Identifier { token: Token, var_id: VarId, type_arg_ids: Vec<(TypeId, Range)>, type_id: TypeId },
NoneValue { token: Token, type_id: TypeId },
Invocation { target: Box<TypedNode>, arguments: Vec<Option<TypedNode>>, type_id: TypeId },
Accessor { target: Box<TypedNode>, kind: AccessorKind, is_opt_safe: bool, member_idx: usize, member_span: Range, type_id: TypeId },
Accessor { target: Box<TypedNode>, kind: AccessorKind, is_opt_safe: bool, member_idx: usize, member_span: Range, type_id: TypeId, type_arg_ids: Vec<(TypeId, Range)> },
Indexing { target: Box<TypedNode>, index: IndexingMode<TypedNode>, type_id: TypeId },
Lambda { span: Range, func_id: FuncId, type_id: TypeId },
Assignment { span: Range, kind: AssignmentKind, type_id: TypeId, expr: Box<TypedNode> },
Expand Down Expand Up @@ -4629,6 +4629,14 @@ impl<'a, L: LoadModule> Typechecker2<'a, L> {
}
}

let mut type_arg_ids = Vec::with_capacity(type_args.as_ref().map(|args| args.len()).unwrap_or(0));
if let Some(type_args) = type_args {
for type_arg in type_args {
let type_id = self.resolve_type_identifier(&type_arg)?;
type_arg_ids.push((type_id, type_arg.get_ident().get_range()));
}
}

if let Some(alias_module_id) = target_type_id.as_module_type_alias() {
let m = &self.project.modules[alias_module_id.0];
let Some(export) = m.exports.iter().find_map(|(name, val)| if name == &field_name { Some(val) } else { None }) else {
Expand All @@ -4642,14 +4650,6 @@ impl<'a, L: LoadModule> Typechecker2<'a, L> {
};
let type_id = self.project.get_var_by_id(&var_id).type_id;

let mut type_arg_ids = Vec::with_capacity(type_args.as_ref().map(|args| args.len()).unwrap_or(0));
if let Some(type_args) = type_args {
for type_arg in type_args {
let type_id = self.resolve_type_identifier(&type_arg)?;
type_arg_ids.push((type_id, type_arg.get_ident().get_range()));
}
}

return Ok(TypedNode::Identifier { token: field_ident, var_id, type_arg_ids, type_id });
}

Expand Down Expand Up @@ -4766,7 +4766,7 @@ impl<'a, L: LoadModule> Typechecker2<'a, L> {
type_id = self.add_or_find_type_id(self.project.option_type(type_id))
}

Ok(TypedNode::Accessor { target: Box::new(typed_target), kind, is_opt_safe: n.is_opt_safe, member_idx, member_span: field_ident.get_range(), type_id })
Ok(TypedNode::Accessor { target: Box::new(typed_target), kind, is_opt_safe: n.is_opt_safe, member_idx, member_span: field_ident.get_range(), type_id, type_arg_ids })
} else {
Err(TypeError::UnknownMember { span: field_span, field_name, type_id: target_type_id })
}
Expand All @@ -4785,28 +4785,30 @@ impl<'a, L: LoadModule> Typechecker2<'a, L> {
};

let params_data;
let fn_generic_ids;
let provided_type_arg_ids;
let mut return_type_id;
let mut is_instantiation = false;
let mut fn_is_variadic = false;
let mut forbid_labels = false;
match &typed_target {
TypedNode::Identifier { var_id, type_arg_ids: type_args, .. } if self.project.get_var_by_id(var_id).alias != VariableAlias::None => {
TypedNode::Identifier { var_id, type_arg_ids, .. } if self.project.get_var_by_id(var_id).alias != VariableAlias::None => {
provided_type_arg_ids = type_arg_ids.clone();
let var = self.project.get_var_by_id(var_id);

let generic_ids;
match var.alias {
VariableAlias::Function(alias_func_id) => {
let function = self.project.get_func_by_id(&alias_func_id);
fn_is_variadic = function.is_variadic();
generic_ids = &function.generic_ids;
fn_generic_ids = function.generic_ids.clone();
params_data = function.params.iter().enumerate().map(|(idx, p)| (idx, p.name.clone(), p.type_id, is_param_optional(&p), p.is_variadic)).collect_vec();
return_type_id = function.return_type_id;
}
VariableAlias::Type(id) => {
match id {
TypeKind::Struct(alias_struct_id) => {
let struct_ = self.project.get_struct_by_id(&alias_struct_id);
generic_ids = &struct_.generic_ids;
fn_generic_ids = struct_.generic_ids.clone();
params_data = struct_.fields.iter().enumerate().map(|(idx, f)| (idx, f.name.clone(), f.type_id, f.default_value.is_some(), false)).collect_vec();
return_type_id = struct_.self_type_id;
is_instantiation = true;
Expand All @@ -4819,21 +4821,10 @@ impl<'a, L: LoadModule> Typechecker2<'a, L> {
}
VariableAlias::None => unreachable!("VariableAlias::None identifiers are excluded from this match case and are handled below"),
}
if !type_args.is_empty() && type_args.len() != generic_ids.len() {
let span = if type_args.len() > generic_ids.len() {
let (_, span) = &type_args[generic_ids.len()];
span.clone()
} else {
typed_target.span()
};
let span = self.make_span(&span);
return Err(TypeError::InvalidTypeArgumentArity { span, num_required_args: generic_ids.len(), num_provided_args: type_args.len() });
}
for (generic_id, (type_arg_id, _)) in generic_ids.iter().zip(type_args.iter()) {
filled_in_generic_types.insert(*generic_id, *type_arg_id);
}
}
TypedNode::Accessor { target, kind, member_idx, is_opt_safe, .. } => {
TypedNode::Accessor { target, kind, member_idx, is_opt_safe, type_arg_ids, .. } => {
provided_type_arg_ids = type_arg_ids.clone();

let mut target_type_id = *target.type_id();
let mut target_is_option_type = false;
if *is_opt_safe {
Expand All @@ -4853,6 +4844,7 @@ impl<'a, L: LoadModule> Typechecker2<'a, L> {
let Type::Function(param_type_ids, num_required_args, is_variadic, ret_type_id) = field_ty else {
return Err(TypeError::IllegalInvocation { span: self.make_span(&typed_target.span()), type_id: target_type_id });
};
fn_generic_ids = vec![]; // Cannot determine whether a function accepts type args solely based on its type
fn_is_variadic = *is_variadic;
let num_param_type_ids = param_type_ids.len();
params_data = param_type_ids.iter().enumerate()
Expand All @@ -4866,6 +4858,7 @@ impl<'a, L: LoadModule> Typechecker2<'a, L> {
}
AccessorKind::Method => {
let function = self.project.get_func_by_id(&struct_.methods[*member_idx]);
fn_generic_ids = function.generic_ids.clone();
fn_is_variadic = function.is_variadic();
params_data = function.params.iter().skip(1).enumerate().map(|(idx, p)| (idx, p.name.clone(), p.type_id, is_param_optional(&p), p.is_variadic)).collect_vec();
return_type_id = function.return_type_id;
Expand All @@ -4884,6 +4877,7 @@ impl<'a, L: LoadModule> Typechecker2<'a, L> {
AccessorKind::Field => todo!(),
AccessorKind::Method => {
let function = self.project.get_func_by_id(&enum_.methods[*member_idx]);
fn_generic_ids = function.generic_ids.clone();
fn_is_variadic = function.is_variadic();
params_data = function.params.iter().skip(1).enumerate().map(|(idx, p)| (idx, p.name.clone(), p.type_id, is_param_optional(&p), p.is_variadic)).collect_vec();
return_type_id = function.return_type_id;
Expand All @@ -4902,13 +4896,15 @@ impl<'a, L: LoadModule> Typechecker2<'a, L> {
};
let struct_ = self.project.get_struct_by_id(&struct_id);
let function = self.project.get_func_by_id(&struct_.methods[*member_idx]);
fn_generic_ids = function.generic_ids.clone();
fn_is_variadic = function.is_variadic();
params_data = function.params.iter().skip(1).enumerate().map(|(idx, p)| (idx, p.name.clone(), p.type_id, is_param_optional(&p), p.is_variadic)).collect_vec();
return_type_id = function.return_type_id;
}
Type::Type(TypeKind::Struct(struct_id)) => {
let struct_ = self.project.get_struct_by_id(struct_id);
let function = self.project.get_func_by_id(&struct_.static_methods[*member_idx]);
fn_generic_ids = function.generic_ids.clone();
fn_is_variadic = function.is_variadic();
params_data = function.params.iter().enumerate().map(|(idx, p)| (idx, p.name.clone(), p.type_id, is_param_optional(&p), p.is_variadic)).collect_vec();
return_type_id = function.return_type_id;
Expand All @@ -4920,6 +4916,7 @@ impl<'a, L: LoadModule> Typechecker2<'a, L> {
return Err(TypeError::IllegalEnumVariantConstruction { span: self.make_span(&typed_target.span()), enum_id: *enum_id, variant_idx: *member_idx });
};
let function = self.project.get_func_by_id(&func_id);
fn_generic_ids = function.generic_ids.clone();
fn_is_variadic = function.is_variadic();
params_data = function.params.iter().enumerate().map(|(idx, p)| (idx, p.name.clone(), p.type_id, is_param_optional(&p), p.is_variadic)).collect_vec();
return_type_id = function.return_type_id;
Expand All @@ -4934,11 +4931,13 @@ impl<'a, L: LoadModule> Typechecker2<'a, L> {
}
}
_ => {
provided_type_arg_ids = vec![];
let target_type_id = typed_target.type_id();
let target_ty = self.project.get_type_by_id(target_type_id);
let Type::Function(param_type_ids, num_required_args, is_variadic, ret_type_id) = target_ty else {
return Err(TypeError::IllegalInvocation { span: self.make_span(&typed_target.span()), type_id: *target_type_id });
};
fn_generic_ids = vec![];
fn_is_variadic = *is_variadic;
let num_param_type_ids = param_type_ids.len();
params_data = param_type_ids.iter().enumerate()
Expand All @@ -4957,6 +4956,23 @@ impl<'a, L: LoadModule> Typechecker2<'a, L> {
return Ok(TypedNode::Invocation { target: Box::new(typed_target), arguments: vec![], type_id: return_type_id });
}

if !provided_type_arg_ids.is_empty() && provided_type_arg_ids.len() != fn_generic_ids.len() {
let span = if provided_type_arg_ids.len() > fn_generic_ids.len() {
let (_, span) = &provided_type_arg_ids[fn_generic_ids.len()];
span.clone()
} else {
typed_target.span()
};
let span = self.make_span(&span);
return Err(TypeError::InvalidTypeArgumentArity { span, num_required_args: fn_generic_ids.len(), num_provided_args: provided_type_arg_ids.len() });
}
for (generic_id, (type_arg_id, _)) in fn_generic_ids.iter().zip(provided_type_arg_ids.iter()) {
filled_in_generic_types.insert(*generic_id, *type_arg_id);
}

if self.type_contains_generics(&return_type_id) {
return_type_id = self.substitute_generics_with_known(&return_type_id, &filled_in_generic_types);
}
if let Some(type_hint) = type_hint {
if self.type_contains_generics(&return_type_id) {
self.extract_values_for_generics(&type_hint, &return_type_id, &mut filled_in_generic_types);
Expand Down
Loading

0 comments on commit fb19518

Please sign in to comment.