Skip to content

Commit

Permalink
Match statements/expressions: constant cases
Browse files Browse the repository at this point in the history
Add constant cases for match expressions.
  • Loading branch information
kengorab committed Dec 29, 2023
1 parent 2fe2495 commit 087b15f
Show file tree
Hide file tree
Showing 3 changed files with 102 additions and 18 deletions.
21 changes: 7 additions & 14 deletions abra_cli/abra-files/example.abra
Original file line number Diff line number Diff line change
Expand Up @@ -25,18 +25,11 @@
//
//Color.Red == Color.Green

func foo(): Int {
val arr = [1, 2, 3]
val i = 4
match arr[i] {
//None v => println("Got: ${v}")
_ v => println("Got: ${v}")
}
val n = match arr[i] {
None => return -4
_ => 100
}
n
val arr = [1, 2, 3, 4]
val n = match arr[1] {
None => -4
2 v => 16 + v
1 => 15
_ => 100
}

foo()
n
55 changes: 51 additions & 4 deletions abra_llvm/src/compiler2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2647,8 +2647,8 @@ impl<'a> LLVMCompiler2<'a> {
TypedMatchCaseKind::None => {
seen_none_case = true;
let is_none_bb = self.context.append_basic_block(self.current_fn.0, "is_none");
let else_bb = self.context.append_basic_block(self.current_fn.0, "is_not_none");
self.builder.build_conditional_branch(data_is_set, else_bb, is_none_bb);
let next_case_bb = self.context.append_basic_block(self.current_fn.0, "next_case");
self.builder.build_conditional_branch(data_is_set, next_case_bb, is_none_bb);

self.builder.position_at_end(is_none_bb);
if let Some(var_id) = &case.case_binding {
Expand All @@ -2675,7 +2675,7 @@ impl<'a> LLVMCompiler2<'a> {
self.builder.build_unconditional_branch(end_bb);
}

self.builder.position_at_end(else_bb);
self.builder.position_at_end(next_case_bb);
}
TypedMatchCaseKind::Wildcard(_) => {
if let Some(var_id) = &case.case_binding {
Expand Down Expand Up @@ -2703,7 +2703,54 @@ impl<'a> LLVMCompiler2<'a> {
}
}
TypedMatchCaseKind::Type(_, _) => todo!(),
TypedMatchCaseKind::Constant(_, _) => todo!(),
TypedMatchCaseKind::Constant(constant_node_type_id, constant_node) => {
let next_case_bb = self.context.append_basic_block(self.current_fn.0, "next_case");

let target_type_id = if let Some(inner_type_id) = self.type_is_option(target_type_id) {
if !seen_none_case {
let cont_bb = self.context.append_basic_block(self.current_fn.0, "cont");

self.builder.build_conditional_branch(data_is_set, cont_bb, next_case_bb);
self.builder.position_at_end(cont_bb);
}

inner_type_id
} else {
*target_type_id
};

let is_eq_bb = self.context.append_basic_block(self.current_fn.0, "is_eq");
let constant_value = self.visit_expression(constant_node, &resolved_generics).unwrap();
let eq = self.compile_eq(false, &target_type_id, data, constant_node_type_id, constant_value, &resolved_generics);
self.builder.build_conditional_branch(eq, is_eq_bb, next_case_bb);

self.builder.position_at_end(is_eq_bb);
if let Some(var_id) = &case.case_binding {
let expr_val = data;
let var = self.project.get_var_by_id(var_id);
let pat = BindingPattern::Variable(Token::Ident(var.defined_span.as_ref().unwrap().range.start.clone(), var.name.clone()));
self.compile_binding_declaration(&pat, &vec![*var_id], Some(expr_val), &resolved_generics);
}

let mut case_value = None;
let case_body_len = case.body.len();
for (idx, node) in case.body.iter().enumerate() {
if idx == case_body_len - 1 {
case_value = self.visit_statement(node, &resolved_generics);
} else {
self.visit_statement(node, &resolved_generics);
}
}
if case.block_terminator.is_none() {
if let Some(result_slot) = result_slot {
let case_value = case_value.expect("If we're able to treat the match as an expression, then the resulting value exists");
self.builder.build_store(result_slot, case_value);
}
self.builder.build_unconditional_branch(end_bb);
}

self.builder.position_at_end(next_case_bb);
}
}
}

Expand Down
44 changes: 44 additions & 0 deletions abra_llvm/tests/match.abra
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
}

// With case binding and _no_ `None` case

/// Expect: value: 2
match arr[1] {
_ v => println("value: ${v}")
Expand All @@ -42,6 +43,49 @@
match arr[14] {
_ v => println("value: ${v}")
}

// With constant cases

/// Expect: 18
match arr[1] {
2 v => println(16 + v)
1 => println(15)
None => println(-4)
_ => println(100)
}
/// Expect: 18
match arr[1] {
None => println(-4)
2 v => println(16 + v)
1 => println(15)
_ => println(100)
}
/// Expect: 18
match arr[1] {
2 v => println(16 + v)
1 => println(15)
_ => println(100)
}
/// Expect: -4
match arr[5] {
2 v => println(16 + v)
1 => println(15)
None => println(-4)
_ => println(100)
}
/// Expect: -4
match arr[5] {
None => println(-4)
2 v => println(16 + v)
1 => println(15)
_ => println(100)
}
/// Expect: 100
match arr[5] {
2 v => println(16 + v)
1 => println(15)
_ => println(100)
}
})()

// Testing match as expression
Expand Down

0 comments on commit 087b15f

Please sign in to comment.