From 17bada2d7b8810f10216dc833760e86ec35e5865 Mon Sep 17 00:00:00 2001 From: Bilgehan Sahin Date: Thu, 16 Mar 2023 22:34:05 -0700 Subject: [PATCH] only lookup staked account balance when needed --- service/backend/pchain/account.go | 27 +++++++------ service/backend/pchain/account_test.go | 55 ++++++++++++++++++++++++++ 2 files changed, 70 insertions(+), 12 deletions(-) diff --git a/service/backend/pchain/account.go b/service/backend/pchain/account.go index af188669..aa0548a5 100644 --- a/service/backend/pchain/account.go +++ b/service/backend/pchain/account.go @@ -64,8 +64,9 @@ func (b *Backend) AccountBalance(ctx context.Context, req *types.AccountBalanceR } fetchImportable := balanceType == pmapper.SubAccountTypeSharedMemory + fetchStaked := balanceType == pmapper.SubAccountTypeStaked || balanceType == "" - height, balance, typedErr := b.fetchBalance(ctx, req.AccountIdentifier.Address, fetchImportable, currencyAssetIDs) + height, balance, typedErr := b.fetchBalance(ctx, req.AccountIdentifier.Address, fetchStaked, fetchImportable, currencyAssetIDs) if typedErr != nil { return nil, typedErr } @@ -255,18 +256,18 @@ func (b *Backend) getPendingRewardsBalance(ctx context.Context, req *types.Accou }, nil } -func (b *Backend) fetchBalance(ctx context.Context, addrString string, fetchImportable bool, assetIds set.Set[ids.ID]) (uint64, *AccountBalance, *types.Error) { +func (b *Backend) fetchBalance(ctx context.Context, addrString string, fetchStaked bool, fetchImportable bool, assetIds set.Set[ids.ID]) (uint64, *AccountBalance, *types.Error) { addr, err := address.ParseToID(addrString) if err != nil { return 0, nil, service.WrapError(service.ErrInvalidInput, "unable to convert address") } - // utxos from fetchUTXOsAndStakedOutputs are guarateed to: + // utxos from fetchUTXOsAndStakedOutputs are guaranteed to: // 1. be unique (no duplicates) - // 2. containt only assetIDs + // 2. contain only assetIDs // 3. have not multisign utxos // by parseAndFilterUTXOs call in fetchUTXOsAndStakedOutputs - height, utxos, stakedUTXOBytes, typedErr := b.fetchUTXOsAndStakedOutputs(ctx, addr, !fetchImportable, fetchImportable, assetIds) + height, utxos, stakedUTXOBytes, typedErr := b.fetchUTXOsAndStakedOutputs(ctx, addr, fetchStaked, fetchImportable, assetIds) if typedErr != nil { return 0, nil, typedErr } @@ -276,14 +277,16 @@ func (b *Backend) fetchBalance(ctx context.Context, addrString string, fetchImpo return 0, nil, service.WrapError(service.ErrInternalError, err) } - // parse staked UTXO bytes to UTXO structs - stakedAmount, err := b.calculateStakedAmount(stakedUTXOBytes) - if err != nil { - return 0, nil, service.WrapError(service.ErrInternalError, err) - } + if fetchStaked { + // parse staked UTXO bytes to UTXO structs + stakedAmount, err := b.calculateStakedAmount(stakedUTXOBytes) + if err != nil { + return 0, nil, service.WrapError(service.ErrInternalError, err) + } - balance.Staked = stakedAmount - balance.Total += stakedAmount + balance.Staked = stakedAmount + balance.Total += stakedAmount + } return height, balance, nil } diff --git a/service/backend/pchain/account_test.go b/service/backend/pchain/account_test.go index 41acfa4d..16e78290 100644 --- a/service/backend/pchain/account_test.go +++ b/service/backend/pchain/account_test.go @@ -118,6 +118,61 @@ func TestAccountBalance(t *testing.T) { parserMock.AssertExpectations(t) }) + t.Run("Account Balance should not call GetStake if sub account type is not staked or empty", func(t *testing.T) { + addr, _ := address.ParseToID(pChainAddr) + utxo0Bytes := makeUtxoBytes(t, backend, utxos[0].id, utxos[0].amount) + utxo1Bytes := makeUtxoBytes(t, backend, utxos[1].id, utxos[1].amount) + utxo1Id, _ := ids.FromString(utxos[1].id) + + // Mock on GetAssetDescription + pChainMock.Mock.On("GetAssetDescription", ctx, mapper.AtomicAvaxCurrency.Symbol).Return(mockAssetDescription, nil) + + // once before other calls, once after + pChainMock.Mock.On("GetHeight", ctx).Return(blockHeight, nil).Twice() + // Make sure pagination works as well + pageSize := uint32(2) + backend.getUTXOsPageSize = pageSize + pChainMock.Mock.On("GetAtomicUTXOs", ctx, []ids.ShortID{addr}, "", pageSize, ids.ShortEmpty, ids.Empty). + Return([][]byte{utxo0Bytes, utxo1Bytes}, addr, utxo1Id, nil).Once() + pChainMock.Mock.On("GetAtomicUTXOs", ctx, []ids.ShortID{addr}, "", pageSize, addr, utxo1Id). + Return([][]byte{utxo1Bytes}, addr, utxo1Id, nil).Once() + + resp, err := backend.AccountBalance( + ctx, + &types.AccountBalanceRequest{ + NetworkIdentifier: &types.NetworkIdentifier{ + Network: constants.FujiNetwork, + SubNetworkIdentifier: &types.SubNetworkIdentifier{ + Network: constants.PChain.String(), + }, + }, + AccountIdentifier: &types.AccountIdentifier{ + Address: pChainAddr, + SubAccount: &types.SubAccountIdentifier{ + Address: pmapper.SubAccountTypeUnlocked, + }, + }, + Currencies: []*types.Currency{ + mapper.AtomicAvaxCurrency, + }, + }, + ) + + expected := &types.AccountBalanceResponse{ + Balances: []*types.Amount{ + { + Value: "3000000000", // 1B + 2B from UTXOs + Currency: mapper.AtomicAvaxCurrency, + }, + }, + } + + assert.Nil(t, err) + assert.Equal(t, expected.Balances, resp.Balances) + pChainMock.AssertExpectations(t) + parserMock.AssertExpectations(t) + }) + t.Run("Account Balance should return total of shared memory balance", func(t *testing.T) { // Mock on GetUTXOs utxo0Bytes := makeUtxoBytes(t, backend, utxos[0].id, utxos[0].amount)