Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: extend signature aggregation to a variable number of signers #96

Merged
merged 5 commits into from
Dec 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitmodules
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,6 @@
[submodule "contracts/lib/safe-modules"]
path = contracts/lib/safe-modules
url = https://github.com/worldcoin/safe-modules
[submodule "contracts/lib/solady"]
path = contracts/lib/solady
url = https://github.com/vectorized/solady
1 change: 1 addition & 0 deletions contracts/lib/solady
Submodule solady added at 717c0a
3 changes: 2 additions & 1 deletion contracts/remappings.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,5 @@ openzeppelin-contracts/=lib/world-id-contracts/lib/openzeppelin-contracts/contra
@4337=lib/safe-modules/modules/4337/contracts/
@safe-global/safe-contracts/contracts/=lib/safe-contracts/contracts/
@forge-std/=lib/forge-std/src/
forge-std/=lib/forge-std/src/
forge-std/=lib/forge-std/src/
@solady=lib/solady/src/utils/
23 changes: 19 additions & 4 deletions contracts/src/PBHSignatureAggregator.sol
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ pragma solidity ^0.8.28;
import "@account-abstraction/contracts/interfaces/PackedUserOperation.sol";
import {IPBHEntryPoint} from "./interfaces/IPBHEntryPoint.sol";
import {IAggregator} from "@account-abstraction/contracts/interfaces/IAggregator.sol";
import {ISafe} from "@4337/interfaces/Safe.sol";

/// @title PBH Signature Aggregator
/// @author Worldcoin
Expand All @@ -12,10 +13,20 @@ import {IAggregator} from "@account-abstraction/contracts/interfaces/IAggregator
/// Smart Accounts that return the `PBHSignatureAggregator` as the authorizer in `validationData`
/// will be considered as Priority User Operations, and will need to pack a World ID proof in the signature field.
contract PBHSignatureAggregator is IAggregator {
/// @notice The length of an ECDSA signature.
uint256 constant ECDSA_SIGNATURE_LENGTH = 65;
/// @notice The length of the timestamp bytes.
uint256 constant TIMESTAMP_BYTES = 12; // 6 bytes each for validAfter and validUntil
/// @notice The length of the encoded proof data.
uint256 constant PROOF_DATA_LENGTH = 352;

/// @notice Thrown when the Hash of the UserOperations is not
/// in transient storage of the `PBHVerifier`.
error InvalidUserOperations();

/// @notice Thrown when the length of the signature is invalid.
error InvalidSignatureLength(uint256 expected, uint256 actual);

/// @notice The PBHVerifier contract.
IPBHEntryPoint public immutable pbhEntryPoint;

Expand Down Expand Up @@ -60,14 +71,18 @@ contract PBHSignatureAggregator is IAggregator {
*/
function aggregateSignatures(PackedUserOperation[] calldata userOps)
external
pure
view
returns (bytes memory aggregatedSignature)
{
IPBHEntryPoint.PBHPayload[] memory pbhPayloads = new IPBHEntryPoint.PBHPayload[](userOps.length);
for (uint256 i = 0; i < userOps.length; ++i) {
// Bytes (0:65) - UserOp Signature
// Bytes (65:65 + 352) - Packed Proof Data
bytes memory proofData = userOps[i].signature[65:];
uint256 expectedLength =
TIMESTAMP_BYTES + (ISafe(payable(userOps[i].sender)).getThreshold() * ECDSA_SIGNATURE_LENGTH);
require(
userOps[i].signature.length == expectedLength + PROOF_DATA_LENGTH,
InvalidSignatureLength(expectedLength + PROOF_DATA_LENGTH, userOps[i].signature.length)
);
bytes memory proofData = userOps[i].signature[expectedLength:];
pbhPayloads[i] = abi.decode(proofData, (IPBHEntryPoint.PBHPayload));
}
aggregatedSignature = abi.encode(pbhPayloads);
Expand Down
111 changes: 107 additions & 4 deletions contracts/test/PBHSignatureAggregator.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import {IPBHEntryPoint} from "../src/interfaces/IPBHEntryPoint.sol";
import "@account-abstraction/contracts/interfaces/PackedUserOperation.sol";
import "@BokkyPooBahsDateTimeLibrary/BokkyPooBahsDateTimeLibrary.sol";
import "../src/helpers/PBHExternalNullifier.sol";
import {PBHSignatureAggregator} from "../src/PBHSignatureAggregator.sol";

contract PBHSignatureAggregatorTest is TestUtils, Setup {
function setUp() public override {
Expand Down Expand Up @@ -40,7 +41,7 @@ contract PBHSignatureAggregatorTest is TestUtils, Setup {
proofs[0] = abi.encode(proof0);
proofs[1] = abi.encode(proof1);

PackedUserOperation[] memory uoTestFixture = createUOTestData(address(safe), proofs);
PackedUserOperation[] memory uoTestFixture = createUOTestData(address(safe), proofs, 1);
bytes memory aggregatedSignature = pbhAggregator.aggregateSignatures(uoTestFixture);

IEntryPoint.UserOpsPerAggregator[] memory userOpsPerAggregator = new IEntryPoint.UserOpsPerAggregator[](1);
Expand All @@ -53,18 +54,30 @@ contract PBHSignatureAggregatorTest is TestUtils, Setup {
pbhEntryPoint.handleAggregatedOps(userOpsPerAggregator, payable(address(this)));
}

function testAggregateSignatures(uint256 root, uint256 pbhExternalNullifier, uint256 nullifierHash) public {
function testAggregateSignatures(
uint256 root,
uint256 pbhExternalNullifier,
uint256 nullifierHash,
uint256 p0,
uint256 p1,
uint256 p2,
uint256 p3,
uint256 p4,
uint256 p5,
uint256 p6,
uint256 p7
) public {
IPBHEntryPoint.PBHPayload memory proof = IPBHEntryPoint.PBHPayload({
root: root,
pbhExternalNullifier: pbhExternalNullifier,
nullifierHash: nullifierHash,
proof: [uint256(0), 0, 0, 0, 0, 0, 0, 0]
proof: [p0, p1, p2, p3, p4, p5, p6, p7]
});

bytes[] memory proofs = new bytes[](2);
proofs[0] = abi.encode(proof);
proofs[1] = abi.encode(proof);
PackedUserOperation[] memory uoTestFixture = createUOTestData(address(safe), proofs);
PackedUserOperation[] memory uoTestFixture = createUOTestData(address(safe), proofs, 1);
bytes memory aggregatedSignature = pbhAggregator.aggregateSignatures(uoTestFixture);
IPBHEntryPoint.PBHPayload[] memory decodedProofs =
abi.decode(aggregatedSignature, (IPBHEntryPoint.PBHPayload[]));
Expand All @@ -74,12 +87,102 @@ contract PBHSignatureAggregatorTest is TestUtils, Setup {
decodedProofs[0].pbhExternalNullifier, proof.pbhExternalNullifier, "PBH External Nullifier should match"
);
assertEq(decodedProofs[0].nullifierHash, proof.nullifierHash, "Nullifier Hash should match");
assertEq(decodedProofs[0].proof[0], proof.proof[0], "Proof should match");
assertEq(decodedProofs[0].proof[1], proof.proof[1], "Proof should match");
assertEq(decodedProofs[0].proof[2], proof.proof[2], "Proof should match");
assertEq(decodedProofs[0].proof[3], proof.proof[3], "Proof should match");
assertEq(decodedProofs[0].proof[4], proof.proof[4], "Proof should match");
assertEq(decodedProofs[0].proof[5], proof.proof[5], "Proof should match");
assertEq(decodedProofs[0].proof[6], proof.proof[6], "Proof should match");
assertEq(decodedProofs[0].proof[7], proof.proof[7], "Proof should match");
assertEq(decodedProofs[1].root, proof.root, "Root should match");
assertEq(
decodedProofs[1].pbhExternalNullifier, proof.pbhExternalNullifier, "PBH External Nullifier should match"
);
assertEq(decodedProofs[1].nullifierHash, proof.nullifierHash, "Nullifier Hash should match");
assertEq(decodedProofs[1].proof[0], proof.proof[0], "Proof should match");
assertEq(decodedProofs[1].proof[1], proof.proof[1], "Proof should match");
assertEq(decodedProofs[1].proof[2], proof.proof[2], "Proof should match");
assertEq(decodedProofs[1].proof[3], proof.proof[3], "Proof should match");
assertEq(decodedProofs[1].proof[4], proof.proof[4], "Proof should match");
assertEq(decodedProofs[1].proof[5], proof.proof[5], "Proof should match");
assertEq(decodedProofs[1].proof[6], proof.proof[6], "Proof should match");
assertEq(decodedProofs[1].proof[7], proof.proof[7], "Proof should match");
}

function testAggregateSignatures_VariableThreshold(
uint256 root,
uint256 pbhExternalNullifier,
uint256 nullifierHash,
uint8 threshold,
uint256 p0,
uint256 p1,
uint256 p2,
uint256 p3,
uint256 p4,
uint256 p5,
uint256 p6,
uint256 p7
) public {
deploySafeAccount(address(pbhAggregator), threshold);
IPBHEntryPoint.PBHPayload memory proof = IPBHEntryPoint.PBHPayload({
root: root,
pbhExternalNullifier: pbhExternalNullifier,
nullifierHash: nullifierHash,
proof: [p0, p1, p2, p3, p4, p5, p6, p7]
});

bytes[] memory proofs = new bytes[](2);
proofs[0] = abi.encode(proof);
proofs[1] = abi.encode(proof);
PackedUserOperation[] memory uoTestFixture = createUOTestData(address(safe), proofs, threshold);
bytes memory aggregatedSignature = pbhAggregator.aggregateSignatures(uoTestFixture);
IPBHEntryPoint.PBHPayload[] memory decodedProofs =
abi.decode(aggregatedSignature, (IPBHEntryPoint.PBHPayload[]));
assertEq(decodedProofs.length, 2, "Decoded proof length should be 1");
assertEq(decodedProofs[0].root, proof.root, "Root should match");
assertEq(
decodedProofs[0].pbhExternalNullifier, proof.pbhExternalNullifier, "PBH External Nullifier should match"
);
assertEq(decodedProofs[0].nullifierHash, proof.nullifierHash, "Nullifier Hash should match");
assertEq(decodedProofs[0].proof[0], proof.proof[0], "Proof should match");
assertEq(decodedProofs[0].proof[1], proof.proof[1], "Proof should match");
assertEq(decodedProofs[0].proof[2], proof.proof[2], "Proof should match");
assertEq(decodedProofs[0].proof[3], proof.proof[3], "Proof should match");
assertEq(decodedProofs[0].proof[4], proof.proof[4], "Proof should match");
assertEq(decodedProofs[0].proof[5], proof.proof[5], "Proof should match");
assertEq(decodedProofs[0].proof[6], proof.proof[6], "Proof should match");
assertEq(decodedProofs[0].proof[7], proof.proof[7], "Proof should match");
assertEq(decodedProofs[1].root, proof.root, "Root should match");
assertEq(
decodedProofs[1].pbhExternalNullifier, proof.pbhExternalNullifier, "PBH External Nullifier should match"
);
assertEq(decodedProofs[1].nullifierHash, proof.nullifierHash, "Nullifier Hash should match");
assertEq(decodedProofs[1].proof[0], proof.proof[0], "Proof should match");
assertEq(decodedProofs[1].proof[1], proof.proof[1], "Proof should match");
assertEq(decodedProofs[1].proof[2], proof.proof[2], "Proof should match");
assertEq(decodedProofs[1].proof[3], proof.proof[3], "Proof should match");
assertEq(decodedProofs[1].proof[4], proof.proof[4], "Proof should match");
assertEq(decodedProofs[1].proof[5], proof.proof[5], "Proof should match");
assertEq(decodedProofs[1].proof[6], proof.proof[6], "Proof should match");
assertEq(decodedProofs[1].proof[7], proof.proof[7], "Proof should match");
}

function testFailAggregateSignatures_InvalidSignatureLength() public {
IPBHEntryPoint.PBHPayload memory proof = IPBHEntryPoint.PBHPayload({
root: 0,
pbhExternalNullifier: 0,
nullifierHash: 0,
proof: [uint256(1), 0, 0, 0, 0, 0, 0, 0]
});

bytes[] memory proofs = new bytes[](2);
proofs[0] = abi.encode(proof);
proofs[1] = abi.encode(proof);
PackedUserOperation[] memory uoTestFixture = createUOTestData(address(safe), proofs, 1);
uoTestFixture[0].signature = new bytes(12);
vm.expectRevert(PBHSignatureAggregator.InvalidSignatureLength.selector);
pbhAggregator.aggregateSignatures(uoTestFixture);
}

receive() external payable {}
Expand Down
6 changes: 3 additions & 3 deletions contracts/test/Setup.sol
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ contract Setup is Test {
deployWorldIDGroups();
deployPBHEntryPoint(worldIDGroups, entryPoint);
deployPBHSignatureAggregator(address(pbhEntryPoint));
deploySafeAccount(address(pbhAggregator));
deploySafeAccount(address(pbhAggregator), 1);

// Label the addresses for better errors.
vm.label(address(entryPoint), "ERC-4337 Entry Point");
Expand Down Expand Up @@ -91,8 +91,8 @@ contract Setup is Test {

/// @notice Initializes a new safe account.
/// @dev It is constructed in the globals.
function deploySafeAccount(address _pbhSignatureAggregator) public {
safe = new MockAccount(_pbhSignatureAggregator);
function deploySafeAccount(address _pbhSignatureAggregator, uint256 threshold) public {
safe = new MockAccount(_pbhSignatureAggregator, threshold);
}

/// @notice Initializes a new World ID Groups contract.
Expand Down
32 changes: 11 additions & 21 deletions contracts/test/TestUtils.sol
Original file line number Diff line number Diff line change
Expand Up @@ -5,37 +5,27 @@ import "@account-abstraction/contracts/interfaces/PackedUserOperation.sol";
import {IEntryPoint} from "@account-abstraction/contracts/interfaces/IEntryPoint.sol";
import {IAggregator} from "@account-abstraction/contracts/interfaces/IAggregator.sol";
import "@forge-std/console.sol";
import "@solady/LibBytes.sol";

contract TestUtils {
function encodeSignature(bytes memory proofData) public pure returns (bytes memory) {
bytes memory sigBuffer = new bytes(65);
bytes memory signature = new bytes(417);
assembly {
mstore(signature, sigBuffer)
mstore(add(signature, 32), mload(add(sigBuffer, 32)))
mstore(add(signature, 64), mload(add(sigBuffer, 64)))
mstore(add(signature, 65), mload(proofData))
mstore(add(add(signature, 65), 32), mload(add(proofData, 32)))
mstore(add(add(signature, 65), 64), mload(add(proofData, 64)))
mstore(add(add(signature, 65), 96), mload(add(proofData, 96)))
mstore(add(add(signature, 65), 128), mload(add(proofData, 128)))
mstore(add(add(signature, 65), 160), mload(add(proofData, 160)))
mstore(add(add(signature, 65), 192), mload(add(proofData, 192)))
mstore(add(add(signature, 65), 224), mload(add(proofData, 224)))
mstore(add(add(signature, 65), 256), mload(add(proofData, 256)))
}
return signature;
function encodeSignature(bytes memory proofData, uint256 signatureThreshold)
public
pure
returns (bytes memory res)
{
bytes memory sigBuffer = new bytes(65 * signatureThreshold + 12);
res = LibBytes.concat(sigBuffer, proofData);
}

/// @notice Create a test data for UserOperations.
function createUOTestData(address sender, bytes[] memory proofs)
function createUOTestData(address sender, bytes[] memory proofs, uint256 signatureThreshold)
public
pure
returns (PackedUserOperation[] memory)
{
PackedUserOperation[] memory uOps = new PackedUserOperation[](proofs.length);
for (uint256 i = 0; i < proofs.length; ++i) {
bytes memory signature = encodeSignature(proofs[i]);
for (uint256 i = 0; i < proofs.length; i++) {
bytes memory signature = encodeSignature(proofs[i], signatureThreshold);
PackedUserOperation memory uo = PackedUserOperation({
sender: sender,
nonce: i,
Expand Down
8 changes: 7 additions & 1 deletion contracts/test/mocks/MockAccount.sol
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,11 @@ import "@account-abstraction/contracts/interfaces/PackedUserOperation.sol";

contract MockAccount is IAccount, IAccountExecute {
address public pbhAggregator;
uint256 public immutable threshold;

constructor(address _pbhAggregator) {
constructor(address _pbhAggregator, uint256 _threshold) {
pbhAggregator = _pbhAggregator;
threshold = _threshold;
}

function validateUserOp(PackedUserOperation calldata, bytes32, uint256)
Expand All @@ -26,4 +28,8 @@ contract MockAccount is IAccount, IAccountExecute {
function executeUserOp(PackedUserOperation calldata userOp, bytes32 userOpHash) external {
// Do nothing
}

function getThreshold() external view returns (uint256) {
return threshold;
}
}
Loading