From c1bd3d4242ddb821d427deceaa6f77477553002d Mon Sep 17 00:00:00 2001 From: Lai-YT <381xvmvbib@gmail.com> Date: Tue, 9 Jul 2024 14:19:43 +0800 Subject: [PATCH 1/2] Use dynamic cast on reference when should not fail There is a common pattern in our code where we dynamically cast a pointer and assert that it is not null. This pattern implies that the cast should not fail. Therefore, we can use the reference version of dynamic casting, which throws a `std::bad_cast` exception if the cast fails. Additionally, we are using dynamic casting with unique pointers, which requires a manual `get()` to obtain the underlying pointer. Since `get()` should only be used when necessary, this change also eliminates its use. --- parser.y | 35 +++++++++++++------------- src/llvm_ir_generator.cpp | 35 +++++++++++--------------- src/qbe_ir_generator.cpp | 53 ++++++++++++++++----------------------- src/type_checker.cpp | 17 ++++++------- 4 files changed, 61 insertions(+), 79 deletions(-) diff --git a/parser.y b/parser.y index b5a3637c..f351d19b 100644 --- a/parser.y +++ b/parser.y @@ -184,16 +184,16 @@ func_def: declaration_specifiers declarator compound_stmt { auto func_def = $2; assert(dynamic_cast(func_def.get())); assert(func_def->type->IsFunc()); - const auto* func_type = static_cast(func_def->type.get()); + const auto& func_type = dynamic_cast(*func_def->type); auto type = std::get>($1); - auto resolved_return_type = ResolveType(std::move(type), func_type->return_type().Clone()); + auto resolved_return_type = ResolveType(std::move(type), func_type.return_type().Clone()); auto param_types = std::vector>{}; - for (auto& param : func_type->param_types()) { + for (auto& param : func_type.param_types()) { param_types.push_back(param->Clone()); } func_def->type = std::make_unique(std::move(resolved_return_type), std::move(param_types)); - static_cast(func_def.get())->body = $3; - $$ = std::unique_ptr(static_cast(func_def.release())); + dynamic_cast(*func_def).body = $3; + $$ = std::unique_ptr(dynamic_cast(func_def.release())); } ; @@ -425,11 +425,11 @@ decl: declaration_specifiers init_declarator_list_opt SEMICOLON { decl_list.push_back(std::move(decl)); } - auto* rec_decl = dynamic_cast(decl.get()); + auto& rec_decl = dynamic_cast(*decl); // Initialize record variable. for (auto& init_decl : init_decl_list) { if (init_decl) { - init_decl->type = ResolveType(rec_decl->type->Clone(), std::move(init_decl->type)); + init_decl->type = ResolveType(rec_decl.type->Clone(), std::move(init_decl->type)); } decl_list.push_back(std::move(init_decl)); } @@ -469,10 +469,9 @@ init_declarator: declarator { $$ = $1; } auto decl = $1; auto init = $3; if (std::holds_alternative>(init)) { - auto* var_decl = dynamic_cast(decl.get()); - assert(var_decl); + auto& var_decl = dynamic_cast(*decl); auto initializer = std::move(std::get>(init)); - var_decl->init = std::move(initializer->expr); + var_decl.init = std::move(initializer->expr); } else { // The initializer is a list of expressions. auto init_expr_list = std::move(std::get>>(init)); if (auto* arr_decl = dynamic_cast(decl.get())) { @@ -801,21 +800,21 @@ std::unique_ptr ResolveType(std::unique_ptr resolved_type, } // Since we cannot change the internal state of a type, we construct a new one. if (unknown_type->IsPtr()) { - auto ptr_type = static_cast(unknown_type.get()); - resolved_type = ResolveType(std::move(resolved_type), ptr_type->base_type().Clone()); + auto& ptr_type = static_cast(*unknown_type); + resolved_type = ResolveType(std::move(resolved_type), ptr_type.base_type().Clone()); return std::make_unique(std::move(resolved_type)); } if (unknown_type->IsArr()) { - auto arr_type = static_cast(unknown_type.get()); - resolved_type = ResolveType(std::move(resolved_type), arr_type->element_type().Clone()); - return std::make_unique(std::move(resolved_type), arr_type->len()); + auto& arr_type = dynamic_cast(*unknown_type); + resolved_type = ResolveType(std::move(resolved_type), arr_type.element_type().Clone()); + return std::make_unique(std::move(resolved_type), arr_type.len()); } if (unknown_type->IsFunc()) { // NOTE: Due to the structure of the grammar, the return type of a function is to be resolved. - auto func_type = static_cast(unknown_type.get()); - resolved_type = ResolveType(std::move(resolved_type), func_type->return_type().Clone()); + auto& func_type = dynamic_cast(*unknown_type); + resolved_type = ResolveType(std::move(resolved_type), func_type.return_type().Clone()); auto param_types = std::vector>{}; - for (const auto& param : func_type->param_types()) { + for (const auto& param : func_type.param_types()) { param_types.push_back(param->Clone()); } return std::make_unique(std::move(resolved_type), std::move(param_types)); diff --git a/src/llvm_ir_generator.cpp b/src/llvm_ir_generator.cpp index c66db671..8c7f7444 100644 --- a/src/llvm_ir_generator.cpp +++ b/src/llvm_ir_generator.cpp @@ -198,10 +198,10 @@ void LLVMIRGenerator::Visit(const ArrDeclNode& arr_decl) { id_to_val[arr_decl.id] = addr; } - auto arr_decl_type = dynamic_cast(arr_decl.type.get()); + auto& arr_decl_type = dynamic_cast(*arr_decl.type); // This vector stores the initialize values for a global array. std::vector arr_elems{}; - for (auto i = std::size_t{0}, e = arr_decl_type->len(), + for (auto i = std::size_t{0}, e = arr_decl_type.len(), init_len = arr_decl.init_list.size(); i < e; ++i) { if (i < init_len) { @@ -250,8 +250,7 @@ void LLVMIRGenerator::Visit(const FieldNode& field) { } void LLVMIRGenerator::Visit(const RecordVarDeclNode& record_var_decl) { - auto* record_type = dynamic_cast(record_var_decl.type.get()); - assert(record_type); + auto& record_type = dynamic_cast(*record_var_decl.type); auto type = builder_helper_.GetLLVMType(*(record_var_decl.type)); auto base_addr = builder_.CreateAlloca(type); id_to_val[record_var_decl.id] = base_addr; @@ -260,7 +259,7 @@ void LLVMIRGenerator::Visit(const RecordVarDeclNode& record_var_decl) { // exceed the total number of members in a record. Also, it guarantees // that accessing element in the initializers will not go out of bound. for (auto i = std::size_t{0}, e = record_var_decl.inits.size(), - slot_count = record_type->SlotCount(); + slot_count = record_type.SlotCount(); i < slot_count && i < e; ++i) { auto& init = record_var_decl.inits.at(i); init->Accept(*this); @@ -513,10 +512,9 @@ void LLVMIRGenerator::Visit(const IdLabeledStmtNode& id_labeled_stmt) { void LLVMIRGenerator::Visit(const CaseStmtNode& case_stmt) { case_stmt.expr->Accept(*this); auto val = val_recorder.ValOfPrevExpr(); - auto int_expr = dynamic_cast(case_stmt.expr.get()); - assert(int_expr); + auto& int_expr = dynamic_cast(*case_stmt.expr); - auto case_label = "case_" + std::to_string(int_expr->val); + auto case_label = "case_" + std::to_string(int_expr.val); auto case_bb = llvm::BasicBlock::Create(context_, case_label, builder_helper_.CurrFunc()); @@ -586,11 +584,10 @@ void LLVMIRGenerator::Visit(const ArrSubExprNode& arr_sub_expr) { auto base_addr = val_to_id_addr.at(val); auto arr_type = builder_helper_.GetLLVMType(*(arr_sub_expr.arr->type)); arr_sub_expr.index->Accept(*this); - auto index = dynamic_cast(arr_sub_expr.index.get()); - assert(index); + auto& index = dynamic_cast(*arr_sub_expr.index); auto res_addr = builder_.CreateConstInBoundsGEP2_32(arr_type, base_addr, 0, - (unsigned int)index->val); + (unsigned int)index.val); auto res_val = builder_.CreateLoad(arr_type->getArrayElementType(), res_addr); val_to_id_addr[res_val] = res_addr; val_recorder.Record(res_val); @@ -694,9 +691,8 @@ void LLVMIRGenerator::Visit(const PostfixArithExprNode& postfix_expr) { auto one = llvm::ConstantInt::get(builder_.getInt32Ty(), 1, true); auto res = builder_.CreateBinOp(arith_op, val, one); - const auto* id_expr = dynamic_cast((postfix_expr.operand).get()); - assert(id_expr); - builder_.CreateStore(res, id_to_val.at(id_expr->id)); + const auto& id_expr = dynamic_cast(*postfix_expr.operand); + builder_.CreateStore(res, id_to_val.at(id_expr.id)); } void LLVMIRGenerator::Visit(const RecordMemExprNode& mem_expr) { @@ -704,13 +700,12 @@ void LLVMIRGenerator::Visit(const RecordMemExprNode& mem_expr) { auto val = val_recorder.ValOfPrevExpr(); auto base_addr = val_to_id_addr.at(val); auto llvm_type = builder_helper_.GetLLVMType(*(mem_expr.expr->type)); - auto* record_type = dynamic_cast(mem_expr.expr->type.get()); - assert(record_type); + auto& record_type = dynamic_cast(*mem_expr.expr->type); auto res_addr = builder_.CreateStructGEP( - llvm_type, base_addr, record_type->MemberIndex(mem_expr.id)); + llvm_type, base_addr, record_type.MemberIndex(mem_expr.id)); auto res_val = builder_.CreateLoad( - builder_helper_.GetLLVMType(record_type->MemberType(mem_expr.id)), + builder_helper_.GetLLVMType(record_type.MemberType(mem_expr.id)), res_addr); val_to_id_addr[res_val] = res_addr; val_recorder.Record(res_val); @@ -765,8 +760,8 @@ void LLVMIRGenerator::Visit(const UnaryExprNode& unary_expr) { case UnaryOperator::kDeref: { // Is function pointer. if (unary_expr.operand->type->IsPtr() && - dynamic_cast((unary_expr.operand->type).get()) - ->base_type() + dynamic_cast(*unary_expr.operand->type) + .base_type() .IsFunc()) { // No-op; the value is still the function itself. break; diff --git a/src/qbe_ir_generator.cpp b/src/qbe_ir_generator.cpp index 36434878..97a7ec2c 100644 --- a/src/qbe_ir_generator.cpp +++ b/src/qbe_ir_generator.cpp @@ -222,14 +222,13 @@ void QbeIrGenerator::Visit(const VarDeclNode& decl) { void QbeIrGenerator::Visit(const ArrDeclNode& arr_decl) { if (arr_decl.is_global) { - const auto* arr_type = dynamic_cast((arr_decl.type).get()); - assert(arr_type); + const auto& arr_type = dynamic_cast(*arr_decl.type); Write_("export data {} = align {} {{ ", user_defined::GlobalPointer{arr_decl.id}, - arr_type->element_type().size()); + arr_type.element_type().size()); global_var_init_vals.clear(); - auto arr_size = arr_type->len(); + auto arr_size = arr_type.len(); auto init_len = arr_decl.init_list.size(); assert(init_len <= arr_size); for (auto i = std::size_t{0}; i < init_len; ++i) { @@ -243,14 +242,14 @@ void QbeIrGenerator::Visit(const ArrDeclNode& arr_decl) { // set remaining elements as 0 if (init_len < arr_size) { - Write_("z {}", (arr_size - init_len) * arr_type->element_type().size()); + Write_("z {}", (arr_size - init_len) * arr_type.element_type().size()); } Write_(" }}\n"); } else { int base_addr_num = NextLocalNum(); assert(arr_decl.type->IsArr()); - const auto* arr_type = dynamic_cast((arr_decl.type).get()); - auto element_size = arr_type->element_type().size(); + const auto& arr_type = dynamic_cast(*arr_decl.type); + auto element_size = arr_type.element_type().size(); WriteInstr_("{} =l alloc{} {}", FuncScopeTemp{base_addr_num}, element_size, arr_decl.type->size()); id_to_num[arr_decl.id] = base_addr_num; @@ -259,7 +258,7 @@ void QbeIrGenerator::Visit(const ArrDeclNode& arr_decl) { // 6.7.9 Initialization // 10. If an object that has automatic storage duration is not initialized // explicitly, its value is indeterminate. - for (auto i = std::size_t{0}, e = arr_type->len(), + for (auto i = std::size_t{0}, e = arr_type.len(), init_len = arr_decl.init_list.size(); i < e && init_len != 0; ++i) { if (i < init_len) { @@ -302,13 +301,12 @@ void QbeIrGenerator::Visit(const RecordVarDeclNode& record_var_decl) { record_var_decl.type->size()); id_to_num[record_var_decl.id] = base_addr; - auto* record_type = dynamic_cast(record_var_decl.type.get()); - assert(record_type); + auto& record_type = dynamic_cast(*record_var_decl.type); // NOTE: This predicate will make sure that we don't initialize members that - // exceed the total number of members in a record. Also, it gurantees + // exceed the total number of members in a record. Also, it guarantees // that accessing element in the initializers will not go out of bound. for (auto i = std::size_t{0}, e = record_var_decl.inits.size(), - slot_count = record_type->SlotCount(); + slot_count = record_type.SlotCount(); i < slot_count && i < e; ++i) { const auto& init = record_var_decl.inits.at(i); init->Accept(*this); @@ -316,7 +314,7 @@ void QbeIrGenerator::Visit(const RecordVarDeclNode& record_var_decl) { // res_addr = base_addr + offset const int res_addr_num = NextLocalNum(); - const auto offset = record_type->OffsetOf(i); + const auto offset = record_type.OffsetOf(i); WriteInstr_("{} =l add {}, {}", FuncScopeTemp{res_addr_num}, FuncScopeTemp{base_addr}, offset); WriteInstr_("storew {}, {}", FuncScopeTemp{init_num}, @@ -780,10 +778,9 @@ void QbeIrGenerator::Visit(const ArrSubExprNode& arr_sub_expr) { // e.g. int a[3] // a[1]'s offset = 1 * 4 (int size) const int offset = NextLocalNum(); - const auto* arr_type = dynamic_cast((arr_sub_expr.arr->type).get()); - assert(arr_type); + const auto& arr_type = dynamic_cast(*arr_sub_expr.arr->type); WriteInstr_("{} =l mul {}, {}", FuncScopeTemp{offset}, - FuncScopeTemp{extended_num}, arr_type->element_type().size()); + FuncScopeTemp{extended_num}, arr_type.element_type().size()); // res_addr = base_addr + offset const int res_addr_num = NextLocalNum(); @@ -896,22 +893,20 @@ void QbeIrGenerator::Visit(const PostfixArithExprNode& postfix_expr) { // TODO: support pointer arithmetic WriteInstr_("{} =w {} {}, 1", FuncScopeTemp{res_num}, GetBinaryOperator(arith_op), FuncScopeTemp{expr_num}); - const auto* id_expr = dynamic_cast((postfix_expr.operand).get()); - assert(id_expr); + const auto& id_expr = dynamic_cast(*postfix_expr.operand); WriteInstr_("storew {}, {}", FuncScopeTemp{res_num}, - FuncScopeTemp{id_to_num.at(id_expr->id)}); + FuncScopeTemp{id_to_num.at(id_expr.id)}); } void QbeIrGenerator::Visit(const RecordMemExprNode& mem_expr) { mem_expr.expr->Accept(*this); const auto num = num_recorder.NumOfPrevExpr(); const auto id_num = reg_num_to_id_num.at(num); - auto* record_type = dynamic_cast(mem_expr.expr->type.get()); - assert(record_type); + auto& record_type = dynamic_cast(*mem_expr.expr->type); const auto res_addr_num = NextLocalNum(); WriteInstr_("{} =l add {}, {}", FuncScopeTemp{res_addr_num}, - FuncScopeTemp{id_num}, record_type->OffsetOf(mem_expr.id)); + FuncScopeTemp{id_num}, record_type.OffsetOf(mem_expr.id)); const int res_num = NextLocalNum(); WriteInstr_("{} =w loadw {}", FuncScopeTemp{res_num}, @@ -933,11 +928,9 @@ void QbeIrGenerator::Visit(const UnaryExprNode& unary_expr) { : BinaryOperator::kSub; WriteInstr_("{} =w {} {}, 1", FuncScopeTemp{res_num}, GetBinaryOperator(arith_op), FuncScopeTemp{expr_num}); - const auto* id_expr = - dynamic_cast((unary_expr.operand).get()); - assert(id_expr); + const auto& id_expr = dynamic_cast(*unary_expr.operand); WriteInstr_("storew {}, {}", FuncScopeTemp{res_num}, - FuncScopeTemp{id_to_num.at(id_expr->id)}); + FuncScopeTemp{id_to_num.at(id_expr.id)}); num_recorder.Record(res_num); } break; case UnaryOperator::kPos: @@ -974,11 +967,9 @@ void QbeIrGenerator::Visit(const UnaryExprNode& unary_expr) { // No-op; the function itself already evaluates to the address. break; } - const auto* id_expr = - dynamic_cast((unary_expr.operand).get()); // NOTE: The operand of the address-of operator must be an lvalue, and we // do not support arrays now, so it must have been backed by an id. - assert(id_expr); + assert(dynamic_cast(unary_expr.operand.get())); // The address of the id is the id itself. const int reg_num = num_recorder.NumOfPrevExpr(); const int id_num = reg_num_to_id_num.at(reg_num); @@ -993,8 +984,8 @@ void QbeIrGenerator::Visit(const UnaryExprNode& unary_expr) { case UnaryOperator::kDeref: { // Is function pointer. if (unary_expr.operand->type->IsPtr() && - dynamic_cast((unary_expr.operand->type).get()) - ->base_type() + dynamic_cast(*unary_expr.operand->type) + .base_type() .IsFunc()) { // No-op; the function itself also evaluates to the address. break; diff --git a/src/type_checker.cpp b/src/type_checker.cpp index 5f55a220..59bf98a2 100644 --- a/src/type_checker.cpp +++ b/src/type_checker.cpp @@ -157,8 +157,7 @@ void TypeChecker::Visit(RecordVarDeclNode& record_var_decl) { // struct birth bd1 { .date = 1 }; // RecordVarDeclNode -> search type entry // to update its type. // record_type_id is "struct_birth" in the above example. - auto record_type_id = - dynamic_cast(record_var_decl.type.get())->id(); + auto record_type_id = dynamic_cast(*record_var_decl.type).id(); auto record_type = env_.LookUpType( MangleRecordTypeId(record_type_id, record_var_decl.type)); assert(record_type); @@ -191,7 +190,7 @@ void TypeChecker::Visit(ParamNode& parameter) { if (parameter.type->IsArr()) { // Decay to simple pointer type. parameter.type = std::make_unique( - dynamic_cast(parameter.type.get())->element_type().Clone()); + dynamic_cast(*parameter.type).element_type().Clone()); } else if (parameter.type->IsFunc()) { // Decay to function pointer type. parameter.type = std::make_unique(parameter.type->Clone()); @@ -233,7 +232,7 @@ void TypeChecker::Visit(FuncDefNode& func_def) { decayed_param_types.push_back(parameter->type->Clone()); } auto return_type = - dynamic_cast(func_def.type.get())->return_type().Clone(); + dynamic_cast(*func_def.type).return_type().Clone(); func_def.type = std::make_unique(std::move(return_type), std::move(decayed_param_types)); auto symbol = @@ -461,10 +460,9 @@ void TypeChecker::Visit(ArgExprNode& arg_expr) { void TypeChecker::Visit(ArrSubExprNode& arr_sub_expr) { arr_sub_expr.arr->Accept(*this); arr_sub_expr.index->Accept(*this); - const auto* arr_type = dynamic_cast((arr_sub_expr.arr->type).get()); - assert(arr_type); + const auto& arr_type = dynamic_cast(*arr_sub_expr.arr->type); // arr_sub_expr should have the element type of the array. - arr_sub_expr.type = arr_type->element_type().Clone(); + arr_sub_expr.type = arr_type.element_type().Clone(); arr_sub_expr.is_global = arr_sub_expr.arr->is_global; } @@ -565,9 +563,8 @@ void TypeChecker::Visit(UnaryExprNode& unary_expr) { if (!unary_expr.operand->type->IsPtr()) { // TODO: the operand of unary '*' shall have pointer type } - unary_expr.type = dynamic_cast(unary_expr.operand->type.get()) - ->base_type() - .Clone(); + unary_expr.type = + dynamic_cast(*unary_expr.operand->type).base_type().Clone(); break; default: unary_expr.type = unary_expr.operand->type->Clone(); From 5943f743b6b8a224ca0994db7037dfc378e77e4e Mon Sep 17 00:00:00 2001 From: Lai-YT <381xvmvbib@gmail.com> Date: Wed, 10 Jul 2024 14:10:37 +0800 Subject: [PATCH 2/2] Fix use-after-move --- parser.y | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/parser.y b/parser.y index f351d19b..75ea12f6 100644 --- a/parser.y +++ b/parser.y @@ -423,15 +423,15 @@ decl: declaration_specifiers init_declarator_list_opt SEMICOLON { // A record declaration that doesn't declare any identifier, e.g., `struct point {int x, int y};`. if (init_decl_list.empty()) { decl_list.push_back(std::move(decl)); - } - - auto& rec_decl = dynamic_cast(*decl); - // Initialize record variable. - for (auto& init_decl : init_decl_list) { - if (init_decl) { - init_decl->type = ResolveType(rec_decl.type->Clone(), std::move(init_decl->type)); + } else { + auto& rec_decl = dynamic_cast(*decl); + // Initialize record variable. + for (auto& init_decl : init_decl_list) { + if (init_decl) { + init_decl->type = ResolveType(rec_decl.type->Clone(), std::move(init_decl->type)); + } + decl_list.push_back(std::move(init_decl)); } - decl_list.push_back(std::move(init_decl)); } } $$ = std::make_unique(Loc(@1), std::move(decl_list));