From 82eddcdb1b7f6b4c6ea6d178d348c65389c6a309 Mon Sep 17 00:00:00 2001 From: Lai-YT <381xvmvbib@gmail.com> Date: Sat, 6 Jul 2024 00:28:35 +0800 Subject: [PATCH] Return constant reference on query of member type Since the `MemberType` function is a query function, it should not have to clone the type. Note that the function is no longer `noexcept` after the change. Originally, it returned the `unknown` type when the `id` was not a member of the record. However, since we cannot make up such an unknown type as a temporary object and return its reference (undefined behavior), we have to throw an exception. This is acceptable because the type checker should already ensure that the member exists, thus querying with a non-existing `id` indicates some internal errors. --- include/type.hpp | 13 +++++-------- src/llvm_ir_generator.cpp | 2 +- src/type.cpp | 16 ++++++---------- src/type_checker.cpp | 2 +- 4 files changed, 13 insertions(+), 20 deletions(-) diff --git a/include/type.hpp b/include/type.hpp index 4380aa1b..aa74580f 100644 --- a/include/type.hpp +++ b/include/type.hpp @@ -194,10 +194,9 @@ class RecordType : public Type { const noexcept = 0; /// @brief Checks if `id` is a member of the record type. virtual bool IsMember(const std::string& id) const noexcept = 0; - /// @return The type of a member in struct or union. The unknown type if the - /// `id` is not a member of the record type. - virtual std::unique_ptr MemberType( - const std::string& id) const noexcept = 0; + /// @return The type of a member in struct or union. + /// @throw `std::runtime_error` if the `id` is not a member of the record. + virtual const Type& MemberType(const std::string& id) const = 0; /// @note Every member in union shares the same index 0. /// @return The index of a member in struct or union. /// @throw `std::runtime_error` if the `id` is not a member of the record. @@ -227,8 +226,7 @@ class StructType : public RecordType { } std::string id() const noexcept override; bool IsMember(const std::string& id) const noexcept override; - std::unique_ptr MemberType( - const std::string& id) const noexcept override; + const Type& MemberType(const std::string& id) const override; std::size_t MemberIndex(const std::string& id) const override; std::size_t OffsetOf(const std::string& id) const override; std::size_t OffsetOf(std::size_t index) const override; @@ -260,8 +258,7 @@ class UnionType : public RecordType { } std::string id() const noexcept override; bool IsMember(const std::string& id) const noexcept override; - std::unique_ptr MemberType( - const std::string& id) const noexcept override; + const Type& MemberType(const std::string& id) const override; std::size_t MemberIndex(const std::string& id) const override; std::size_t OffsetOf(const std::string& id) const override; std::size_t OffsetOf(std::size_t index) const override; diff --git a/src/llvm_ir_generator.cpp b/src/llvm_ir_generator.cpp index a26759aa..78ea9511 100644 --- a/src/llvm_ir_generator.cpp +++ b/src/llvm_ir_generator.cpp @@ -671,7 +671,7 @@ void LLVMIRGenerator::Visit(const RecordMemExprNode& mem_expr) { auto res_addr = builder_.CreateStructGEP( llvm_type, base_addr, record_type->MemberIndex(mem_expr.id)); auto res_val = builder_.CreateLoad( - builder_helper_.GetLLVMType(*(record_type->MemberType(mem_expr.id))), + builder_helper_.GetLLVMType(record_type->MemberType(mem_expr.id)), res_addr); val_to_id_addr[res_val] = res_addr; val_recorder.Record(res_val); diff --git a/src/type.cpp b/src/type.cpp index f81c2574..63e257cd 100644 --- a/src/type.cpp +++ b/src/type.cpp @@ -171,15 +171,13 @@ bool StructType::IsMember(const std::string& id) const noexcept { return false; } -std::unique_ptr StructType::MemberType( - const std::string& id) const noexcept { +const Type& StructType::MemberType(const std::string& id) const { for (const auto& field : fields_) { if (field->id == id) { - return field->type->Clone(); + return *field->type; } } - - return std::make_unique(PrimitiveType::kUnknown); + throw std::runtime_error{"member not found in struct!"}; } std::size_t StructType::MemberIndex(const std::string& id) const { @@ -277,15 +275,13 @@ bool UnionType::IsMember(const std::string& id) const noexcept { return false; } -std::unique_ptr UnionType::MemberType( - const std::string& id) const noexcept { +const Type& UnionType::MemberType(const std::string& id) const { for (const auto& field : fields_) { if (field->id == id) { - return field->type->Clone(); + return *field->type; } } - - return std::make_unique(PrimitiveType::kUnknown); + throw std::runtime_error{"member not found in union!"}; } std::size_t UnionType::MemberIndex(const std::string& id) const { diff --git a/src/type_checker.cpp b/src/type_checker.cpp index 792adbf5..b43d6813 100644 --- a/src/type_checker.cpp +++ b/src/type_checker.cpp @@ -497,7 +497,7 @@ void TypeChecker::Visit(RecordMemExprNode& mem_expr) { if (auto* record_type = dynamic_cast((mem_expr.expr->type).get())) { if (record_type->IsMember(mem_expr.id)) { - mem_expr.type = record_type->MemberType(mem_expr.id); + mem_expr.type = record_type->MemberType(mem_expr.id).Clone(); } else { assert(false); // TODO: Throw error if mem_expr.id is not a symbol's member.