Skip to content

Commit

Permalink
Enums: basic/constant variants
Browse files Browse the repository at this point in the history
Typechecker2: Add default toString/hash/eq method stubs onto each `enum`
declaration, similar to how it works for `type`s.
LLVM2: Add compilation for enums with basic/constant variants, in
addition to default implementations of toString/hash/eq.
  • Loading branch information
kengorab committed Dec 27, 2023
1 parent 0198823 commit 8532ff6
Show file tree
Hide file tree
Showing 6 changed files with 607 additions and 30 deletions.
57 changes: 52 additions & 5 deletions abra_cli/abra-files/example.abra
Original file line number Diff line number Diff line change
@@ -1,10 +1,57 @@
// TODO: This shouldn't stackoverflow
//println([])

val i: Int? = 17

if !i {
println("i is None")
enum Color {
Red
Blue
Green
}

println("done")
val colors = [Color.Red, Color.Green, Color.Blue]
println(colors[3]?.toString())
println(colors[2]?.toString())

//enum Color {
// Red
// Blue
// Green
//
// func hex(self): String {
// if self == Color.Red {
// "0xFF0000"
// } else if self == Color.Green {
// "0x00FF00"
// } else if self == Color.Blue {
// "0x0000FF"
// } else {
// "unreachable"
// }
// }
//}
//
//val r = Color.Red
//val b = Color.Blue
//val g = Color.Green
//
////println(b)
////r.toString()
////b.hash()
//
////r == r
//
////func hex(color: Color): String {
//// if color == Color.Red {
//// "0xFF0000"
//// } else if color == Color.Green {
//// "0x00FF00"
//// } else if color == Color.Blue {
//// "0x0000FF"
//// } else {
//// "unreachable"
//// }
////}
//
////hex(b)
//
//r.hex()
//
121 changes: 112 additions & 9 deletions abra_core/src/typechecker/typechecker2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -696,9 +696,13 @@ impl Type {
}

pub fn get_method(&self, project: &Project, method_idx: usize) -> Option<FuncId> {
let Some(struct_id) = self.get_struct_id(project) else { return None; };

Some(project.get_struct_by_id(&struct_id).methods[method_idx])
if let Some(struct_id) = self.get_struct_id(project) {
Some(project.get_struct_by_id(&struct_id).methods[method_idx])
} else if let Type::GenericEnumInstance(enum_id, _, _) = self {
Some(project.get_enum_by_id(enum_id).methods[method_idx])
} else {
return None;
}
}

pub fn get_static_method(&self, project: &Project, static_method_idx: usize) -> Option<FuncId> {
Expand All @@ -712,10 +716,16 @@ impl Type {
}

pub fn find_method_by_name<'a, S: AsRef<str>>(&self, project: &'a Project, method_name: S) -> Option<(usize, &'a FuncId)> {
let Some(struct_id) = self.get_struct_id(project) else { return None; };

let method_name = method_name.as_ref();
project.get_struct_by_id(&struct_id).methods.iter().enumerate().find(|(_, m)| &project.get_func_by_id(m).name == method_name)
let methods = if let Some(struct_id) = self.get_struct_id(project) {
&project.get_struct_by_id(&struct_id).methods
} else if let Type::GenericEnumInstance(enum_id, _, _) = self {
&project.get_enum_by_id(enum_id).methods
} else {
return None;
};

methods.iter().enumerate().find(|(_, m)| &project.get_func_by_id(m).name == method_name)
}

fn get_struct_id(&self, project: &Project) -> Option<StructId> {
Expand Down Expand Up @@ -3558,6 +3568,7 @@ impl<'a, L: LoadModule> Typechecker2<'a, L> {
let EnumDeclNode { variants, methods, .. } = node;

let enum_ = self.project.get_enum_by_id(enum_id);
let self_type_id = enum_.self_type_id;

self.current_scope_id = enum_.enum_scope_id;
self.current_type_decl = Some(enum_.self_type_id);
Expand All @@ -3579,12 +3590,23 @@ impl<'a, L: LoadModule> Typechecker2<'a, L> {
self.project.get_enum_by_id_mut(enum_id).variants.push(variant);
}

let mut tostring_func_id = None;
let mut hash_func_id = None;
let mut eq_func_id = None;
for method in methods {
let AstNode::FunctionDecl(_, decl_node) = method else { unreachable!("Internal error: an enum's methods must be of type AstNode::FunctionDecl") };
let func_id = self.typecheck_function_pass_0(decl_node)?;

let is_method = decl_node.args.get(0).map(|(token, _, _, _)| matches!(token, Token::Self_(_))).unwrap_or(false);
if is_method {
if Token::get_ident_name(&decl_node.name) == "toString" {
tostring_func_id = Some(func_id);
} else if Token::get_ident_name(&decl_node.name) == "hash" {
hash_func_id = Some(func_id);
} else if Token::get_ident_name(&decl_node.name) == "eq" {
eq_func_id = Some(func_id);
}

self.project.get_enum_by_id_mut(enum_id).methods.push(func_id);
} else {
let enum_ = self.project.get_enum_by_id(&enum_id);
Expand All @@ -3600,6 +3622,79 @@ impl<'a, L: LoadModule> Typechecker2<'a, L> {
self.typecheck_function_pass_1(&func_id, decl_node, true)?;
}

let tostring_func_id = if let Some(func_id) = tostring_func_id {
let tostring_func_idx = self.project.get_enum_by_id(enum_id).methods.iter().find_position(|m| m == &&func_id).unwrap().0;
self.project.get_enum_by_id_mut(enum_id).methods.remove(tostring_func_idx);
func_id
} else {
let string_type_id = self.add_or_find_type_id(Type::Primitive(PrimitiveType::String));
let tostring_func_id = self.add_function_to_current_scope(
ScopeId::BOGUS,
&Token::Ident(POSITION_BOGUS, "toString".to_string()),
vec![],
true,
vec![
FunctionParam { name: "self".to_string(), type_id: self_type_id, var_id: VarId::BOGUS, defined_span: None, default_value: None, is_variadic: false, is_incomplete: false }
],
string_type_id,
)?;
self.project.get_func_by_id_mut(&tostring_func_id).defined_span = None;
let func = self.project.get_func_by_id(&tostring_func_id);
let fn_type_id = self.add_or_find_type_id(self.project.function_type_for_function(&func));
self.project.get_func_by_id_mut(&tostring_func_id).fn_type_id = fn_type_id;
tostring_func_id
};
self.project.get_enum_by_id_mut(enum_id).methods.insert(METHOD_IDX_TOSTRING, tostring_func_id);

let hash_func_id = if let Some(func_id) = hash_func_id {
let hash_func_idx = self.project.get_enum_by_id(enum_id).methods.iter().find_position(|m| m == &&func_id).unwrap().0;
self.project.get_enum_by_id_mut(enum_id).methods.remove(hash_func_idx);
func_id
} else {
let int_type_id = self.add_or_find_type_id(Type::Primitive(PrimitiveType::Int));
let hash_func_id = self.add_function_to_current_scope(
ScopeId::BOGUS,
&Token::Ident(POSITION_BOGUS, "hash".to_string()),
vec![],
true,
vec![
FunctionParam { name: "self".to_string(), type_id: self_type_id, var_id: VarId::BOGUS, defined_span: None, default_value: None, is_variadic: false, is_incomplete: false }
],
int_type_id,
)?;
self.project.get_func_by_id_mut(&hash_func_id).defined_span = None;
let func = self.project.get_func_by_id(&hash_func_id);
let fn_type_id = self.add_or_find_type_id(self.project.function_type_for_function(&func));
self.project.get_func_by_id_mut(&hash_func_id).fn_type_id = fn_type_id;
hash_func_id
};
self.project.get_enum_by_id_mut(enum_id).methods.insert(METHOD_IDX_HASH, hash_func_id);

let eq_func_id = if let Some(func_id) = eq_func_id {
let eq_func_idx = self.project.get_enum_by_id(enum_id).methods.iter().find_position(|m| m == &&func_id).unwrap().0;
self.project.get_enum_by_id_mut(enum_id).methods.remove(eq_func_idx);
func_id
} else {
let bool_type_id = self.add_or_find_type_id(Type::Primitive(PrimitiveType::Bool));
let eq_func_id = self.add_function_to_current_scope(
ScopeId::BOGUS,
&Token::Ident(POSITION_BOGUS, "eq".to_string()),
vec![],
true,
vec![
FunctionParam { name: "self".to_string(), type_id: self_type_id, var_id: VarId::BOGUS, defined_span: None, default_value: None, is_variadic: false, is_incomplete: false },
FunctionParam { name: "other".to_string(), type_id: self_type_id, var_id: VarId::BOGUS, defined_span: None, default_value: None, is_variadic: false, is_incomplete: false },
],
bool_type_id,
)?;
self.project.get_func_by_id_mut(&eq_func_id).defined_span = None;
let func = self.project.get_func_by_id(&eq_func_id);
let fn_type_id = self.add_or_find_type_id(self.project.function_type_for_function(&func));
self.project.get_func_by_id_mut(&eq_func_id).fn_type_id = fn_type_id;
eq_func_id
};
self.project.get_enum_by_id_mut(enum_id).methods.insert(METHOD_IDX_EQ, eq_func_id);

self.current_type_decl = None;

self.end_child_scope();
Expand Down Expand Up @@ -3640,16 +3735,24 @@ impl<'a, L: LoadModule> Typechecker2<'a, L> {
}
}

let mut method_func_id_idx = 0;
let mut method_func_id_idx = 3;
let mut static_method_func_id_idx = 0;
for method in methods.into_iter() {
let AstNode::FunctionDecl(_, decl_node) = method else { unreachable!("Internal error: an enum's methods must be of type AstNode::FunctionDecl") };
let is_method = decl_node.args.get(0).map(|(token, _, _, _)| if let Token::Self_(_) = token { true } else { false }).unwrap_or(false);

let enum_ = self.project.get_enum_by_id(&enum_id);
let func_id = if is_method {
method_func_id_idx += 1;
enum_.methods[method_func_id_idx - 1]
if Token::get_ident_name(&decl_node.name) == "toString" {
enum_.methods[METHOD_IDX_TOSTRING]
} else if Token::get_ident_name(&decl_node.name) == "hash" {
enum_.methods[METHOD_IDX_HASH]
} else if Token::get_ident_name(&decl_node.name) == "eq" {
enum_.methods[METHOD_IDX_EQ]
} else {
method_func_id_idx += 1;
enum_.methods[method_func_id_idx - 1]
}
} else {
static_method_func_id_idx += 1;
enum_.static_methods[static_method_func_id_idx - 1]
Expand Down
12 changes: 9 additions & 3 deletions abra_core/src/typechecker/typechecker2_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2624,7 +2624,13 @@ fn typecheck_enum_declaration() {
").unwrap();
let module = &project.modules[1];
let enum_id = EnumId(ModuleId(1), 0);
let baz_func_id = FuncId(ScopeId(ModuleId(1), 1), 0);
let baz_func_id = FuncId(ScopeId(ModuleId(1), 1), 3);
let tostring_func_id = FuncId(ScopeId(ModuleId(1), 1), 0);
assert_eq!("toString", project.get_func_by_id(&tostring_func_id).name);
let hash_func_id = FuncId(ScopeId(ModuleId(1), 1), 1);
assert_eq!("hash", project.get_func_by_id(&hash_func_id).name);
let eq_func_id = FuncId(ScopeId(ModuleId(1), 1), 2);
assert_eq!("eq", project.get_func_by_id(&eq_func_id).name);
let expected = vec![
Enum {
id: enum_id,
Expand All @@ -2637,7 +2643,7 @@ fn typecheck_enum_declaration() {
EnumVariant { name: "Bar".to_string(), defined_span: Span::new(ModuleId(1), (2, 1), (2, 3)), kind: EnumVariantKind::Constant },
EnumVariant { name: "Baz".to_string(), defined_span: Span::new(ModuleId(1), (3, 1), (3, 3)), kind: EnumVariantKind::Container(baz_func_id) },
],
methods: vec![],
methods: vec![tostring_func_id, hash_func_id, eq_func_id],
static_methods: vec![],
}
];
Expand Down Expand Up @@ -2684,7 +2690,7 @@ fn typecheck_enum_declaration() {
captured_vars: vec![],
captured_closures: vec![],
};
assert_eq!(baz_variant_func, module.scopes[1].funcs[0]);
assert_eq!(baz_variant_func, module.scopes[1].funcs[3]);
}

#[test]
Expand Down
Loading

0 comments on commit 8532ff6

Please sign in to comment.