diff --git a/src/TokenEscrow.sol b/src/TokenEscrow.sol index fe6de9b..e48070f 100644 --- a/src/TokenEscrow.sol +++ b/src/TokenEscrow.sol @@ -15,19 +15,28 @@ import "@openzeppelin/contracts-upgradeable/utils/math/SafeMathUpgradeable.sol"; contract TokenEscrow is OwnableUpgradeable { using SafeMathUpgradeable for uint256; - event VestingScheduleAdded(address indexed user, uint256 amount, uint256 startTime, uint256 endTime, uint256 step); + event VestingScheduleAdded( + address indexed user, + uint256 startAmount, + uint256 vestingAmount, + uint256 startTime, + uint256 endTime, + uint256 step + ); event VestingScheduleRemoved(address indexed user); event TokenVested(address indexed user, uint256 amount); /** - * @param amount Total amount to be vested over the complete period + * @param startAmount Amount immediately available at the beginning of the whole schedule + * @param vestingAmount Amount to be vested over the complete period * @param startTime Unix timestamp in seconds for the period start time * @param endTime Unix timestamp in seconds for the period end time * @param step Interval in seconds at which vestable amounts are accumulated * @param lastClaimTime Unix timestamp in seconds for the last claim time */ struct VestingSchedule { - uint128 amount; + uint128 startAmount; + uint128 vestingAmount; uint32 startTime; uint32 endTime; uint32 step; @@ -38,9 +47,9 @@ contract TokenEscrow is OwnableUpgradeable { mapping(address => VestingSchedule) public vestingSchedules; function getWithdrawableAmount(address user) external view returns (uint256) { - (uint256 withdrawableFromVesting,,) = calculateWithdrawableFromVesting(user); + (uint256 withdrawableFromSchedule,,) = calculateWithdrawableFromSchedule(user); - return withdrawableFromVesting; + return withdrawableFromSchedule; } function __TokenEscrow_init(IERC20Upgradeable _token) public initializer { @@ -50,35 +59,47 @@ contract TokenEscrow is OwnableUpgradeable { token = _token; } - function setVestingSchedule(address user, uint256 amount, uint256 startTime, uint256 endTime, uint256 step) - external - onlyOwner - { + function setVestingSchedule( + address user, + uint256 startAmount, + uint256 vestingAmount, + uint256 startTime, + uint256 endTime, + uint256 step + ) external onlyOwner { require(user != address(0), "TokenEscrow: zero address"); - require(amount > 0, "TokenEscrow: zero amount"); + require(startAmount > 0 || vestingAmount > 0, "TokenEscrow: zero amount"); require(startTime < endTime, "TokenEscrow: invalid time range"); require(step > 0 && endTime.sub(startTime) % step == 0, "TokenEscrow: invalid step"); - require(vestingSchedules[user].amount == 0, "TokenEscrow: vesting schedule already exists"); + require( + vestingSchedules[user].startAmount == 0 && vestingSchedules[user].vestingAmount == 0, + "TokenEscrow: vesting schedule already exists" + ); // Overflow checks - require(uint256(uint128(amount)) == amount, "TokenEscrow: amount overflow"); + require(uint256(uint128(startAmount)) == startAmount, "TokenEscrow: startAmount overflow"); + require(uint256(uint128(vestingAmount)) == vestingAmount, "TokenEscrow: vestingAmount overflow"); require(uint256(uint32(startTime)) == startTime, "TokenEscrow: startTime overflow"); require(uint256(uint32(endTime)) == endTime, "TokenEscrow: endTime overflow"); require(uint256(uint32(step)) == step, "TokenEscrow: step overflow"); vestingSchedules[user] = VestingSchedule({ - amount: uint128(amount), + startAmount: uint128(startAmount), + vestingAmount: uint128(vestingAmount), startTime: uint32(startTime), endTime: uint32(endTime), step: uint32(step), - lastClaimTime: uint32(startTime) + lastClaimTime: 0 }); - emit VestingScheduleAdded(user, amount, startTime, endTime, step); + emit VestingScheduleAdded(user, startAmount, vestingAmount, startTime, endTime, step); } function removeVestingSchedule(address user) external onlyOwner { - require(vestingSchedules[user].amount != 0, "TokenEscrow: vesting schedule not set"); + require( + vestingSchedules[user].startAmount != 0 || vestingSchedules[user].vestingAmount != 0, + "TokenEscrow: vesting schedule not set" + ); delete vestingSchedules[user]; @@ -86,15 +107,15 @@ contract TokenEscrow is OwnableUpgradeable { } function withdraw() external { - uint256 withdrawableFromVesting; + uint256 withdrawableFromSchedule; - // Withdraw from vesting + // Withdraw from schedule { uint256 newClaimTime; bool allVested; - (withdrawableFromVesting, newClaimTime, allVested) = calculateWithdrawableFromVesting(msg.sender); + (withdrawableFromSchedule, newClaimTime, allVested) = calculateWithdrawableFromSchedule(msg.sender); - if (withdrawableFromVesting > 0) { + if (withdrawableFromSchedule > 0) { if (allVested) { // Remove storage slot to save gas delete vestingSchedules[msg.sender]; @@ -104,26 +125,29 @@ contract TokenEscrow is OwnableUpgradeable { } } - uint256 totalAmountToSend = withdrawableFromVesting; + uint256 totalAmountToSend = withdrawableFromSchedule; require(totalAmountToSend > 0, "TokenEscrow: nothing to withdraw"); - if (withdrawableFromVesting > 0) { - emit TokenVested(msg.sender, withdrawableFromVesting); + if (withdrawableFromSchedule > 0) { + emit TokenVested(msg.sender, withdrawableFromSchedule); } token.transfer(msg.sender, totalAmountToSend); } - function calculateWithdrawableFromVesting(address user) + function calculateWithdrawableFromSchedule(address user) private view returns (uint256 amount, uint256 newClaimTime, bool allVested) { VestingSchedule memory vestingSchedule = vestingSchedules[user]; - if (vestingSchedule.amount == 0) { + // Schedule not set? + if (vestingSchedule.startAmount == 0 && vestingSchedule.vestingAmount == 0) { return (0, 0, false); } + + // Schedule not started? if (block.timestamp < uint256(vestingSchedule.startTime)) { return (0, 0, false); } @@ -135,28 +159,41 @@ contract TokenEscrow is OwnableUpgradeable { uint256(vestingSchedule.endTime) ); - if (currentStepTime <= uint256(vestingSchedule.lastClaimTime)) { - return (0, 0, false); - } - - uint256 totalSteps = - uint256(vestingSchedule.endTime).sub(uint256(vestingSchedule.startTime)).div(vestingSchedule.step); - - if (currentStepTime == uint256(vestingSchedule.endTime)) { - // All vested + uint256 amountFromStart = + vestingSchedule.lastClaimTime >= vestingSchedule.startTime ? 0 : vestingSchedule.startAmount; - uint256 stepsVested = - uint256(vestingSchedule.lastClaimTime).sub(uint256(vestingSchedule.startTime)).div(vestingSchedule.step); - uint256 amountToVest = - uint256(vestingSchedule.amount).sub(uint256(vestingSchedule.amount).div(totalSteps).mul(stepsVested)); + uint256 amountFromVesting; + { + uint256 effectiveLastClaimTime = + MathUpgradeable.max(uint256(vestingSchedule.lastClaimTime), uint256(vestingSchedule.startTime)); + + if (currentStepTime <= effectiveLastClaimTime) { + amountFromVesting = 0; + } else { + uint256 totalSteps = + uint256(vestingSchedule.endTime).sub(uint256(vestingSchedule.startTime)).div(vestingSchedule.step); + + if (currentStepTime == uint256(vestingSchedule.endTime)) { + // All vested + + uint256 stepsVested = + effectiveLastClaimTime.sub(uint256(vestingSchedule.startTime)).div(vestingSchedule.step); + amountFromVesting = uint256(vestingSchedule.vestingAmount).sub( + uint256(vestingSchedule.vestingAmount).div(totalSteps).mul(stepsVested) + ); + } else { + // Partially vested + uint256 stepsToVest = currentStepTime.sub(effectiveLastClaimTime).div(vestingSchedule.step); + amountFromVesting = uint256(vestingSchedule.vestingAmount).div(totalSteps).mul(stepsToVest); + } + } + } - return (amountToVest, currentStepTime, true); + uint256 totalAmount = amountFromStart + amountFromVesting; + if (totalAmount > 0) { + return (totalAmount, currentStepTime, currentStepTime == uint256(vestingSchedule.endTime)); } else { - // Partially vested - uint256 stepsToVest = currentStepTime.sub(uint256(vestingSchedule.lastClaimTime)).div(vestingSchedule.step); - uint256 amountToVest = uint256(vestingSchedule.amount).div(totalSteps).mul(stepsToVest); - - return (amountToVest, currentStepTime, false); + return (0, 0, false); } } } diff --git a/test/TokenEscrow.t.sol b/test/TokenEscrow.t.sol index 9ff9b6d..780d19c 100644 --- a/test/TokenEscrow.t.sol +++ b/test/TokenEscrow.t.sol @@ -9,7 +9,14 @@ import "test/mock/MockToken.sol"; contract TokenEscrowTest is Test { event TokenVested(address indexed user, uint256 amount); event Transfer(address indexed from, address indexed to, uint256 value); - event VestingScheduleAdded(address indexed user, uint256 amount, uint256 startTime, uint256 endTime, uint256 step); + event VestingScheduleAdded( + address indexed user, + uint256 startAmount, + uint256 vestingAmount, + uint256 startTime, + uint256 endTime, + uint256 step + ); MockERC20 erc20Token; TokenEscrow tokenEscrow; @@ -61,7 +68,8 @@ contract TokenEscrowTest is Test { vm.prank(ALICE); tokenEscrow.setVestingSchedule( BOB, // user - 1, // amount + 0, // startAmount + 1, // vestingAmount 2, // startTime 10, // endTime 2 // step @@ -70,7 +78,8 @@ contract TokenEscrowTest is Test { vm.prank(OWNER); tokenEscrow.setVestingSchedule( BOB, // user - 1, // amount + 0, // startAmount + 1, // vestingAmount 2, // startTime 10, // endTime 2 // step @@ -81,13 +90,14 @@ contract TokenEscrowTest is Test { vm.prank(OWNER); tokenEscrow.setVestingSchedule( BOB, // user - 1, // amount + 0, // startAmount + 1, // vestingAmount 2, // startTime 10, // endTime 2 // step ); - (uint128 amount,,,,) = tokenEscrow.vestingSchedules(BOB); - assertNotEq(amount, 0); + (, uint128 vestingAmount,,,,) = tokenEscrow.vestingSchedules(BOB); + assertNotEq(vestingAmount, 0); vm.expectRevert(bytes("Ownable: caller is not the owner")); vm.prank(ALICE); @@ -99,16 +109,28 @@ contract TokenEscrowTest is Test { tokenEscrow.removeVestingSchedule( BOB // user ); - (amount,,,,) = tokenEscrow.vestingSchedules(BOB); - assertEq(amount, 0); + (, vestingAmount,,,,) = tokenEscrow.vestingSchedules(BOB); + assertEq(vestingAmount, 0); } function testVestingScheduleParamsMustNotOverflow() public { - vm.expectRevert(bytes("TokenEscrow: amount overflow")); + vm.expectRevert(bytes("TokenEscrow: startAmount overflow")); vm.prank(OWNER); tokenEscrow.setVestingSchedule( BOB, // user - 0xffffffffffffffffffffffffffffffff + 1, // amount + 0xffffffffffffffffffffffffffffffff + 1, // startAmount + 0, // vestingAmount + 2, // startTime + 10, // endTime + 1 // step + ); + + vm.expectRevert(bytes("TokenEscrow: vestingAmount overflow")); + vm.prank(OWNER); + tokenEscrow.setVestingSchedule( + BOB, // user + 0, // startAmount + 0xffffffffffffffffffffffffffffffff + 1, // vestingAmount 2, // startTime 10, // endTime 1 // step @@ -118,7 +140,8 @@ contract TokenEscrowTest is Test { vm.prank(OWNER); tokenEscrow.setVestingSchedule( BOB, // user - 1, // amount + 0, // startAmount + 1, // vestingAmount 0xffffffff + 1, // startTime 0xffffffff + 2, // endTime 1 // step @@ -128,7 +151,8 @@ contract TokenEscrowTest is Test { vm.prank(OWNER); tokenEscrow.setVestingSchedule( BOB, // user - 1, // amount + 0, // startAmount + 1, // vestingAmount 2, // startTime 0xffffffff + 1, // endTime 1 // step @@ -139,7 +163,8 @@ contract TokenEscrowTest is Test { vm.expectEmit(address(tokenEscrow)); emit VestingScheduleAdded( BOB, // user - 100, // amount + 0, // startAmount + 100, // vestingAmount 200, // startTime 300, // endTime 50 // step @@ -147,7 +172,8 @@ contract TokenEscrowTest is Test { vm.prank(OWNER); tokenEscrow.setVestingSchedule( BOB, // user - 100, // amount + 0, // startAmount + 100, // vestingAmount 200, // startTime 300, // endTime 50 // step @@ -157,7 +183,8 @@ contract TokenEscrowTest is Test { vm.prank(OWNER); tokenEscrow.setVestingSchedule( BOB, // user - 100, // amount + 0, // startAmount + 100, // vestingAmount 200, // startTime 300, // endTime 50 // step @@ -171,7 +198,8 @@ contract TokenEscrowTest is Test { vm.prank(OWNER); tokenEscrow.setVestingSchedule( ALICE, // user - 10000, // amount + 0, // startAmount + 10000, // vestingAmount startTime, // startTime startTime + step * 10, // endTime step // step @@ -213,21 +241,22 @@ contract TokenEscrowTest is Test { vm.prank(OWNER); tokenEscrow.setVestingSchedule( ALICE, // user - 10000, // amount + 555, // startAmount + 10000, // vestingAmount startTime, // startTime startTime + step * 10, // endTime step // step ); - // Claim 3 steps at once + // Claim 3 steps at once (including start amount) vm.warp(startTime + step * 3); vm.expectEmit(address(tokenEscrow)); - emit TokenVested(ALICE, 3000); + emit TokenVested(ALICE, 3555); vm.expectEmit(address(erc20Token)); emit Transfer( address(tokenEscrow), // sender ALICE, // recipient - 3000 // amount + 3555 // amount ); vm.prank(ALICE); tokenEscrow.withdraw();