diff --git a/.gitmodules b/.gitmodules index ddf6254e..3703f9fc 100644 --- a/.gitmodules +++ b/.gitmodules @@ -22,3 +22,6 @@ [submodule "contracts/lib/safe-modules"] path = contracts/lib/safe-modules url = https://github.com/worldcoin/safe-modules +[submodule "contracts/lib/solady"] + path = contracts/lib/solady + url = https://github.com/vectorized/solady diff --git a/contracts/lib/solady b/contracts/lib/solady new file mode 160000 index 00000000..717c0a86 --- /dev/null +++ b/contracts/lib/solady @@ -0,0 +1 @@ +Subproject commit 717c0a860a83dc9a1520d32384577aec51cecd84 diff --git a/contracts/remappings.txt b/contracts/remappings.txt index cb51404c..77de759c 100644 --- a/contracts/remappings.txt +++ b/contracts/remappings.txt @@ -8,4 +8,5 @@ openzeppelin-contracts/=lib/world-id-contracts/lib/openzeppelin-contracts/contra @4337=lib/safe-modules/modules/4337/contracts/ @safe-global/safe-contracts/contracts/=lib/safe-contracts/contracts/ @forge-std/=lib/forge-std/src/ -forge-std/=lib/forge-std/src/ \ No newline at end of file +forge-std/=lib/forge-std/src/ +@solady=lib/solady/src/utils/ diff --git a/contracts/src/PBHSignatureAggregator.sol b/contracts/src/PBHSignatureAggregator.sol index e4d2255d..e1e4b945 100644 --- a/contracts/src/PBHSignatureAggregator.sol +++ b/contracts/src/PBHSignatureAggregator.sol @@ -4,6 +4,7 @@ pragma solidity ^0.8.28; import "@account-abstraction/contracts/interfaces/PackedUserOperation.sol"; import {IPBHEntryPoint} from "./interfaces/IPBHEntryPoint.sol"; import {IAggregator} from "@account-abstraction/contracts/interfaces/IAggregator.sol"; +import {ISafe} from "@4337/interfaces/Safe.sol"; /// @title PBH Signature Aggregator /// @author Worldcoin @@ -12,10 +13,20 @@ import {IAggregator} from "@account-abstraction/contracts/interfaces/IAggregator /// Smart Accounts that return the `PBHSignatureAggregator` as the authorizer in `validationData` /// will be considered as Priority User Operations, and will need to pack a World ID proof in the signature field. contract PBHSignatureAggregator is IAggregator { + /// @notice The length of an ECDSA signature. + uint256 constant ECDSA_SIGNATURE_LENGTH = 65; + /// @notice The length of the timestamp bytes. + uint256 constant TIMESTAMP_BYTES = 12; // 6 bytes each for validAfter and validUntil + /// @notice The length of the encoded proof data. + uint256 constant PROOF_DATA_LENGTH = 352; + /// @notice Thrown when the Hash of the UserOperations is not /// in transient storage of the `PBHVerifier`. error InvalidUserOperations(); + /// @notice Thrown when the length of the signature is invalid. + error InvalidSignatureLength(uint256 expected, uint256 actual); + /// @notice The PBHVerifier contract. IPBHEntryPoint public immutable pbhEntryPoint; @@ -60,14 +71,18 @@ contract PBHSignatureAggregator is IAggregator { */ function aggregateSignatures(PackedUserOperation[] calldata userOps) external - pure + view returns (bytes memory aggregatedSignature) { IPBHEntryPoint.PBHPayload[] memory pbhPayloads = new IPBHEntryPoint.PBHPayload[](userOps.length); for (uint256 i = 0; i < userOps.length; ++i) { - // Bytes (0:65) - UserOp Signature - // Bytes (65:65 + 352) - Packed Proof Data - bytes memory proofData = userOps[i].signature[65:]; + uint256 expectedLength = + TIMESTAMP_BYTES + (ISafe(payable(userOps[i].sender)).getThreshold() * ECDSA_SIGNATURE_LENGTH); + require( + userOps[i].signature.length == expectedLength + PROOF_DATA_LENGTH, + InvalidSignatureLength(expectedLength + PROOF_DATA_LENGTH, userOps[i].signature.length) + ); + bytes memory proofData = userOps[i].signature[expectedLength:]; pbhPayloads[i] = abi.decode(proofData, (IPBHEntryPoint.PBHPayload)); } aggregatedSignature = abi.encode(pbhPayloads); diff --git a/contracts/test/PBHSignatureAggregator.t.sol b/contracts/test/PBHSignatureAggregator.t.sol index 03099185..648869b9 100644 --- a/contracts/test/PBHSignatureAggregator.t.sol +++ b/contracts/test/PBHSignatureAggregator.t.sol @@ -9,6 +9,7 @@ import {IPBHEntryPoint} from "../src/interfaces/IPBHEntryPoint.sol"; import "@account-abstraction/contracts/interfaces/PackedUserOperation.sol"; import "@BokkyPooBahsDateTimeLibrary/BokkyPooBahsDateTimeLibrary.sol"; import "../src/helpers/PBHExternalNullifier.sol"; +import {PBHSignatureAggregator} from "../src/PBHSignatureAggregator.sol"; contract PBHSignatureAggregatorTest is TestUtils, Setup { function setUp() public override { @@ -40,7 +41,7 @@ contract PBHSignatureAggregatorTest is TestUtils, Setup { proofs[0] = abi.encode(proof0); proofs[1] = abi.encode(proof1); - PackedUserOperation[] memory uoTestFixture = createUOTestData(address(safe), proofs); + PackedUserOperation[] memory uoTestFixture = createUOTestData(address(safe), proofs, 1); bytes memory aggregatedSignature = pbhAggregator.aggregateSignatures(uoTestFixture); IEntryPoint.UserOpsPerAggregator[] memory userOpsPerAggregator = new IEntryPoint.UserOpsPerAggregator[](1); @@ -53,18 +54,30 @@ contract PBHSignatureAggregatorTest is TestUtils, Setup { pbhEntryPoint.handleAggregatedOps(userOpsPerAggregator, payable(address(this))); } - function testAggregateSignatures(uint256 root, uint256 pbhExternalNullifier, uint256 nullifierHash) public { + function testAggregateSignatures( + uint256 root, + uint256 pbhExternalNullifier, + uint256 nullifierHash, + uint256 p0, + uint256 p1, + uint256 p2, + uint256 p3, + uint256 p4, + uint256 p5, + uint256 p6, + uint256 p7 + ) public { IPBHEntryPoint.PBHPayload memory proof = IPBHEntryPoint.PBHPayload({ root: root, pbhExternalNullifier: pbhExternalNullifier, nullifierHash: nullifierHash, - proof: [uint256(0), 0, 0, 0, 0, 0, 0, 0] + proof: [p0, p1, p2, p3, p4, p5, p6, p7] }); bytes[] memory proofs = new bytes[](2); proofs[0] = abi.encode(proof); proofs[1] = abi.encode(proof); - PackedUserOperation[] memory uoTestFixture = createUOTestData(address(safe), proofs); + PackedUserOperation[] memory uoTestFixture = createUOTestData(address(safe), proofs, 1); bytes memory aggregatedSignature = pbhAggregator.aggregateSignatures(uoTestFixture); IPBHEntryPoint.PBHPayload[] memory decodedProofs = abi.decode(aggregatedSignature, (IPBHEntryPoint.PBHPayload[])); @@ -74,12 +87,102 @@ contract PBHSignatureAggregatorTest is TestUtils, Setup { decodedProofs[0].pbhExternalNullifier, proof.pbhExternalNullifier, "PBH External Nullifier should match" ); assertEq(decodedProofs[0].nullifierHash, proof.nullifierHash, "Nullifier Hash should match"); + assertEq(decodedProofs[0].proof[0], proof.proof[0], "Proof should match"); + assertEq(decodedProofs[0].proof[1], proof.proof[1], "Proof should match"); + assertEq(decodedProofs[0].proof[2], proof.proof[2], "Proof should match"); + assertEq(decodedProofs[0].proof[3], proof.proof[3], "Proof should match"); + assertEq(decodedProofs[0].proof[4], proof.proof[4], "Proof should match"); + assertEq(decodedProofs[0].proof[5], proof.proof[5], "Proof should match"); + assertEq(decodedProofs[0].proof[6], proof.proof[6], "Proof should match"); + assertEq(decodedProofs[0].proof[7], proof.proof[7], "Proof should match"); + assertEq(decodedProofs[1].root, proof.root, "Root should match"); + assertEq( + decodedProofs[1].pbhExternalNullifier, proof.pbhExternalNullifier, "PBH External Nullifier should match" + ); + assertEq(decodedProofs[1].nullifierHash, proof.nullifierHash, "Nullifier Hash should match"); + assertEq(decodedProofs[1].proof[0], proof.proof[0], "Proof should match"); + assertEq(decodedProofs[1].proof[1], proof.proof[1], "Proof should match"); + assertEq(decodedProofs[1].proof[2], proof.proof[2], "Proof should match"); + assertEq(decodedProofs[1].proof[3], proof.proof[3], "Proof should match"); + assertEq(decodedProofs[1].proof[4], proof.proof[4], "Proof should match"); + assertEq(decodedProofs[1].proof[5], proof.proof[5], "Proof should match"); + assertEq(decodedProofs[1].proof[6], proof.proof[6], "Proof should match"); + assertEq(decodedProofs[1].proof[7], proof.proof[7], "Proof should match"); + } + function testAggregateSignatures_VariableThreshold( + uint256 root, + uint256 pbhExternalNullifier, + uint256 nullifierHash, + uint8 threshold, + uint256 p0, + uint256 p1, + uint256 p2, + uint256 p3, + uint256 p4, + uint256 p5, + uint256 p6, + uint256 p7 + ) public { + deploySafeAccount(address(pbhAggregator), threshold); + IPBHEntryPoint.PBHPayload memory proof = IPBHEntryPoint.PBHPayload({ + root: root, + pbhExternalNullifier: pbhExternalNullifier, + nullifierHash: nullifierHash, + proof: [p0, p1, p2, p3, p4, p5, p6, p7] + }); + + bytes[] memory proofs = new bytes[](2); + proofs[0] = abi.encode(proof); + proofs[1] = abi.encode(proof); + PackedUserOperation[] memory uoTestFixture = createUOTestData(address(safe), proofs, threshold); + bytes memory aggregatedSignature = pbhAggregator.aggregateSignatures(uoTestFixture); + IPBHEntryPoint.PBHPayload[] memory decodedProofs = + abi.decode(aggregatedSignature, (IPBHEntryPoint.PBHPayload[])); + assertEq(decodedProofs.length, 2, "Decoded proof length should be 1"); + assertEq(decodedProofs[0].root, proof.root, "Root should match"); + assertEq( + decodedProofs[0].pbhExternalNullifier, proof.pbhExternalNullifier, "PBH External Nullifier should match" + ); + assertEq(decodedProofs[0].nullifierHash, proof.nullifierHash, "Nullifier Hash should match"); + assertEq(decodedProofs[0].proof[0], proof.proof[0], "Proof should match"); + assertEq(decodedProofs[0].proof[1], proof.proof[1], "Proof should match"); + assertEq(decodedProofs[0].proof[2], proof.proof[2], "Proof should match"); + assertEq(decodedProofs[0].proof[3], proof.proof[3], "Proof should match"); + assertEq(decodedProofs[0].proof[4], proof.proof[4], "Proof should match"); + assertEq(decodedProofs[0].proof[5], proof.proof[5], "Proof should match"); + assertEq(decodedProofs[0].proof[6], proof.proof[6], "Proof should match"); + assertEq(decodedProofs[0].proof[7], proof.proof[7], "Proof should match"); assertEq(decodedProofs[1].root, proof.root, "Root should match"); assertEq( decodedProofs[1].pbhExternalNullifier, proof.pbhExternalNullifier, "PBH External Nullifier should match" ); assertEq(decodedProofs[1].nullifierHash, proof.nullifierHash, "Nullifier Hash should match"); + assertEq(decodedProofs[1].proof[0], proof.proof[0], "Proof should match"); + assertEq(decodedProofs[1].proof[1], proof.proof[1], "Proof should match"); + assertEq(decodedProofs[1].proof[2], proof.proof[2], "Proof should match"); + assertEq(decodedProofs[1].proof[3], proof.proof[3], "Proof should match"); + assertEq(decodedProofs[1].proof[4], proof.proof[4], "Proof should match"); + assertEq(decodedProofs[1].proof[5], proof.proof[5], "Proof should match"); + assertEq(decodedProofs[1].proof[6], proof.proof[6], "Proof should match"); + assertEq(decodedProofs[1].proof[7], proof.proof[7], "Proof should match"); + } + + function testFailAggregateSignatures_InvalidSignatureLength() public { + IPBHEntryPoint.PBHPayload memory proof = IPBHEntryPoint.PBHPayload({ + root: 0, + pbhExternalNullifier: 0, + nullifierHash: 0, + proof: [uint256(1), 0, 0, 0, 0, 0, 0, 0] + }); + + bytes[] memory proofs = new bytes[](2); + proofs[0] = abi.encode(proof); + proofs[1] = abi.encode(proof); + PackedUserOperation[] memory uoTestFixture = createUOTestData(address(safe), proofs, 1); + uoTestFixture[0].signature = new bytes(12); + vm.expectRevert(PBHSignatureAggregator.InvalidSignatureLength.selector); + pbhAggregator.aggregateSignatures(uoTestFixture); } receive() external payable {} diff --git a/contracts/test/Setup.sol b/contracts/test/Setup.sol index 81e132d1..2436953b 100644 --- a/contracts/test/Setup.sol +++ b/contracts/test/Setup.sol @@ -47,7 +47,7 @@ contract Setup is Test { deployWorldIDGroups(); deployPBHEntryPoint(worldIDGroups, entryPoint); deployPBHSignatureAggregator(address(pbhEntryPoint)); - deploySafeAccount(address(pbhAggregator)); + deploySafeAccount(address(pbhAggregator), 1); // Label the addresses for better errors. vm.label(address(entryPoint), "ERC-4337 Entry Point"); @@ -91,8 +91,8 @@ contract Setup is Test { /// @notice Initializes a new safe account. /// @dev It is constructed in the globals. - function deploySafeAccount(address _pbhSignatureAggregator) public { - safe = new MockAccount(_pbhSignatureAggregator); + function deploySafeAccount(address _pbhSignatureAggregator, uint256 threshold) public { + safe = new MockAccount(_pbhSignatureAggregator, threshold); } /// @notice Initializes a new World ID Groups contract. diff --git a/contracts/test/TestUtils.sol b/contracts/test/TestUtils.sol index ac3b093d..499da634 100644 --- a/contracts/test/TestUtils.sol +++ b/contracts/test/TestUtils.sol @@ -5,37 +5,27 @@ import "@account-abstraction/contracts/interfaces/PackedUserOperation.sol"; import {IEntryPoint} from "@account-abstraction/contracts/interfaces/IEntryPoint.sol"; import {IAggregator} from "@account-abstraction/contracts/interfaces/IAggregator.sol"; import "@forge-std/console.sol"; +import "@solady/LibBytes.sol"; contract TestUtils { - function encodeSignature(bytes memory proofData) public pure returns (bytes memory) { - bytes memory sigBuffer = new bytes(65); - bytes memory signature = new bytes(417); - assembly { - mstore(signature, sigBuffer) - mstore(add(signature, 32), mload(add(sigBuffer, 32))) - mstore(add(signature, 64), mload(add(sigBuffer, 64))) - mstore(add(signature, 65), mload(proofData)) - mstore(add(add(signature, 65), 32), mload(add(proofData, 32))) - mstore(add(add(signature, 65), 64), mload(add(proofData, 64))) - mstore(add(add(signature, 65), 96), mload(add(proofData, 96))) - mstore(add(add(signature, 65), 128), mload(add(proofData, 128))) - mstore(add(add(signature, 65), 160), mload(add(proofData, 160))) - mstore(add(add(signature, 65), 192), mload(add(proofData, 192))) - mstore(add(add(signature, 65), 224), mload(add(proofData, 224))) - mstore(add(add(signature, 65), 256), mload(add(proofData, 256))) - } - return signature; + function encodeSignature(bytes memory proofData, uint256 signatureThreshold) + public + pure + returns (bytes memory res) + { + bytes memory sigBuffer = new bytes(65 * signatureThreshold + 12); + res = LibBytes.concat(sigBuffer, proofData); } /// @notice Create a test data for UserOperations. - function createUOTestData(address sender, bytes[] memory proofs) + function createUOTestData(address sender, bytes[] memory proofs, uint256 signatureThreshold) public pure returns (PackedUserOperation[] memory) { PackedUserOperation[] memory uOps = new PackedUserOperation[](proofs.length); - for (uint256 i = 0; i < proofs.length; ++i) { - bytes memory signature = encodeSignature(proofs[i]); + for (uint256 i = 0; i < proofs.length; i++) { + bytes memory signature = encodeSignature(proofs[i], signatureThreshold); PackedUserOperation memory uo = PackedUserOperation({ sender: sender, nonce: i, diff --git a/contracts/test/mocks/MockAccount.sol b/contracts/test/mocks/MockAccount.sol index 5223f880..93c4a941 100644 --- a/contracts/test/mocks/MockAccount.sol +++ b/contracts/test/mocks/MockAccount.sol @@ -8,9 +8,11 @@ import "@account-abstraction/contracts/interfaces/PackedUserOperation.sol"; contract MockAccount is IAccount, IAccountExecute { address public pbhAggregator; + uint256 public immutable threshold; - constructor(address _pbhAggregator) { + constructor(address _pbhAggregator, uint256 _threshold) { pbhAggregator = _pbhAggregator; + threshold = _threshold; } function validateUserOp(PackedUserOperation calldata, bytes32, uint256) @@ -26,4 +28,8 @@ contract MockAccount is IAccount, IAccountExecute { function executeUserOp(PackedUserOperation calldata userOp, bytes32 userOpHash) external { // Do nothing } + + function getThreshold() external view returns (uint256) { + return threshold; + } }