Skip to content

Commit

Permalink
More closures
Browse files Browse the repository at this point in the history
LLVM2/Typechecker2: Allow for functions to capture variables from more
than just their defined scope; this allows for situations like this to
be supported:
```
val a = 1
func outer() {
  func inner() { println(a) }
}
```
  • Loading branch information
kengorab committed Nov 29, 2023
1 parent 032f223 commit 9744499
Show file tree
Hide file tree
Showing 5 changed files with 72 additions and 16 deletions.
10 changes: 5 additions & 5 deletions abra_cli/abra-files/example.abra
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
// TODO: This shouldn't stackoverflow
//println([])

val one = 1
func container() {
val arr = [1, 2, 3]
arr.map(i => i + one)
func makeClosureCapturingParam(x: Int): (Int) => Int {
i => i + x
}
container()
val closureCapturingParam = makeClosureCapturingParam(1)
/// Expect: 12
println(closureCapturingParam(11))
23 changes: 23 additions & 0 deletions abra_core/src/typechecker/typechecker2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5436,6 +5436,29 @@ impl<'a, L: LoadModule> Typechecker2<'a, L> {
self.end_child_scope();
self.current_function = prev_func_id;

let func = self.project.get_func_by_id(&lambda_func_id);
if func.is_closure() {
if let Some(current_func_id) = self.current_function {
let FuncId(func_scope_id, _) = current_func_id;

let captured_vars = func.captured_vars.clone();
for var_id in &captured_vars {
let VarId(var_scope_id, _) = var_id;
if !self.scope_contains_other(&var_scope_id, &func_scope_id) {
let var = self.project.get_var_by_id_mut(&var_id);
if var.alias == VariableAlias::None {
var.is_captured = true;

let func = self.project.get_func_by_id_mut(&current_func_id);
if !func.captured_vars.contains(&var_id) {
func.captured_vars.push(*var_id);
}
}
}
}
}
}

let resolved_type_id = type_hint.unwrap_or(func_type_id);
Ok(TypedNode::Lambda { span, func_id: lambda_func_id, type_id: func_type_id, resolved_type_id })
}
Expand Down
31 changes: 20 additions & 11 deletions abra_llvm/src/compiler2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -684,26 +684,35 @@ impl<'a> LLVMCompiler2<'a> {
self.end_abra_main();
}

fn create_closure_captures(&self, function: &Function) -> PointerValue<'a> {
fn create_closure_captures(&self, function: &Function, resolved_generics: &ResolvedGenerics) -> PointerValue<'a> {
// Create array of captured variables for a closure. This is implemented as an `i64*`, where each `i64` item is an encoded representation
// of the closed-over value. Variables are known to be captured at compile-time, so when they're initialized they're moved to the heap.
// When constructing this array, allocate enough memory to hold all known captured variables (each of which will be a pointer), and treat
// that pointer as an i64 which is stored in this chunk of memory. Upon retrieval, the value will be converted back into the appropriate
// type, which is also known at compile-time. Using an `i64*` as the captures array helps simplify the model behind the scenes, and makes
// calling functions/closures simpler.
//
// In addition to captured variables, the captures array _also_ includes any captures arrays of any closures that are closed-over within
// this function (these `i64*` values are encoded as `i64` in the same way as above). Also, it's possible that a closure captures a variable
// from _outside_ the current function scope. In this case, the containing functions must themselves become closures (if they're not already)
// and that captured variable must be carried through the call stack. For example:
// val a = 1
// func outer() {
// func inner() { println(a) }
// }
let malloc_size = self.const_i64((function.captured_vars.len() + function.captured_closures.len() * 8) as u64);
let captured_vars_mem = self.malloc(malloc_size, self.closure_captures_t());
for (idx, captured_var_id) in function.captured_vars.iter().enumerate() {
let captured_var = self.project.get_var_by_id(captured_var_id);
let llvm_var = self.ctx_stack.last().unwrap().variables.get(&captured_var_id).expect(&format!("No stored slot for variable {} ({:?})", &captured_var.name, &captured_var));
let captured_var_value = match llvm_var {
LLVMVar::Slot(slot) => {
let val = self.builder.build_load(*slot, &captured_var.name);
debug_assert!(val.is_pointer_value(), "Captured variables should be lifted to heap space upon initialization");
self.builder.build_ptr_to_int(val.into_pointer_value(), self.i64(), &format!("capture_{}_ptr_as_value", &captured_var.name))
let val = if let Some(llvm_var) = self.ctx_stack.last().unwrap().variables.get(&captured_var_id) {
match llvm_var {
LLVMVar::Slot(slot) => self.builder.build_load(*slot, &captured_var.name).into_pointer_value(),
LLVMVar::Param(_) => todo!(),
}
LLVMVar::Param(_) => todo!(),
} else {
self.get_captured_var_slot(captured_var_id, resolved_generics).unwrap()
};
let captured_var_value = self.builder.build_ptr_to_int(val, self.i64(), &format!("capture_{}_ptr_as_value", &captured_var.name));
let slot = unsafe { self.builder.build_gep(captured_vars_mem, &[self.const_i32(idx as u64).into()], &format!("captured_var_{}_slot", &captured_var.name)) };
self.builder.build_store(slot, captured_var_value);
}
Expand Down Expand Up @@ -797,7 +806,7 @@ impl<'a> LLVMCompiler2<'a> {
// If a function captures variables, gather those captures into a captures array, and store as a local. This local
// is used later on to invoke a function (if the closure is known statically) or to create a runtime function
// value (when a function-aliased identifier is referenced in a non-invocation context).
let captured_vars_mem = self.create_closure_captures(function);
let captured_vars_mem = self.create_closure_captures(function, resolved_generics);
let captures_name = format!("captures_{}_{}_{}_{}", func_id.0.0.0, func_id.0.1, func_id.1, &function.name);
let captured_vars_slot = self.builder.build_alloca(self.closure_captures_t(), &captures_name);
self.builder.build_store(captured_vars_slot, captured_vars_mem);
Expand All @@ -812,7 +821,7 @@ impl<'a> LLVMCompiler2<'a> {
for func_id in struct_.methods.iter().chain(&struct_.static_methods) {
let function = self.project.get_func_by_id(func_id);
if function.is_closure() {
let captured_vars_mem = self.create_closure_captures(function);
let captured_vars_mem = self.create_closure_captures(function, resolved_generics);
let captures_name = format!("captures_{}_{}_{}_{}", func_id.0.0.0, func_id.0.1, func_id.1, &function.name);
let captured_vars_slot = self.builder.build_alloca(self.closure_captures_t(), &captures_name);
self.builder.build_store(captured_vars_slot, captured_vars_mem);
Expand Down Expand Up @@ -1730,7 +1739,7 @@ impl<'a> LLVMCompiler2<'a> {
let function = self.project.get_func_by_id(func_id);

let captures = if function.is_closure() {
Some(self.create_closure_captures(function))
Some(self.create_closure_captures(function, resolved_generics))
} else {
None
};
Expand Down
3 changes: 3 additions & 0 deletions abra_llvm/tests/arrays.abra
Original file line number Diff line number Diff line change
Expand Up @@ -86,13 +86,16 @@
// Array#map
func addOne(i: Int): Int = i + 1
func exclaim(i: Int, _: Int, x = "!"): String = "$i$x"
val one = 1
(() => {
val arr = [1, 2, 3, 4]

/// Expect: [2, 3, 4, 5]
println(arr.map(addOne))
/// Expect: [2, 3, 4, 5]
println(arr.map(i => i + 1))
/// Expect: [2, 3, 4, 5]
println(arr.map(i => i + one))

/// Expect: [1!, 2!, 3!, 4!]
println(arr.map(exclaim))
Expand Down
21 changes: 21 additions & 0 deletions abra_llvm/tests/functions.abra
Original file line number Diff line number Diff line change
Expand Up @@ -183,3 +183,24 @@ println(capturedInt)
containsClosures2()
/// Expect: 5
println(capturedInt)

// Returning a function/closure value
func makeNonClosure(): (Int) => Int = i => i + 1
val nonClosure = makeNonClosure()
/// Expect: 12
println(nonClosure(11))

val one = 1
func makeClosureCapturingOutside(): (Int) => Int {
i => i + one
}
val closureCapturingOutside = makeClosureCapturingOutside()
/// Expect: 12
println(closureCapturingOutside(11))

//func makeClosureCapturingParam(x: Int): (Int) => Int {
// i => i + x
//}
//val closureCapturingParam = makeClosureCapturingParam(1)
///// Expect: 12
//println(closureCapturingParam(11))

0 comments on commit 9744499

Please sign in to comment.