diff --git a/src/CompleteMerkle.sol b/src/CompleteMerkle.sol index 8653215..62c56fe 100644 --- a/src/CompleteMerkle.sol +++ b/src/CompleteMerkle.sol @@ -1,54 +1,32 @@ // SPDX-License-Identifier: MIT pragma solidity ^0.8.4; -import "./common/MurkyBase.sol"; - /// @notice Nascent, simple, kinda efficient (and improving!) Merkle proof generator and verifier /// @author dmfxyz /// @dev Note Generic Merkle Tree -contract CompleteMerkle is MurkyBase { +contract CompleteMerkle { /** * * HASHING FUNCTION * * */ - - /// ascending sort and concat prior to hashing - function hashLeafPairs(bytes32 left, bytes32 right) public pure override returns (bytes32 _hash) { - assembly { - switch lt(left, right) - case 0 { - mstore(0x0, right) - mstore(0x20, left) - } - default { - mstore(0x0, left) - mstore(0x20, right) - } - _hash := keccak256(0x0, 0x40) - } - } - - function getTree(bytes32[] memory data) public pure returns (bytes32[] memory) { - require(data.length > 1, 'wont generate root for single leaf'); + function initTree(bytes32[] memory data) private pure returns (bytes32[] memory) { + require(data.length > 1, "wont generate root for single leaf"); bytes32[] memory tree = new bytes32[](2 * data.length - 1); assembly { let dlen := mload(data) let titer := add(sub(mload(tree), dlen), 1) - for {let i := add(data, mul(dlen, 0x20))} gt(i, data) { i := sub(i, 0x20)} { + for { let i := add(data, mul(dlen, 0x20)) } gt(i, data) { i := sub(i, 0x20) } { mstore(add(tree, mul(titer, 0x20)), mload(i)) titer := add(titer, 1) } - } - // for (uint256 i = 0; i < data.length; ++i) { - // tree[tree.length - i - 1] = data[i]; - // } + } return tree; } - function getRoot(bytes32[] memory data) public pure override returns (bytes32) { - bytes32[] memory tree = getTree(data); + function buildTree(bytes32[] memory data) public pure returns (bytes32[] memory) { + bytes32[] memory tree = initTree(data); assembly { function hash_leafs(left, right) -> _hash { switch lt(left, right) @@ -62,7 +40,7 @@ contract CompleteMerkle is MurkyBase { } _hash := keccak256(0x0, 0x40) } - for {let i := mload(tree)} gt(i, 1) {i := sub(i, 2)} { + for { let i := mload(tree) } gt(i, 1) { i := sub(i, 2) } { // TODO: clean all this up, mainly broken out for early understanding and debugging let left := mload(add(tree, mul(sub(i, 1), 0x20))) let right := mload(add(tree, mul(i, 0x20))) @@ -71,12 +49,71 @@ contract CompleteMerkle is MurkyBase { mstore(posToWrite, hash_leafs(left, right)) } } - // for (uint256 i = tree.length - 1; i > 0; i -= 2) { - // uint256 posToWrite = (i - 1) / 2; - // tree[posToWrite] = hashLeafPairs(tree[i - 1], tree[i]); - // } + return tree; + } + + function getRoot(bytes32[] memory data) public pure returns (bytes32) { + bytes32[] memory tree = buildTree(data); return tree[0]; } + function verifyProof(bytes32 root, bytes32[] memory proof, bytes32 valueToProve) external pure returns (bool) { + assembly { + function hash_leafs(left, right) -> _hash { + switch lt(left, right) + case 0 { + mstore(0x0, right) + mstore(0x20, left) + } + default { + mstore(0x0, left) + mstore(0x20, right) + } + _hash := keccak256(0x0, 0x40) + } + let roll := mload(0x40) + mstore(roll, valueToProve) + let len := mload(proof) + for { let i := 0 } lt(i, len) { i := add(i, 1) } { + mstore(roll, hash_leafs(mload(roll), mload(add(add(proof, 0x20), mul(i, 0x20))))) + } + mstore(roll, eq(mload(roll), root)) + return(roll, 0x20) + } + } + + function getProof(bytes32[] memory data, uint256 index) public pure returns (bytes32[] memory) { + require(index < data.length); + bytes32[] memory tree = buildTree(data); + assembly { + let iter := sub(sub(mload(tree), index), 0x1) + let ptr := mload(0x40) + mstore(ptr, 0x20) + let proofSizePtr := add(ptr, 0x20) + let proofIndexPtr := proofSizePtr + for {} eq(0x0, 0x0) {} { + // WHile true + switch eq(iter, 0) + case 1 { break } + mstore(proofSizePtr, add(mload(proofSizePtr), 0x1)) + proofIndexPtr := add(proofIndexPtr, 0x20) + switch and(iter, 0x1) + case 0x1 { + // iter is ODD + let sibling := mload(add(add(tree, 0x20), mul(add(iter, 1), 0x20))) + mstore(proofIndexPtr, sibling) + iter := div(iter, 2) + } + default { + // iter is EVEN + let sibling := mload(add(add(tree, 0x20), mul(sub(iter, 1), 0x20))) + mstore(proofIndexPtr, sibling) + iter := div(sub(iter, 1), 2) + } + } + mstore(0x40, add(proofIndexPtr, 0x20)) + return(ptr, add(0x20, sub(add(proofIndexPtr, 0x20), proofSizePtr))) + } + } } diff --git a/src/common/MurkyBase.sol b/src/common/MurkyBase.sol index e3310af..06d6f37 100644 --- a/src/common/MurkyBase.sol +++ b/src/common/MurkyBase.sol @@ -2,45 +2,40 @@ pragma solidity ^0.8.4; abstract contract MurkyBase { - /** - * - * CONSTRUCTOR * - * - */ + /*************** + * CONSTRUCTOR * + ***************/ constructor() {} - /** - * - * VIRTUAL HASHING FUNCTIONS * - * - */ + /******************** + * VIRTUAL HASHING FUNCTIONS * + ********************/ function hashLeafPairs(bytes32 left, bytes32 right) public pure virtual returns (bytes32 _hash); - /** - * - * PROOF VERIFICATION * - * - */ + + /********************** + * PROOF VERIFICATION * + **********************/ + function verifyProof(bytes32 root, bytes32[] memory proof, bytes32 valueToProve) external pure returns (bool) { // proof length must be less than max array size bytes32 rollingHash = valueToProve; uint256 length = proof.length; unchecked { - for (uint256 i = 0; i < length; ++i) { + for(uint i = 0; i < length; ++i){ rollingHash = hashLeafPairs(rollingHash, proof[i]); } } return root == rollingHash; } - /** - * - * PROOF GENERATION * - * - */ - function getRoot(bytes32[] memory data) public pure virtual returns (bytes32) { + /******************** + * PROOF GENERATION * + ********************/ + + function getRoot(bytes32[] memory data) public pure returns (bytes32) { require(data.length > 1, "won't generate root for single leaf"); - while (data.length > 1) { + while(data.length > 1) { data = hashLevel(data); } return data[0]; @@ -48,21 +43,23 @@ abstract contract MurkyBase { function getProof(bytes32[] memory data, uint256 node) public pure returns (bytes32[] memory) { require(data.length > 1, "won't generate proof for single leaf"); - // The size of the proof is equal to the ceiling of log2(numLeaves) + // The size of the proof is equal to the ceiling of log2(numLeaves) bytes32[] memory result = new bytes32[](log2ceilBitMagic(data.length)); uint256 pos = 0; // Two overflow risks: node, pos // node: max array size is 2**256-1. Largest index in the array will be 1 less than that. Also, - // for dynamic arrays, size is limited to 2**64-1 + // for dynamic arrays, size is limited to 2**64-1 // pos: pos is bounded by log2(data.length), which should be less than type(uint256).max - while (data.length > 1) { + while(data.length > 1) { unchecked { - if (node & 0x1 == 1) { + if(node & 0x1 == 1) { result[pos] = data[node - 1]; - } else if (node + 1 == data.length) { - result[pos] = bytes32(0); - } else { + } + else if (node + 1 == data.length) { + result[pos] = bytes32(0); + } + else { result[pos] = data[node + 1]; } ++pos; @@ -79,38 +76,36 @@ abstract contract MurkyBase { // Function is private, and all internal callers check that data.length >=2. // Underflow is not possible as lowest possible value for data/result index is 1 - // overflow should be safe as length is / 2 always. + // overflow should be safe as length is / 2 always. unchecked { uint256 length = data.length; - if (length & 0x1 == 1) { + if (length & 0x1 == 1){ result = new bytes32[](length / 2 + 1); - result[result.length - 1] = data[length - 1]; + result[result.length - 1] = hashLeafPairs(data[length - 1], bytes32(0)); } else { result = new bytes32[](length / 2); - } - // pos is upper bounded by data.length / 2, so safe even if array is at max size + } + // pos is upper bounded by data.length / 2, so safe even if array is at max size uint256 pos = 0; - for (uint256 i = 0; i < length - 1; i += 2) { - result[pos] = hashLeafPairs(data[i], data[i + 1]); + for (uint256 i = 0; i < length-1; i+=2){ + result[pos] = hashLeafPairs(data[i], data[i+1]); ++pos; } } return result; } - /** - * - * MATH "LIBRARY" * - * - */ - + /****************** + * MATH "LIBRARY" * + ******************/ + /// @dev Note that x is assumed > 0 function log2ceil(uint256 x) public pure returns (uint256) { uint256 ceil = 0; - uint256 pOf2; + uint pOf2; // If x is a power of 2, then this function will return a ceiling // that is 1 greater than the actual ceiling. So we need to check if - // x is a power of 2, and subtract one from ceil if so. + // x is a power of 2, and subtract one from ceil if so. assembly { // we check by seeing if x == (~x + 1) & x. This applies a mask // to find the lowest set bit of x and then checks it for equality @@ -131,11 +126,11 @@ abstract contract MurkyBase { // we do some assembly magic to treat the bool as an integer later on pOf2 := eq(and(add(not(x), 1), x), x) } - + // if x == type(uint256).max, than ceil is capped at 256 // if x == 0, then pO2 == 0, so ceil won't underflow unchecked { - while (x > 0) { + while( x > 0) { x >>= 1; ceil++; } @@ -146,41 +141,41 @@ abstract contract MurkyBase { /// Original bitmagic adapted from https://github.com/paulrberg/prb-math/blob/main/contracts/PRBMath.sol /// @dev Note that x assumed > 1 - function log2ceilBitMagic(uint256 x) public pure returns (uint256) { + function log2ceilBitMagic(uint256 x) public pure returns (uint256){ if (x <= 1) { return 0; } uint256 msb = 0; uint256 _x = x; - if (x >= 2 ** 128) { + if (x >= 2**128) { x >>= 128; msb += 128; } - if (x >= 2 ** 64) { + if (x >= 2**64) { x >>= 64; msb += 64; } - if (x >= 2 ** 32) { + if (x >= 2**32) { x >>= 32; msb += 32; } - if (x >= 2 ** 16) { + if (x >= 2**16) { x >>= 16; msb += 16; } - if (x >= 2 ** 8) { + if (x >= 2**8) { x >>= 8; msb += 8; } - if (x >= 2 ** 4) { + if (x >= 2**4) { x >>= 4; msb += 4; } - if (x >= 2 ** 2) { + if (x >= 2**2) { x >>= 2; msb += 2; } - if (x >= 2 ** 1) { + if (x >= 2**1) { msb += 1; } @@ -191,4 +186,4 @@ abstract contract MurkyBase { return msb + 1; } } -} +} \ No newline at end of file