diff --git a/CMakeLists.txt b/CMakeLists.txt index 036e657b..03a810e8 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -173,10 +173,14 @@ export(EXPORT ${PROJECT_NAME}_Targets # Register package in the User Package Registry export(PACKAGE trieste) +# ############################################# +# # Add core Trieste tests +enable_testing() +add_subdirectory(test) + # ############################################# # # Add samples if(TRIESTE_BUILD_SAMPLES) - enable_testing() add_subdirectory(samples/infix) endif() diff --git a/include/trieste/ast.h b/include/trieste/ast.h index 453748e9..f7b4a342 100644 --- a/include/trieste/ast.h +++ b/include/trieste/ast.h @@ -2,6 +2,7 @@ // SPDX-License-Identifier: MIT #pragma once +#include "intrusive_ptr.h" #include "token.h" #include @@ -33,10 +34,10 @@ namespace trieste using Nodes = std::vector; using NodeIt = Nodes::iterator; - using NodeSet = std::set>; + using NodeSet = std::set; template - using NodeMap = std::map>; + using NodeMap = std::map; #ifdef TRIESTE_USE_CXX17 class NodeRange @@ -105,7 +106,7 @@ namespace trieste using NodeRange = std::span; #endif - class SymtabDef + class SymtabDef final : public intrusive_refcounted { friend class NodeDef; @@ -134,7 +135,7 @@ namespace trieste void str(std::ostream& out, size_t level); }; - using Symtab = std::shared_ptr; + using Symtab = intrusive_ptr; struct Index { @@ -181,7 +182,7 @@ namespace trieste } }; - class NodeDef : public std::enable_shared_from_this + class NodeDef final : public intrusive_refcounted { private: Token type_; @@ -195,7 +196,7 @@ namespace trieste : type_(type), location_(location), parent_(nullptr) { if (type_ & flag::symtab) - symtab_ = std::make_shared(); + symtab_ = Symtab::make(); } void add_flags() @@ -291,12 +292,12 @@ namespace trieste static Node create(const Token& type) { - return std::shared_ptr(new NodeDef(type, {nullptr, 0, 0})); + return Node(new NodeDef(type, Location{nullptr, 0, 0})); } static Node create(const Token& type, Location location) { - return std::shared_ptr(new NodeDef(type, location)); + return Node(new NodeDef(type, location)); } static Node create(const Token& type, NodeRange range) @@ -304,7 +305,7 @@ namespace trieste if (range.empty()) return create(type); - return std::shared_ptr( + return Node( new NodeDef(type, range.front()->location_ * range.back()->location_)); } @@ -340,7 +341,7 @@ namespace trieste while (p) { if (p->type_.in(list)) - return p->shared_from_this(); + return p->intrusive_ptr_from_this(); p = p->parent_; } @@ -406,7 +407,7 @@ namespace trieste auto find_first(Token token, NodeIt begin) { - assert((*begin)->parent() == this); + assert((*begin)->parent() == this); return std::find_if( begin, children.end(), [token](auto& n) { return n->type() == token; }); } @@ -548,7 +549,7 @@ namespace trieste while (p) { - auto node = p->shared_from_this(); + auto node = p->intrusive_ptr_from_this(); if (node->symtab_) return node; @@ -655,7 +656,7 @@ namespace trieste throw std::runtime_error("No symbol table"); auto& entry = st->symtab_->symbols[loc]; - entry.push_back(shared_from_this()); + entry.push_back(intrusive_ptr_from_this()); // If there are multiple definitions, none can be shadowing. return (entry.size() == 1) || @@ -671,7 +672,7 @@ namespace trieste if (!st) throw std::runtime_error("No symbol table"); - st->symtab_->includes.emplace_back(shared_from_this()); + st->symtab_->includes.emplace_back(intrusive_ptr_from_this()); } Location fresh(const Location& prefix = {}) @@ -752,10 +753,10 @@ namespace trieste // If p and q are the same, then one is contained within the other. if (p == q) - return p->shared_from_this(); + return p->intrusive_ptr_from_this(); // Otherwise return the common parent. - return p->parent_->shared_from_this(); + return p->parent_->intrusive_ptr_from_this(); } bool precedes(Node node) @@ -775,8 +776,8 @@ namespace trieste // Check that p is to the left of q. auto parent = p->parent_; - return parent->find(p->shared_from_this()) < - parent->find(q->shared_from_this()); + return parent->find(p->intrusive_ptr_from_this()) < + parent->find(q->intrusive_ptr_from_this()); } void str(std::ostream& out, size_t level = 0) const @@ -846,7 +847,7 @@ namespace trieste template SNMALLOC_FAST_PATH void traverse(Pre pre, Post post = NopPost()) { - Node root = shared_from_this(); + Node root = intrusive_ptr_from_this(); if (!pre(root)) return; @@ -936,6 +937,18 @@ namespace trieste } }; + constexpr void + intrusive_refcounted_traits::intrusive_inc_ref(NodeDef* node) + { + node->intrusive_inc_ref(); + } + + constexpr void + intrusive_refcounted_traits::intrusive_dec_ref(NodeDef* node) + { + node->intrusive_dec_ref(); + } + inline TokenDef::operator Node() const { return NodeDef::create(Token(*this)); diff --git a/include/trieste/intrusive_ptr.h b/include/trieste/intrusive_ptr.h new file mode 100644 index 00000000..bdf8ad36 --- /dev/null +++ b/include/trieste/intrusive_ptr.h @@ -0,0 +1,409 @@ +#pragma once + +#include "snmalloc/ds_core/defines.h" + +#include +#include +#include +#include +#include +#include + +namespace trieste +{ + namespace detail + { + // In principle, std::atomic should not be copied. + // It should be a single object that is pointer-to and manipulated by + // multiple threads. For refcounts however, it should be possible to copy a + // refcounted object. The catch is that _everything but the refcount should + // be copied_. The copy constructors here will just set the new refcount to + // 0, as if the object was constructed from scratch, so different + // intrusive_ptr can take ownership of the new object. + struct copyable_refcount final + { + private: + // The refcount here starts at 0, not 1 like in other reference counting + // systems. It's because we're not even sure we're reference counting a + // heap allocated object at all. + // + // The reference count is embedded into a user-allocatable object that can + // (and does in this codebase) live on the stack in some cases. If a + // pointer to an intrusive_refcounted is given to an intrusive_ptr, then + // its refcount is incremented to 1, and it becomes managed as a + // reference-counted object. If not, it is convenient to start and keep + // the refcount at 0 - no intrusive_ptr should point to a stack-allocated + // intrusive_refcounted. Also, we assert that the end refcount of a + // destroyed intrusive_refcounted is 0, and starting at 1 would prevent + // that assertion from holding in general. + static constexpr size_t refcount_init = 0; + std::atomic value; + + public: + constexpr copyable_refcount(size_t value_) : value{value_} {} + + constexpr copyable_refcount() : value{refcount_init} {} + constexpr copyable_refcount(const copyable_refcount&) + : value{refcount_init} + {} + + operator size_t() const + { + return value; + } + + copyable_refcount& operator+=(size_t inc) + { + value += inc; + return *this; + } + + size_t fetch_sub(size_t dec) + { + return value.fetch_sub(dec); + } + }; + } + + // These traits are an indirect helper for incrementing and decrementing + // refcounts on intrusive_refcounted objects. Usually, it is fine to just call + // the intrusive_inc_ref and intrusive_dec_ref methods on T* directly, but + // inflexibly doing that all the time interacts poorly with cases where T is + // an incomplete type. + // + // Consider this example, where many intrusive_ptr methods must be compiled in + // a context where T is an incomplete type: + // ```cpp + // class T; + // using Handle = intrusive_ptr; + // void foo(Handle /*...*/) { /*...*/ } + // ``` + // + // Without any special handling, this would try to compile both method calls + // and destructor calls on an incomplete T, which is bad. The solution is to + // specialize these traits with forward-declared functions: + // ```cpp + // template<> + // struct intrusive_refcounted_traits + // { + // static constexpr void intrusive_inc_ref(T* ptr); + // static constexpr void intrusive_dec_ref(T* ptr); + // }; + // ``` + // + // Then, you can implement the function bodies later on in your code when T is + // complete, and using Handle at any point will be fine because the 2 + // functions prototypes prevent Handle from actually trying to compile the + // method calls it needs too early. + template + struct intrusive_refcounted_traits + { + static constexpr void intrusive_inc_ref(T* ptr) + { + ptr->intrusive_inc_ref(); + } + + static constexpr void intrusive_dec_ref(T* ptr) + { + ptr->intrusive_dec_ref(); + } + }; + + template + struct intrusive_ptr final + { + private: + T* ptr; + + constexpr void inc_ref() const + { + if (ptr) + { + intrusive_refcounted_traits::intrusive_inc_ref(ptr); + } + } + + constexpr void dec_ref() + { + if (ptr) + { + intrusive_refcounted_traits::intrusive_dec_ref(ptr); + ptr = nullptr; + } + } + + public: + template + static intrusive_ptr make(Args&&... args) + { + return intrusive_ptr(new T(std::forward(args)...)); + } + + constexpr intrusive_ptr() : ptr{nullptr} {} + + constexpr intrusive_ptr(std::nullptr_t) : ptr{nullptr} {} + + constexpr explicit intrusive_ptr(T* ptr_) : ptr{ptr_} + { + inc_ref(); + } + + template + constexpr intrusive_ptr(const intrusive_ptr& other) : ptr{other.ptr} + { + inc_ref(); + } + + template + constexpr intrusive_ptr(intrusive_ptr&& other) : ptr{other.release()} + {} + + constexpr intrusive_ptr(const intrusive_ptr& other) : ptr{other.ptr} + { + inc_ref(); + } + + constexpr intrusive_ptr(intrusive_ptr&& other) : ptr{other.release()} {} + + constexpr intrusive_ptr& operator=(const intrusive_ptr& other) + { + // Self-assignment case, don't bother touching refcounts then + if (ptr == other.ptr) + { + return *this; + } + // Increment other's refcount before copying the ptr + other.inc_ref(); + + intrusive_ptr tmp; + // Don't actually inc_ref, but putting old ptr in tmp lets us leverage the + // built in dec_ref with null checks below. + tmp.ptr = ptr; + + ptr = other.ptr; + // tmp gets dec_ref here, potentially destroying the value at old ptr + return *this; + } + + constexpr intrusive_ptr& operator=(intrusive_ptr&& other) + { + intrusive_ptr old; + old.ptr = ptr; + ptr = other.ptr; + other.ptr = nullptr; + return *this; + } + + constexpr void swap(intrusive_ptr& other) + { + std::swap(ptr, other.ptr); + } + + constexpr void reset() + { + dec_ref(); + } + + constexpr T* get() const + { + return ptr; + } + + constexpr T* operator->() const + { + return get(); + } + + constexpr T& operator*() const + { + return *get(); + } + + constexpr operator bool() const + { + return ptr; + } + + constexpr T* release() + { + auto p = get(); + ptr = nullptr; + return p; + } + + ~intrusive_ptr() + { + dec_ref(); + } + + friend std::hash>; + }; + + template + struct intrusive_refcounted + { + private: + // See docs on this type for an explanation of its unusual refcounting + // semantics. + // + // Note: this is a separate type because it allows subclasses to ignore it. + // Any generated copy/default constructors will call into it properly, and + // if a hand-written copy constructor ignores it, it will be silently + // default initialized. It is always necessary to make a fresh refcount for + // a fresh object, so it's fine to ignore copy semantics here - it makes no + // difference to what will happen. + detail::copyable_refcount intrusive_refcount; + + constexpr void intrusive_inc_ref() + { + intrusive_refcount += 1; + } + + // It's better to have the non-null case dec_ref code all in one place, + // because it's long for something that might be pasted over 10x + // into functions that use intrusive_ptr a lot. + SNMALLOC_SLOW_PATH + constexpr void intrusive_dec_ref() + { + // Atomically subtract 1 from refcount and get the _old value_. + size_t prev_rc = intrusive_refcount.fetch_sub(1); + // If the value _was_ 0, we just did a negative wrap-around to + // max(size_t). We should stop now and think about how we got here. + assert(prev_rc > 0); + + // If the value was 1, it is now 0 and we can clean up. + if (prev_rc == 1) + { + delete static_cast(this); + } + } + + public: + template + friend struct intrusive_refcounted_traits; + + constexpr intrusive_ptr intrusive_ptr_from_this() + { + return intrusive_ptr{static_cast(this)}; + } + + ~intrusive_refcounted() + { + assert(intrusive_refcount == 0); + } + }; + + template + constexpr intrusive_ptr static_pointer_cast(const intrusive_ptr& ptr) + { + return intrusive_ptr(static_cast(ptr.get())); + } + + template + constexpr intrusive_ptr dynamic_pointer_cast(const intrusive_ptr& ptr) + { + // nullptr dynamic_cast case handled: constructor tolerates it anyway + return intrusive_ptr(dynamic_cast(ptr.get())); + } + + template + constexpr intrusive_ptr const_pointer_cast(const intrusive_ptr& ptr) + { + return intrusive_ptr(const_cast(ptr.get())); + } + + // impl note: + // It is important that these functions are non-member template functions. + // If you make them just member functions, they clash with operator==(Node, + // const Token&) defined elsewhere. + + template + constexpr bool + operator==(const intrusive_ptr& lhs, const intrusive_ptr& rhs) + { + return lhs.get() == rhs.get(); + } + + template + constexpr bool operator==(const intrusive_ptr& lhs, std::nullptr_t) + { + return lhs.get() == nullptr; + } + + template + constexpr bool operator==(std::nullptr_t, const intrusive_ptr& rhs) + { + return nullptr == rhs.get(); + } + + template + constexpr bool + operator!=(const intrusive_ptr& lhs, const intrusive_ptr& rhs) + { + return lhs.get() != rhs.get(); + } + + template + constexpr bool operator!=(const intrusive_ptr& lhs, std::nullptr_t) + { + return lhs.get() != nullptr; + } + + template + constexpr bool operator!=(std::nullptr_t, const intrusive_ptr& rhs) + { + return nullptr != rhs.get(); + } + + template + constexpr bool + operator<(const intrusive_ptr& lhs, const intrusive_ptr& rhs) + { + return lhs.get() < rhs.get(); + } + + template + constexpr bool + operator>(const intrusive_ptr& lhs, const intrusive_ptr& rhs) + { + return lhs.get() > rhs.get(); + } + + template + constexpr bool + operator<=(const intrusive_ptr& lhs, const intrusive_ptr& rhs) + { + return lhs.get() <= rhs.get(); + } + + template + constexpr bool + operator>=(const intrusive_ptr& lhs, const intrusive_ptr& rhs) + { + return lhs.get() >= rhs.get(); + } + + template + std::ostream& operator<<(std::ostream& os, const intrusive_ptr ptr) + { + return os << ptr.get(); + } +} + +namespace std +{ + template + struct hash> + { + size_t operator()(const trieste::intrusive_ptr ptr) const + { + return std::hash{}(ptr.ptr); + } + }; + + template + void swap(trieste::intrusive_ptr& lhs, trieste::intrusive_ptr& rhs) + { + lhs.swap(rhs); + } +} diff --git a/include/trieste/parse.h b/include/trieste/parse.h index fc970ac0..eb181723 100644 --- a/include/trieste/parse.h +++ b/include/trieste/parse.h @@ -7,6 +7,7 @@ #include "gen.h" #include "logging.h" #include "regex.h" +#include "trieste/intrusive_ptr.h" #include "wf.h" #include @@ -22,7 +23,7 @@ namespace trieste class Make; using ParseEffect = std::function; - class RuleDef + class RuleDef final : public intrusive_refcounted { friend class trieste::Parse; @@ -41,7 +42,7 @@ namespace trieste {} }; - using Rule = std::shared_ptr; + using Rule = intrusive_ptr; class Make { @@ -139,7 +140,7 @@ namespace trieste while (node->parent()->type().in(skip)) { extend(); - node = node->parent()->shared_from_this(); + node = node->parent()->intrusive_ptr_from_this(); } extend(); @@ -147,7 +148,7 @@ namespace trieste if (p == type) { - node = p->shared_from_this(); + node = p->intrusive_ptr_from_this(); } else { @@ -215,7 +216,7 @@ namespace trieste { extend(); - node = node->parent()->shared_from_this(); + node = node->parent()->intrusive_ptr_from_this(); return true; } @@ -244,7 +245,7 @@ namespace trieste { node->push_back(make_error(node->location(), "this is unclosed")); term(); - node = node->parent()->shared_from_this(); + node = node->parent()->intrusive_ptr_from_this(); term(); } @@ -533,13 +534,13 @@ namespace trieste inline detail::Rule operator>>(detail::Located s, detail::ParseEffect effect) { - return std::make_shared(s, effect); + return detail::Rule::make(s, effect); } inline detail::Rule operator>>(detail::Located s, detail::ParseEffect effect) { - return std::make_shared(s, effect); + return detail::Rule::make(s, effect); } inline std::pair diff --git a/include/trieste/pass.h b/include/trieste/pass.h index 1e4073eb..f8323f3c 100644 --- a/include/trieste/pass.h +++ b/include/trieste/pass.h @@ -2,6 +2,7 @@ #include "defaultmap.h" #include "rewrite.h" +#include "trieste/intrusive_ptr.h" #include "wf.h" #include @@ -17,9 +18,9 @@ namespace trieste } class PassDef; - using Pass = std::shared_ptr; + using Pass = intrusive_ptr; - class PassDef + class PassDef : public intrusive_refcounted { public: using F = std::function; @@ -88,7 +89,7 @@ namespace trieste operator Pass() const { - return std::make_shared(std::move(*this)); + return Pass::make(std::move(*this)); } const std::string& name() diff --git a/include/trieste/regex.h b/include/trieste/regex.h index 36c99f25..4237ff20 100644 --- a/include/trieste/regex.h +++ b/include/trieste/regex.h @@ -3,7 +3,7 @@ #pragma once #include "logging.h" -#include "source.h" +#include "ast.h" #include @@ -186,7 +186,7 @@ namespace trieste if (!parent) return ast; - ast = parent->shared_from_this(); + ast = parent->intrusive_ptr_from_this(); } } diff --git a/include/trieste/rewrite.h b/include/trieste/rewrite.h index 58a7fcd9..9d7c6352 100644 --- a/include/trieste/rewrite.h +++ b/include/trieste/rewrite.h @@ -6,11 +6,11 @@ #include "debug.h" #include "regex.h" #include "token.h" +#include "trieste/intrusive_ptr.h" #include #include #include -#include #include namespace trieste @@ -277,7 +277,8 @@ namespace trieste return FastPattern(new_first, new_parent, new_pass_through); } - static FastPattern SNMALLOC_SLOW_PATH match_opt(const FastPattern& pattern) + static FastPattern SNMALLOC_SLOW_PATH + match_opt(const FastPattern& pattern) { if (pattern.any_first()) return pattern; @@ -297,9 +298,9 @@ namespace trieste }; class PatternDef; - using PatternPtr = std::shared_ptr; + using PatternPtr = intrusive_ptr; - class PatternDef + class PatternDef : public intrusive_refcounted { PatternPtr continuation{}; @@ -369,8 +370,6 @@ namespace trieste } }; - using PatternPtr = std::shared_ptr; - class Cap : public PatternDef { private: @@ -389,7 +388,7 @@ namespace trieste PatternPtr clone() const& override { - return std::make_shared(*this); + return intrusive_ptr::make(*this); } bool match(NodeIt& it, const Node& parent, Match& match) const& override @@ -411,7 +410,7 @@ namespace trieste PatternPtr clone() const& override { - return std::make_shared(*this); + return intrusive_ptr::make(*this); } bool match(NodeIt& it, const Node& parent, Match& match) const& override @@ -435,7 +434,7 @@ namespace trieste PatternPtr clone() const& override { - return std::make_shared(*this); + return intrusive_ptr::make(*this); } bool match(NodeIt& it, const Node& parent, Match& match) const& override @@ -477,7 +476,7 @@ namespace trieste PatternPtr clone() const& override { - return std::make_shared(*this); + return intrusive_ptr::make(*this); } bool match(NodeIt& it, const Node& parent, Match& match) const& override @@ -508,7 +507,7 @@ namespace trieste PatternPtr clone() const& override { - return std::make_shared(*this); + return intrusive_ptr::make(*this); } bool match(NodeIt& it, const Node& parent, Match& match) const& override @@ -539,7 +538,7 @@ namespace trieste PatternPtr clone() const& override { - return std::make_shared(*this); + return intrusive_ptr::make(*this); } PatternPtr custom_rep() override @@ -579,7 +578,7 @@ namespace trieste PatternPtr clone() const& override { - return std::make_shared(*this); + return intrusive_ptr::make(*this); } bool match(NodeIt& it, const Node& parent, Match& match) const& override @@ -618,7 +617,7 @@ namespace trieste PatternPtr clone() const& override { - return std::make_shared(*this); + return intrusive_ptr::make(*this); } bool match(NodeIt& it, const Node& parent, Match& match) const& override @@ -657,7 +656,7 @@ namespace trieste PatternPtr clone() const& override { - return std::make_shared(*this); + return intrusive_ptr::make(*this); } PatternPtr custom_rep() override @@ -694,14 +693,14 @@ namespace trieste PatternPtr clone() const& override { - return std::make_shared(*this); + return intrusive_ptr::make(*this); } PatternPtr custom_rep() override { // Rep(Inside) -> InsideStar if (no_continuation()) - return std::make_shared>(types); + return intrusive_ptr>::make(types); return {}; } @@ -724,7 +723,7 @@ namespace trieste PatternPtr clone() const& override { - return std::make_shared(*this); + return intrusive_ptr::make(*this); } PatternPtr custom_rep() override @@ -745,7 +744,7 @@ namespace trieste PatternPtr clone() const& override { - return std::make_shared(*this); + return intrusive_ptr::make(*this); } PatternPtr custom_rep() override @@ -778,7 +777,7 @@ namespace trieste PatternPtr clone() const& override { - return std::make_shared(*this); + return intrusive_ptr::make(*this); } bool match(NodeIt& it, const Node& parent, Match& match) const& override @@ -812,7 +811,7 @@ namespace trieste PatternPtr clone() const& override { - return std::make_shared(*this); + return intrusive_ptr::make(*this); } PatternPtr custom_rep() override @@ -843,7 +842,7 @@ namespace trieste PatternPtr clone() const& override { - return std::make_shared(*this); + return intrusive_ptr::make(*this); } PatternPtr custom_rep() override @@ -873,7 +872,7 @@ namespace trieste PatternPtr clone() const& override { - return std::make_shared(*this); + return intrusive_ptr::make(*this); } bool match(NodeIt& it, const Node& parent, Match& match) const& override @@ -917,29 +916,31 @@ namespace trieste Pattern operator()(F&& action) const { return { - std::make_shared>(std::forward(action), pattern), + intrusive_ptr>::make(std::forward(action), pattern), fast_pattern}; } Pattern operator[](const Token& name) const { - return {std::make_shared(name, pattern), fast_pattern}; + return {intrusive_ptr::make(name, pattern), fast_pattern}; } Pattern operator~() const { return { - std::make_shared(pattern), FastPattern::match_opt(fast_pattern)}; + intrusive_ptr::make(pattern), + FastPattern::match_opt(fast_pattern)}; } Pattern operator++() const { - return {std::make_shared(pattern), FastPattern::match_pred()}; + return {intrusive_ptr::make(pattern), FastPattern::match_pred()}; } Pattern operator--() const { - return {std::make_shared(pattern), FastPattern::match_pred()}; + return { + intrusive_ptr::make(pattern), FastPattern::match_pred()}; } Pattern operator++(int) const @@ -951,12 +952,13 @@ namespace trieste return {result, FastPattern::match_any()}; return { - std::make_shared(pattern), FastPattern::match_opt(fast_pattern)}; + intrusive_ptr::make(pattern), + FastPattern::match_opt(fast_pattern)}; } Pattern operator!() const { - return {std::make_shared(pattern), FastPattern::match_pred()}; + return {intrusive_ptr::make(pattern), FastPattern::match_pred()}; } Pattern operator*(Pattern rhs) const @@ -977,23 +979,24 @@ namespace trieste tokens.insert(tokens.end(), lhs_tokens.begin(), lhs_tokens.end()); tokens.insert(tokens.end(), rhs_tokens.begin(), rhs_tokens.end()); return { - std::make_shared(tokens), + intrusive_ptr::make(tokens), FastPattern::match_choice(fast_pattern, rhs.fast_pattern)}; } if (pattern->has_captures()) return { - std::make_shared>(pattern, rhs.pattern), + intrusive_ptr>::make(pattern, rhs.pattern), FastPattern::match_choice(fast_pattern, rhs.fast_pattern)}; else return { - std::make_shared>(pattern, rhs.pattern), + intrusive_ptr>::make(pattern, rhs.pattern), FastPattern::match_choice(fast_pattern, rhs.fast_pattern)}; } Pattern operator<<(Pattern rhs) const { - return {std::make_shared(pattern, rhs.pattern), fast_pattern}; + return { + intrusive_ptr::make(pattern, rhs.pattern), fast_pattern}; } const std::set& get_starts() const @@ -1040,17 +1043,17 @@ namespace trieste } inline const auto Any = detail::Pattern( - std::make_shared(), detail::FastPattern::match_any()); + intrusive_ptr::make(), detail::FastPattern::match_any()); inline const auto Start = detail::Pattern( - std::make_shared(), detail::FastPattern::match_pred()); + intrusive_ptr::make(), detail::FastPattern::match_pred()); inline const auto End = detail::Pattern( - std::make_shared(), detail::FastPattern::match_pred()); + intrusive_ptr::make(), detail::FastPattern::match_pred()); inline detail::Pattern T(const Token& type) { std::vector types = {type}; return detail::Pattern( - std::make_shared(types), + intrusive_ptr::make(types), detail::FastPattern::match_token({type})); } @@ -1060,14 +1063,14 @@ namespace trieste { std::vector types_ = {type1, type2, types...}; return detail::Pattern( - std::make_shared(types_), + intrusive_ptr::make(types_), detail::FastPattern::match_token({type1, type2, types...})); } inline detail::Pattern T(const Token& type, const std::string& r) { return detail::Pattern( - std::make_shared(type, r), + intrusive_ptr::make(type, r), detail::FastPattern::match_token({type})); } @@ -1076,7 +1079,7 @@ namespace trieste { std::array types_ = {type1, types...}; return detail::Pattern( - std::make_shared>(types_), + intrusive_ptr>::make(types_), detail::FastPattern::match_parent({type1, types...})); } diff --git a/include/trieste/source.h b/include/trieste/source.h index 1346851e..53245165 100644 --- a/include/trieste/source.h +++ b/include/trieste/source.h @@ -2,26 +2,25 @@ // SPDX-License-Identifier: MIT #pragma once +#include "intrusive_ptr.h" + #include #include #include #include #include -#include #include #include -#include #include namespace trieste { class SourceDef; struct Location; - class NodeDef; - using Source = std::shared_ptr; - using Node = std::shared_ptr; - class SourceDef + using Source = intrusive_ptr; + + class SourceDef final : public intrusive_refcounted { private: std::string origin_; @@ -39,7 +38,7 @@ namespace trieste auto size = f.tellg(); f.seekg(0, std::ios::beg); - auto source = std::make_shared(); + auto source = Source::make(); source->origin_ = std::filesystem::relative(file).string(); source->contents.resize(static_cast(size)); f.read(&source->contents[0], size); @@ -53,7 +52,7 @@ namespace trieste static Source synthetic(const std::string& contents) { - auto source = std::make_shared(); + auto source = Source::make(); source->contents = contents; source->find_lines(); return source; diff --git a/include/trieste/token.h b/include/trieste/token.h index e751d17e..b1c07090 100644 --- a/include/trieste/token.h +++ b/include/trieste/token.h @@ -10,6 +10,24 @@ namespace trieste { + class NodeDef; + + // Certain uses of the Node alias before the full definition of NodeDef can + // cause incomplete type errors, so this manually relocates the problematic + // code to after NodeDef is fully defined. See the docs on the specialized + // trait for details. + // + // Note: this is only needed by our C++17 implementation of NodeRange (in + // ast.h). If we stop supporting C++17, this can be deleted. + template<> + struct intrusive_refcounted_traits + { + static constexpr void intrusive_inc_ref(NodeDef*); + static constexpr void intrusive_dec_ref(NodeDef*); + }; + + using Node = intrusive_ptr; + struct TokenDef; struct Token; diff --git a/include/trieste/writer.h b/include/trieste/writer.h index df076298..ba7d2d6a 100644 --- a/include/trieste/writer.h +++ b/include/trieste/writer.h @@ -3,6 +3,7 @@ #pragma once #include "passes.h" +#include "trieste/intrusive_ptr.h" #include "trieste/wf.h" #include @@ -10,9 +11,9 @@ namespace trieste { class DestinationDef; - using Destination = std::shared_ptr; + using Destination = intrusive_ptr; - class DestinationDef + class DestinationDef : public intrusive_refcounted { private: enum class Mode @@ -129,7 +130,7 @@ namespace trieste static Destination dir(const std::filesystem::path& path) { - auto d = std::make_shared(); + auto d = Destination::make(); d->mode_ = Mode::FileSystem; d->path_ = path; return d; @@ -137,7 +138,7 @@ namespace trieste static Destination console() { - auto d = std::make_shared(); + auto d = Destination::make(); d->mode_ = Mode::Console; d->path_ = "."; return d; @@ -145,7 +146,7 @@ namespace trieste static Destination synthetic() { - auto d = std::make_shared(); + auto d = Destination::make(); d->mode_ = Mode::Synthetic; d->path_ = "."; return d; diff --git a/parsers/yaml/reader.cc b/parsers/yaml/reader.cc index e1dc7312..5cdb861a 100644 --- a/parsers/yaml/reader.cc +++ b/parsers/yaml/reader.cc @@ -1,11 +1,8 @@ #include "internal.h" -#include "trieste/pass.h" -#include "trieste/rewrite.h" -#include "trieste/source.h" -#include "trieste/token.h" #include #include +#include namespace { @@ -1157,8 +1154,8 @@ namespace (T(MaybeDirective)[MaybeDirective] * ~T(NewLine) * End)([](auto& n) { Node dir = n.front(); - Node doc = dir->parent()->parent()->shared_from_this(); - Node stream = doc->parent()->shared_from_this(); + Node doc = dir->parent()->parent()->intrusive_ptr_from_this(); + Node stream = doc->parent()->intrusive_ptr_from_this(); return stream->find(doc) < stream->end() - 1; }) >> [](Match& _) { @@ -1866,7 +1863,7 @@ namespace In(SequenceIndent) * T(Indent, BlockIndent)[Indent]([](auto& n) { Node indent = n.front(); - Node parent = indent->parent()->shared_from_this(); + Node parent = indent->parent()->intrusive_ptr_from_this(); return same_indent(parent, indent); }) >> [](Match& _) -> Node { @@ -2187,7 +2184,7 @@ namespace (T(ValueGroup) << (T(FlowMapping, FlowSequence))[Flow])( [](auto& n) { Node group = n.front(); - Node item = group->parent()->shared_from_this(); + Node item = group->parent()->intrusive_ptr_from_this(); Node flow = group->front(); std::size_t item_indent = item->location().linecol().second; std::size_t flow_indent = min_indent(flow); diff --git a/parsers/yaml/writer.cc b/parsers/yaml/writer.cc index e62afd64..958d17c3 100644 --- a/parsers/yaml/writer.cc +++ b/parsers/yaml/writer.cc @@ -1,10 +1,9 @@ #include "internal.h" -#include "trieste/rewrite.h" -#include "trieste/utf8.h" -#include "trieste/wf.h" #include "yaml.h" #include +#include +#include namespace { @@ -116,7 +115,7 @@ namespace if (current->in({MappingItem, FlowMappingItem})) { - newline = newline || !is_complex(current->shared_from_this()); + newline = newline || !is_complex(current->intrusive_ptr_from_this()); } return newline && !current->in({Sequence, FlowSequence}); diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt new file mode 100644 index 00000000..1dfae8d3 --- /dev/null +++ b/test/CMakeLists.txt @@ -0,0 +1,14 @@ + +add_executable(trieste_intrusive_ptr_test + intrusive_ptr_test.cc +) +target_link_libraries(trieste_intrusive_ptr_test trieste::trieste) + +# This test might not make so much sense without asan enabled, but might as well +# check that the test compiles and doesn't crash on other compilers. +if(CMAKE_CXX_COMPILER_ID STREQUAL "Clang" AND NOT TRIESTE_SANITIZE) + target_compile_options(trieste_intrusive_ptr_test PUBLIC -g -fsanitize=thread) + target_link_libraries(trieste_intrusive_ptr_test -fsanitize=thread) +endif() + +add_test(NAME trieste_intrusive_ptr_test COMMAND trieste_intrusive_ptr_test) diff --git a/test/intrusive_ptr_test.cc b/test/intrusive_ptr_test.cc new file mode 100644 index 00000000..62124bc2 --- /dev/null +++ b/test/intrusive_ptr_test.cc @@ -0,0 +1,165 @@ +#include +#include +#include +#include +#include + +struct Dummy : public trieste::intrusive_refcounted +{ + size_t tag; + + Dummy(size_t tag_) : tag{tag_} {} +}; + +using ptr_t = trieste::intrusive_ptr; +using ActionFn = ptr_t(ptr_t); + +std::vector actions{ + [](ptr_t ptr) -> ptr_t { + if (ptr == nullptr) + { + std::cout << "Should only be setting to nullptr once per thread!" + << std::endl; + std::abort(); + } + return nullptr; // dec_ref on this ptr + }, + [](ptr_t ptr) { + auto tmp = std::move(ptr); + return tmp; + }, + [](ptr_t ptr) { + auto tmp = ptr; + return ptr; + }, + [](ptr_t ptr) { + auto& alias = ptr; + alias = ptr; + return ptr; + }, +}; + +struct Behavior +{ + size_t action_idx; + size_t ptr_idx; + + bool operator<(const Behavior& other) const + { + return std::pair{action_idx, ptr_idx} < + std::pair{other.action_idx, other.ptr_idx}; + } +}; + +struct Test +{ + size_t ptr_count; + std::vector> thread_behaviors; + + void run() const + { + // Each thread gets its own copy of an array of N pointers, where every + // thread shares refcounts with every other thread. + std::vector> ptrs_per_thread; + ptrs_per_thread.emplace_back(); + for (size_t i = 0; i < ptr_count; ++i) + { + ptrs_per_thread.front().push_back(ptr_t::make(i)); + } + while (ptrs_per_thread.size() < thread_behaviors.size()) + { + ptrs_per_thread.push_back(ptrs_per_thread.back()); + } + + std::vector threads; + for (size_t i = 0; i < thread_behaviors.size(); ++i) + { + threads.emplace_back([&, i]() { + for (auto& behavior : thread_behaviors.at(i)) + { + auto& ptr = ptrs_per_thread.at(i).at(behavior.ptr_idx); + ptr = actions[behavior.action_idx](ptr); + } + }); + } + + for (auto& thread : threads) + { + thread.join(); + } + + // Sanity check: every thread should be setting their ptr to nullptr at some + // point + for (const auto& ptrs : ptrs_per_thread) + { + for (const auto& ptr : ptrs) + { + if (ptr != nullptr) + { + std::cout << "non-null ptr!" << std::endl; + std::abort(); + } + } + } + } +}; + +std::vector +build_tests(size_t ptr_count, size_t thread_count, size_t permutations) +{ + std::vector all_behaviors; + for (size_t action_idx = 0; action_idx < actions.size(); ++action_idx) + { + for (size_t ptr_idx = 0; ptr_idx < ptr_count; ++ptr_idx) + { + all_behaviors.push_back({ + action_idx, + ptr_idx, + }); + } + } + + std::vector tests = {{ptr_count, {}}}; + for (size_t i = 0; i < thread_count; ++i) + { + std::vector next_tests; + for (const auto& test : tests) + { + // Allow adding some extra permutations if you think you're stuck at the + // first few. + for (size_t permutation_idx = 0; permutation_idx < permutations; + ++permutation_idx) + { + auto mod_test = test; + mod_test.thread_behaviors.push_back(all_behaviors); + next_tests.push_back(mod_test); + + // Unconditionally permute the behaviors. We're not looking for total + // coverage, just variety. + std::next_permutation(all_behaviors.begin(), all_behaviors.end()); + } + } + tests = next_tests; + } + return tests; +} + +// The intention of this test is to do a lot of work to refcounts, while under +// some kind of thread sanitizer. Changing the tag on Dummy from async to sync +// should make Clang's thread sanitizer unhappy, for instance, whereas if the +// tag is async then everything _should_ be fine. +int main() +{ + // Be very careful when increasing these numbers... they can quickly eat up + // your memory and time. + auto tests = build_tests(3, 6, 4); + std::cout << "Found " << tests.size() << " permutations." << std::endl; + + for (auto test : tests) + { + test.run(); + } + + std::cout << "Ran " << tests.size() << " permutations." << std::endl; + return 0; +}