diff --git a/src/BucketDLL.sol b/src/BucketDLL.sol index 911d495b..42154136 100644 --- a/src/BucketDLL.sol +++ b/src/BucketDLL.sol @@ -1,8 +1,12 @@ // SPDX-License-Identifier: AGPL-3.0-only pragma solidity ^0.8.0; +/// @title BucketDLL +/// @author Morpho Labs +/// @custom:contact security@morpho.xyz +/// @notice The doubly linked list used in logarithmic buckets. library BucketDLL { - /// STRUCTS /// + /* STRUCTS */ struct Account { address prev; @@ -13,78 +17,59 @@ library BucketDLL { mapping(address => Account) accounts; } - /// INTERNAL /// + /* INTERNAL */ - /// @notice Returns the address at the head of the `_list`. - /// @param _list The list from which to get the head. - /// @return The address of the head. - function getHead(List storage _list) internal view returns (address) { - return _list.accounts[address(0)].next; - } - - /// @notice Returns the address at the tail of the `_list`. - /// @param _list The list from which to get the tail. - /// @return The address of the tail. - function getTail(List storage _list) internal view returns (address) { - return _list.accounts[address(0)].prev; - } - - /// @notice Returns the next id address from the current `_id`. - /// @param _list The list to search in. - /// @param _id The address of the current account. + /// @notice Returns the next id address from the current `id`. + /// @dev Pass the address 0 to get the head of the list. + /// @param list The list to search in. + /// @param id The address of the current account. /// @return The address of the next account. - function getNext(List storage _list, address _id) internal view returns (address) { - return _list.accounts[_id].next; - } - - /// @notice Returns the previous id address from the current `_id`. - /// @param _list The list to search in. - /// @param _id The address of the current account. - /// @return The address of the previous account. - function getPrev(List storage _list, address _id) internal view returns (address) { - return _list.accounts[_id].prev; + function getNext(List storage list, address id) internal view returns (address) { + return list.accounts[id].next; } - /// @notice Removes an account of the `_list`. - /// @dev This function should not be called with `_id` equal to address 0. - /// @param _list The list to search in. - /// @param _id The address of the account. + /// @notice Removes an account of the `list`. + /// @dev This function should not be called with `id` equal to address 0. + /// @dev This function should not be called with an `_id` that is not in the list. + /// @param list The list to search in. + /// @param id The address of the account. /// @return Whether the bucket is empty after removal. - function remove(List storage _list, address _id) internal returns (bool) { - Account memory account = _list.accounts[_id]; + function remove(List storage list, address id) internal returns (bool) { + Account memory account = list.accounts[id]; address prev = account.prev; address next = account.next; - _list.accounts[prev].next = next; - _list.accounts[next].prev = prev; + list.accounts[prev].next = next; + list.accounts[next].prev = prev; - delete _list.accounts[_id]; + delete list.accounts[id]; return (prev == address(0) && next == address(0)); } - /// @notice Inserts an account in the `_list`. - /// @dev This function should not be called with `_id` equal to address 0. - /// @param _list The list to search in. - /// @param _id The address of the account. - /// @param _head Tells whether to insert at the head or at the tail of the list. + /// @notice Inserts an account in the `list`. + /// @dev This function should not be called with `id` equal to address 0. + /// @dev This function should not be called with an `id` that is already in the list. + /// @param list The list to search in. + /// @param id The address of the account. + /// @param atHead Tells whether to insert at the head or at the tail of the list. /// @return Whether the bucket was empty before insertion. function insert( - List storage _list, - address _id, - bool _head + List storage list, + address id, + bool atHead ) internal returns (bool) { - if (_head) { - address head = _list.accounts[address(0)].next; - _list.accounts[address(0)].next = _id; - _list.accounts[head].prev = _id; - _list.accounts[_id].next = head; + if (atHead) { + address head = list.accounts[address(0)].next; + list.accounts[address(0)].next = id; + list.accounts[head].prev = id; + list.accounts[id].next = head; return head == address(0); } else { - address tail = _list.accounts[address(0)].prev; - _list.accounts[address(0)].prev = _id; - _list.accounts[tail].next = _id; - _list.accounts[_id].prev = tail; + address tail = list.accounts[address(0)].prev; + list.accounts[address(0)].prev = id; + list.accounts[tail].next = id; + list.accounts[id].prev = tail; return tail == address(0); } } diff --git a/src/LogarithmicBuckets.sol b/src/LogarithmicBuckets.sol index dcb51f73..248c7a07 100644 --- a/src/LogarithmicBuckets.sol +++ b/src/LogarithmicBuckets.sol @@ -3,6 +3,10 @@ pragma solidity ^0.8.0; import "./BucketDLL.sol"; +/// @title LogarithmicBuckets +/// @author Morpho Labs +/// @custom:contact security@morpho.xyz +/// @notice The logarithmic buckets data-structure. library LogarithmicBuckets { using BucketDLL for BucketDLL.List; @@ -12,7 +16,7 @@ library LogarithmicBuckets { uint256 bucketsMask; } - /// ERRORS /// + /* ERRORS */ /// @notice Thrown when the address is zero at insertion. error ZeroAddress(); @@ -20,96 +24,100 @@ library LogarithmicBuckets { /// @notice Thrown when 0 value is inserted. error ZeroValue(); - /// INTERNAL /// + /* INTERNAL */ - /// @notice Updates an account in the `_buckets`. - /// @param _buckets The buckets to update. - /// @param _id The address of the account. - /// @param _newValue The new value of the account. - /// @param _head Indicates whether to insert the new values at the head or at the tail of the buckets list. + /// @notice Updates an account in the `buckets`. + /// @param buckets The buckets to update. + /// @param id The address of the account. + /// @param newValue The new value of the account. + /// @param head Indicates whether to insert the new values at the head or at the tail of the buckets list. function update( - Buckets storage _buckets, - address _id, - uint256 _newValue, - bool _head + Buckets storage buckets, + address id, + uint256 newValue, + bool head ) internal { - if (_id == address(0)) revert ZeroAddress(); - uint256 value = _buckets.valueOf[_id]; - _buckets.valueOf[_id] = _newValue; + if (id == address(0)) revert ZeroAddress(); + uint256 value = buckets.valueOf[id]; + buckets.valueOf[id] = newValue; if (value == 0) { - if (_newValue == 0) revert ZeroValue(); - _insert(_buckets, _id, computeBucket(_newValue), _head); + if (newValue == 0) revert ZeroValue(); + // `highestSetBit` is used to compute the bucket associated with `newValue`. + _insert(buckets, id, highestSetBit(newValue), head); return; } - uint256 currentBucket = computeBucket(value); - if (_newValue == 0) { - _remove(_buckets, _id, currentBucket); + // `highestSetBit` is used to compute the bucket associated with `value`. + uint256 currentBucket = highestSetBit(value); + if (newValue == 0) { + _remove(buckets, id, currentBucket); return; } - uint256 newBucket = computeBucket(_newValue); + // `highestSetBit` is used to compute the bucket associated with `newValue`. + uint256 newBucket = highestSetBit(newValue); if (newBucket != currentBucket) { - _remove(_buckets, _id, currentBucket); - _insert(_buckets, _id, newBucket, _head); + _remove(buckets, id, currentBucket); + _insert(buckets, id, newBucket, head); } } - /// @notice Returns the address in `_buckets` that is a candidate for matching the value `_value`. - /// @param _buckets The buckets to get the head. - /// @param _value The value to match. + /// @notice Returns the address in `buckets` that is a candidate for matching the value `value`. + /// @param buckets The buckets to get the head. + /// @param value The value to match. /// @return The address of the head. - function getMatch(Buckets storage _buckets, uint256 _value) internal view returns (address) { - uint256 bucketsMask = _buckets.bucketsMask; + function getMatch(Buckets storage buckets, uint256 value) internal view returns (address) { + uint256 bucketsMask = buckets.bucketsMask; if (bucketsMask == 0) return address(0); - uint256 lowerMask = setLowerBits(_value); - uint256 next = nextBucket(lowerMask, bucketsMask); + uint256 next = nextBucket(value, bucketsMask); + if (next != 0) return buckets.buckets[next].getNext(address(0)); - if (next != 0) return _buckets.buckets[next].getHead(); - - uint256 prev = prevBucket(lowerMask, bucketsMask); - - return _buckets.buckets[prev].getHead(); + // `highestSetBit` is used to compute the highest non-empty bucket. + // Knowing that `next` == 0, it is also the highest previous non-empty bucket. + uint256 prev = highestSetBit(bucketsMask); + return buckets.buckets[prev].getNext(address(0)); } - /// PRIVATE /// + /* PRIVATE */ - /// @notice Removes an account in the `_buckets`. + /// @notice Removes an account in the `buckets`. /// @dev Does not update the value. - /// @param _buckets The buckets to modify. - /// @param _id The address of the account to remove. - /// @param _bucket The mask of the bucket where to remove. + /// @param buckets The buckets to modify. + /// @param id The address of the account to remove. + /// @param bucket The mask of the bucket where to remove. function _remove( - Buckets storage _buckets, - address _id, - uint256 _bucket + Buckets storage buckets, + address id, + uint256 bucket ) private { - if (_buckets.buckets[_bucket].remove(_id)) _buckets.bucketsMask &= ~_bucket; + if (buckets.buckets[bucket].remove(id)) buckets.bucketsMask &= ~bucket; } - /// @notice Inserts an account in the `_buckets`. - /// @dev Expects that `_id` != 0. + /// @notice Inserts an account in the `buckets`. + /// @dev Expects that `id` != 0. /// @dev Does not update the value. - /// @param _buckets The buckets to modify. - /// @param _id The address of the account to update. - /// @param _bucket The mask of the bucket where to insert. - /// @param _head Whether to insert at the head or at the tail of the list. + /// @param buckets The buckets to modify. + /// @param id The address of the account to update. + /// @param bucket The mask of the bucket where to insert. + /// @param head Whether to insert at the head or at the tail of the list. function _insert( - Buckets storage _buckets, - address _id, - uint256 _bucket, - bool _head + Buckets storage buckets, + address id, + uint256 bucket, + bool head ) private { - if (_buckets.buckets[_bucket].insert(_id, _head)) _buckets.bucketsMask |= _bucket; + if (buckets.buckets[bucket].insert(id, head)) buckets.bucketsMask |= bucket; } - /// PURE HELPERS /// + /* PURE HELPERS */ - /// @notice Returns the bucket in which the given value would fall. - function computeBucket(uint256 _value) internal pure returns (uint256) { - uint256 lowerMask = setLowerBits(_value); + /// @notice Returns the highest set bit. + /// @dev Used to compute the bucket associated to a given `value`. + /// @dev Used to compute the highest non empty bucket given the `bucketsMask`. + function highestSetBit(uint256 value) internal pure returns (uint256) { + uint256 lowerMask = setLowerBits(value); return lowerMask ^ (lowerMask >> 1); } @@ -128,23 +136,13 @@ library LogarithmicBuckets { } } - /// @notice Returns the following bucket which contains greater values. + /// @notice Returns the lowest non-empty bucket containing larger values. /// @dev The bucket returned is the lowest that is in `bucketsMask` and not in `lowerMask`. - function nextBucket(uint256 lowerMask, uint256 bucketsMask) - internal - pure - returns (uint256 bucket) - { + function nextBucket(uint256 value, uint256 bucketsMask) internal pure returns (uint256 bucket) { + uint256 lowerMask = setLowerBits(value); assembly { let higherBucketsMask := and(not(lowerMask), bucketsMask) bucket := and(higherBucketsMask, add(not(higherBucketsMask), 1)) } } - - /// @notice Returns the preceding bucket which contains smaller values. - /// @dev The bucket returned is the highest that is in both `bucketsMask` and `lowerMask`. - function prevBucket(uint256 lowerMask, uint256 bucketsMask) internal pure returns (uint256) { - uint256 lowerBucketsMask = setLowerBits(lowerMask & bucketsMask); - return lowerBucketsMask ^ (lowerBucketsMask >> 1); - } } diff --git a/test/TestBucketDLL.t.sol b/test/TestBucketDLL.t.sol index ce1496ca..5019fd9e 100644 --- a/test/TestBucketDLL.t.sol +++ b/test/TestBucketDLL.t.sol @@ -2,10 +2,10 @@ pragma solidity ^0.8.0; import "forge-std/Test.sol"; -import "src/BucketDLL.sol"; +import "./mocks/BucketDLLMock.sol"; contract TestBucketDLL is Test { - using BucketDLL for BucketDLL.List; + using BucketDLLMock for BucketDLL.List; uint256 internal numberOfAccounts = 50; address[] public accounts; diff --git a/test/TestLogarithmicBuckets.t.sol b/test/TestLogarithmicBuckets.t.sol index 747b82b2..2d9c5038 100644 --- a/test/TestLogarithmicBuckets.t.sol +++ b/test/TestLogarithmicBuckets.t.sol @@ -5,7 +5,7 @@ import "forge-std/Test.sol"; import "./mocks/LogarithmicBucketsMock.sol"; contract TestLogarithmicBuckets is LogarithmicBucketsMock, Test { - using BucketDLL for BucketDLL.List; + using BucketDLLMock for BucketDLL.List; using LogarithmicBuckets for LogarithmicBuckets.Buckets; uint256 public accountsLength = 50; @@ -125,7 +125,7 @@ contract TestProveLogarithmicBuckets is LogarithmicBucketsMock, Test { } function testProveComputeBucket(uint256 _value) public { - uint256 bucket = LogarithmicBuckets.computeBucket(_value); + uint256 bucket = LogarithmicBuckets.highestSetBit(_value); unchecked { // cross-check that bucket == 2^{floor(log_2 value)}, or 0 if value == 0 assertTrue(bucket == 0 || isPowerOfTwo(bucket)); @@ -135,11 +135,12 @@ contract TestProveLogarithmicBuckets is LogarithmicBucketsMock, Test { } function testProveNextBucket(uint256 _value) public { - uint256 curr = LogarithmicBuckets.computeBucket(_value); + uint256 curr = LogarithmicBuckets.highestSetBit(_value); uint256 next = nextBucketValue(_value); uint256 bucketsMask = buckets.bucketsMask; - // check that `next` is a strictly higer non-empty bucket, or zero + // Check that `next` is a power of two or zero. assertTrue(next == 0 || isPowerOfTwo(next)); + // Check that `next` is a strictly higher non-empty bucket, or zero. assertTrue(next == 0 || next > curr); assertTrue(next == 0 || bucketsMask & next != 0); unchecked { @@ -151,20 +152,30 @@ contract TestProveLogarithmicBuckets is LogarithmicBucketsMock, Test { } } - function testProvePrevBucket(uint256 _value) public { - uint256 curr = LogarithmicBuckets.computeBucket(_value); - uint256 prev = prevBucketValue(_value); + function testProveHighestBucket() public { + uint256 highest = highestBucketValue(); uint256 bucketsMask = buckets.bucketsMask; - // check that `prev` is a non-empty bucket that is lower than or equal to `curr`; or zero - assertTrue(prev == 0 || isPowerOfTwo(prev)); - assertTrue(prev <= curr); - assertTrue(prev == 0 || bucketsMask & prev != 0); + // check that `highest` is a power of two or zero. + assertTrue(highest == 0 || isPowerOfTwo(highest)); + // check that `highest` is a non-empty bucket or zero. + assertTrue(highest == 0 || bucketsMask & highest != 0); unchecked { - // check that `prev` is the highest one among such lower non-empty buckets, if exist - // note: this also checks that all the lower buckets are empty when `prev` == 0 - for (uint256 i = curr; i > prev; i >>= 1) { + // check that `highest` is the highest non-empty bucket, if exists. + // note: this also checks that all the lower buckets are empty when `highest` == 0 + for (uint256 i = 1 >> 256; i > highest; i >>= 1) { assertTrue(bucketsMask & i == 0); } } } + + function testProveHighestPrevEquivalence(uint256 _value) public { + uint256 curr = LogarithmicBuckets.highestSetBit(_value); + uint256 next = nextBucketValue(_value); + + if (next == 0) { + uint256 highest = highestBucketValue(); + // check that in this case, `highest` is smaller than `curr`. + assertTrue(highest <= curr); + } + } } diff --git a/test/TestLogarithmicBucketsInvariant.t.sol b/test/TestLogarithmicBucketsInvariant.t.sol index 392621cd..6d0193c8 100644 --- a/test/TestLogarithmicBucketsInvariant.t.sol +++ b/test/TestLogarithmicBucketsInvariant.t.sol @@ -3,6 +3,7 @@ pragma solidity ^0.8.0; import "./helpers/Random.sol"; import "./mocks/LogarithmicBucketsMock.sol"; +import "./mocks/BucketDLLMock.sol"; import "forge-std/Test.sol"; contract TestLogarithmicBucketsInvariant is Test, Random { @@ -31,7 +32,7 @@ contract TestLogarithmicBucketsInvariant is Test, Random { } contract LogarithmicBucketsSeenMock is LogarithmicBucketsMock { - using BucketDLL for BucketDLL.List; + using BucketDLLMock for BucketDLL.List; using LogarithmicBuckets for LogarithmicBuckets.Buckets; address[] public seen; @@ -75,7 +76,7 @@ contract TestLogarithmicBucketsSeenInvariant is Test, Random { address user = buckets.seen(i); uint256 value = buckets.getValueOf(user); if (value != 0) { - uint256 bucket = LogarithmicBuckets.computeBucket(value); + uint256 bucket = LogarithmicBuckets.highestSetBit(value); address next = buckets.getNext(bucket, user); address prev = buckets.getPrev(bucket, user); assertEq(buckets.getNext(bucket, prev), user); diff --git a/test/mocks/BucketDLLMock.sol b/test/mocks/BucketDLLMock.sol new file mode 100644 index 00000000..50857642 --- /dev/null +++ b/test/mocks/BucketDLLMock.sol @@ -0,0 +1,34 @@ +// SPDX-License-Identifier: AGPL-3.0-only +pragma solidity ^0.8.0; + +import "src/BucketDLL.sol"; + +library BucketDLLMock { + function remove(BucketDLL.List storage _list, address _id) internal returns (bool) { + return BucketDLL.remove(_list, _id); + } + + function insert( + BucketDLL.List storage _list, + address _id, + bool _head + ) internal returns (bool) { + return BucketDLL.insert(_list, _id, _head); + } + + function getNext(BucketDLL.List storage _list, address _id) internal view returns (address) { + return BucketDLL.getNext(_list, _id); + } + + function getHead(BucketDLL.List storage _list) internal view returns (address) { + return _list.accounts[address(0)].next; + } + + function getPrev(BucketDLL.List storage _list, address _id) internal view returns (address) { + return _list.accounts[_id].prev; + } + + function getTail(BucketDLL.List storage _list) internal view returns (address) { + return _list.accounts[address(0)].prev; + } +} diff --git a/test/mocks/LogarithmicBucketsMock.sol b/test/mocks/LogarithmicBucketsMock.sol index 1d73999f..e6793091 100644 --- a/test/mocks/LogarithmicBucketsMock.sol +++ b/test/mocks/LogarithmicBucketsMock.sol @@ -2,9 +2,10 @@ pragma solidity ^0.8.0; import "src/LogarithmicBuckets.sol"; +import "./BucketDLLMock.sol"; contract LogarithmicBucketsMock { - using BucketDLL for BucketDLL.List; + using BucketDLLMock for BucketDLL.List; using LogarithmicBuckets for LogarithmicBuckets.Buckets; LogarithmicBuckets.Buckets public buckets; @@ -54,9 +55,7 @@ contract LogarithmicBucketsMock { return LogarithmicBuckets.nextBucket(lowerMask, bucketsMask); } - function prevBucketValue(uint256 _value) internal view returns (uint256) { - uint256 bucketsMask = buckets.bucketsMask; - uint256 lowerMask = LogarithmicBuckets.setLowerBits(_value); - return LogarithmicBuckets.prevBucket(lowerMask, bucketsMask); + function highestBucketValue() internal view returns (uint256) { + return LogarithmicBuckets.highestSetBit(buckets.bucketsMask); } }