diff --git a/src/pool/Vault.sol b/src/pool/Vault.sol index 15dc64e..ee92611 100644 --- a/src/pool/Vault.sol +++ b/src/pool/Vault.sol @@ -23,7 +23,7 @@ interface IVault { event Rebalanced( uint256 previousTotalStaked, uint256 previousTotalUnstaked, uint256 totalStaked, uint256 totalUnstaked ); - event SharesTransferred(address indexed from, address indexed to, uint256 amount, bytes data); + event StakeTransferred(address indexed from, address indexed to, uint256 amount, bytes data); function depositLimit() external view returns (uint256); function totalAssets() external view returns (uint256); @@ -44,15 +44,15 @@ interface IVault { function withdraw(uint256 amount, address beneficiary) external; function claim(uint256 amount, address beneficiary) external; function claimFees(uint256 amount, address beneficiary) external; - function transferShares(address to, uint256 amount, bytes calldata data) external; + function transferStake(address to, uint256 amount, bytes calldata data) external; } -interface IVaultSharesRecipient { - /// @notice Shares have been transfered to the recipient. +interface IVaultStakeRecipient { + /// @notice Amount of stake have been transfered to the recipient. /// @param from The address of the sender. - /// @param amount The amount of shares. + /// @param amount The amount of stake. /// @param data Additional data. - function onSharesReceived(address from, uint256 amount, bytes calldata data) external; + function onVaultStakeReceived(address from, uint256 amount, bytes calldata data) external; } contract Vault is IVault, ERC165, OwnableUnset, ReentrancyGuardUpgradeable, PausableUpgradeable { @@ -479,7 +479,7 @@ contract Vault is IVault, ERC165, OwnableUnset, ReentrancyGuardUpgradeable, Paus } } - function transferShares(address to, uint256 amount, bytes calldata data) + function transferStake(address to, uint256 amount, bytes calldata data) external override whenNotPaused @@ -489,17 +489,24 @@ contract Vault is IVault, ERC165, OwnableUnset, ReentrancyGuardUpgradeable, Paus if (amount == 0) { revert InvalidAmount(amount); } - if (amount > _shares[account]) { - revert InsufficientBalance(_shares[account], amount); + if (amount > balanceOf(account)) { + revert InsufficientBalance(balanceOf(account), amount); + } + uint256 shares = _toShares(amount); + if (shares == 0) { + revert InvalidAmount(shares); + } + if (shares > _shares[account]) { + revert InsufficientBalance(_shares[account], shares); } if (to == address(0)) { revert InvalidAddress(to); } - _shares[account] -= amount; - _shares[to] += amount; - emit SharesTransferred(account, to, amount, data); - if (ERC165Checker.supportsInterface(to, type(IVaultSharesRecipient).interfaceId)) { - IVaultSharesRecipient(to).onSharesReceived(account, amount, data); + _shares[account] -= shares; + _shares[to] += shares; + emit StakeTransferred(account, to, amount, data); + if (ERC165Checker.supportsInterface(to, type(IVaultStakeRecipient).interfaceId)) { + IVaultStakeRecipient(to).onVaultStakeReceived(account, amount, data); } } } diff --git a/test/pool/Vault.t.sol b/test/pool/Vault.t.sol index a8f43a8..7fb78e7 100644 --- a/test/pool/Vault.t.sol +++ b/test/pool/Vault.t.sol @@ -7,7 +7,7 @@ import {OwnableCallerNotTheOwner} from "@erc725/smart-contracts/contracts/errors import {Math} from "@openzeppelin/contracts/utils/math/Math.sol"; import {ERC165} from "openzeppelin-contracts/contracts/utils/introspection/ERC165.sol"; import {ERC165Checker} from "openzeppelin-contracts/contracts/utils/introspection/ERC165Checker.sol"; -import {IVault, Vault, IVaultSharesRecipient} from "../../src/pool/Vault.sol"; +import {IVault, Vault, IVaultStakeRecipient} from "../../src/pool/Vault.sol"; import {IDepositContract} from "../../src/pool/IDepositContract.sol"; contract VaultTest is Test { @@ -26,7 +26,7 @@ contract VaultTest is Test { uint256 previousTotalStaked, uint256 previousTotalUnstaked, uint256 totalStaked, uint256 totalUnstaked ); event ValidatorExited(bytes pubkey, uint256 total); - event SharesTransferred(address indexed from, address indexed to, uint256 amount, bytes data); + event StakeTransferred(address indexed from, address indexed to, uint256 amount, bytes data); Vault vault; address admin; @@ -1595,7 +1595,7 @@ contract VaultTest is Test { ); } - function test_TransferShares() public { + function test_TransferStake() public { vm.startPrank(owner); vault.setDepositLimit(1_000_000 ether); vault.enableOracle(oracle, true); @@ -1610,33 +1610,27 @@ contract VaultTest is Test { vm.prank(alice); vault.deposit{value: 100 ether}(alice); - uint256 shares = vault.sharesOf(alice); uint256 balance = vault.balanceOf(alice); - - assertEq(0, vault.sharesOf(bob)); assertEq(0, vault.balanceOf(bob)); vm.prank(alice); vm.expectEmit(); - emit SharesTransferred(alice, bob, shares, "0x1234"); - vault.transferShares(bob, shares, "0x1234"); + emit StakeTransferred(alice, bob, balance, "0x1234"); + vault.transferStake(bob, balance, "0x1234"); - assertEq(shares, vault.sharesOf(bob)); assertEq(0, vault.balanceOf(alice)); - assertEq(0, vault.sharesOf(alice)); + assertEq(balance, vault.balanceOf(bob)); assertEq(0, bob.balance); vm.prank(bob); vault.withdraw(balance, bob); assertEq(balance, bob.balance); - assertEq(0, vault.sharesOf(bob)); assertEq(0, vault.balanceOf(bob)); assertEq(0, vault.balanceOf(alice)); - assertEq(0, vault.sharesOf(alice)); } - function test_Revert_TransferShares() public { + function test_Revert_TransferStake() public { vm.startPrank(owner); vault.setDepositLimit(1_000_000 ether); vault.enableOracle(oracle, true); @@ -1651,23 +1645,18 @@ contract VaultTest is Test { vm.prank(alice); vault.deposit{value: 100 ether}(alice); - uint256 shares = vault.sharesOf(alice); uint256 balance = vault.balanceOf(alice); - - assertEq(0, vault.sharesOf(bob)); assertEq(0, vault.balanceOf(bob)); vm.prank(bob); - vm.expectRevert(abi.encodeWithSelector(Vault.InsufficientBalance.selector, 0, shares)); - vault.transferShares(bob, shares, "0x1234"); + vm.expectRevert(abi.encodeWithSelector(Vault.InsufficientBalance.selector, 0, balance)); + vault.transferStake(bob, balance, "0x1234"); - assertEq(0, vault.sharesOf(bob)); assertEq(0, vault.balanceOf(bob)); - assertEq(shares, vault.sharesOf(alice)); assertEq(balance, vault.balanceOf(alice)); } - function test_NotifyAfterSharesTransferred() public { + function test_NotifyAfterStakeTransferred() public { vm.startPrank(owner); vault.setDepositLimit(1_000_000 ether); vault.enableOracle(oracle, true); @@ -1675,7 +1664,7 @@ contract VaultTest is Test { vault.setFeeRecipient(feeRecipient); vm.stopPrank(); - MockSharesRecipient recipient = new MockSharesRecipient(); + MockStakeRecipient recipient = new MockStakeRecipient(); address alice = vm.addr(100); @@ -1683,33 +1672,34 @@ contract VaultTest is Test { vm.prank(alice); vault.deposit{value: 100 ether}(alice); - uint256 shares = vault.sharesOf(alice); + uint256 balance = vault.balanceOf(alice); vm.prank(alice); vm.expectEmit(); - emit SharesTransferred(alice, address(recipient), shares, "0x1234"); - vault.transferShares(address(recipient), shares, "0x1234"); + emit StakeTransferred(alice, address(recipient), balance, "0x1234"); + vault.transferStake(address(recipient), balance, "0x1234"); assertEq(0, vault.sharesOf(alice)); + assertEq(0, vault.balanceOf(alice)); assertEq(alice, recipient.lastFrom()); assertEq(address(vault), recipient.lastSender()); - assertEq(shares, recipient.lastAmount()); + assertEq(balance, recipient.lastAmount()); assertEq("0x1234", recipient.lastData()); } } -contract MockSharesRecipient is ERC165, IVaultSharesRecipient { +contract MockStakeRecipient is ERC165, IVaultStakeRecipient { address public lastFrom; uint256 public lastAmount; bytes public lastData; address public lastSender; function supportsInterface(bytes4 interfaceId) public view override returns (bool) { - return interfaceId == type(IVaultSharesRecipient).interfaceId || super.supportsInterface(interfaceId); + return interfaceId == type(IVaultStakeRecipient).interfaceId || super.supportsInterface(interfaceId); } - function onSharesReceived(address from, uint256 amount, bytes calldata data) external override { + function onVaultStakeReceived(address from, uint256 amount, bytes calldata data) external override { lastSender = msg.sender; lastFrom = from; lastAmount = amount;