From 97444994fd726ab3a220eb7bd63b2e8c36a58a15 Mon Sep 17 00:00:00 2001 From: Ken Gorab Date: Wed, 29 Nov 2023 08:55:53 -0500 Subject: [PATCH] More closures LLVM2/Typechecker2: Allow for functions to capture variables from more than just their defined scope; this allows for situations like this to be supported: ``` val a = 1 func outer() { func inner() { println(a) } } ``` --- abra_cli/abra-files/example.abra | 10 ++++---- abra_core/src/typechecker/typechecker2.rs | 23 +++++++++++++++++ abra_llvm/src/compiler2.rs | 31 +++++++++++++++-------- abra_llvm/tests/arrays.abra | 3 +++ abra_llvm/tests/functions.abra | 21 +++++++++++++++ 5 files changed, 72 insertions(+), 16 deletions(-) diff --git a/abra_cli/abra-files/example.abra b/abra_cli/abra-files/example.abra index 1074a6ef..9c019e5e 100644 --- a/abra_cli/abra-files/example.abra +++ b/abra_cli/abra-files/example.abra @@ -1,9 +1,9 @@ // TODO: This shouldn't stackoverflow //println([]) -val one = 1 -func container() { - val arr = [1, 2, 3] - arr.map(i => i + one) +func makeClosureCapturingParam(x: Int): (Int) => Int { + i => i + x } -container() +val closureCapturingParam = makeClosureCapturingParam(1) +/// Expect: 12 +println(closureCapturingParam(11)) diff --git a/abra_core/src/typechecker/typechecker2.rs b/abra_core/src/typechecker/typechecker2.rs index 4a528dd3..b5694c5e 100644 --- a/abra_core/src/typechecker/typechecker2.rs +++ b/abra_core/src/typechecker/typechecker2.rs @@ -5436,6 +5436,29 @@ impl<'a, L: LoadModule> Typechecker2<'a, L> { self.end_child_scope(); self.current_function = prev_func_id; + let func = self.project.get_func_by_id(&lambda_func_id); + if func.is_closure() { + if let Some(current_func_id) = self.current_function { + let FuncId(func_scope_id, _) = current_func_id; + + let captured_vars = func.captured_vars.clone(); + for var_id in &captured_vars { + let VarId(var_scope_id, _) = var_id; + if !self.scope_contains_other(&var_scope_id, &func_scope_id) { + let var = self.project.get_var_by_id_mut(&var_id); + if var.alias == VariableAlias::None { + var.is_captured = true; + + let func = self.project.get_func_by_id_mut(¤t_func_id); + if !func.captured_vars.contains(&var_id) { + func.captured_vars.push(*var_id); + } + } + } + } + } + } + let resolved_type_id = type_hint.unwrap_or(func_type_id); Ok(TypedNode::Lambda { span, func_id: lambda_func_id, type_id: func_type_id, resolved_type_id }) } diff --git a/abra_llvm/src/compiler2.rs b/abra_llvm/src/compiler2.rs index cae64473..31afed9b 100644 --- a/abra_llvm/src/compiler2.rs +++ b/abra_llvm/src/compiler2.rs @@ -684,26 +684,35 @@ impl<'a> LLVMCompiler2<'a> { self.end_abra_main(); } - fn create_closure_captures(&self, function: &Function) -> PointerValue<'a> { + fn create_closure_captures(&self, function: &Function, resolved_generics: &ResolvedGenerics) -> PointerValue<'a> { // Create array of captured variables for a closure. This is implemented as an `i64*`, where each `i64` item is an encoded representation // of the closed-over value. Variables are known to be captured at compile-time, so when they're initialized they're moved to the heap. // When constructing this array, allocate enough memory to hold all known captured variables (each of which will be a pointer), and treat // that pointer as an i64 which is stored in this chunk of memory. Upon retrieval, the value will be converted back into the appropriate // type, which is also known at compile-time. Using an `i64*` as the captures array helps simplify the model behind the scenes, and makes // calling functions/closures simpler. + // + // In addition to captured variables, the captures array _also_ includes any captures arrays of any closures that are closed-over within + // this function (these `i64*` values are encoded as `i64` in the same way as above). Also, it's possible that a closure captures a variable + // from _outside_ the current function scope. In this case, the containing functions must themselves become closures (if they're not already) + // and that captured variable must be carried through the call stack. For example: + // val a = 1 + // func outer() { + // func inner() { println(a) } + // } let malloc_size = self.const_i64((function.captured_vars.len() + function.captured_closures.len() * 8) as u64); let captured_vars_mem = self.malloc(malloc_size, self.closure_captures_t()); for (idx, captured_var_id) in function.captured_vars.iter().enumerate() { let captured_var = self.project.get_var_by_id(captured_var_id); - let llvm_var = self.ctx_stack.last().unwrap().variables.get(&captured_var_id).expect(&format!("No stored slot for variable {} ({:?})", &captured_var.name, &captured_var)); - let captured_var_value = match llvm_var { - LLVMVar::Slot(slot) => { - let val = self.builder.build_load(*slot, &captured_var.name); - debug_assert!(val.is_pointer_value(), "Captured variables should be lifted to heap space upon initialization"); - self.builder.build_ptr_to_int(val.into_pointer_value(), self.i64(), &format!("capture_{}_ptr_as_value", &captured_var.name)) + let val = if let Some(llvm_var) = self.ctx_stack.last().unwrap().variables.get(&captured_var_id) { + match llvm_var { + LLVMVar::Slot(slot) => self.builder.build_load(*slot, &captured_var.name).into_pointer_value(), + LLVMVar::Param(_) => todo!(), } - LLVMVar::Param(_) => todo!(), + } else { + self.get_captured_var_slot(captured_var_id, resolved_generics).unwrap() }; + let captured_var_value = self.builder.build_ptr_to_int(val, self.i64(), &format!("capture_{}_ptr_as_value", &captured_var.name)); let slot = unsafe { self.builder.build_gep(captured_vars_mem, &[self.const_i32(idx as u64).into()], &format!("captured_var_{}_slot", &captured_var.name)) }; self.builder.build_store(slot, captured_var_value); } @@ -797,7 +806,7 @@ impl<'a> LLVMCompiler2<'a> { // If a function captures variables, gather those captures into a captures array, and store as a local. This local // is used later on to invoke a function (if the closure is known statically) or to create a runtime function // value (when a function-aliased identifier is referenced in a non-invocation context). - let captured_vars_mem = self.create_closure_captures(function); + let captured_vars_mem = self.create_closure_captures(function, resolved_generics); let captures_name = format!("captures_{}_{}_{}_{}", func_id.0.0.0, func_id.0.1, func_id.1, &function.name); let captured_vars_slot = self.builder.build_alloca(self.closure_captures_t(), &captures_name); self.builder.build_store(captured_vars_slot, captured_vars_mem); @@ -812,7 +821,7 @@ impl<'a> LLVMCompiler2<'a> { for func_id in struct_.methods.iter().chain(&struct_.static_methods) { let function = self.project.get_func_by_id(func_id); if function.is_closure() { - let captured_vars_mem = self.create_closure_captures(function); + let captured_vars_mem = self.create_closure_captures(function, resolved_generics); let captures_name = format!("captures_{}_{}_{}_{}", func_id.0.0.0, func_id.0.1, func_id.1, &function.name); let captured_vars_slot = self.builder.build_alloca(self.closure_captures_t(), &captures_name); self.builder.build_store(captured_vars_slot, captured_vars_mem); @@ -1730,7 +1739,7 @@ impl<'a> LLVMCompiler2<'a> { let function = self.project.get_func_by_id(func_id); let captures = if function.is_closure() { - Some(self.create_closure_captures(function)) + Some(self.create_closure_captures(function, resolved_generics)) } else { None }; diff --git a/abra_llvm/tests/arrays.abra b/abra_llvm/tests/arrays.abra index a202ac75..cca640d1 100644 --- a/abra_llvm/tests/arrays.abra +++ b/abra_llvm/tests/arrays.abra @@ -86,6 +86,7 @@ // Array#map func addOne(i: Int): Int = i + 1 func exclaim(i: Int, _: Int, x = "!"): String = "$i$x" +val one = 1 (() => { val arr = [1, 2, 3, 4] @@ -93,6 +94,8 @@ func exclaim(i: Int, _: Int, x = "!"): String = "$i$x" println(arr.map(addOne)) /// Expect: [2, 3, 4, 5] println(arr.map(i => i + 1)) + /// Expect: [2, 3, 4, 5] + println(arr.map(i => i + one)) /// Expect: [1!, 2!, 3!, 4!] println(arr.map(exclaim)) diff --git a/abra_llvm/tests/functions.abra b/abra_llvm/tests/functions.abra index e67bd8bd..b8709410 100644 --- a/abra_llvm/tests/functions.abra +++ b/abra_llvm/tests/functions.abra @@ -183,3 +183,24 @@ println(capturedInt) containsClosures2() /// Expect: 5 println(capturedInt) + +// Returning a function/closure value +func makeNonClosure(): (Int) => Int = i => i + 1 +val nonClosure = makeNonClosure() +/// Expect: 12 +println(nonClosure(11)) + +val one = 1 +func makeClosureCapturingOutside(): (Int) => Int { + i => i + one +} +val closureCapturingOutside = makeClosureCapturingOutside() +/// Expect: 12 +println(closureCapturingOutside(11)) + +//func makeClosureCapturingParam(x: Int): (Int) => Int { +// i => i + x +//} +//val closureCapturingParam = makeClosureCapturingParam(1) +///// Expect: 12 +//println(closureCapturingParam(11))