Skip to content

Commit

Permalink
Use reference dynamic_cast when the cast should not fail
Browse files Browse the repository at this point in the history
There is a common pattern in our code where we dynamically cast a
pointer and assert that it is not null. This pattern implies that the
cast should not fail. Therefore, we can use the reference version of
dynamic casting, which throws a `std::bad_cast` exception if the cast
fails.

Additionally, we are using dynamic casting with unique pointers, which
requires a manual `get()` to obtain the underlying pointer. Since
`get()` should only be used when necessary, this change also eliminates
its use.
  • Loading branch information
Lai-YT authored and leewei05 committed Jul 12, 2024
1 parent 2152c9c commit 683cf36
Show file tree
Hide file tree
Showing 4 changed files with 61 additions and 79 deletions.
35 changes: 17 additions & 18 deletions parser.y
Original file line number Diff line number Diff line change
Expand Up @@ -184,16 +184,16 @@ func_def: declaration_specifiers declarator compound_stmt {
auto func_def = $2;
assert(dynamic_cast<FuncDefNode*>(func_def.get()));
assert(func_def->type->IsFunc());
const auto* func_type = static_cast<FuncType*>(func_def->type.get());
const auto& func_type = dynamic_cast<FuncType&>(*func_def->type);
auto type = std::get<std::unique_ptr<Type>>($1);
auto resolved_return_type = ResolveType(std::move(type), func_type->return_type().Clone());
auto resolved_return_type = ResolveType(std::move(type), func_type.return_type().Clone());
auto param_types = std::vector<std::unique_ptr<Type>>{};
for (auto& param : func_type->param_types()) {
for (auto& param : func_type.param_types()) {
param_types.push_back(param->Clone());
}
func_def->type = std::make_unique<FuncType>(std::move(resolved_return_type), std::move(param_types));
static_cast<FuncDefNode*>(func_def.get())->body = $3;
$$ = std::unique_ptr<FuncDefNode>(static_cast<FuncDefNode*>(func_def.release()));
dynamic_cast<FuncDefNode&>(*func_def).body = $3;
$$ = std::unique_ptr<FuncDefNode>(dynamic_cast<FuncDefNode*>(func_def.release()));
}
;

Expand Down Expand Up @@ -425,11 +425,11 @@ decl: declaration_specifiers init_declarator_list_opt SEMICOLON {
decl_list.push_back(std::move(decl));
}

auto* rec_decl = dynamic_cast<RecordDeclNode*>(decl.get());
auto& rec_decl = dynamic_cast<RecordDeclNode&>(*decl);
// Initialize record variable.
for (auto& init_decl : init_decl_list) {
if (init_decl) {
init_decl->type = ResolveType(rec_decl->type->Clone(), std::move(init_decl->type));
init_decl->type = ResolveType(rec_decl.type->Clone(), std::move(init_decl->type));
}
decl_list.push_back(std::move(init_decl));
}
Expand Down Expand Up @@ -469,10 +469,9 @@ init_declarator: declarator { $$ = $1; }
auto decl = $1;
auto init = $3;
if (std::holds_alternative<std::unique_ptr<InitExprNode>>(init)) {
auto* var_decl = dynamic_cast<VarDeclNode*>(decl.get());
assert(var_decl);
auto& var_decl = dynamic_cast<VarDeclNode&>(*decl);
auto initializer = std::move(std::get<std::unique_ptr<InitExprNode>>(init));
var_decl->init = std::move(initializer->expr);
var_decl.init = std::move(initializer->expr);
} else { // The initializer is a list of expressions.
auto init_expr_list = std::move(std::get<std::vector<std::unique_ptr<InitExprNode>>>(init));
if (auto* arr_decl = dynamic_cast<ArrDeclNode*>(decl.get())) {
Expand Down Expand Up @@ -801,21 +800,21 @@ std::unique_ptr<Type> ResolveType(std::unique_ptr<Type> resolved_type,
}
// Since we cannot change the internal state of a type, we construct a new one.
if (unknown_type->IsPtr()) {
auto ptr_type = static_cast<PtrType*>(unknown_type.get());
resolved_type = ResolveType(std::move(resolved_type), ptr_type->base_type().Clone());
auto& ptr_type = static_cast<PtrType&>(*unknown_type);
resolved_type = ResolveType(std::move(resolved_type), ptr_type.base_type().Clone());
return std::make_unique<PtrType>(std::move(resolved_type));
}
if (unknown_type->IsArr()) {
auto arr_type = static_cast<ArrType*>(unknown_type.get());
resolved_type = ResolveType(std::move(resolved_type), arr_type->element_type().Clone());
return std::make_unique<ArrType>(std::move(resolved_type), arr_type->len());
auto& arr_type = dynamic_cast<ArrType&>(*unknown_type);
resolved_type = ResolveType(std::move(resolved_type), arr_type.element_type().Clone());
return std::make_unique<ArrType>(std::move(resolved_type), arr_type.len());
}
if (unknown_type->IsFunc()) {
// NOTE: Due to the structure of the grammar, the return type of a function is to be resolved.
auto func_type = static_cast<FuncType*>(unknown_type.get());
resolved_type = ResolveType(std::move(resolved_type), func_type->return_type().Clone());
auto& func_type = dynamic_cast<FuncType&>(*unknown_type);
resolved_type = ResolveType(std::move(resolved_type), func_type.return_type().Clone());
auto param_types = std::vector<std::unique_ptr<Type>>{};
for (const auto& param : func_type->param_types()) {
for (const auto& param : func_type.param_types()) {
param_types.push_back(param->Clone());
}
return std::make_unique<FuncType>(std::move(resolved_type), std::move(param_types));
Expand Down
35 changes: 15 additions & 20 deletions src/llvm_ir_generator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -198,10 +198,10 @@ void LLVMIRGenerator::Visit(const ArrDeclNode& arr_decl) {
id_to_val[arr_decl.id] = addr;
}

auto arr_decl_type = dynamic_cast<ArrType*>(arr_decl.type.get());
auto& arr_decl_type = dynamic_cast<ArrType&>(*arr_decl.type);
// This vector stores the initialize values for a global array.
std::vector<llvm::Constant*> arr_elems{};
for (auto i = std::size_t{0}, e = arr_decl_type->len(),
for (auto i = std::size_t{0}, e = arr_decl_type.len(),
init_len = arr_decl.init_list.size();
i < e; ++i) {
if (i < init_len) {
Expand Down Expand Up @@ -250,8 +250,7 @@ void LLVMIRGenerator::Visit(const FieldNode& field) {
}

void LLVMIRGenerator::Visit(const RecordVarDeclNode& record_var_decl) {
auto* record_type = dynamic_cast<RecordType*>(record_var_decl.type.get());
assert(record_type);
auto& record_type = dynamic_cast<RecordType&>(*record_var_decl.type);
auto type = builder_helper_.GetLLVMType(*(record_var_decl.type));
auto base_addr = builder_.CreateAlloca(type);
id_to_val[record_var_decl.id] = base_addr;
Expand All @@ -260,7 +259,7 @@ void LLVMIRGenerator::Visit(const RecordVarDeclNode& record_var_decl) {
// exceed the total number of members in a record. Also, it guarantees
// that accessing element in the initializers will not go out of bound.
for (auto i = std::size_t{0}, e = record_var_decl.inits.size(),
slot_count = record_type->SlotCount();
slot_count = record_type.SlotCount();
i < slot_count && i < e; ++i) {
auto& init = record_var_decl.inits.at(i);
init->Accept(*this);
Expand Down Expand Up @@ -513,10 +512,9 @@ void LLVMIRGenerator::Visit(const IdLabeledStmtNode& id_labeled_stmt) {
void LLVMIRGenerator::Visit(const CaseStmtNode& case_stmt) {
case_stmt.expr->Accept(*this);
auto val = val_recorder.ValOfPrevExpr();
auto int_expr = dynamic_cast<IntConstExprNode*>(case_stmt.expr.get());
assert(int_expr);
auto& int_expr = dynamic_cast<IntConstExprNode&>(*case_stmt.expr);

auto case_label = "case_" + std::to_string(int_expr->val);
auto case_label = "case_" + std::to_string(int_expr.val);
auto case_bb = llvm::BasicBlock::Create(context_, case_label,
builder_helper_.CurrFunc());

Expand Down Expand Up @@ -586,11 +584,10 @@ void LLVMIRGenerator::Visit(const ArrSubExprNode& arr_sub_expr) {
auto base_addr = val_to_id_addr.at(val);
auto arr_type = builder_helper_.GetLLVMType(*(arr_sub_expr.arr->type));
arr_sub_expr.index->Accept(*this);
auto index = dynamic_cast<IntConstExprNode*>(arr_sub_expr.index.get());
assert(index);
auto& index = dynamic_cast<IntConstExprNode&>(*arr_sub_expr.index);

auto res_addr = builder_.CreateConstInBoundsGEP2_32(arr_type, base_addr, 0,
(unsigned int)index->val);
(unsigned int)index.val);
auto res_val = builder_.CreateLoad(arr_type->getArrayElementType(), res_addr);
val_to_id_addr[res_val] = res_addr;
val_recorder.Record(res_val);
Expand Down Expand Up @@ -694,23 +691,21 @@ void LLVMIRGenerator::Visit(const PostfixArithExprNode& postfix_expr) {

auto one = llvm::ConstantInt::get(builder_.getInt32Ty(), 1, true);
auto res = builder_.CreateBinOp(arith_op, val, one);
const auto* id_expr = dynamic_cast<IdExprNode*>((postfix_expr.operand).get());
assert(id_expr);
builder_.CreateStore(res, id_to_val.at(id_expr->id));
const auto& id_expr = dynamic_cast<IdExprNode&>(*postfix_expr.operand);
builder_.CreateStore(res, id_to_val.at(id_expr.id));
}

void LLVMIRGenerator::Visit(const RecordMemExprNode& mem_expr) {
mem_expr.expr->Accept(*this);
auto val = val_recorder.ValOfPrevExpr();
auto base_addr = val_to_id_addr.at(val);
auto llvm_type = builder_helper_.GetLLVMType(*(mem_expr.expr->type));
auto* record_type = dynamic_cast<RecordType*>(mem_expr.expr->type.get());
assert(record_type);
auto& record_type = dynamic_cast<RecordType&>(*mem_expr.expr->type);

auto res_addr = builder_.CreateStructGEP(
llvm_type, base_addr, record_type->MemberIndex(mem_expr.id));
llvm_type, base_addr, record_type.MemberIndex(mem_expr.id));
auto res_val = builder_.CreateLoad(
builder_helper_.GetLLVMType(record_type->MemberType(mem_expr.id)),
builder_helper_.GetLLVMType(record_type.MemberType(mem_expr.id)),
res_addr);
val_to_id_addr[res_val] = res_addr;
val_recorder.Record(res_val);
Expand Down Expand Up @@ -765,8 +760,8 @@ void LLVMIRGenerator::Visit(const UnaryExprNode& unary_expr) {
case UnaryOperator::kDeref: {
// Is function pointer.
if (unary_expr.operand->type->IsPtr() &&
dynamic_cast<PtrType*>((unary_expr.operand->type).get())
->base_type()
dynamic_cast<PtrType&>(*unary_expr.operand->type)
.base_type()
.IsFunc()) {
// No-op; the value is still the function itself.
break;
Expand Down
53 changes: 22 additions & 31 deletions src/qbe_ir_generator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -222,14 +222,13 @@ void QbeIrGenerator::Visit(const VarDeclNode& decl) {

void QbeIrGenerator::Visit(const ArrDeclNode& arr_decl) {
if (arr_decl.is_global) {
const auto* arr_type = dynamic_cast<ArrType*>((arr_decl.type).get());
assert(arr_type);
const auto& arr_type = dynamic_cast<ArrType&>(*arr_decl.type);
Write_("export data {} = align {} {{ ",
user_defined::GlobalPointer{arr_decl.id},
arr_type->element_type().size());
arr_type.element_type().size());

global_var_init_vals.clear();
auto arr_size = arr_type->len();
auto arr_size = arr_type.len();
auto init_len = arr_decl.init_list.size();
assert(init_len <= arr_size);
for (auto i = std::size_t{0}; i < init_len; ++i) {
Expand All @@ -243,14 +242,14 @@ void QbeIrGenerator::Visit(const ArrDeclNode& arr_decl) {

// set remaining elements as 0
if (init_len < arr_size) {
Write_("z {}", (arr_size - init_len) * arr_type->element_type().size());
Write_("z {}", (arr_size - init_len) * arr_type.element_type().size());
}
Write_(" }}\n");
} else {
int base_addr_num = NextLocalNum();
assert(arr_decl.type->IsArr());
const auto* arr_type = dynamic_cast<ArrType*>((arr_decl.type).get());
auto element_size = arr_type->element_type().size();
const auto& arr_type = dynamic_cast<ArrType&>(*arr_decl.type);
auto element_size = arr_type.element_type().size();
WriteInstr_("{} =l alloc{} {}", FuncScopeTemp{base_addr_num}, element_size,
arr_decl.type->size());
id_to_num[arr_decl.id] = base_addr_num;
Expand All @@ -259,7 +258,7 @@ void QbeIrGenerator::Visit(const ArrDeclNode& arr_decl) {
// 6.7.9 Initialization
// 10. If an object that has automatic storage duration is not initialized
// explicitly, its value is indeterminate.
for (auto i = std::size_t{0}, e = arr_type->len(),
for (auto i = std::size_t{0}, e = arr_type.len(),
init_len = arr_decl.init_list.size();
i < e && init_len != 0; ++i) {
if (i < init_len) {
Expand Down Expand Up @@ -302,21 +301,20 @@ void QbeIrGenerator::Visit(const RecordVarDeclNode& record_var_decl) {
record_var_decl.type->size());
id_to_num[record_var_decl.id] = base_addr;

auto* record_type = dynamic_cast<RecordType*>(record_var_decl.type.get());
assert(record_type);
auto& record_type = dynamic_cast<RecordType&>(*record_var_decl.type);
// NOTE: This predicate will make sure that we don't initialize members that
// exceed the total number of members in a record. Also, it gurantees
// exceed the total number of members in a record. Also, it guarantees
// that accessing element in the initializers will not go out of bound.
for (auto i = std::size_t{0}, e = record_var_decl.inits.size(),
slot_count = record_type->SlotCount();
slot_count = record_type.SlotCount();
i < slot_count && i < e; ++i) {
const auto& init = record_var_decl.inits.at(i);
init->Accept(*this);
const auto init_num = num_recorder.NumOfPrevExpr();

// res_addr = base_addr + offset
const int res_addr_num = NextLocalNum();
const auto offset = record_type->OffsetOf(i);
const auto offset = record_type.OffsetOf(i);
WriteInstr_("{} =l add {}, {}", FuncScopeTemp{res_addr_num},
FuncScopeTemp{base_addr}, offset);
WriteInstr_("storew {}, {}", FuncScopeTemp{init_num},
Expand Down Expand Up @@ -780,10 +778,9 @@ void QbeIrGenerator::Visit(const ArrSubExprNode& arr_sub_expr) {
// e.g. int a[3]
// a[1]'s offset = 1 * 4 (int size)
const int offset = NextLocalNum();
const auto* arr_type = dynamic_cast<ArrType*>((arr_sub_expr.arr->type).get());
assert(arr_type);
const auto& arr_type = dynamic_cast<ArrType&>(*arr_sub_expr.arr->type);
WriteInstr_("{} =l mul {}, {}", FuncScopeTemp{offset},
FuncScopeTemp{extended_num}, arr_type->element_type().size());
FuncScopeTemp{extended_num}, arr_type.element_type().size());

// res_addr = base_addr + offset
const int res_addr_num = NextLocalNum();
Expand Down Expand Up @@ -896,22 +893,20 @@ void QbeIrGenerator::Visit(const PostfixArithExprNode& postfix_expr) {
// TODO: support pointer arithmetic
WriteInstr_("{} =w {} {}, 1", FuncScopeTemp{res_num},
GetBinaryOperator(arith_op), FuncScopeTemp{expr_num});
const auto* id_expr = dynamic_cast<IdExprNode*>((postfix_expr.operand).get());
assert(id_expr);
const auto& id_expr = dynamic_cast<IdExprNode&>(*postfix_expr.operand);
WriteInstr_("storew {}, {}", FuncScopeTemp{res_num},
FuncScopeTemp{id_to_num.at(id_expr->id)});
FuncScopeTemp{id_to_num.at(id_expr.id)});
}

void QbeIrGenerator::Visit(const RecordMemExprNode& mem_expr) {
mem_expr.expr->Accept(*this);
const auto num = num_recorder.NumOfPrevExpr();
const auto id_num = reg_num_to_id_num.at(num);
auto* record_type = dynamic_cast<RecordType*>(mem_expr.expr->type.get());
assert(record_type);
auto& record_type = dynamic_cast<RecordType&>(*mem_expr.expr->type);

const auto res_addr_num = NextLocalNum();
WriteInstr_("{} =l add {}, {}", FuncScopeTemp{res_addr_num},
FuncScopeTemp{id_num}, record_type->OffsetOf(mem_expr.id));
FuncScopeTemp{id_num}, record_type.OffsetOf(mem_expr.id));

const int res_num = NextLocalNum();
WriteInstr_("{} =w loadw {}", FuncScopeTemp{res_num},
Expand All @@ -933,11 +928,9 @@ void QbeIrGenerator::Visit(const UnaryExprNode& unary_expr) {
: BinaryOperator::kSub;
WriteInstr_("{} =w {} {}, 1", FuncScopeTemp{res_num},
GetBinaryOperator(arith_op), FuncScopeTemp{expr_num});
const auto* id_expr =
dynamic_cast<IdExprNode*>((unary_expr.operand).get());
assert(id_expr);
const auto& id_expr = dynamic_cast<IdExprNode&>(*unary_expr.operand);
WriteInstr_("storew {}, {}", FuncScopeTemp{res_num},
FuncScopeTemp{id_to_num.at(id_expr->id)});
FuncScopeTemp{id_to_num.at(id_expr.id)});
num_recorder.Record(res_num);
} break;
case UnaryOperator::kPos:
Expand Down Expand Up @@ -974,11 +967,9 @@ void QbeIrGenerator::Visit(const UnaryExprNode& unary_expr) {
// No-op; the function itself already evaluates to the address.
break;
}
const auto* id_expr =
dynamic_cast<IdExprNode*>((unary_expr.operand).get());
// NOTE: The operand of the address-of operator must be an lvalue, and we
// do not support arrays now, so it must have been backed by an id.
assert(id_expr);
assert(dynamic_cast<IdExprNode*>(unary_expr.operand.get()));
// The address of the id is the id itself.
const int reg_num = num_recorder.NumOfPrevExpr();
const int id_num = reg_num_to_id_num.at(reg_num);
Expand All @@ -993,8 +984,8 @@ void QbeIrGenerator::Visit(const UnaryExprNode& unary_expr) {
case UnaryOperator::kDeref: {
// Is function pointer.
if (unary_expr.operand->type->IsPtr() &&
dynamic_cast<PtrType*>((unary_expr.operand->type).get())
->base_type()
dynamic_cast<PtrType&>(*unary_expr.operand->type)
.base_type()
.IsFunc()) {
// No-op; the function itself also evaluates to the address.
break;
Expand Down
17 changes: 7 additions & 10 deletions src/type_checker.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -157,8 +157,7 @@ void TypeChecker::Visit(RecordVarDeclNode& record_var_decl) {
// struct birth bd1 { .date = 1 }; // RecordVarDeclNode -> search type entry
// to update its type.
// record_type_id is "struct_birth" in the above example.
auto record_type_id =
dynamic_cast<RecordType*>(record_var_decl.type.get())->id();
auto record_type_id = dynamic_cast<RecordType&>(*record_var_decl.type).id();
auto record_type = env_.LookUpType(
MangleRecordTypeId(record_type_id, record_var_decl.type));
assert(record_type);
Expand Down Expand Up @@ -191,7 +190,7 @@ void TypeChecker::Visit(ParamNode& parameter) {
if (parameter.type->IsArr()) {
// Decay to simple pointer type.
parameter.type = std::make_unique<PtrType>(
dynamic_cast<ArrType*>(parameter.type.get())->element_type().Clone());
dynamic_cast<ArrType&>(*parameter.type).element_type().Clone());
} else if (parameter.type->IsFunc()) {
// Decay to function pointer type.
parameter.type = std::make_unique<PtrType>(parameter.type->Clone());
Expand Down Expand Up @@ -233,7 +232,7 @@ void TypeChecker::Visit(FuncDefNode& func_def) {
decayed_param_types.push_back(parameter->type->Clone());
}
auto return_type =
dynamic_cast<FuncType*>(func_def.type.get())->return_type().Clone();
dynamic_cast<FuncType&>(*func_def.type).return_type().Clone();
func_def.type = std::make_unique<FuncType>(std::move(return_type),
std::move(decayed_param_types));
auto symbol =
Expand Down Expand Up @@ -461,10 +460,9 @@ void TypeChecker::Visit(ArgExprNode& arg_expr) {
void TypeChecker::Visit(ArrSubExprNode& arr_sub_expr) {
arr_sub_expr.arr->Accept(*this);
arr_sub_expr.index->Accept(*this);
const auto* arr_type = dynamic_cast<ArrType*>((arr_sub_expr.arr->type).get());
assert(arr_type);
const auto& arr_type = dynamic_cast<ArrType&>(*arr_sub_expr.arr->type);
// arr_sub_expr should have the element type of the array.
arr_sub_expr.type = arr_type->element_type().Clone();
arr_sub_expr.type = arr_type.element_type().Clone();
arr_sub_expr.is_global = arr_sub_expr.arr->is_global;
}

Expand Down Expand Up @@ -565,9 +563,8 @@ void TypeChecker::Visit(UnaryExprNode& unary_expr) {
if (!unary_expr.operand->type->IsPtr()) {
// TODO: the operand of unary '*' shall have pointer type
}
unary_expr.type = dynamic_cast<PtrType*>(unary_expr.operand->type.get())
->base_type()
.Clone();
unary_expr.type =
dynamic_cast<PtrType&>(*unary_expr.operand->type).base_type().Clone();
break;
default:
unary_expr.type = unary_expr.operand->type->Clone();
Expand Down

0 comments on commit 683cf36

Please sign in to comment.