Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use dynamic cast on reference when should not fail #176

Merged
merged 2 commits into from
Jul 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 23 additions & 24 deletions parser.y
Original file line number Diff line number Diff line change
Expand Up @@ -184,16 +184,16 @@ func_def: declaration_specifiers declarator compound_stmt {
auto func_def = $2;
assert(dynamic_cast<FuncDefNode*>(func_def.get()));
assert(func_def->type->IsFunc());
const auto* func_type = static_cast<FuncType*>(func_def->type.get());
const auto& func_type = dynamic_cast<FuncType&>(*func_def->type);
auto type = std::get<std::unique_ptr<Type>>($1);
auto resolved_return_type = ResolveType(std::move(type), func_type->return_type().Clone());
auto resolved_return_type = ResolveType(std::move(type), func_type.return_type().Clone());
auto param_types = std::vector<std::unique_ptr<Type>>{};
for (auto& param : func_type->param_types()) {
for (auto& param : func_type.param_types()) {
param_types.push_back(param->Clone());
}
func_def->type = std::make_unique<FuncType>(std::move(resolved_return_type), std::move(param_types));
static_cast<FuncDefNode*>(func_def.get())->body = $3;
$$ = std::unique_ptr<FuncDefNode>(static_cast<FuncDefNode*>(func_def.release()));
dynamic_cast<FuncDefNode&>(*func_def).body = $3;
$$ = std::unique_ptr<FuncDefNode>(dynamic_cast<FuncDefNode*>(func_def.release()));
}
;

Expand Down Expand Up @@ -423,15 +423,15 @@ decl: declaration_specifiers init_declarator_list_opt SEMICOLON {
// A record declaration that doesn't declare any identifier, e.g., `struct point {int x, int y};`.
if (init_decl_list.empty()) {
decl_list.push_back(std::move(decl));
}

auto* rec_decl = dynamic_cast<RecordDeclNode*>(decl.get());
// Initialize record variable.
for (auto& init_decl : init_decl_list) {
if (init_decl) {
init_decl->type = ResolveType(rec_decl->type->Clone(), std::move(init_decl->type));
} else {
auto& rec_decl = dynamic_cast<RecordDeclNode&>(*decl);
// Initialize record variable.
for (auto& init_decl : init_decl_list) {
if (init_decl) {
init_decl->type = ResolveType(rec_decl.type->Clone(), std::move(init_decl->type));
}
decl_list.push_back(std::move(init_decl));
}
decl_list.push_back(std::move(init_decl));
}
}
$$ = std::make_unique<DeclStmtNode>(Loc(@1), std::move(decl_list));
Expand Down Expand Up @@ -469,10 +469,9 @@ init_declarator: declarator { $$ = $1; }
auto decl = $1;
auto init = $3;
if (std::holds_alternative<std::unique_ptr<InitExprNode>>(init)) {
auto* var_decl = dynamic_cast<VarDeclNode*>(decl.get());
assert(var_decl);
auto& var_decl = dynamic_cast<VarDeclNode&>(*decl);
auto initializer = std::move(std::get<std::unique_ptr<InitExprNode>>(init));
var_decl->init = std::move(initializer->expr);
var_decl.init = std::move(initializer->expr);
} else { // The initializer is a list of expressions.
auto init_expr_list = std::move(std::get<std::vector<std::unique_ptr<InitExprNode>>>(init));
if (auto* arr_decl = dynamic_cast<ArrDeclNode*>(decl.get())) {
Expand Down Expand Up @@ -801,21 +800,21 @@ std::unique_ptr<Type> ResolveType(std::unique_ptr<Type> resolved_type,
}
// Since we cannot change the internal state of a type, we construct a new one.
if (unknown_type->IsPtr()) {
auto ptr_type = static_cast<PtrType*>(unknown_type.get());
resolved_type = ResolveType(std::move(resolved_type), ptr_type->base_type().Clone());
auto& ptr_type = static_cast<PtrType&>(*unknown_type);
resolved_type = ResolveType(std::move(resolved_type), ptr_type.base_type().Clone());
return std::make_unique<PtrType>(std::move(resolved_type));
}
if (unknown_type->IsArr()) {
auto arr_type = static_cast<ArrType*>(unknown_type.get());
resolved_type = ResolveType(std::move(resolved_type), arr_type->element_type().Clone());
return std::make_unique<ArrType>(std::move(resolved_type), arr_type->len());
auto& arr_type = dynamic_cast<ArrType&>(*unknown_type);
resolved_type = ResolveType(std::move(resolved_type), arr_type.element_type().Clone());
return std::make_unique<ArrType>(std::move(resolved_type), arr_type.len());
}
if (unknown_type->IsFunc()) {
// NOTE: Due to the structure of the grammar, the return type of a function is to be resolved.
auto func_type = static_cast<FuncType*>(unknown_type.get());
resolved_type = ResolveType(std::move(resolved_type), func_type->return_type().Clone());
auto& func_type = dynamic_cast<FuncType&>(*unknown_type);
resolved_type = ResolveType(std::move(resolved_type), func_type.return_type().Clone());
auto param_types = std::vector<std::unique_ptr<Type>>{};
for (const auto& param : func_type->param_types()) {
for (const auto& param : func_type.param_types()) {
param_types.push_back(param->Clone());
}
return std::make_unique<FuncType>(std::move(resolved_type), std::move(param_types));
Expand Down
35 changes: 15 additions & 20 deletions src/llvm_ir_generator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -198,10 +198,10 @@ void LLVMIRGenerator::Visit(const ArrDeclNode& arr_decl) {
id_to_val[arr_decl.id] = addr;
}

auto arr_decl_type = dynamic_cast<ArrType*>(arr_decl.type.get());
auto& arr_decl_type = dynamic_cast<ArrType&>(*arr_decl.type);
// This vector stores the initialize values for a global array.
std::vector<llvm::Constant*> arr_elems{};
for (auto i = std::size_t{0}, e = arr_decl_type->len(),
for (auto i = std::size_t{0}, e = arr_decl_type.len(),
init_len = arr_decl.init_list.size();
i < e; ++i) {
if (i < init_len) {
Expand Down Expand Up @@ -250,8 +250,7 @@ void LLVMIRGenerator::Visit(const FieldNode& field) {
}

void LLVMIRGenerator::Visit(const RecordVarDeclNode& record_var_decl) {
auto* record_type = dynamic_cast<RecordType*>(record_var_decl.type.get());
assert(record_type);
auto& record_type = dynamic_cast<RecordType&>(*record_var_decl.type);
auto type = builder_helper_.GetLLVMType(*(record_var_decl.type));
auto base_addr = builder_.CreateAlloca(type);
id_to_val[record_var_decl.id] = base_addr;
Expand All @@ -260,7 +259,7 @@ void LLVMIRGenerator::Visit(const RecordVarDeclNode& record_var_decl) {
// exceed the total number of members in a record. Also, it guarantees
// that accessing element in the initializers will not go out of bound.
for (auto i = std::size_t{0}, e = record_var_decl.inits.size(),
slot_count = record_type->SlotCount();
slot_count = record_type.SlotCount();
i < slot_count && i < e; ++i) {
auto& init = record_var_decl.inits.at(i);
init->Accept(*this);
Expand Down Expand Up @@ -513,10 +512,9 @@ void LLVMIRGenerator::Visit(const IdLabeledStmtNode& id_labeled_stmt) {
void LLVMIRGenerator::Visit(const CaseStmtNode& case_stmt) {
case_stmt.expr->Accept(*this);
auto val = val_recorder.ValOfPrevExpr();
auto int_expr = dynamic_cast<IntConstExprNode*>(case_stmt.expr.get());
assert(int_expr);
auto& int_expr = dynamic_cast<IntConstExprNode&>(*case_stmt.expr);

auto case_label = "case_" + std::to_string(int_expr->val);
auto case_label = "case_" + std::to_string(int_expr.val);
auto case_bb = llvm::BasicBlock::Create(context_, case_label,
builder_helper_.CurrFunc());

Expand Down Expand Up @@ -586,11 +584,10 @@ void LLVMIRGenerator::Visit(const ArrSubExprNode& arr_sub_expr) {
auto base_addr = val_to_id_addr.at(val);
auto arr_type = builder_helper_.GetLLVMType(*(arr_sub_expr.arr->type));
arr_sub_expr.index->Accept(*this);
auto index = dynamic_cast<IntConstExprNode*>(arr_sub_expr.index.get());
assert(index);
auto& index = dynamic_cast<IntConstExprNode&>(*arr_sub_expr.index);

auto res_addr = builder_.CreateConstInBoundsGEP2_32(arr_type, base_addr, 0,
(unsigned int)index->val);
(unsigned int)index.val);
auto res_val = builder_.CreateLoad(arr_type->getArrayElementType(), res_addr);
val_to_id_addr[res_val] = res_addr;
val_recorder.Record(res_val);
Expand Down Expand Up @@ -694,23 +691,21 @@ void LLVMIRGenerator::Visit(const PostfixArithExprNode& postfix_expr) {

auto one = llvm::ConstantInt::get(builder_.getInt32Ty(), 1, true);
auto res = builder_.CreateBinOp(arith_op, val, one);
const auto* id_expr = dynamic_cast<IdExprNode*>((postfix_expr.operand).get());
assert(id_expr);
builder_.CreateStore(res, id_to_val.at(id_expr->id));
const auto& id_expr = dynamic_cast<IdExprNode&>(*postfix_expr.operand);
builder_.CreateStore(res, id_to_val.at(id_expr.id));
}

void LLVMIRGenerator::Visit(const RecordMemExprNode& mem_expr) {
mem_expr.expr->Accept(*this);
auto val = val_recorder.ValOfPrevExpr();
auto base_addr = val_to_id_addr.at(val);
auto llvm_type = builder_helper_.GetLLVMType(*(mem_expr.expr->type));
auto* record_type = dynamic_cast<RecordType*>(mem_expr.expr->type.get());
assert(record_type);
auto& record_type = dynamic_cast<RecordType&>(*mem_expr.expr->type);

auto res_addr = builder_.CreateStructGEP(
llvm_type, base_addr, record_type->MemberIndex(mem_expr.id));
llvm_type, base_addr, record_type.MemberIndex(mem_expr.id));
auto res_val = builder_.CreateLoad(
builder_helper_.GetLLVMType(record_type->MemberType(mem_expr.id)),
builder_helper_.GetLLVMType(record_type.MemberType(mem_expr.id)),
res_addr);
val_to_id_addr[res_val] = res_addr;
val_recorder.Record(res_val);
Expand Down Expand Up @@ -765,8 +760,8 @@ void LLVMIRGenerator::Visit(const UnaryExprNode& unary_expr) {
case UnaryOperator::kDeref: {
// Is function pointer.
if (unary_expr.operand->type->IsPtr() &&
dynamic_cast<PtrType*>((unary_expr.operand->type).get())
->base_type()
dynamic_cast<PtrType&>(*unary_expr.operand->type)
.base_type()
.IsFunc()) {
// No-op; the value is still the function itself.
break;
Expand Down
53 changes: 22 additions & 31 deletions src/qbe_ir_generator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -222,14 +222,13 @@ void QbeIrGenerator::Visit(const VarDeclNode& decl) {

void QbeIrGenerator::Visit(const ArrDeclNode& arr_decl) {
if (arr_decl.is_global) {
const auto* arr_type = dynamic_cast<ArrType*>((arr_decl.type).get());
assert(arr_type);
const auto& arr_type = dynamic_cast<ArrType&>(*arr_decl.type);
Write_("export data {} = align {} {{ ",
user_defined::GlobalPointer{arr_decl.id},
arr_type->element_type().size());
arr_type.element_type().size());

global_var_init_vals.clear();
auto arr_size = arr_type->len();
auto arr_size = arr_type.len();
auto init_len = arr_decl.init_list.size();
assert(init_len <= arr_size);
for (auto i = std::size_t{0}; i < init_len; ++i) {
Expand All @@ -243,14 +242,14 @@ void QbeIrGenerator::Visit(const ArrDeclNode& arr_decl) {

// set remaining elements as 0
if (init_len < arr_size) {
Write_("z {}", (arr_size - init_len) * arr_type->element_type().size());
Write_("z {}", (arr_size - init_len) * arr_type.element_type().size());
}
Write_(" }}\n");
} else {
int base_addr_num = NextLocalNum();
assert(arr_decl.type->IsArr());
const auto* arr_type = dynamic_cast<ArrType*>((arr_decl.type).get());
auto element_size = arr_type->element_type().size();
const auto& arr_type = dynamic_cast<ArrType&>(*arr_decl.type);
auto element_size = arr_type.element_type().size();
WriteInstr_("{} =l alloc{} {}", FuncScopeTemp{base_addr_num}, element_size,
arr_decl.type->size());
id_to_num[arr_decl.id] = base_addr_num;
Expand All @@ -259,7 +258,7 @@ void QbeIrGenerator::Visit(const ArrDeclNode& arr_decl) {
// 6.7.9 Initialization
// 10. If an object that has automatic storage duration is not initialized
// explicitly, its value is indeterminate.
for (auto i = std::size_t{0}, e = arr_type->len(),
for (auto i = std::size_t{0}, e = arr_type.len(),
init_len = arr_decl.init_list.size();
i < e && init_len != 0; ++i) {
if (i < init_len) {
Expand Down Expand Up @@ -302,21 +301,20 @@ void QbeIrGenerator::Visit(const RecordVarDeclNode& record_var_decl) {
record_var_decl.type->size());
id_to_num[record_var_decl.id] = base_addr;

auto* record_type = dynamic_cast<RecordType*>(record_var_decl.type.get());
assert(record_type);
auto& record_type = dynamic_cast<RecordType&>(*record_var_decl.type);
// NOTE: This predicate will make sure that we don't initialize members that
// exceed the total number of members in a record. Also, it gurantees
// exceed the total number of members in a record. Also, it guarantees
// that accessing element in the initializers will not go out of bound.
for (auto i = std::size_t{0}, e = record_var_decl.inits.size(),
slot_count = record_type->SlotCount();
slot_count = record_type.SlotCount();
i < slot_count && i < e; ++i) {
const auto& init = record_var_decl.inits.at(i);
init->Accept(*this);
const auto init_num = num_recorder.NumOfPrevExpr();

// res_addr = base_addr + offset
const int res_addr_num = NextLocalNum();
const auto offset = record_type->OffsetOf(i);
const auto offset = record_type.OffsetOf(i);
WriteInstr_("{} =l add {}, {}", FuncScopeTemp{res_addr_num},
FuncScopeTemp{base_addr}, offset);
WriteInstr_("storew {}, {}", FuncScopeTemp{init_num},
Expand Down Expand Up @@ -780,10 +778,9 @@ void QbeIrGenerator::Visit(const ArrSubExprNode& arr_sub_expr) {
// e.g. int a[3]
// a[1]'s offset = 1 * 4 (int size)
const int offset = NextLocalNum();
const auto* arr_type = dynamic_cast<ArrType*>((arr_sub_expr.arr->type).get());
assert(arr_type);
const auto& arr_type = dynamic_cast<ArrType&>(*arr_sub_expr.arr->type);
WriteInstr_("{} =l mul {}, {}", FuncScopeTemp{offset},
FuncScopeTemp{extended_num}, arr_type->element_type().size());
FuncScopeTemp{extended_num}, arr_type.element_type().size());

// res_addr = base_addr + offset
const int res_addr_num = NextLocalNum();
Expand Down Expand Up @@ -896,22 +893,20 @@ void QbeIrGenerator::Visit(const PostfixArithExprNode& postfix_expr) {
// TODO: support pointer arithmetic
WriteInstr_("{} =w {} {}, 1", FuncScopeTemp{res_num},
GetBinaryOperator(arith_op), FuncScopeTemp{expr_num});
const auto* id_expr = dynamic_cast<IdExprNode*>((postfix_expr.operand).get());
assert(id_expr);
const auto& id_expr = dynamic_cast<IdExprNode&>(*postfix_expr.operand);
WriteInstr_("storew {}, {}", FuncScopeTemp{res_num},
FuncScopeTemp{id_to_num.at(id_expr->id)});
FuncScopeTemp{id_to_num.at(id_expr.id)});
}

void QbeIrGenerator::Visit(const RecordMemExprNode& mem_expr) {
mem_expr.expr->Accept(*this);
const auto num = num_recorder.NumOfPrevExpr();
const auto id_num = reg_num_to_id_num.at(num);
auto* record_type = dynamic_cast<RecordType*>(mem_expr.expr->type.get());
assert(record_type);
auto& record_type = dynamic_cast<RecordType&>(*mem_expr.expr->type);

const auto res_addr_num = NextLocalNum();
WriteInstr_("{} =l add {}, {}", FuncScopeTemp{res_addr_num},
FuncScopeTemp{id_num}, record_type->OffsetOf(mem_expr.id));
FuncScopeTemp{id_num}, record_type.OffsetOf(mem_expr.id));

const int res_num = NextLocalNum();
WriteInstr_("{} =w loadw {}", FuncScopeTemp{res_num},
Expand All @@ -933,11 +928,9 @@ void QbeIrGenerator::Visit(const UnaryExprNode& unary_expr) {
: BinaryOperator::kSub;
WriteInstr_("{} =w {} {}, 1", FuncScopeTemp{res_num},
GetBinaryOperator(arith_op), FuncScopeTemp{expr_num});
const auto* id_expr =
dynamic_cast<IdExprNode*>((unary_expr.operand).get());
assert(id_expr);
const auto& id_expr = dynamic_cast<IdExprNode&>(*unary_expr.operand);
WriteInstr_("storew {}, {}", FuncScopeTemp{res_num},
FuncScopeTemp{id_to_num.at(id_expr->id)});
FuncScopeTemp{id_to_num.at(id_expr.id)});
num_recorder.Record(res_num);
} break;
case UnaryOperator::kPos:
Expand Down Expand Up @@ -974,11 +967,9 @@ void QbeIrGenerator::Visit(const UnaryExprNode& unary_expr) {
// No-op; the function itself already evaluates to the address.
break;
}
const auto* id_expr =
dynamic_cast<IdExprNode*>((unary_expr.operand).get());
// NOTE: The operand of the address-of operator must be an lvalue, and we
// do not support arrays now, so it must have been backed by an id.
assert(id_expr);
assert(dynamic_cast<IdExprNode*>(unary_expr.operand.get()));
// The address of the id is the id itself.
const int reg_num = num_recorder.NumOfPrevExpr();
const int id_num = reg_num_to_id_num.at(reg_num);
Expand All @@ -993,8 +984,8 @@ void QbeIrGenerator::Visit(const UnaryExprNode& unary_expr) {
case UnaryOperator::kDeref: {
// Is function pointer.
if (unary_expr.operand->type->IsPtr() &&
dynamic_cast<PtrType*>((unary_expr.operand->type).get())
->base_type()
dynamic_cast<PtrType&>(*unary_expr.operand->type)
.base_type()
.IsFunc()) {
// No-op; the function itself also evaluates to the address.
break;
Expand Down
17 changes: 7 additions & 10 deletions src/type_checker.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -157,8 +157,7 @@ void TypeChecker::Visit(RecordVarDeclNode& record_var_decl) {
// struct birth bd1 { .date = 1 }; // RecordVarDeclNode -> search type entry
// to update its type.
// record_type_id is "struct_birth" in the above example.
auto record_type_id =
dynamic_cast<RecordType*>(record_var_decl.type.get())->id();
auto record_type_id = dynamic_cast<RecordType&>(*record_var_decl.type).id();
auto record_type = env_.LookUpType(
MangleRecordTypeId(record_type_id, record_var_decl.type));
assert(record_type);
Expand Down Expand Up @@ -191,7 +190,7 @@ void TypeChecker::Visit(ParamNode& parameter) {
if (parameter.type->IsArr()) {
// Decay to simple pointer type.
parameter.type = std::make_unique<PtrType>(
dynamic_cast<ArrType*>(parameter.type.get())->element_type().Clone());
dynamic_cast<ArrType&>(*parameter.type).element_type().Clone());
} else if (parameter.type->IsFunc()) {
// Decay to function pointer type.
parameter.type = std::make_unique<PtrType>(parameter.type->Clone());
Expand Down Expand Up @@ -233,7 +232,7 @@ void TypeChecker::Visit(FuncDefNode& func_def) {
decayed_param_types.push_back(parameter->type->Clone());
}
auto return_type =
dynamic_cast<FuncType*>(func_def.type.get())->return_type().Clone();
dynamic_cast<FuncType&>(*func_def.type).return_type().Clone();
func_def.type = std::make_unique<FuncType>(std::move(return_type),
std::move(decayed_param_types));
auto symbol =
Expand Down Expand Up @@ -461,10 +460,9 @@ void TypeChecker::Visit(ArgExprNode& arg_expr) {
void TypeChecker::Visit(ArrSubExprNode& arr_sub_expr) {
arr_sub_expr.arr->Accept(*this);
arr_sub_expr.index->Accept(*this);
const auto* arr_type = dynamic_cast<ArrType*>((arr_sub_expr.arr->type).get());
assert(arr_type);
const auto& arr_type = dynamic_cast<ArrType&>(*arr_sub_expr.arr->type);
// arr_sub_expr should have the element type of the array.
arr_sub_expr.type = arr_type->element_type().Clone();
arr_sub_expr.type = arr_type.element_type().Clone();
arr_sub_expr.is_global = arr_sub_expr.arr->is_global;
}

Expand Down Expand Up @@ -565,9 +563,8 @@ void TypeChecker::Visit(UnaryExprNode& unary_expr) {
if (!unary_expr.operand->type->IsPtr()) {
// TODO: the operand of unary '*' shall have pointer type
}
unary_expr.type = dynamic_cast<PtrType*>(unary_expr.operand->type.get())
->base_type()
.Clone();
unary_expr.type =
dynamic_cast<PtrType&>(*unary_expr.operand->type).base_type().Clone();
break;
default:
unary_expr.type = unary_expr.operand->type->Clone();
Expand Down
Loading