Skip to content

Commit

Permalink
initial implementation of CompleteMerkle.sol
Browse files Browse the repository at this point in the history
  • Loading branch information
dmfxyz committed Mar 22, 2024
1 parent 374a161 commit 4b1ab18
Show file tree
Hide file tree
Showing 2 changed files with 123 additions and 91 deletions.
105 changes: 71 additions & 34 deletions src/CompleteMerkle.sol
Original file line number Diff line number Diff line change
@@ -1,54 +1,32 @@
// SPDX-License-Identifier: MIT
pragma solidity ^0.8.4;

import "./common/MurkyBase.sol";

/// @notice Nascent, simple, kinda efficient (and improving!) Merkle proof generator and verifier
/// @author dmfxyz
/// @dev Note Generic Merkle Tree
contract CompleteMerkle is MurkyBase {
contract CompleteMerkle {
/**
*
* HASHING FUNCTION *
*
*/

/// ascending sort and concat prior to hashing
function hashLeafPairs(bytes32 left, bytes32 right) public pure override returns (bytes32 _hash) {
assembly {
switch lt(left, right)
case 0 {
mstore(0x0, right)
mstore(0x20, left)
}
default {
mstore(0x0, left)
mstore(0x20, right)
}
_hash := keccak256(0x0, 0x40)
}
}

function getTree(bytes32[] memory data) public pure returns (bytes32[] memory) {
require(data.length > 1, 'wont generate root for single leaf');
function initTree(bytes32[] memory data) private pure returns (bytes32[] memory) {
require(data.length > 1, "wont generate root for single leaf");

bytes32[] memory tree = new bytes32[](2 * data.length - 1);
assembly {
let dlen := mload(data)
let titer := add(sub(mload(tree), dlen), 1)
for {let i := add(data, mul(dlen, 0x20))} gt(i, data) { i := sub(i, 0x20)} {
for { let i := add(data, mul(dlen, 0x20)) } gt(i, data) { i := sub(i, 0x20) } {
mstore(add(tree, mul(titer, 0x20)), mload(i))
titer := add(titer, 1)
}
}
// for (uint256 i = 0; i < data.length; ++i) {
// tree[tree.length - i - 1] = data[i];
// }
}
return tree;
}

function getRoot(bytes32[] memory data) public pure override returns (bytes32) {
bytes32[] memory tree = getTree(data);
function buildTree(bytes32[] memory data) public pure returns (bytes32[] memory) {
bytes32[] memory tree = initTree(data);
assembly {
function hash_leafs(left, right) -> _hash {
switch lt(left, right)
Expand All @@ -62,7 +40,7 @@ contract CompleteMerkle is MurkyBase {
}
_hash := keccak256(0x0, 0x40)
}
for {let i := mload(tree)} gt(i, 1) {i := sub(i, 2)} {
for { let i := mload(tree) } gt(i, 1) { i := sub(i, 2) } {
// TODO: clean all this up, mainly broken out for early understanding and debugging
let left := mload(add(tree, mul(sub(i, 1), 0x20)))
let right := mload(add(tree, mul(i, 0x20)))
Expand All @@ -71,12 +49,71 @@ contract CompleteMerkle is MurkyBase {
mstore(posToWrite, hash_leafs(left, right))
}
}
// for (uint256 i = tree.length - 1; i > 0; i -= 2) {
// uint256 posToWrite = (i - 1) / 2;
// tree[posToWrite] = hashLeafPairs(tree[i - 1], tree[i]);
// }
return tree;
}

function getRoot(bytes32[] memory data) public pure returns (bytes32) {
bytes32[] memory tree = buildTree(data);
return tree[0];
}

function verifyProof(bytes32 root, bytes32[] memory proof, bytes32 valueToProve) external pure returns (bool) {
assembly {
function hash_leafs(left, right) -> _hash {
switch lt(left, right)
case 0 {
mstore(0x0, right)
mstore(0x20, left)
}
default {
mstore(0x0, left)
mstore(0x20, right)
}
_hash := keccak256(0x0, 0x40)
}
let roll := mload(0x40)
mstore(roll, valueToProve)
let len := mload(proof)
for { let i := 0 } lt(i, len) { i := add(i, 1) } {
mstore(roll, hash_leafs(mload(roll), mload(add(add(proof, 0x20), mul(i, 0x20)))))
}
mstore(roll, eq(mload(roll), root))
return(roll, 0x20)
}
}

function getProof(bytes32[] memory data, uint256 index) public pure returns (bytes32[] memory) {
require(index < data.length);
bytes32[] memory tree = buildTree(data);

assembly {
let iter := sub(sub(mload(tree), index), 0x1)
let ptr := mload(0x40)
mstore(ptr, 0x20)
let proofSizePtr := add(ptr, 0x20)
let proofIndexPtr := proofSizePtr
for {} eq(0x0, 0x0) {} {
// WHile true
switch eq(iter, 0)
case 1 { break }
mstore(proofSizePtr, add(mload(proofSizePtr), 0x1))
proofIndexPtr := add(proofIndexPtr, 0x20)
switch and(iter, 0x1)
case 0x1 {
// iter is ODD
let sibling := mload(add(add(tree, 0x20), mul(add(iter, 1), 0x20)))
mstore(proofIndexPtr, sibling)
iter := div(iter, 2)
}
default {
// iter is EVEN
let sibling := mload(add(add(tree, 0x20), mul(sub(iter, 1), 0x20)))
mstore(proofIndexPtr, sibling)
iter := div(sub(iter, 1), 2)
}
}
mstore(0x40, add(proofIndexPtr, 0x20))
return(ptr, add(0x20, sub(add(proofIndexPtr, 0x20), proofSizePtr)))
}
}
}
109 changes: 52 additions & 57 deletions src/common/MurkyBase.sol
Original file line number Diff line number Diff line change
Expand Up @@ -2,67 +2,64 @@
pragma solidity ^0.8.4;

abstract contract MurkyBase {
/**
*
* CONSTRUCTOR *
*
*/
/***************
* CONSTRUCTOR *
***************/
constructor() {}

/**
*
* VIRTUAL HASHING FUNCTIONS *
*
*/
/********************
* VIRTUAL HASHING FUNCTIONS *
********************/
function hashLeafPairs(bytes32 left, bytes32 right) public pure virtual returns (bytes32 _hash);

/**
*
* PROOF VERIFICATION *
*
*/

/**********************
* PROOF VERIFICATION *
**********************/

function verifyProof(bytes32 root, bytes32[] memory proof, bytes32 valueToProve) external pure returns (bool) {
// proof length must be less than max array size
bytes32 rollingHash = valueToProve;
uint256 length = proof.length;
unchecked {
for (uint256 i = 0; i < length; ++i) {
for(uint i = 0; i < length; ++i){
rollingHash = hashLeafPairs(rollingHash, proof[i]);
}
}
return root == rollingHash;
}

/**
*
* PROOF GENERATION *
*
*/
function getRoot(bytes32[] memory data) public pure virtual returns (bytes32) {
/********************
* PROOF GENERATION *
********************/

function getRoot(bytes32[] memory data) public pure returns (bytes32) {
require(data.length > 1, "won't generate root for single leaf");
while (data.length > 1) {
while(data.length > 1) {
data = hashLevel(data);
}
return data[0];
}

function getProof(bytes32[] memory data, uint256 node) public pure returns (bytes32[] memory) {
require(data.length > 1, "won't generate proof for single leaf");
// The size of the proof is equal to the ceiling of log2(numLeaves)
// The size of the proof is equal to the ceiling of log2(numLeaves)
bytes32[] memory result = new bytes32[](log2ceilBitMagic(data.length));
uint256 pos = 0;

// Two overflow risks: node, pos
// node: max array size is 2**256-1. Largest index in the array will be 1 less than that. Also,
// for dynamic arrays, size is limited to 2**64-1
// for dynamic arrays, size is limited to 2**64-1
// pos: pos is bounded by log2(data.length), which should be less than type(uint256).max
while (data.length > 1) {
while(data.length > 1) {
unchecked {
if (node & 0x1 == 1) {
if(node & 0x1 == 1) {
result[pos] = data[node - 1];
} else if (node + 1 == data.length) {
result[pos] = bytes32(0);
} else {
}
else if (node + 1 == data.length) {
result[pos] = bytes32(0);
}
else {
result[pos] = data[node + 1];
}
++pos;
Expand All @@ -79,38 +76,36 @@ abstract contract MurkyBase {

// Function is private, and all internal callers check that data.length >=2.
// Underflow is not possible as lowest possible value for data/result index is 1
// overflow should be safe as length is / 2 always.
// overflow should be safe as length is / 2 always.
unchecked {
uint256 length = data.length;
if (length & 0x1 == 1) {
if (length & 0x1 == 1){
result = new bytes32[](length / 2 + 1);
result[result.length - 1] = data[length - 1];
result[result.length - 1] = hashLeafPairs(data[length - 1], bytes32(0));
} else {
result = new bytes32[](length / 2);
}
// pos is upper bounded by data.length / 2, so safe even if array is at max size
}
// pos is upper bounded by data.length / 2, so safe even if array is at max size
uint256 pos = 0;
for (uint256 i = 0; i < length - 1; i += 2) {
result[pos] = hashLeafPairs(data[i], data[i + 1]);
for (uint256 i = 0; i < length-1; i+=2){
result[pos] = hashLeafPairs(data[i], data[i+1]);
++pos;
}
}
return result;
}

/**
*
* MATH "LIBRARY" *
*
*/

/******************
* MATH "LIBRARY" *
******************/

/// @dev Note that x is assumed > 0
function log2ceil(uint256 x) public pure returns (uint256) {
uint256 ceil = 0;
uint256 pOf2;
uint pOf2;
// If x is a power of 2, then this function will return a ceiling
// that is 1 greater than the actual ceiling. So we need to check if
// x is a power of 2, and subtract one from ceil if so.
// x is a power of 2, and subtract one from ceil if so.
assembly {
// we check by seeing if x == (~x + 1) & x. This applies a mask
// to find the lowest set bit of x and then checks it for equality
Expand All @@ -131,11 +126,11 @@ abstract contract MurkyBase {
// we do some assembly magic to treat the bool as an integer later on
pOf2 := eq(and(add(not(x), 1), x), x)
}

// if x == type(uint256).max, than ceil is capped at 256
// if x == 0, then pO2 == 0, so ceil won't underflow
unchecked {
while (x > 0) {
while( x > 0) {
x >>= 1;
ceil++;
}
Expand All @@ -146,41 +141,41 @@ abstract contract MurkyBase {

/// Original bitmagic adapted from https://github.com/paulrberg/prb-math/blob/main/contracts/PRBMath.sol
/// @dev Note that x assumed > 1
function log2ceilBitMagic(uint256 x) public pure returns (uint256) {
function log2ceilBitMagic(uint256 x) public pure returns (uint256){
if (x <= 1) {
return 0;
}
uint256 msb = 0;
uint256 _x = x;
if (x >= 2 ** 128) {
if (x >= 2**128) {
x >>= 128;
msb += 128;
}
if (x >= 2 ** 64) {
if (x >= 2**64) {
x >>= 64;
msb += 64;
}
if (x >= 2 ** 32) {
if (x >= 2**32) {
x >>= 32;
msb += 32;
}
if (x >= 2 ** 16) {
if (x >= 2**16) {
x >>= 16;
msb += 16;
}
if (x >= 2 ** 8) {
if (x >= 2**8) {
x >>= 8;
msb += 8;
}
if (x >= 2 ** 4) {
if (x >= 2**4) {
x >>= 4;
msb += 4;
}
if (x >= 2 ** 2) {
if (x >= 2**2) {
x >>= 2;
msb += 2;
}
if (x >= 2 ** 1) {
if (x >= 2**1) {
msb += 1;
}

Expand All @@ -191,4 +186,4 @@ abstract contract MurkyBase {
return msb + 1;
}
}
}
}

0 comments on commit 4b1ab18

Please sign in to comment.