Skip to content

Commit

Permalink
prefer methods to traits
Browse files Browse the repository at this point in the history
  • Loading branch information
mina86 committed Aug 14, 2023
1 parent ea34370 commit a818d09
Show file tree
Hide file tree
Showing 6 changed files with 74 additions and 109 deletions.
152 changes: 59 additions & 93 deletions sealable-trie/src/nodes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,24 @@ impl<'a, P, S> Node<'a, P, S> {
}
}

impl<'a> Node<'a> {
/// Builds raw representation of given node.
///
/// Returns an error if this node is an Extension with a key of invalid
/// length (either empty or too long).
pub fn encode(&self) -> Result<RawNode, ()> {
match self {
Node::Branch { children: [left, right] } => {
Ok(RawNode::branch(*left, *right))
}
Node::Extension { key, child } => {
RawNode::extension(*key, *child).ok_or(())
}
Node::Value { value, child } => Ok(RawNode::value(*value, *child)),
}
}
}

/// Hashes an Extension node with oversized key.
///
/// Normally, this is never called since we should calculate hashes of nodes
Expand All @@ -230,8 +248,8 @@ impl RawNode {
pub fn branch(left: Reference, right: Reference) -> Self {
let mut res = Self([0; 72]);
let (lft, rht) = res.halfs_mut();
*lft = left.encode_raw();
*rht = right.encode_raw();
*lft = left.encode();
*rht = right.encode();
res
}

Expand All @@ -244,20 +262,55 @@ impl RawNode {
let mut res = Self([0; 72]);
let (lft, rht) = res.halfs_mut();
key.encode_into(lft, 0x80)?;
*rht = child.encode_raw();
*rht = child.encode();
Some(res)
}

/// Constructs a Value node with given value hash and child.
pub fn value(value: ValueRef, child: NodeRef) -> Self {
let mut res = Self([0; 72]);
let (lft, rht) = res.halfs_mut();
*lft = Reference::Value(value).encode_raw();
*lft = Reference::Value(value).encode();
lft[0] |= 0x80;
*rht = Reference::Node(child).encode_raw();
*rht = Reference::Node(child).encode();
res
}

/// Decodes raw node into a [`Node`].
///
/// In debug builds panics if `node` holds malformed representation, i.e. if
/// any unused bits (which must be cleared) are set.
pub fn decode(&self) -> Node {
let (left, right) = self.halfs();
let right = Reference::from_raw(right, false);
let tag = self.first() >> 6;
if tag == 0 || tag == 1 {
// Branch
Node::Branch { children: [Reference::from_raw(left, false), right] }
} else if tag == 2 {
// Extension
let key = Slice::decode(left, 0x80).unwrap_or_else(|| {
panic!("Failed decoding raw: {self:?}");
});
Node::Extension { key, child: right }
} else {
// Value
let (num, value) = stdx::split_array_ref::<4, 32, 36>(left);
let num = u32::from_be_bytes(*num);
debug_assert_eq!(
0xC000_0000,
num & !0x2000_0000,
"Failed decoding raw node: {self:?}",
);
let value = ValueRef::new(num & 0x2000_0000 != 0, value.into());
let child = right.try_into().unwrap_or_else(|_| {
debug_assert!(false, "Failed decoding raw node: {self:?}");
NodeRef::new(None, &CryptoHash::DEFAULT)
});
Node::Value { value, child }
}
}

/// Returns the first byte in the raw representation.
fn first(&self) -> u8 { self.0[0] }

Expand Down Expand Up @@ -347,7 +400,7 @@ impl<'a> Reference<'a> {
}

/// Encodes the node reference into the buffer.
fn encode_raw(&self) -> [u8; 36] {
fn encode(&self) -> [u8; 36] {
let (num, hash) = match self {
Self::Node(node) => {
(node.ptr.map_or(0, |ptr| ptr.get()), node.hash)
Expand Down Expand Up @@ -400,38 +453,6 @@ impl<'a> ValueRef<'a, bool> {
}


// =============================================================================
// Trait implementations

impl<'a> From<&'a RawNode> for Node<'a> {
/// Decodes raw node into a [`Node`] assuming that raw bytes are trusted and
/// thus well formed.
///
/// The function is safe even if the bytes aren’t well-formed.
#[inline]
fn from(node: &'a RawNode) -> Self { decode_raw(node) }
}

impl<'a> TryFrom<Node<'a>> for RawNode {
type Error = ();

/// Builds raw representation for given node.
#[inline]
fn try_from(node: Node<'a>) -> Result<Self, Self::Error> {
Self::try_from(&node)
}
}

impl<'a> TryFrom<&Node<'a>> for RawNode {
type Error = ();

/// Builds raw representation for given node.
#[inline]
fn try_from(node: &Node<'a>) -> Result<Self, Self::Error> {
raw_from_node(node).ok_or(())
}
}

// =============================================================================
// PartialEq

Expand Down Expand Up @@ -500,61 +521,6 @@ where
}
}

// =============================================================================
// Conversion functions

/// Decodes raw node into a [`Node`] assuming that raw bytes are trusted and
/// thus well formed.
///
/// In debug builds panics if `node` holds malformed representation, i.e. if any
/// unused bits (which must be cleared) are set.
fn decode_raw<'a>(node: &'a RawNode) -> Node<'a> {
let (left, right) = node.halfs();
let right = Reference::from_raw(right, false);
let tag = node.first() >> 6;
if tag == 0 || tag == 1 {
// Branch
Node::Branch { children: [Reference::from_raw(left, false), right] }
} else if tag == 2 {
// Extension
let key = Slice::decode(left, 0x80).unwrap_or_else(|| {
panic!("Failed decoding raw: {node:?}");
});
Node::Extension { key, child: right }
} else {
// Value
let (num, value) = stdx::split_array_ref::<4, 32, 36>(left);
let num = u32::from_be_bytes(*num);
debug_assert_eq!(
0xC000_0000,
num & !0x2000_0000,
"Failed decoding raw node: {node:?}",
);
let value = ValueRef::new(num & 0x2000_0000 != 0, value.into());
let child = right.try_into().unwrap_or_else(|_| {
debug_assert!(false, "Failed decoding raw node: {node:?}");
NodeRef::new(None, &CryptoHash::DEFAULT)
});
Node::Value { value, child }
}
}

/// Builds raw representation for given node.
///
/// Returns reference to slice of the output buffer holding the representation
/// (node representation used in proofs is variable-length). If the given node
/// cannot be encoded (which happens if it’s an extension with a key whose byte
/// buffer is longer than 34 bytes), returns `None`.
fn raw_from_node<'a>(node: &Node<'a>) -> Option<RawNode> {
match node {
Node::Branch { children: [left, right] } => {
Some(RawNode::branch(*left, *right))
}
Node::Extension { key, child } => RawNode::extension(*key, *child),
Node::Value { value, child } => Some(RawNode::value(*value, *child)),
}
}

// =============================================================================
// Formatting

Expand Down
7 changes: 3 additions & 4 deletions sealable-trie/src/nodes/stress_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,9 @@ fn stress_test_raw_encoding_round_trip() {
let mut raw = RawNode([0; 72]);
for _ in 0..get_iteration_count() {
gen_random_raw_node(&mut rng, &mut raw.0);
let node = Node::from(&raw);

let node = raw.decode();
// Test RawNode→Node→RawNode round trip conversion.
assert_eq!(Ok(raw), RawNode::try_from(node), "node: {node:?}");
assert_eq!(Ok(raw), node.encode(), "node: {node:?}");
}
}

Expand Down Expand Up @@ -89,7 +88,7 @@ fn stress_test_node_encoding_round_trip() {
let node = gen_random_node(&mut rng, &mut buf);

let raw = super::tests::raw_from_node(&node);
assert_eq!(node, Node::from(&raw), "Failed decoding Raw: {raw:?}");
assert_eq!(node, raw.decode(), "Failed decoding Raw: {raw:?}");
}
}

Expand Down
9 changes: 5 additions & 4 deletions sealable-trie/src/nodes/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,12 @@ const TWO: CryptoHash = CryptoHash([2; 32]);
/// first and last objects aren’t equal. Returns the raw node.
#[track_caller]
pub(super) fn raw_from_node(node: &Node) -> RawNode {
let raw = RawNode::try_from(node)
let raw = node
.encode()
.unwrap_or_else(|()| panic!("Failed encoding node as raw: {node:?}"));
let decoded = Node::from(&raw);
assert_eq!(
*node, decoded,
*node,
raw.decode(),
"Node → RawNode → Node gave different result:\n Raw: {raw:?}"
);
raw
Expand All @@ -47,7 +48,7 @@ pub(super) fn raw_from_node(node: &Node) -> RawNode {
fn check_node_encoding(node: Node, want: [u8; 72], want_hash: &str) {
let raw = raw_from_node(&node);
assert_eq!(want, raw.0, "Unexpected raw representation");
assert_eq!(node, Node::from(&RawNode(want)), "Bad Raw→Node conversion");
assert_eq!(node, RawNode(want).decode(), "Bad Raw→Node conversion");

let want_hash = BASE64_ENGINE.decode(want_hash).unwrap();
let want_hash = <&[u8; 32]>::try_from(want_hash.as_slice()).unwrap();
Expand Down
5 changes: 2 additions & 3 deletions sealable-trie/src/trie.rs
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ impl<A: memory::Allocator> Trie<A> {
let mut node_hash = self.root_hash.clone();
loop {
let node = self.alloc.get(node_ptr.ok_or(Error::Sealed)?);
let node = Node::from(&node);
let node = node.decode();
debug_assert_eq!(node_hash, node.hash());

let child = match node {
Expand Down Expand Up @@ -288,8 +288,7 @@ impl<A: memory::Allocator> Trie<A> {
println!(" (sealed)");
return;
};
let raw = self.alloc.get(ptr);
match Node::from(&raw) {
match self.alloc.get(ptr).decode() {
Node::Branch { children } => {
println!(" Branch");
print_ref(children[0], depth + 2);
Expand Down
4 changes: 2 additions & 2 deletions sealable-trie/src/trie/seal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ impl<'a, A: memory::Allocator> SealContext<'a, A> {
pub(super) fn seal(&mut self, nref: NodeRef) -> Result<bool> {
let ptr = nref.ptr.ok_or(Error::Sealed)?;
let node = self.alloc.get(ptr);
let node = Node::from(&node);
let node = node.decode();
debug_assert_eq!(*nref.hash, node.hash());

let result = match node {
Expand Down Expand Up @@ -167,7 +167,7 @@ fn get_children(node: &RawNode) -> (Option<Ptr>, Option<Ptr>) {
}
}

match Node::from(node) {
match node.decode() {
Node::Branch { children: [lft, rht] } => (get_ptr(lft), get_ptr(rht)),
Node::Extension { child, .. } => (get_ptr(child), None),
Node::Value { child, .. } => (child.ptr, None),
Expand Down
6 changes: 3 additions & 3 deletions sealable-trie/src/trie/set.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ impl<'a, A: memory::Allocator> SetContext<'a, A> {
fn handle(&mut self, nref: NodeRef) -> Result<(Ptr, CryptoHash)> {
let nref = (nref.ptr.ok_or(Error::Sealed)?, nref.hash);
let node = self.wlog.allocator().get(nref.0);
let node = Node::from(&node);
let node = node.decode();
debug_assert_eq!(*nref.1, node.hash());
match node {
Node::Branch { children } => self.handle_branch(nref, children),
Expand Down Expand Up @@ -316,14 +316,14 @@ impl<'a, A: memory::Allocator> SetContext<'a, A> {

/// Sets value of a node cell at given address and returns its hash.
fn set_node(&mut self, ptr: Ptr, node: RawNode) -> (Ptr, CryptoHash) {
let hash = Node::from(&node).hash();
let hash = node.decode().hash();
self.wlog.set(ptr, node);
(ptr, hash)
}

/// Allocates a new node and sets it to given value.
fn alloc_node(&mut self, node: RawNode) -> Result<(Ptr, CryptoHash)> {
let hash = Node::from(&node).hash();
let hash = node.decode().hash();
let ptr = self.wlog.alloc(node)?;
Ok((ptr, hash))
}
Expand Down

0 comments on commit a818d09

Please sign in to comment.