diff --git a/dora-frontend/src/generator.rs b/dora-frontend/src/generator.rs index a01189328..217209462 100644 --- a/dora-frontend/src/generator.rs +++ b/dora-frontend/src/generator.rs @@ -718,34 +718,25 @@ impl<'a> AstBytecodeGen<'a> { let variant = &enum_.variants[variant_idx as usize]; - iterate_subpatterns( - self.analysis, - pattern, - variant.fields.len(), - |idx, param| { - let element_ty = variant.fields[idx].parsed_type.ty(); - let element_ty = specialize_type(self.sa, element_ty, enum_type_params); - let ty = register_bty_from_ty(element_ty.clone()); - let field_reg = self.alloc_temp(ty); - - let idx = self.builder.add_const_enum_element( - bc_enum_id, - bc_enum_type_params.clone(), - variant_idx, - idx as u32, - ); + iterate_subpatterns(self.analysis, pattern, |idx, param| { + let element_ty = variant.fields[idx].parsed_type.ty(); + let element_ty = specialize_type(self.sa, element_ty, enum_type_params); + let ty = register_bty_from_ty(element_ty.clone()); + let field_reg = self.alloc_temp(ty); + + let idx = self.builder.add_const_enum_element( + bc_enum_id, + bc_enum_type_params.clone(), + variant_idx, + idx as u32, + ); - self.builder.emit_load_enum_element( - field_reg, - value, - idx, - self.loc(pattern.span()), - ); + self.builder + .emit_load_enum_element(field_reg, value, idx, self.loc(pattern.span())); - self.destruct_pattern_inner(pck, ¶m.pattern, field_reg, element_ty); - self.free_temp(field_reg); - }, - ); + self.destruct_pattern_inner(pck, ¶m.pattern, field_reg, element_ty); + self.free_temp(field_reg); + }); self.free_temp(actual_variant_reg); self.free_temp(expected_variant_reg); @@ -763,25 +754,20 @@ impl<'a> AstBytecodeGen<'a> { ) { let struct_ = self.sa.struct_(struct_id); - iterate_subpatterns( - self.analysis, - pattern, - struct_.fields.len(), - |idx, field| { - let field_ty = struct_.fields[idx].ty(); - let field_ty = specialize_type(self.sa, field_ty, struct_type_params); - let register_ty = register_bty_from_ty(field_ty.clone()); - let idx = self.builder.add_const_struct_field( - StructId(struct_id.index().try_into().expect("overflow")), - bty_array_from_ty(struct_type_params), - idx as u32, - ); - let temp_reg = self.alloc_temp(register_ty); - self.builder.emit_load_struct_field(temp_reg, value, idx); - self.destruct_pattern_inner(pck, &field.pattern, temp_reg, field_ty); - self.free_temp(temp_reg); - }, - ) + iterate_subpatterns(self.analysis, pattern, |idx, field| { + let field_ty = struct_.fields[idx].ty(); + let field_ty = specialize_type(self.sa, field_ty, struct_type_params); + let register_ty = register_bty_from_ty(field_ty.clone()); + let idx = self.builder.add_const_struct_field( + StructId(struct_id.index().try_into().expect("overflow")), + bty_array_from_ty(struct_type_params), + idx as u32, + ); + let temp_reg = self.alloc_temp(register_ty); + self.builder.emit_load_struct_field(temp_reg, value, idx); + self.destruct_pattern_inner(pck, &field.pattern, temp_reg, field_ty); + self.free_temp(temp_reg); + }) } fn destruct_pattern_class( @@ -795,26 +781,21 @@ impl<'a> AstBytecodeGen<'a> { ) { let class = self.sa.class(class_id); - iterate_subpatterns( - self.analysis, - pattern, - class.fields.len(), - |idx, field_pattern| { - let field_ty = class.fields[idx].ty(); - let field_ty = specialize_type(self.sa, field_ty, class_type_params); - let register_ty = register_bty_from_ty(field_ty.clone()); - let idx = self.builder.add_const_field_types( - ClassId(class_id.index().try_into().expect("overflow")), - bty_array_from_ty(class_type_params), - idx as u32, - ); - let temp_reg = self.alloc_temp(register_ty); - self.builder - .emit_load_field(temp_reg, value, idx, self.loc(pattern.span())); - self.destruct_pattern_inner(pck, &field_pattern.pattern, temp_reg, field_ty); - self.free_temp(temp_reg); - }, - ) + iterate_subpatterns(self.analysis, pattern, |idx, field_pattern| { + let field_ty = class.fields[idx].ty(); + let field_ty = specialize_type(self.sa, field_ty, class_type_params); + let register_ty = register_bty_from_ty(field_ty.clone()); + let idx = self.builder.add_const_field_types( + ClassId(class_id.index().try_into().expect("overflow")), + bty_array_from_ty(class_type_params), + idx as u32, + ); + let temp_reg = self.alloc_temp(register_ty); + self.builder + .emit_load_field(temp_reg, value, idx, self.loc(pattern.span())); + self.destruct_pattern_inner(pck, &field_pattern.pattern, temp_reg, field_ty); + self.free_temp(temp_reg); + }) } fn destruct_pattern_tuple( @@ -831,23 +812,18 @@ impl<'a> AstBytecodeGen<'a> { } else { let tuple_subtypes = ty.tuple_subtypes().expect("tuple expected"); - iterate_subpatterns( - self.analysis, - pattern, - tuple_subtypes.len(), - |idx, field_pattern| { - let subtype = tuple_subtypes[idx].clone(); - let register_ty = register_bty_from_ty(subtype.clone()); - let cp_idx = self - .builder - .add_const_tuple_element(bty_from_ty(ty.clone()), idx as u32); - let temp_reg = self.alloc_temp(register_ty); - self.builder - .emit_load_tuple_element(temp_reg, value, cp_idx); - self.destruct_pattern_inner(pck, &field_pattern.pattern, temp_reg, subtype); - self.free_temp(temp_reg); - }, - ); + iterate_subpatterns(self.analysis, pattern, |idx, field_pattern| { + let subtype = tuple_subtypes[idx].clone(); + let register_ty = register_bty_from_ty(subtype.clone()); + let cp_idx = self + .builder + .add_const_tuple_element(bty_from_ty(ty.clone()), idx as u32); + let temp_reg = self.alloc_temp(register_ty); + self.builder + .emit_load_tuple_element(temp_reg, value, cp_idx); + self.destruct_pattern_inner(pck, &field_pattern.pattern, temp_reg, subtype); + self.free_temp(temp_reg); + }); } } @@ -3349,28 +3325,21 @@ fn get_subpatterns(p: &ast::PatternAlt) -> Option<&Vec>> } } -fn iterate_subpatterns(a: &AnalysisData, p: &ast::PatternAlt, def_length: usize, mut f: F) +fn iterate_subpatterns(a: &AnalysisData, p: &ast::PatternAlt, mut f: F) where F: FnMut(usize, &ast::PatternField), { if let Some(subpatterns) = get_subpatterns(p) { - let rest_len = if subpatterns.iter().find(|p| p.pattern.is_rest()).is_some() { - def_length - (subpatterns.len() - 1) - } else { - 0 - }; - - let mut idx = 0; - for subpattern in subpatterns { - if subpattern.pattern.is_rest() { - idx += rest_len; - } else if subpattern.pattern.is_underscore() { - idx += 1; + if subpattern.pattern.is_rest() || subpattern.pattern.is_underscore() { + // Do nothing. } else { - let field_id = a.map_field_ids.get(subpattern.id).cloned().unwrap_or(idx); + let field_id = a + .map_field_ids + .get(subpattern.id) + .cloned() + .expect("missing field_id"); f(field_id, subpattern.as_ref()); - idx += 1; } } } diff --git a/dora-frontend/src/typeck/stmt.rs b/dora-frontend/src/typeck/stmt.rs index a648890e4..a4d286b4f 100644 --- a/dora-frontend/src/typeck/stmt.rs +++ b/dora-frontend/src/typeck/stmt.rs @@ -557,50 +557,45 @@ fn check_subpatterns<'a>( let subpatterns = get_subpatterns(pattern); if let Some(subpatterns) = subpatterns { - let rest_count = subpatterns.iter().filter(|p| p.pattern.is_rest()).count(); - - if rest_count == 0 { - if subpatterns.len() != expected_types.len() { - let msg = ErrorMessage::PatternWrongNumberOfParams( - subpatterns.len(), - expected_types.len(), - ); - ck.sa.report(ck.file_id, pattern.span(), msg); - } + let mut idx = 0; + let mut rest_seen = false; + let mut pattern_count: usize = 0; - for (idx, subpattern) in subpatterns.iter().enumerate() { + for subpattern in subpatterns { + if subpattern.pattern.is_rest() { + if rest_seen { + let msg = ErrorMessage::PatternMultipleRest; + ck.sa.report(ck.file_id, subpattern.span, msg); + } else { + idx += expected_types + .len() + .checked_sub(subpatterns.len() - 1) + .unwrap_or(0); + rest_seen = true; + } + } else { let ty = expected_types.get(idx).cloned().unwrap_or(ty::error()); + ck.analysis.map_field_ids.insert(subpattern.id, idx); check_pattern_inner(ck, ctxt, &subpattern.pattern, ty); + idx += 1; + pattern_count += 1; } - } else if rest_count == 1 { - let pattern_count = subpatterns.len() - 1; + } + if rest_seen { if pattern_count > expected_types.len() { let msg = ErrorMessage::PatternWrongNumberOfParams(pattern_count, expected_types.len()); ck.sa.report(ck.file_id, pattern.span(), msg); - - check_subpatterns_error(ck, ctxt, pattern); - return; - } - - let rest_len = expected_types.len() - pattern_count; - let mut idx = 0; - - for subpattern in subpatterns { - if subpattern.pattern.is_rest() { - idx += rest_len; - } else { - let ty = expected_types.get(idx).cloned().unwrap_or(ty::error()); - check_pattern_inner(ck, ctxt, &subpattern.pattern, ty); - idx += 1; - } } } else { - let msg = ErrorMessage::PatternMultipleRest; - ck.sa.report(ck.file_id, pattern.span(), msg); - - check_subpatterns_error(ck, ctxt, pattern); + if expected_types.len() != pattern_count { + let msg = ErrorMessage::PatternWrongNumberOfParams( + subpatterns.len(), + expected_types.len(), + ); + ck.sa.report(ck.file_id, pattern.span(), msg); + } } } else { if expected_types.len() > 0 { diff --git a/dora-frontend/src/typeck/tests.rs b/dora-frontend/src/typeck/tests.rs index fc0382fad..4d5776418 100644 --- a/dora-frontend/src/typeck/tests.rs +++ b/dora-frontend/src/typeck/tests.rs @@ -4546,7 +4546,7 @@ fn test_pattern_rest() { let (.., a, ..) = x; } ", - (3, 17), + (3, 25), ErrorMessage::PatternMultipleRest, ); @@ -5163,3 +5163,49 @@ fn struct_named_pattern_rest_last() { ErrorMessage::PatternRestShouldBeLast, ); } + +#[test] +fn class_named_pattern() { + ok(" + class Foo { a: Int, b: Int } + fn f(x: Foo): Int { + let Foo(a = x, b = y) = x; + x + y + } + "); + + ok(" + class Foo { a: Int, b: Int } + fn f(x: Foo): Int { + let Foo(a, b) = x; + a + b + } + "); +} + +#[test] +fn enum_named_pattern() { + ok(" + enum Foo { + A, + B { a: Int, b: Int } + } + + fn f(x: Foo): Int { + let Foo::B(a = x, b = y) = x; + x + y + } + "); + + ok(" + enum Foo { + A, + B { a: Int, b: Int } + } + + fn f(x: Foo): Int { + let Foo::B(a, b) = x; + a + b + } + "); +} diff --git a/tests/struct/struct-named-pattern2.dora b/tests/struct/struct-named-pattern2.dora new file mode 100644 index 000000000..71c9ec23a --- /dev/null +++ b/tests/struct/struct-named-pattern2.dora @@ -0,0 +1,18 @@ +struct Foo { a: Int, b: Int, c: Int, d: Int } + +fn get_a1(x: Foo): Int { + let Foo(a, c, ..) = x; + a +} + +fn get_a2(x: Foo): Int { + let Foo(c, a, ..) = x; + a +} + +fn main() { + let x = Foo(a=10, b=23, c=47, d=91); + let a1 = get_a1(x); + let a2 = get_a2(x); + assert(a1 == 10 && a2 == 10); +}