From bc7c33260828973e1e99db56817f5eb4ded7fe50 Mon Sep 17 00:00:00 2001 From: Ken Gorab Date: Sun, 26 Nov 2023 11:05:46 -0500 Subject: [PATCH] Closures LLVM2: Implement closure support for function values. If a function has captured variables, we need to create a captures array at compile-time, and we also need to pass that array into the function when it's called. --- abra_cli/abra-files/example.abra | 20 ++- abra_llvm/src/compiler2.rs | 278 +++++++++++++++++++++++++------ abra_llvm/tests/functions.abra | 36 ++++ 3 files changed, 279 insertions(+), 55 deletions(-) diff --git a/abra_cli/abra-files/example.abra b/abra_cli/abra-files/example.abra index 25c2de3a..7ba2f40a 100644 --- a/abra_cli/abra-files/example.abra +++ b/abra_cli/abra-files/example.abra @@ -1,9 +1,19 @@ // TODO: This shouldn't stackoverflow //println([]) -val arr = [1, 2, 3, 4] -//func sum(acc: Int, i: Int): Int = acc + i -//println(arr.reduce(0, (acc, i) => acc + i)) +var a = 1.1 +val arr = [1, 2, 3] +func foo(): Float { + //val x = a + 1 + arr.pop() + a += 1.1 + //x +} -val fn = (x: Int) => x + 1 -println(fn(123)) +a = 10.1 + +println("arr:", arr) +println("a:", a) +println("foo():", foo()) +println("a:", a) +println("arr:", arr) diff --git a/abra_llvm/src/compiler2.rs b/abra_llvm/src/compiler2.rs index 84248c88..1be223d5 100644 --- a/abra_llvm/src/compiler2.rs +++ b/abra_llvm/src/compiler2.rs @@ -131,7 +131,7 @@ pub struct LLVMCompiler2<'a> { context: &'a Context, builder: Builder<'a>, main_module: Module<'a>, - current_fn: FunctionValue<'a>, + current_fn: (FunctionValue<'a>, Option), ctx_stack: Vec>, // cached for convenience @@ -215,7 +215,7 @@ impl<'a> LLVMCompiler2<'a> { context, builder, main_module, - current_fn: abra_main_fn, + current_fn: (abra_main_fn, None), ctx_stack: vec![CompilerContext::default()], // cached values @@ -279,6 +279,10 @@ impl<'a> LLVMCompiler2<'a> { self.ptr(self.i8()).const_null() } + fn closure_captures_t(&self) -> PointerType<'a> { + self.ptr(self.i64()) + } + fn fn_type>(&self, ret: T, param_types: &[BasicMetadataTypeEnum<'a>]) -> FunctionType<'a> { ret.fn_type(param_types, false) } @@ -411,7 +415,7 @@ impl<'a> LLVMCompiler2<'a> { } } Type::GenericEnumInstance(_, _, _) => todo!(), - Type::Function(_, _, _, _) => self.make_function_value_type_by_type(ty, resolved_generics).as_basic_type_enum(), + Type::Function(_, _, _, _) => self.make_function_value_type_by_type(ty, resolved_generics).0.as_basic_type_enum(), Type::Type(_) | Type::ModuleAlias => todo!() }; @@ -500,6 +504,9 @@ impl<'a> LLVMCompiler2<'a> { if num_optional_params > 16 { unimplemented!("A function can have at most 16 optional parameters currently"); } params.push(self.i16().into()); } + if !function.captured_vars.is_empty() { + params.insert(0, self.closure_captures_t().into()); + } if function.return_type_id == PRELUDE_UNIT_TYPE_ID { self.context.void_type().fn_type(params.as_slice(), false) @@ -594,7 +601,7 @@ impl<'a> LLVMCompiler2<'a> { } fn end_abra_main(&self) { - let abra_main_fn = self.current_fn; + let (abra_main_fn, _) = self.current_fn; debug_assert!(abra_main_fn.get_name().to_str().unwrap() == ABRA_MAIN_FN_NAME); let b = abra_main_fn.get_last_basic_block().expect("abra_main is guaranteed to have >=1 block"); self.builder.position_at_end(b); @@ -628,7 +635,7 @@ impl<'a> LLVMCompiler2<'a> { let mod_fn_type = self.fn_type(self.bool(), &[]); let mod_fn = self.main_module.add_function(&mod_fn_name, mod_fn_type, None); let prev_fn = self.current_fn; - self.current_fn = mod_fn; + self.current_fn = (mod_fn, None); let block = self.context.append_basic_block(mod_fn, ""); self.builder.position_at_end(block); @@ -662,20 +669,90 @@ impl<'a> LLVMCompiler2<'a> { self.builder.build_return(Some(&self.const_bool(true))); self.current_fn = prev_fn; - debug_assert!(self.current_fn.get_name().to_str().unwrap() == ABRA_MAIN_FN_NAME); + let (current_fn, _) = self.current_fn; + debug_assert!(current_fn.get_name().to_str().unwrap() == ABRA_MAIN_FN_NAME); // Call the $mod_{mod_id} fn in the _abra_main fn. - self.builder.position_at_end(self.current_fn.get_last_basic_block().unwrap()); + self.builder.position_at_end(current_fn.get_last_basic_block().unwrap()); self.builder.build_call(mod_fn, &[], ""); } self.end_abra_main(); } + fn create_closure_captures(&self, function: &Function) -> PointerValue<'a> { + // Create array of captured variables for a closure. This is implemented as an `i64*`, where each `i64` item is an encoded representation + // of the closed-over value. Variables are known to be captured at compile-time, so when they're initialized they're moved to the heap. + // When constructing this array, allocate enough memory to hold all known captured variables (each of which will be a pointer), and treat + // that pointer as an i64 which is stored in this chunk of memory. Upon retrieval, the value will be converted back into the appropriate + // type, which is also known at compile-time. Using an `i64*` as the captures array helps simplify the model behind the scenes, and makes + // calling functions/closures simpler. + let malloc_size = self.const_i64((function.captured_vars.len() * 8) as u64); + let captured_vars_mem = self.malloc(malloc_size, self.closure_captures_t()); + for (idx, captured_var_id) in function.captured_vars.iter().enumerate() { + let captured_var = self.project.get_var_by_id(captured_var_id); + let llvm_var = self.ctx_stack.last().unwrap().variables.get(&captured_var_id).expect(&format!("No stored slot for variable {} ({:?})", &captured_var.name, &captured_var)); + let captured_var_value = match llvm_var { + LLVMVar::Slot(slot) => { + let val = self.builder.build_load(*slot, &captured_var.name); + debug_assert!(val.is_pointer_value(), "Captured variables should be lifted to heap space upon initialization"); + self.builder.build_ptr_to_int(val.into_pointer_value(), self.i64(), &format!("capture_{}_ptr_as_value", &captured_var.name)) + } + LLVMVar::Param(_) => todo!(), + }; + let slot = unsafe { self.builder.build_gep(captured_vars_mem, &[self.const_i32(idx as u64).into()], &format!("captured_var_{}_slot", &captured_var.name)) }; + self.builder.build_store(slot, captured_var_value); + } + + captured_vars_mem + } + + fn get_captured_var_slot(&self, var_id: &VarId, resolved_generics: &ResolvedGenerics) -> Option> { + // See `self.create_closure_captures` for more explanation of the underlying data model for the captures array. + // When retrieving a captured variable, we expect that we are in a function context, and that the variable being + // resolved, if it's a capture of that function, will be included in the function's `captured_vars` list. If so, + // retrieve it by index (known at compile-time). This value will be an `i64`, since the captures is of type `i64*`, + // so we need to decode that `i64` back into a pointer of the appropriate type for that captured variable. + let variable = self.project.get_var_by_id(var_id); + if variable.is_captured { + if let Some(func_id) = self.current_fn.1 { + let current_function = self.project.get_func_by_id(&func_id); + if let Some((captured_var_idx, _)) = current_function.captured_vars.iter().find_position(|v| v == &var_id) { + let current_func_captures_arg = self.current_fn.0.get_nth_param(0).unwrap().into_pointer_value(); + let captured_arg_slot = unsafe { self.builder.build_gep(current_func_captures_arg, &[self.const_i32(captured_var_idx as u64).into()], &format!("captured_arg_{}_slot", &variable.name)) }; + let encoded_captured_arg = self.builder.build_load(captured_arg_slot, &format!("captured_arg_{}", &variable.name)).into_int_value(); + + let Some(llvm_type) = self.llvm_underlying_type_by_id(&variable.type_id, resolved_generics) else { todo!() }; + let llvm_type = self.llvm_ptr_wrap_type_if_needed(llvm_type); + let ptr = self.builder.build_int_to_ptr(encoded_captured_arg, self.ptr(llvm_type), ""); + return Some(ptr); + } + } + } + + None + } + fn visit_statement(&mut self, node: &TypedNode, resolved_generics: &ResolvedGenerics) -> Option> { match node { node @ TypedNode::If { .. } => self.visit_if_node(node, resolved_generics), TypedNode::Match { .. } => todo!(), - TypedNode::FuncDeclaration(_) => None, + TypedNode::FuncDeclaration(func_id) => { + let function = self.project.get_func_by_id(func_id); + if !function.captured_vars.is_empty() { + // If a function captures variables, gather those captures into a captures array, and store as global (this works + // for now because at the moment all functions are defined at the top-level, but this will need to change in order + // to support nested functions, methods, and lambda expressions). This global is used later on to create a runtime + // function value, when a function-aliased identifier is referenced in a non-invocation context. + let captured_vars_mem = self.create_closure_captures(function); + let captures_name = format!("captures_{}_{}_{}_{}", func_id.0.0.0, func_id.0.1, func_id.1, &function.name); + let global = self.main_module.add_global(self.ptr(self.i64()), None, &captures_name); + global.set_constant(false); + global.set_initializer(&self.closure_captures_t().const_null()); + self.builder.build_store(global.as_pointer_value(), captured_vars_mem); + } + + None + } TypedNode::TypeDeclaration(_) => None, TypedNode::EnumDeclaration(_) => None, TypedNode::BindingDeclaration { vars, pattern, expr, .. } => { @@ -691,10 +768,24 @@ impl<'a> LLVMCompiler2<'a> { if let Some(expr_val) = expr_val { let llvm_type = self.llvm_underlying_type_by_id(&variable.type_id, resolved_generics).unwrap_or_else(|| expr_val.get_type().as_basic_type_enum()); let llvm_type = self.llvm_ptr_wrap_type_if_needed(llvm_type); - let slot = self.builder.build_alloca(llvm_type, &var_name); - self.ctx_stack.last_mut().unwrap().variables.insert(variable.id, LLVMVar::Slot(slot)); - self.builder.build_store(slot, expr_val); + // If variable is captured, move value to heap so its lifetime extends beyond the current stack frame. There is specific logic + // to handle references to the variable later on (see TypedNode::Identifier and TypedNode::Assignment logic). + if variable.is_captured { + let ptr_type = llvm_type.ptr_type(AddressSpace::Generic); + let heap_mem = self.malloc(self.const_i64(8), ptr_type); + self.builder.build_store(heap_mem, expr_val); + + let slot = self.builder.build_alloca(ptr_type, &var_name); + self.ctx_stack.last_mut().unwrap().variables.insert(variable.id, LLVMVar::Slot(slot)); + + self.builder.build_store(slot, heap_mem); + } else { + let slot = self.builder.build_alloca(llvm_type, &var_name); + self.ctx_stack.last_mut().unwrap().variables.insert(variable.id, LLVMVar::Slot(slot)); + + self.builder.build_store(slot, expr_val); + } } None @@ -704,9 +795,9 @@ impl<'a> LLVMCompiler2<'a> { debug_assert!(condition_var_id.is_none(), "Not implemented yet"); debug_assert!(condition.as_ref().type_id() == &PRELUDE_BOOL_TYPE_ID, "Only implement while loops for boolean conditions for now (no Optionals yet)"); - let loop_cond_block = self.context.append_basic_block(self.current_fn, "while_loop_cond"); - let loop_body_block = self.context.append_basic_block(self.current_fn, "while_loop_body"); - let loop_end_block = self.context.append_basic_block(self.current_fn, "while_loop_end"); + let loop_cond_block = self.context.append_basic_block(self.current_fn.0, "while_loop_cond"); + let loop_body_block = self.context.append_basic_block(self.current_fn.0, "while_loop_body"); + let loop_end_block = self.context.append_basic_block(self.current_fn.0, "while_loop_end"); self.builder.build_unconditional_branch(loop_cond_block); self.builder.position_at_end(loop_cond_block); @@ -965,9 +1056,9 @@ impl<'a> LLVMCompiler2<'a> { let left_val = self.visit_expression(left, resolved_generics).unwrap(); let op_name = if op == &BinaryOp::And { "and" } else { "or" }; - let then_bb = self.context.append_basic_block(self.current_fn, &format!("binary_{op_name}_then")); - let else_bb = self.context.append_basic_block(self.current_fn, &format!("binary_{op_name}_else")); - let cont_bb = self.context.append_basic_block(self.current_fn, &format!("binary_{op_name}_cont")); + let then_bb = self.context.append_basic_block(self.current_fn.0, &format!("binary_{op_name}_then")); + let else_bb = self.context.append_basic_block(self.current_fn.0, &format!("binary_{op_name}_else")); + let cont_bb = self.context.append_basic_block(self.current_fn.0, &format!("binary_{op_name}_cont")); let cond = self.builder.build_int_compare(IntPredicate::EQ, left_val.into_int_value(), self.const_bool(true), ""); self.builder.build_conditional_branch(cond, then_bb, else_bb); @@ -1096,17 +1187,45 @@ impl<'a> LLVMCompiler2<'a> { TypedNode::Map { .. } => todo!(), TypedNode::Identifier { var_id, resolved_type_id, .. } => { let variable = self.project.get_var_by_id(var_id); + let value = match variable.alias { VariableAlias::None => { + // If the variable is captured, then the underlying model is different and needs to be handled specially. + // Captured variables are moved to the heap upon initialization (see TypedNode::BindingDeclaration), and + // their backing type uses pointer indirection. If we're currently in a function which closes over the + // variable, we handle it here... + if let Some(ptr) = self.get_captured_var_slot(var_id, resolved_generics) { + let decoded_captured_val = self.builder.build_load(ptr, ""); + return Some(decoded_captured_val); + } + let llvm_var = self.ctx_stack.last().unwrap().variables.get(&variable.id).expect(&format!("No stored slot for variable {} ({:?})", &variable.name, &variable)); match llvm_var { - LLVMVar::Slot(slot) => self.builder.build_load(*slot, &variable.name), + LLVMVar::Slot(slot) => { + let val = self.builder.build_load(*slot, &variable.name); + // ...otherwise, we need to handle it here. If the variable is captured by some closure, then it'll + // be represented as a pointer value which needs to be dereferenced upon access. + if variable.is_captured { + self.builder.build_load(val.into_pointer_value(), "") + } else { + val + } + } LLVMVar::Param(value) => *value, } } VariableAlias::Function(func_id) => { // todo: cache value to not create duplicate? - self.make_function_value(&func_id, resolved_type_id, resolved_generics) + let function = self.project.get_func_by_id(&func_id); + let captures = if !function.captured_vars.is_empty() { + let captures_name = format!("captures_{}_{}_{}_{}", func_id.0.0.0, func_id.0.1, func_id.1, &function.name); + let captures_ptr = self.main_module.get_global(&captures_name).unwrap().as_pointer_value(); + Some(self.builder.build_load(captures_ptr, "").into_pointer_value()) + } else { + None + }; + + self.make_function_value(&func_id, resolved_type_id, captures, resolved_generics) } VariableAlias::Type(_) => todo!() }; @@ -1140,12 +1259,43 @@ impl<'a> LLVMCompiler2<'a> { params_data = function.params.iter().map(|p| (p.type_id, p.default_value.is_some())).collect_vec(); new_resolved_generics = resolved_generics.new_via_func_call(function, type_arg_ids, &self.project); + if let Some(dec) = function.decorators.iter().find(|dec| dec.name == "Intrinsic") { let TypedNode::Literal { value: TypedLiteral::String(intrinsic_name), .. } = &dec.args[0] else { unreachable!("@Intrinsic requires exactly 1 String argument") }; return self.compile_intrinsic_invocation(type_arg_ids, &new_resolved_generics, intrinsic_name, Some(&**target), arguments, resolved_type_id); } - self.get_or_compile_function(func_id, &new_resolved_generics).into() + if !function.captured_vars.is_empty() { + let fn_obj = self.visit_expression(target, resolved_generics).unwrap().into_pointer_value(); + + let captures_slot = self.builder.build_struct_gep(fn_obj, 0, "captures_slot").unwrap(); + let captures_arr = self.builder.build_load(captures_slot, "captures"); + args.push(captures_arr.into()); + + let fn_ptr_slot = self.builder.build_struct_gep(fn_obj, 1, "fn_ptr_slot").unwrap(); + let fn_ptr = self.builder.build_load(fn_ptr_slot, "fn_ptr").into_pointer_value(); + + // The fn_ptr of a Function value is typed in such a way that the `captures` parameter is lost. If we know it's a closure (which + // we can determine programmatically if necessary based on whether there's a non-NULL `captures` value), we need to programmatically + // cast this function pointer to a different type (namely, the same signature, but with an `i64*` as the first parameter). Here + // though, we don't need to do it programmatically since we know here at compile-time that the function is a closure. When all we + // have is a function value though (see wildcard case further on), this is where we'll need to do the programmatic check. + let underlying_fn_type = fn_ptr.get_type().get_element_type().into_function_type(); + let underlying_fn_param_types = underlying_fn_type.get_param_types(); + let mut new_fn_param_types = vec![self.closure_captures_t().into()]; + new_fn_param_types.extend(underlying_fn_param_types); + let new_fn_param_types = new_fn_param_types.into_iter().map(|t| t.into()).collect_vec(); + let new_fn_type = if let Some(ret_type) = underlying_fn_type.get_return_type() { + ret_type.fn_type(&new_fn_param_types.as_slice(), false) + } else { + self.context.void_type().fn_type(&new_fn_param_types.as_slice(), false) + }; + let fn_ptr = self.builder.build_pointer_cast(fn_ptr, new_fn_type.ptr_type(AddressSpace::Generic), ""); + + CallableValue::try_from(fn_ptr).unwrap() + } else { + self.get_or_compile_function(func_id, &new_resolved_generics).into() + } } VariableAlias::Type(TypeKind::Enum(_)) => unreachable!("Cannot invoke an enum directly"), VariableAlias::Type(TypeKind::Struct(struct_id)) => { @@ -1247,7 +1397,7 @@ impl<'a> LLVMCompiler2<'a> { new_resolved_generics = resolved_generics.clone(); let fn_obj = self.visit_expression(target, resolved_generics).unwrap().into_pointer_value(); - let fn_ptr_slot = self.builder.build_struct_gep(fn_obj, 0, "fn_ptr_slot").unwrap(); + let fn_ptr_slot = self.builder.build_struct_gep(fn_obj, 1, "fn_ptr_slot").unwrap(); let fn_ptr = self.builder.build_load(fn_ptr_slot, "fn_ptr").into_pointer_value(); CallableValue::try_from(fn_ptr).unwrap() } @@ -1439,7 +1589,7 @@ impl<'a> LLVMCompiler2<'a> { } } TypedNode::Lambda { func_id, resolved_type_id, .. } => { - Some(self.make_function_value(func_id, resolved_type_id, resolved_generics)) + Some(self.make_function_value(func_id, resolved_type_id, None, resolved_generics)) } TypedNode::Assignment { kind, expr, .. } => { let expr_val = self.visit_expression(expr, resolved_generics).unwrap(); @@ -1447,9 +1597,28 @@ impl<'a> LLVMCompiler2<'a> { match kind { AssignmentKind::Identifier { var_id } => { let variable = self.project.get_var_by_id(var_id); + + // If the variable is captured, then the underlying model is different and needs to be handled specially. + // Captured variables are moved to the heap upon initialization (see TypedNode::BindingDeclaration), and + // their backing type uses pointer indirection. If we're currently in a function which closes over the + // variable, we handle it here... + if let Some(captured_arg_ptr) = self.get_captured_var_slot(var_id, resolved_generics) { + self.builder.build_store(captured_arg_ptr, expr_val); + return Some(expr_val); + } + let llvm_var = self.ctx_stack.last().unwrap().variables.get(var_id).expect(&format!("No known llvm variable for variable '{}' ({:?})", &variable.name, var_id)); match llvm_var { - LLVMVar::Slot(ptr) => { self.builder.build_store(*ptr, expr_val); } + LLVMVar::Slot(ptr) => { + // ...otherwise, we need to handle it here. If the variable is captured by some closure, then it'll + // be represented as a pointer value which needs to be dereferenced upon store. + if variable.is_captured { + let captured_var_ptr = self.builder.build_load(*ptr, "").into_pointer_value(); + self.builder.build_store(captured_var_ptr, expr_val); + } else { + self.builder.build_store(*ptr, expr_val); + } + } LLVMVar::Param(_) => unreachable!("Parameters are not assignable") } } @@ -1489,9 +1658,9 @@ impl<'a> LLVMCompiler2<'a> { debug_assert!(condition_binding.is_none(), "Condition bindings not yet implemented"); debug_assert!(condition.as_ref().type_id() == &PRELUDE_BOOL_TYPE_ID, "Only implement if-statements for boolean conditions for now (no Optionals yet)"); - let then_bb = self.context.append_basic_block(self.current_fn, "then_block"); - let else_bb = self.context.append_basic_block(self.current_fn, "else_block"); - let end_bb = self.context.append_basic_block(self.current_fn, "if_end"); + let then_bb = self.context.append_basic_block(self.current_fn.0, "then_block"); + let else_bb = self.context.append_basic_block(self.current_fn.0, "else_block"); + let end_bb = self.context.append_basic_block(self.current_fn.0, "if_end"); let cond_val = self.visit_expression(&condition, resolved_generics).unwrap().into_int_value(); let cmp = self.builder.build_int_compare(IntPredicate::EQ, cond_val, self.const_bool(true), ""); @@ -1873,7 +2042,7 @@ impl<'a> LLVMCompiler2<'a> { let prev_bb = self.builder.get_insert_block().unwrap(); let prev_fn = self.current_fn; - self.current_fn = llvm_fn; + self.current_fn = (llvm_fn, None); self.ctx_stack.push(CompilerContext { variables: HashMap::new() }); let block = self.context.append_basic_block(llvm_fn, ""); @@ -1934,13 +2103,17 @@ impl<'a> LLVMCompiler2<'a> { let prev_bb = self.builder.get_insert_block().unwrap(); let prev_fn = self.current_fn; - self.current_fn = llvm_fn; + self.current_fn = (llvm_fn, Some(*func_id)); self.ctx_stack.push(CompilerContext { variables: HashMap::new() }); let block = self.context.append_basic_block(llvm_fn, ""); self.builder.position_at_end(block); let mut params_iter = llvm_fn.get_param_iter(); + if !function.captured_vars.is_empty() { + params_iter.next().unwrap().set_name("captures"); + } + let mut default_value_param_idx = 0; for (idx, param) in function.params.iter().enumerate() { params_iter.next().unwrap().set_name(¶m.name); @@ -2072,7 +2245,7 @@ impl<'a> LLVMCompiler2<'a> { let llvm_fn = self.main_module.add_function(&initializer_sig, fn_type, None); let prev_bb = self.builder.get_insert_block().unwrap(); let prev_fn = self.current_fn; - self.current_fn = llvm_fn; + self.current_fn = (llvm_fn, None); let block = self.context.append_basic_block(llvm_fn, ""); self.builder.position_at_end(block); @@ -2155,7 +2328,7 @@ impl<'a> LLVMCompiler2<'a> { let llvm_fn = self.main_module.add_function(&initializer_sig, fn_type, None); let prev_bb = self.builder.get_insert_block().unwrap(); let prev_fn = self.current_fn; - self.current_fn = llvm_fn; + self.current_fn = (llvm_fn, Some(*func_id)); let mut params_iter = llvm_fn.get_param_iter(); params_iter.next().unwrap().set_name("self"); @@ -2221,7 +2394,7 @@ impl<'a> LLVMCompiler2<'a> { let prev_bb = self.builder.get_insert_block().unwrap(); let prev_fn = self.current_fn; - self.current_fn = llvm_fn; + self.current_fn = (llvm_fn, Some(*func_id)); llvm_fn.get_param_iter().next().unwrap().set_name("self"); let block = self.context.append_basic_block(llvm_fn, ""); @@ -2252,7 +2425,7 @@ impl<'a> LLVMCompiler2<'a> { let prev_bb = self.builder.get_insert_block().unwrap(); let prev_fn = self.current_fn; - self.current_fn = llvm_fn; + self.current_fn = (llvm_fn, Some(*func_id)); llvm_fn.get_param_iter().next().unwrap().set_name("self"); let block = self.context.append_basic_block(llvm_fn, ""); @@ -2292,7 +2465,7 @@ impl<'a> LLVMCompiler2<'a> { let prev_bb = self.builder.get_insert_block().unwrap(); let prev_fn = self.current_fn; - self.current_fn = llvm_fn; + self.current_fn = (llvm_fn, Some(*func_id)); llvm_fn.get_param_iter().next().unwrap().set_name("self"); let block = self.context.append_basic_block(llvm_fn, ""); @@ -2330,7 +2503,7 @@ impl<'a> LLVMCompiler2<'a> { let prev_bb = self.builder.get_insert_block().unwrap(); let prev_fn = self.current_fn; - self.current_fn = llvm_fn; + self.current_fn = (llvm_fn, Some(*func_id)); llvm_fn.get_param_iter().next().unwrap().set_name("self"); let block = self.context.append_basic_block(llvm_fn, ""); @@ -2343,7 +2516,7 @@ impl<'a> LLVMCompiler2<'a> { llvm_fn } - fn make_function_value(&mut self, func_id: &FuncId, target_type_id: &TypeId, resolved_generics: &ResolvedGenerics) -> BasicValueEnum<'a> { + fn make_function_value(&mut self, func_id: &FuncId, target_type_id: &TypeId, captures_arr: Option>, resolved_generics: &ResolvedGenerics) -> BasicValueEnum<'a> { let Type::Function(target_param_type_ids, target_num_required_params, _, target_return_type_id) = self.project.get_type_by_id(target_type_id) else { unreachable!() }; let target_arity = *target_num_required_params; debug_assert!(target_param_type_ids.len() == target_arity); @@ -2384,7 +2557,7 @@ impl<'a> LLVMCompiler2<'a> { let prev_bb = self.builder.get_insert_block().unwrap(); let prev_fn = self.current_fn; - self.current_fn = wrapper_llvm_fn; + self.current_fn = (wrapper_llvm_fn, None); self.ctx_stack.push(CompilerContext { variables: HashMap::new() }); let block = self.context.append_basic_block(wrapper_llvm_fn, ""); @@ -2438,7 +2611,7 @@ impl<'a> LLVMCompiler2<'a> { let prev_bb = self.builder.get_insert_block().unwrap(); let prev_fn = self.current_fn; - self.current_fn = wrapper_llvm_fn; + self.current_fn = (wrapper_llvm_fn, None); self.ctx_stack.push(CompilerContext { variables: HashMap::new() }); let block = self.context.append_basic_block(wrapper_llvm_fn, ""); @@ -2488,7 +2661,7 @@ impl<'a> LLVMCompiler2<'a> { let prev_bb = self.builder.get_insert_block().unwrap(); let prev_fn = self.current_fn; - self.current_fn = wrapper_llvm_fn; + self.current_fn = (wrapper_llvm_fn, None); self.ctx_stack.push(CompilerContext { variables: HashMap::new() }); let block = self.context.append_basic_block(wrapper_llvm_fn, ""); @@ -2513,41 +2686,46 @@ impl<'a> LLVMCompiler2<'a> { unreachable!("All prior cases should have been exhausted above") }; - let fn_value_type = self.make_function_value_type_by_type_id(target_type_id, resolved_generics); + let (fn_value_type, fn_value_fn_ptr_type) = self.make_function_value_type_by_type_id(target_type_id, resolved_generics); let mem = self.malloc(self.sizeof_struct(fn_value_type), fn_value_type.ptr_type(AddressSpace::Generic)); - let fn_ptr_slot = self.builder.build_struct_gep(mem, 0, "fn_ptr_slot").unwrap(); - self.builder.build_store(fn_ptr_slot, llvm_fn.as_global_value().as_pointer_value()); + let captures_slot = self.builder.build_struct_gep(mem, 0, "captures_slot").unwrap(); + let captures_val = captures_arr.unwrap_or(self.ptr(self.i64()).const_null()); + self.builder.build_store(captures_slot, captures_val); + + let fn_ptr_slot = self.builder.build_struct_gep(mem, 1, "fn_ptr_slot").unwrap(); + let llvm_fn_ptr = llvm_fn.as_global_value().as_pointer_value(); + let llvm_fn_ptr = self.builder.build_pointer_cast(llvm_fn_ptr, fn_value_fn_ptr_type, ""); + self.builder.build_store(fn_ptr_slot, llvm_fn_ptr); mem.as_basic_value_enum() } - fn make_function_value_type_by_type_id(&self, func_type_id: &TypeId, resolved_generics: &ResolvedGenerics) -> StructType<'a> { + fn make_function_value_type_by_type_id(&self, func_type_id: &TypeId, resolved_generics: &ResolvedGenerics) -> (StructType<'a>, PointerType<'a>) { self.make_function_value_type_by_type(self.project.get_type_by_id(func_type_id), resolved_generics) } - fn make_function_value_type_by_type(&self, func_ty: &Type, resolved_generics: &ResolvedGenerics) -> StructType<'a> { + fn make_function_value_type_by_type(&self, func_ty: &Type, resolved_generics: &ResolvedGenerics) -> (StructType<'a>, PointerType<'a>) { let Type::Function(param_type_ids, num_required_params, is_variadic, return_type_id) = func_ty else { unreachable!() }; let fn_value_type_name = self.llvm_type_name_by_type(func_ty, resolved_generics); + let llvm_fn_type = self.llvm_function_type_by_parts(param_type_ids, *num_required_params, *is_variadic, return_type_id, resolved_generics); + let fn_ptr_type = llvm_fn_type.ptr_type(AddressSpace::Generic); + if let Some(llvm_type) = self.main_module.get_struct_type(&fn_value_type_name) { - llvm_type + (llvm_type, fn_ptr_type) } else { - let llvm_fn_type = self.llvm_function_type_by_parts(param_type_ids, *num_required_params, *is_variadic, return_type_id, resolved_generics); - let fn_val_type = self.context.opaque_struct_type(&fn_value_type_name); - let fn_ptr_type = llvm_fn_type.ptr_type(AddressSpace::Generic); fn_val_type.set_body(&[ // self.i32().into(), // param_trait_flag // self.i32().array_type(arity as u32).into(), // param_type_ids - // self.i8().into(), // num_captures - // self.context.i64_type().ptr_type(AddressSpace::Generic).into(), // captures + self.context.i64_type().ptr_type(AddressSpace::Generic).into(), // captures fn_ptr_type.into(), // fn_ptr ], false); - fn_val_type + (fn_val_type, fn_ptr_type) } } } diff --git a/abra_llvm/tests/functions.abra b/abra_llvm/tests/functions.abra index 0ab6934c..e355f313 100644 --- a/abra_llvm/tests/functions.abra +++ b/abra_llvm/tests/functions.abra @@ -52,3 +52,39 @@ callFn2(f6) func f7(x: Int, y = 12, z = 100): Int = x + y + z /// Expect: 51 callFn3(f7) + +// Closures +var capturedFloat = 1.1 +func closure1(): Float { + capturedFloat += 1.1 + val x = capturedFloat + 1 + x +} +/// Expect: 1.1 +println(capturedFloat) +/// Expect: 3.2 +println(closure1()) +/// Expect: 2.2 +println(capturedFloat) + +capturedFloat = 10.1 +/// Expect: 10.1 +println(capturedFloat) +/// Expect: 12.2 +println(closure1()) +/// Expect: 11.2 +println(capturedFloat) + +val capturedArray = [1, 2, 3] +func closure2() { + capturedArray.pop() +} +/// Expect: [1, 2, 3] +println(capturedArray) +closure2() +/// Expect: [1, 2] +println(capturedArray) +capturedArray.push(3) +/// Expect: [1, 2, 3] +println(capturedArray) +