Skip to content

Commit

Permalink
Merge branch 'development'
Browse files Browse the repository at this point in the history
  • Loading branch information
lykhonis committed Mar 26, 2024
2 parents ddcfd44 + f1a36f3 commit ee17375
Show file tree
Hide file tree
Showing 2 changed files with 213 additions and 25 deletions.
111 changes: 87 additions & 24 deletions src/pool/Vault.sol
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,58 @@ pragma solidity =0.8.22;
import {OwnableUnset} from "@erc725/smart-contracts/contracts/custom/OwnableUnset.sol";
import {ReentrancyGuardUpgradeable} from "@openzeppelin/contracts-upgradeable/security/ReentrancyGuardUpgradeable.sol";
import {PausableUpgradeable} from "@openzeppelin/contracts-upgradeable/security/PausableUpgradeable.sol";
import {ERC165} from "openzeppelin-contracts/contracts/utils/introspection/ERC165.sol";
import {ERC165Checker} from "openzeppelin-contracts/contracts/utils/introspection/ERC165Checker.sol";
import {Math} from "@openzeppelin/contracts/utils/math/Math.sol";
import {IDepositContract, DEPOSIT_AMOUNT} from "./IDepositContract.sol";

contract Vault is OwnableUnset, ReentrancyGuardUpgradeable, PausableUpgradeable {
interface IVault {
event Deposited(address indexed account, address indexed beneficiary, uint256 amount);
event Withdrawn(address indexed account, address indexed beneficiary, uint256 amount);
event WithdrawalRequested(address indexed account, address indexed beneficiary, uint256 amount);
event Claimed(address indexed account, address indexed beneficiary, uint256 amount);
event DepositLimitChanged(uint256 previousLimit, uint256 newLimit);
event FeeChanged(uint32 previousFee, uint32 newFee);
event FeeRecipientChanged(address previousFeeRecipient, address newFeeRecipient);
event FeeClaimed(address indexed account, address indexed beneficiary, uint256 amount);
event RewardsDistributed(uint256 balance, uint256 rewards, uint256 fee);
event OracleEnabled(address indexed oracle, bool enabled);
event Rebalanced(
uint256 previousTotalStaked, uint256 previousTotalUnstaked, uint256 totalStaked, uint256 totalUnstaked
);
event SharesTransferred(address indexed from, address indexed to, uint256 amount, bytes data);

function depositLimit() external view returns (uint256);
function totalAssets() external view returns (uint256);
function totalShares() external view returns (uint256);
function totalStaked() external view returns (uint256);
function totalUnstaked() external view returns (uint256);
function totalPendingWithdrawal() external view returns (uint256);
function totalValidatorsRegistered() external view returns (uint256);
function fee() external view returns (uint32);
function feeRecipient() external view returns (address);
function totalFees() external view returns (uint256);
function restricted() external view returns (bool);
function balanceOf(address account) external view returns (uint256);
function sharesOf(address account) external view returns (uint256);
function pendingBalanceOf(address account) external view returns (uint256);
function claimableBalanceOf(address account) external view returns (uint256);
function deposit(address beneficiary) external payable;
function withdraw(uint256 amount, address beneficiary) external;
function claim(uint256 amount, address beneficiary) external;
function claimFees(uint256 amount, address beneficiary) external;
function transferShares(address to, uint256 amount, bytes calldata data) external;
}

interface IVaultSharesRecipient {
/// @notice Shares have been transfered to the recipient.
/// @param from The address of the sender.
/// @param amount The amount of shares.
/// @param data Additional data.
function onSharesReceived(address from, uint256 amount, bytes calldata data) external;
}

contract Vault is IVault, ERC165, OwnableUnset, ReentrancyGuardUpgradeable, PausableUpgradeable {
uint32 private constant _FEE_BASIS = 100_000;
uint32 private constant _MIN_FEE = 0; // 0%
uint32 private constant _MAX_FEE = 15_000; // 15%
Expand All @@ -25,20 +73,7 @@ contract Vault is OwnableUnset, ReentrancyGuardUpgradeable, PausableUpgradeable
error InvalidAddress(address account);
error ValidatorAlreadyRegistered(bytes pubkey);
error CallerNotOperator(address account);

event Deposited(address indexed account, address indexed beneficiary, uint256 amount);
event Withdrawn(address indexed account, address indexed beneficiary, uint256 amount);
event WithdrawalRequested(address indexed account, address indexed beneficiary, uint256 amount);
event Claimed(address indexed account, address indexed beneficiary, uint256 amount);
event DepositLimitChanged(uint256 previousLimit, uint256 newLimit);
event FeeChanged(uint32 previousFee, uint32 newFee);
event FeeRecipientChanged(address previousFeeRecipient, address newFeeRecipient);
event FeeClaimed(address indexed account, address indexed beneficiary, uint256 amount);
event RewardsDistributed(uint256 balance, uint256 rewards, uint256 fee);
event OracleEnabled(address indexed oracle, bool enabled);
event Rebalanced(
uint256 previousTotalStaked, uint256 previousTotalUnstaked, uint256 totalStaked, uint256 totalUnstaked
);
error InvalidSignature();

// limit of total deposits in wei.
// This limits the total number of validators that can be registered.
Expand Down Expand Up @@ -86,6 +121,10 @@ contract Vault is OwnableUnset, ReentrancyGuardUpgradeable, PausableUpgradeable
_disableInitializers();
}

function supportsInterface(bytes4 interfaceId) public view override returns (bool) {
return interfaceId == type(IVault).interfaceId || super.supportsInterface(interfaceId);
}

function initialize(address owner_, address operator_, IDepositContract depositContract_) external initializer {
if (address(depositContract_) == address(0)) {
revert InvalidAddress(address(depositContract_));
Expand Down Expand Up @@ -187,24 +226,24 @@ contract Vault is OwnableUnset, ReentrancyGuardUpgradeable, PausableUpgradeable
}
}

function sharesOf(address account) external view returns (uint256) {
function sharesOf(address account) external view override returns (uint256) {
return _shares[account];
}

function balanceOf(address account) public view returns (uint256) {
function balanceOf(address account) public view override returns (uint256) {
return _toBalance(_shares[account]);
}

function pendingBalanceOf(address account) external view returns (uint256) {
function pendingBalanceOf(address account) external view override returns (uint256) {
return _pendingWithdrawals[account];
}

function claimableBalanceOf(address account) external view returns (uint256) {
function claimableBalanceOf(address account) external view override returns (uint256) {
uint256 pendingWithdrawal = _pendingWithdrawals[account];
return pendingWithdrawal > totalClaimable ? totalClaimable : pendingWithdrawal;
}

function claim(uint256 amount, address beneficiary) external nonReentrant whenNotPaused {
function claim(uint256 amount, address beneficiary) external override nonReentrant whenNotPaused {
if (beneficiary == address(0)) {
revert InvalidAddress(beneficiary);
}
Expand All @@ -228,7 +267,7 @@ contract Vault is OwnableUnset, ReentrancyGuardUpgradeable, PausableUpgradeable
emit Claimed(account, beneficiary, amount);
}

function totalAssets() public view returns (uint256) {
function totalAssets() public view override returns (uint256) {
return totalStaked + totalUnstaked + totalClaimable - totalPendingWithdrawal;
}

Expand All @@ -248,7 +287,7 @@ contract Vault is OwnableUnset, ReentrancyGuardUpgradeable, PausableUpgradeable
return Math.mulDiv(amount, totalShares, totalAssets());
}

function deposit(address beneficiary) public payable whenNotPaused {
function deposit(address beneficiary) public payable override whenNotPaused {
if (beneficiary == address(0)) {
revert InvalidAddress(beneficiary);
}
Expand Down Expand Up @@ -284,7 +323,7 @@ contract Vault is OwnableUnset, ReentrancyGuardUpgradeable, PausableUpgradeable
emit Deposited(account, beneficiary, amount);
}

function withdraw(uint256 amount, address beneficiary) external nonReentrant whenNotPaused {
function withdraw(uint256 amount, address beneficiary) external override nonReentrant whenNotPaused {
if (beneficiary == address(0)) {
revert InvalidAddress(beneficiary);
}
Expand Down Expand Up @@ -325,7 +364,7 @@ contract Vault is OwnableUnset, ReentrancyGuardUpgradeable, PausableUpgradeable
}
}

function claimFees(uint256 amount, address beneficiary) external nonReentrant whenNotPaused {
function claimFees(uint256 amount, address beneficiary) external override nonReentrant whenNotPaused {
if (beneficiary == address(0)) {
revert InvalidAddress(beneficiary);
}
Expand Down Expand Up @@ -439,4 +478,28 @@ contract Vault is OwnableUnset, ReentrancyGuardUpgradeable, PausableUpgradeable
registerValidator(pubkeys[i], signatures[i], depositDataRoots[i]);
}
}

function transferShares(address to, uint256 amount, bytes calldata data)
external
override
whenNotPaused
nonReentrant
{
address account = msg.sender;
if (amount == 0) {
revert InvalidAmount(amount);
}
if (amount > _shares[account]) {
revert InsufficientBalance(_shares[account], amount);
}
if (to == address(0)) {
revert InvalidAddress(to);
}
_shares[account] -= amount;
_shares[to] += amount;
emit SharesTransferred(account, to, amount, data);
if (ERC165Checker.supportsInterface(to, type(IVaultSharesRecipient).interfaceId)) {
IVaultSharesRecipient(to).onSharesReceived(account, amount, data);
}
}
}
127 changes: 126 additions & 1 deletion test/pool/Vault.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@ import {Test} from "forge-std/Test.sol";
import {TransparentUpgradeableProxy} from "@openzeppelin/contracts/proxy/transparent/TransparentUpgradeableProxy.sol";
import {OwnableCallerNotTheOwner} from "@erc725/smart-contracts/contracts/errors.sol";
import {Math} from "@openzeppelin/contracts/utils/math/Math.sol";
import {Vault} from "../../src/pool/Vault.sol";
import {ERC165} from "openzeppelin-contracts/contracts/utils/introspection/ERC165.sol";
import {ERC165Checker} from "openzeppelin-contracts/contracts/utils/introspection/ERC165Checker.sol";
import {IVault, Vault, IVaultSharesRecipient} from "../../src/pool/Vault.sol";
import {IDepositContract} from "../../src/pool/IDepositContract.sol";

contract VaultTest is Test {
Expand All @@ -24,6 +26,7 @@ contract VaultTest is Test {
uint256 previousTotalStaked, uint256 previousTotalUnstaked, uint256 totalStaked, uint256 totalUnstaked
);
event ValidatorExited(bytes pubkey, uint256 total);
event SharesTransferred(address indexed from, address indexed to, uint256 amount, bytes data);

Vault vault;
address admin;
Expand Down Expand Up @@ -63,6 +66,7 @@ contract VaultTest is Test {
}

function test_Initialize() public {
assertTrue(ERC165Checker.supportsInterface(address(vault), type(IVault).interfaceId));
assertTrue(!vault.paused());
assertEq(owner, vault.owner());
assertEq(0, vault.depositLimit());
Expand Down Expand Up @@ -1590,6 +1594,127 @@ contract VaultTest is Test {
vault.balanceOf(bob) >= 1 ether - 1e15 /* rounding error of 18 decimals - 3 of minimum shares amount */
);
}

function test_TransferShares() public {
vm.startPrank(owner);
vault.setDepositLimit(1_000_000 ether);
vault.enableOracle(oracle, true);
vault.setFee(10_000);
vault.setFeeRecipient(feeRecipient);
vm.stopPrank();

address alice = vm.addr(100);
address bob = vm.addr(101);

vm.deal(alice, 100 ether);
vm.prank(alice);
vault.deposit{value: 100 ether}(alice);

uint256 shares = vault.sharesOf(alice);
uint256 balance = vault.balanceOf(alice);

assertEq(0, vault.sharesOf(bob));
assertEq(0, vault.balanceOf(bob));

vm.prank(alice);
vm.expectEmit();
emit SharesTransferred(alice, bob, shares, "0x1234");
vault.transferShares(bob, shares, "0x1234");

assertEq(shares, vault.sharesOf(bob));
assertEq(0, vault.balanceOf(alice));
assertEq(0, vault.sharesOf(alice));
assertEq(0, bob.balance);

vm.prank(bob);
vault.withdraw(balance, bob);

assertEq(balance, bob.balance);
assertEq(0, vault.sharesOf(bob));
assertEq(0, vault.balanceOf(bob));
assertEq(0, vault.balanceOf(alice));
assertEq(0, vault.sharesOf(alice));
}

function test_Revert_TransferShares() public {
vm.startPrank(owner);
vault.setDepositLimit(1_000_000 ether);
vault.enableOracle(oracle, true);
vault.setFee(10_000);
vault.setFeeRecipient(feeRecipient);
vm.stopPrank();

address alice = vm.addr(100);
address bob = vm.addr(101);

vm.deal(alice, 100 ether);
vm.prank(alice);
vault.deposit{value: 100 ether}(alice);

uint256 shares = vault.sharesOf(alice);
uint256 balance = vault.balanceOf(alice);

assertEq(0, vault.sharesOf(bob));
assertEq(0, vault.balanceOf(bob));

vm.prank(bob);
vm.expectRevert(abi.encodeWithSelector(Vault.InsufficientBalance.selector, 0, shares));
vault.transferShares(bob, shares, "0x1234");

assertEq(0, vault.sharesOf(bob));
assertEq(0, vault.balanceOf(bob));
assertEq(shares, vault.sharesOf(alice));
assertEq(balance, vault.balanceOf(alice));
}

function test_NotifyAfterSharesTransferred() public {
vm.startPrank(owner);
vault.setDepositLimit(1_000_000 ether);
vault.enableOracle(oracle, true);
vault.setFee(10_000);
vault.setFeeRecipient(feeRecipient);
vm.stopPrank();

MockSharesRecipient recipient = new MockSharesRecipient();

address alice = vm.addr(100);

vm.deal(alice, 100 ether);
vm.prank(alice);
vault.deposit{value: 100 ether}(alice);

uint256 shares = vault.sharesOf(alice);

vm.prank(alice);
vm.expectEmit();
emit SharesTransferred(alice, address(recipient), shares, "0x1234");
vault.transferShares(address(recipient), shares, "0x1234");

assertEq(0, vault.sharesOf(alice));

assertEq(alice, recipient.lastFrom());
assertEq(address(vault), recipient.lastSender());
assertEq(shares, recipient.lastAmount());
assertEq("0x1234", recipient.lastData());
}
}

contract MockSharesRecipient is ERC165, IVaultSharesRecipient {
address public lastFrom;
uint256 public lastAmount;
bytes public lastData;
address public lastSender;

function supportsInterface(bytes4 interfaceId) public view override returns (bool) {
return interfaceId == type(IVaultSharesRecipient).interfaceId || super.supportsInterface(interfaceId);
}

function onSharesReceived(address from, uint256 amount, bytes calldata data) external override {
lastSender = msg.sender;
lastFrom = from;
lastAmount = amount;
lastData = data;
}
}

contract MockDepositContract is IDepositContract {
Expand Down

0 comments on commit ee17375

Please sign in to comment.