Skip to content

Commit

Permalink
Switch to transfer stake, not shares
Browse files Browse the repository at this point in the history
  • Loading branch information
lykhonis committed Mar 26, 2024
1 parent f1a36f3 commit 136c6db
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 43 deletions.
35 changes: 21 additions & 14 deletions src/pool/Vault.sol
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ interface IVault {
event Rebalanced(
uint256 previousTotalStaked, uint256 previousTotalUnstaked, uint256 totalStaked, uint256 totalUnstaked
);
event SharesTransferred(address indexed from, address indexed to, uint256 amount, bytes data);
event StakeTransferred(address indexed from, address indexed to, uint256 amount, bytes data);

function depositLimit() external view returns (uint256);
function totalAssets() external view returns (uint256);
Expand All @@ -44,15 +44,15 @@ interface IVault {
function withdraw(uint256 amount, address beneficiary) external;
function claim(uint256 amount, address beneficiary) external;
function claimFees(uint256 amount, address beneficiary) external;
function transferShares(address to, uint256 amount, bytes calldata data) external;
function transferStake(address to, uint256 amount, bytes calldata data) external;
}

interface IVaultSharesRecipient {
/// @notice Shares have been transfered to the recipient.
interface IVaultStakeRecipient {
/// @notice Amount of stake have been transfered to the recipient.
/// @param from The address of the sender.
/// @param amount The amount of shares.
/// @param amount The amount of stake.
/// @param data Additional data.
function onSharesReceived(address from, uint256 amount, bytes calldata data) external;
function onVaultStakeReceived(address from, uint256 amount, bytes calldata data) external;
}

contract Vault is IVault, ERC165, OwnableUnset, ReentrancyGuardUpgradeable, PausableUpgradeable {
Expand Down Expand Up @@ -479,7 +479,7 @@ contract Vault is IVault, ERC165, OwnableUnset, ReentrancyGuardUpgradeable, Paus
}
}

function transferShares(address to, uint256 amount, bytes calldata data)
function transferStake(address to, uint256 amount, bytes calldata data)
external
override
whenNotPaused
Expand All @@ -489,17 +489,24 @@ contract Vault is IVault, ERC165, OwnableUnset, ReentrancyGuardUpgradeable, Paus
if (amount == 0) {
revert InvalidAmount(amount);
}
if (amount > _shares[account]) {
revert InsufficientBalance(_shares[account], amount);
if (amount > balanceOf(account)) {
revert InsufficientBalance(balanceOf(account), amount);
}
uint256 shares = _toShares(amount);
if (shares == 0) {
revert InvalidAmount(shares);
}
if (shares > _shares[account]) {
revert InsufficientBalance(_shares[account], shares);
}
if (to == address(0)) {
revert InvalidAddress(to);
}
_shares[account] -= amount;
_shares[to] += amount;
emit SharesTransferred(account, to, amount, data);
if (ERC165Checker.supportsInterface(to, type(IVaultSharesRecipient).interfaceId)) {
IVaultSharesRecipient(to).onSharesReceived(account, amount, data);
_shares[account] -= shares;
_shares[to] += shares;
emit StakeTransferred(account, to, amount, data);
if (ERC165Checker.supportsInterface(to, type(IVaultStakeRecipient).interfaceId)) {
IVaultStakeRecipient(to).onVaultStakeReceived(account, amount, data);
}
}
}
48 changes: 19 additions & 29 deletions test/pool/Vault.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import {OwnableCallerNotTheOwner} from "@erc725/smart-contracts/contracts/errors
import {Math} from "@openzeppelin/contracts/utils/math/Math.sol";
import {ERC165} from "openzeppelin-contracts/contracts/utils/introspection/ERC165.sol";
import {ERC165Checker} from "openzeppelin-contracts/contracts/utils/introspection/ERC165Checker.sol";
import {IVault, Vault, IVaultSharesRecipient} from "../../src/pool/Vault.sol";
import {IVault, Vault, IVaultStakeRecipient} from "../../src/pool/Vault.sol";
import {IDepositContract} from "../../src/pool/IDepositContract.sol";

contract VaultTest is Test {
Expand All @@ -26,7 +26,7 @@ contract VaultTest is Test {
uint256 previousTotalStaked, uint256 previousTotalUnstaked, uint256 totalStaked, uint256 totalUnstaked
);
event ValidatorExited(bytes pubkey, uint256 total);
event SharesTransferred(address indexed from, address indexed to, uint256 amount, bytes data);
event StakeTransferred(address indexed from, address indexed to, uint256 amount, bytes data);

Vault vault;
address admin;
Expand Down Expand Up @@ -1595,7 +1595,7 @@ contract VaultTest is Test {
);
}

function test_TransferShares() public {
function test_TransferStake() public {
vm.startPrank(owner);
vault.setDepositLimit(1_000_000 ether);
vault.enableOracle(oracle, true);
Expand All @@ -1610,33 +1610,27 @@ contract VaultTest is Test {
vm.prank(alice);
vault.deposit{value: 100 ether}(alice);

uint256 shares = vault.sharesOf(alice);
uint256 balance = vault.balanceOf(alice);

assertEq(0, vault.sharesOf(bob));
assertEq(0, vault.balanceOf(bob));

vm.prank(alice);
vm.expectEmit();
emit SharesTransferred(alice, bob, shares, "0x1234");
vault.transferShares(bob, shares, "0x1234");
emit StakeTransferred(alice, bob, balance, "0x1234");
vault.transferStake(bob, balance, "0x1234");

assertEq(shares, vault.sharesOf(bob));
assertEq(0, vault.balanceOf(alice));
assertEq(0, vault.sharesOf(alice));
assertEq(balance, vault.balanceOf(bob));
assertEq(0, bob.balance);

vm.prank(bob);
vault.withdraw(balance, bob);

assertEq(balance, bob.balance);
assertEq(0, vault.sharesOf(bob));
assertEq(0, vault.balanceOf(bob));
assertEq(0, vault.balanceOf(alice));
assertEq(0, vault.sharesOf(alice));
}

function test_Revert_TransferShares() public {
function test_Revert_TransferStake() public {
vm.startPrank(owner);
vault.setDepositLimit(1_000_000 ether);
vault.enableOracle(oracle, true);
Expand All @@ -1651,65 +1645,61 @@ contract VaultTest is Test {
vm.prank(alice);
vault.deposit{value: 100 ether}(alice);

uint256 shares = vault.sharesOf(alice);
uint256 balance = vault.balanceOf(alice);

assertEq(0, vault.sharesOf(bob));
assertEq(0, vault.balanceOf(bob));

vm.prank(bob);
vm.expectRevert(abi.encodeWithSelector(Vault.InsufficientBalance.selector, 0, shares));
vault.transferShares(bob, shares, "0x1234");
vm.expectRevert(abi.encodeWithSelector(Vault.InsufficientBalance.selector, 0, balance));
vault.transferStake(bob, balance, "0x1234");

assertEq(0, vault.sharesOf(bob));
assertEq(0, vault.balanceOf(bob));
assertEq(shares, vault.sharesOf(alice));
assertEq(balance, vault.balanceOf(alice));
}

function test_NotifyAfterSharesTransferred() public {
function test_NotifyAfterStakeTransferred() public {
vm.startPrank(owner);
vault.setDepositLimit(1_000_000 ether);
vault.enableOracle(oracle, true);
vault.setFee(10_000);
vault.setFeeRecipient(feeRecipient);
vm.stopPrank();

MockSharesRecipient recipient = new MockSharesRecipient();
MockStakeRecipient recipient = new MockStakeRecipient();

address alice = vm.addr(100);

vm.deal(alice, 100 ether);
vm.prank(alice);
vault.deposit{value: 100 ether}(alice);

uint256 shares = vault.sharesOf(alice);
uint256 balance = vault.balanceOf(alice);

vm.prank(alice);
vm.expectEmit();
emit SharesTransferred(alice, address(recipient), shares, "0x1234");
vault.transferShares(address(recipient), shares, "0x1234");
emit StakeTransferred(alice, address(recipient), balance, "0x1234");
vault.transferStake(address(recipient), balance, "0x1234");

assertEq(0, vault.sharesOf(alice));
assertEq(0, vault.balanceOf(alice));

assertEq(alice, recipient.lastFrom());
assertEq(address(vault), recipient.lastSender());
assertEq(shares, recipient.lastAmount());
assertEq(balance, recipient.lastAmount());
assertEq("0x1234", recipient.lastData());
}
}

contract MockSharesRecipient is ERC165, IVaultSharesRecipient {
contract MockStakeRecipient is ERC165, IVaultStakeRecipient {
address public lastFrom;
uint256 public lastAmount;
bytes public lastData;
address public lastSender;

function supportsInterface(bytes4 interfaceId) public view override returns (bool) {
return interfaceId == type(IVaultSharesRecipient).interfaceId || super.supportsInterface(interfaceId);
return interfaceId == type(IVaultStakeRecipient).interfaceId || super.supportsInterface(interfaceId);
}

function onSharesReceived(address from, uint256 amount, bytes calldata data) external override {
function onVaultStakeReceived(address from, uint256 amount, bytes calldata data) external override {
lastSender = msg.sender;
lastFrom = from;
lastAmount = amount;
Expand Down

0 comments on commit 136c6db

Please sign in to comment.