Skip to content

Commit

Permalink
Merge pull request #96 from worldcoin/osiris/threshold-signature-aggr…
Browse files Browse the repository at this point in the history
…egation

fix: extend signature aggregation to a variable number of signers
  • Loading branch information
0xOsiris authored Dec 27, 2024
2 parents 5268032 + 1e1f6a1 commit a557528
Show file tree
Hide file tree
Showing 8 changed files with 153 additions and 34 deletions.
3 changes: 3 additions & 0 deletions .gitmodules
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,6 @@
[submodule "contracts/lib/safe-modules"]
path = contracts/lib/safe-modules
url = https://github.com/worldcoin/safe-modules
[submodule "contracts/lib/solady"]
path = contracts/lib/solady
url = https://github.com/vectorized/solady
1 change: 1 addition & 0 deletions contracts/lib/solady
Submodule solady added at 717c0a
3 changes: 2 additions & 1 deletion contracts/remappings.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,5 @@ openzeppelin-contracts/=lib/world-id-contracts/lib/openzeppelin-contracts/contra
@4337=lib/safe-modules/modules/4337/contracts/
@safe-global/safe-contracts/contracts/=lib/safe-contracts/contracts/
@forge-std/=lib/forge-std/src/
forge-std/=lib/forge-std/src/
forge-std/=lib/forge-std/src/
@solady=lib/solady/src/utils/
23 changes: 19 additions & 4 deletions contracts/src/PBHSignatureAggregator.sol
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ pragma solidity ^0.8.28;
import "@account-abstraction/contracts/interfaces/PackedUserOperation.sol";
import {IPBHEntryPoint} from "./interfaces/IPBHEntryPoint.sol";
import {IAggregator} from "@account-abstraction/contracts/interfaces/IAggregator.sol";
import {ISafe} from "@4337/interfaces/Safe.sol";

/// @title PBH Signature Aggregator
/// @author Worldcoin
Expand All @@ -12,10 +13,20 @@ import {IAggregator} from "@account-abstraction/contracts/interfaces/IAggregator
/// Smart Accounts that return the `PBHSignatureAggregator` as the authorizer in `validationData`
/// will be considered as Priority User Operations, and will need to pack a World ID proof in the signature field.
contract PBHSignatureAggregator is IAggregator {
/// @notice The length of an ECDSA signature.
uint256 constant ECDSA_SIGNATURE_LENGTH = 65;
/// @notice The length of the timestamp bytes.
uint256 constant TIMESTAMP_BYTES = 12; // 6 bytes each for validAfter and validUntil
/// @notice The length of the encoded proof data.
uint256 constant PROOF_DATA_LENGTH = 352;

/// @notice Thrown when the Hash of the UserOperations is not
/// in transient storage of the `PBHVerifier`.
error InvalidUserOperations();

/// @notice Thrown when the length of the signature is invalid.
error InvalidSignatureLength(uint256 expected, uint256 actual);

/// @notice The PBHVerifier contract.
IPBHEntryPoint public immutable pbhEntryPoint;

Expand Down Expand Up @@ -60,14 +71,18 @@ contract PBHSignatureAggregator is IAggregator {
*/
function aggregateSignatures(PackedUserOperation[] calldata userOps)
external
pure
view
returns (bytes memory aggregatedSignature)
{
IPBHEntryPoint.PBHPayload[] memory pbhPayloads = new IPBHEntryPoint.PBHPayload[](userOps.length);
for (uint256 i = 0; i < userOps.length; ++i) {
// Bytes (0:65) - UserOp Signature
// Bytes (65:65 + 352) - Packed Proof Data
bytes memory proofData = userOps[i].signature[65:];
uint256 expectedLength =
TIMESTAMP_BYTES + (ISafe(payable(userOps[i].sender)).getThreshold() * ECDSA_SIGNATURE_LENGTH);
require(
userOps[i].signature.length == expectedLength + PROOF_DATA_LENGTH,
InvalidSignatureLength(expectedLength + PROOF_DATA_LENGTH, userOps[i].signature.length)
);
bytes memory proofData = userOps[i].signature[expectedLength:];
pbhPayloads[i] = abi.decode(proofData, (IPBHEntryPoint.PBHPayload));
}
aggregatedSignature = abi.encode(pbhPayloads);
Expand Down
111 changes: 107 additions & 4 deletions contracts/test/PBHSignatureAggregator.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import {IPBHEntryPoint} from "../src/interfaces/IPBHEntryPoint.sol";
import "@account-abstraction/contracts/interfaces/PackedUserOperation.sol";
import "@BokkyPooBahsDateTimeLibrary/BokkyPooBahsDateTimeLibrary.sol";
import "../src/helpers/PBHExternalNullifier.sol";
import {PBHSignatureAggregator} from "../src/PBHSignatureAggregator.sol";

contract PBHSignatureAggregatorTest is TestUtils, Setup {
function setUp() public override {
Expand Down Expand Up @@ -40,7 +41,7 @@ contract PBHSignatureAggregatorTest is TestUtils, Setup {
proofs[0] = abi.encode(proof0);
proofs[1] = abi.encode(proof1);

PackedUserOperation[] memory uoTestFixture = createUOTestData(address(safe), proofs);
PackedUserOperation[] memory uoTestFixture = createUOTestData(address(safe), proofs, 1);
bytes memory aggregatedSignature = pbhAggregator.aggregateSignatures(uoTestFixture);

IEntryPoint.UserOpsPerAggregator[] memory userOpsPerAggregator = new IEntryPoint.UserOpsPerAggregator[](1);
Expand All @@ -53,18 +54,30 @@ contract PBHSignatureAggregatorTest is TestUtils, Setup {
pbhEntryPoint.handleAggregatedOps(userOpsPerAggregator, payable(address(this)));
}

function testAggregateSignatures(uint256 root, uint256 pbhExternalNullifier, uint256 nullifierHash) public {
function testAggregateSignatures(
uint256 root,
uint256 pbhExternalNullifier,
uint256 nullifierHash,
uint256 p0,
uint256 p1,
uint256 p2,
uint256 p3,
uint256 p4,
uint256 p5,
uint256 p6,
uint256 p7
) public {
IPBHEntryPoint.PBHPayload memory proof = IPBHEntryPoint.PBHPayload({
root: root,
pbhExternalNullifier: pbhExternalNullifier,
nullifierHash: nullifierHash,
proof: [uint256(0), 0, 0, 0, 0, 0, 0, 0]
proof: [p0, p1, p2, p3, p4, p5, p6, p7]
});

bytes[] memory proofs = new bytes[](2);
proofs[0] = abi.encode(proof);
proofs[1] = abi.encode(proof);
PackedUserOperation[] memory uoTestFixture = createUOTestData(address(safe), proofs);
PackedUserOperation[] memory uoTestFixture = createUOTestData(address(safe), proofs, 1);
bytes memory aggregatedSignature = pbhAggregator.aggregateSignatures(uoTestFixture);
IPBHEntryPoint.PBHPayload[] memory decodedProofs =
abi.decode(aggregatedSignature, (IPBHEntryPoint.PBHPayload[]));
Expand All @@ -74,12 +87,102 @@ contract PBHSignatureAggregatorTest is TestUtils, Setup {
decodedProofs[0].pbhExternalNullifier, proof.pbhExternalNullifier, "PBH External Nullifier should match"
);
assertEq(decodedProofs[0].nullifierHash, proof.nullifierHash, "Nullifier Hash should match");
assertEq(decodedProofs[0].proof[0], proof.proof[0], "Proof should match");
assertEq(decodedProofs[0].proof[1], proof.proof[1], "Proof should match");
assertEq(decodedProofs[0].proof[2], proof.proof[2], "Proof should match");
assertEq(decodedProofs[0].proof[3], proof.proof[3], "Proof should match");
assertEq(decodedProofs[0].proof[4], proof.proof[4], "Proof should match");
assertEq(decodedProofs[0].proof[5], proof.proof[5], "Proof should match");
assertEq(decodedProofs[0].proof[6], proof.proof[6], "Proof should match");
assertEq(decodedProofs[0].proof[7], proof.proof[7], "Proof should match");
assertEq(decodedProofs[1].root, proof.root, "Root should match");
assertEq(
decodedProofs[1].pbhExternalNullifier, proof.pbhExternalNullifier, "PBH External Nullifier should match"
);
assertEq(decodedProofs[1].nullifierHash, proof.nullifierHash, "Nullifier Hash should match");
assertEq(decodedProofs[1].proof[0], proof.proof[0], "Proof should match");
assertEq(decodedProofs[1].proof[1], proof.proof[1], "Proof should match");
assertEq(decodedProofs[1].proof[2], proof.proof[2], "Proof should match");
assertEq(decodedProofs[1].proof[3], proof.proof[3], "Proof should match");
assertEq(decodedProofs[1].proof[4], proof.proof[4], "Proof should match");
assertEq(decodedProofs[1].proof[5], proof.proof[5], "Proof should match");
assertEq(decodedProofs[1].proof[6], proof.proof[6], "Proof should match");
assertEq(decodedProofs[1].proof[7], proof.proof[7], "Proof should match");
}

function testAggregateSignatures_VariableThreshold(
uint256 root,
uint256 pbhExternalNullifier,
uint256 nullifierHash,
uint8 threshold,
uint256 p0,
uint256 p1,
uint256 p2,
uint256 p3,
uint256 p4,
uint256 p5,
uint256 p6,
uint256 p7
) public {
deploySafeAccount(address(pbhAggregator), threshold);
IPBHEntryPoint.PBHPayload memory proof = IPBHEntryPoint.PBHPayload({
root: root,
pbhExternalNullifier: pbhExternalNullifier,
nullifierHash: nullifierHash,
proof: [p0, p1, p2, p3, p4, p5, p6, p7]
});

bytes[] memory proofs = new bytes[](2);
proofs[0] = abi.encode(proof);
proofs[1] = abi.encode(proof);
PackedUserOperation[] memory uoTestFixture = createUOTestData(address(safe), proofs, threshold);
bytes memory aggregatedSignature = pbhAggregator.aggregateSignatures(uoTestFixture);
IPBHEntryPoint.PBHPayload[] memory decodedProofs =
abi.decode(aggregatedSignature, (IPBHEntryPoint.PBHPayload[]));
assertEq(decodedProofs.length, 2, "Decoded proof length should be 1");
assertEq(decodedProofs[0].root, proof.root, "Root should match");
assertEq(
decodedProofs[0].pbhExternalNullifier, proof.pbhExternalNullifier, "PBH External Nullifier should match"
);
assertEq(decodedProofs[0].nullifierHash, proof.nullifierHash, "Nullifier Hash should match");
assertEq(decodedProofs[0].proof[0], proof.proof[0], "Proof should match");
assertEq(decodedProofs[0].proof[1], proof.proof[1], "Proof should match");
assertEq(decodedProofs[0].proof[2], proof.proof[2], "Proof should match");
assertEq(decodedProofs[0].proof[3], proof.proof[3], "Proof should match");
assertEq(decodedProofs[0].proof[4], proof.proof[4], "Proof should match");
assertEq(decodedProofs[0].proof[5], proof.proof[5], "Proof should match");
assertEq(decodedProofs[0].proof[6], proof.proof[6], "Proof should match");
assertEq(decodedProofs[0].proof[7], proof.proof[7], "Proof should match");
assertEq(decodedProofs[1].root, proof.root, "Root should match");
assertEq(
decodedProofs[1].pbhExternalNullifier, proof.pbhExternalNullifier, "PBH External Nullifier should match"
);
assertEq(decodedProofs[1].nullifierHash, proof.nullifierHash, "Nullifier Hash should match");
assertEq(decodedProofs[1].proof[0], proof.proof[0], "Proof should match");
assertEq(decodedProofs[1].proof[1], proof.proof[1], "Proof should match");
assertEq(decodedProofs[1].proof[2], proof.proof[2], "Proof should match");
assertEq(decodedProofs[1].proof[3], proof.proof[3], "Proof should match");
assertEq(decodedProofs[1].proof[4], proof.proof[4], "Proof should match");
assertEq(decodedProofs[1].proof[5], proof.proof[5], "Proof should match");
assertEq(decodedProofs[1].proof[6], proof.proof[6], "Proof should match");
assertEq(decodedProofs[1].proof[7], proof.proof[7], "Proof should match");
}

function testFailAggregateSignatures_InvalidSignatureLength() public {
IPBHEntryPoint.PBHPayload memory proof = IPBHEntryPoint.PBHPayload({
root: 0,
pbhExternalNullifier: 0,
nullifierHash: 0,
proof: [uint256(1), 0, 0, 0, 0, 0, 0, 0]
});

bytes[] memory proofs = new bytes[](2);
proofs[0] = abi.encode(proof);
proofs[1] = abi.encode(proof);
PackedUserOperation[] memory uoTestFixture = createUOTestData(address(safe), proofs, 1);
uoTestFixture[0].signature = new bytes(12);
vm.expectRevert(PBHSignatureAggregator.InvalidSignatureLength.selector);
pbhAggregator.aggregateSignatures(uoTestFixture);
}

receive() external payable {}
Expand Down
6 changes: 3 additions & 3 deletions contracts/test/Setup.sol
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ contract Setup is Test {
deployWorldIDGroups();
deployPBHEntryPoint(worldIDGroups, entryPoint);
deployPBHSignatureAggregator(address(pbhEntryPoint));
deploySafeAccount(address(pbhAggregator));
deploySafeAccount(address(pbhAggregator), 1);

// Label the addresses for better errors.
vm.label(address(entryPoint), "ERC-4337 Entry Point");
Expand Down Expand Up @@ -91,8 +91,8 @@ contract Setup is Test {

/// @notice Initializes a new safe account.
/// @dev It is constructed in the globals.
function deploySafeAccount(address _pbhSignatureAggregator) public {
safe = new MockAccount(_pbhSignatureAggregator);
function deploySafeAccount(address _pbhSignatureAggregator, uint256 threshold) public {
safe = new MockAccount(_pbhSignatureAggregator, threshold);
}

/// @notice Initializes a new World ID Groups contract.
Expand Down
32 changes: 11 additions & 21 deletions contracts/test/TestUtils.sol
Original file line number Diff line number Diff line change
Expand Up @@ -5,37 +5,27 @@ import "@account-abstraction/contracts/interfaces/PackedUserOperation.sol";
import {IEntryPoint} from "@account-abstraction/contracts/interfaces/IEntryPoint.sol";
import {IAggregator} from "@account-abstraction/contracts/interfaces/IAggregator.sol";
import "@forge-std/console.sol";
import "@solady/LibBytes.sol";

contract TestUtils {
function encodeSignature(bytes memory proofData) public pure returns (bytes memory) {
bytes memory sigBuffer = new bytes(65);
bytes memory signature = new bytes(417);
assembly {
mstore(signature, sigBuffer)
mstore(add(signature, 32), mload(add(sigBuffer, 32)))
mstore(add(signature, 64), mload(add(sigBuffer, 64)))
mstore(add(signature, 65), mload(proofData))
mstore(add(add(signature, 65), 32), mload(add(proofData, 32)))
mstore(add(add(signature, 65), 64), mload(add(proofData, 64)))
mstore(add(add(signature, 65), 96), mload(add(proofData, 96)))
mstore(add(add(signature, 65), 128), mload(add(proofData, 128)))
mstore(add(add(signature, 65), 160), mload(add(proofData, 160)))
mstore(add(add(signature, 65), 192), mload(add(proofData, 192)))
mstore(add(add(signature, 65), 224), mload(add(proofData, 224)))
mstore(add(add(signature, 65), 256), mload(add(proofData, 256)))
}
return signature;
function encodeSignature(bytes memory proofData, uint256 signatureThreshold)
public
pure
returns (bytes memory res)
{
bytes memory sigBuffer = new bytes(65 * signatureThreshold + 12);
res = LibBytes.concat(sigBuffer, proofData);
}

/// @notice Create a test data for UserOperations.
function createUOTestData(address sender, bytes[] memory proofs)
function createUOTestData(address sender, bytes[] memory proofs, uint256 signatureThreshold)
public
pure
returns (PackedUserOperation[] memory)
{
PackedUserOperation[] memory uOps = new PackedUserOperation[](proofs.length);
for (uint256 i = 0; i < proofs.length; ++i) {
bytes memory signature = encodeSignature(proofs[i]);
for (uint256 i = 0; i < proofs.length; i++) {
bytes memory signature = encodeSignature(proofs[i], signatureThreshold);
PackedUserOperation memory uo = PackedUserOperation({
sender: sender,
nonce: i,
Expand Down
8 changes: 7 additions & 1 deletion contracts/test/mocks/MockAccount.sol
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,11 @@ import "@account-abstraction/contracts/interfaces/PackedUserOperation.sol";

contract MockAccount is IAccount, IAccountExecute {
address public pbhAggregator;
uint256 public immutable threshold;

constructor(address _pbhAggregator) {
constructor(address _pbhAggregator, uint256 _threshold) {
pbhAggregator = _pbhAggregator;
threshold = _threshold;
}

function validateUserOp(PackedUserOperation calldata, bytes32, uint256)
Expand All @@ -26,4 +28,8 @@ contract MockAccount is IAccount, IAccountExecute {
function executeUserOp(PackedUserOperation calldata userOp, bytes32 userOpHash) external {
// Do nothing
}

function getThreshold() external view returns (uint256) {
return threshold;
}
}

0 comments on commit a557528

Please sign in to comment.