diff --git a/.js.env.example b/.js.env.example index 68e8316f034..81d9fea44a1 100644 --- a/.js.env.example +++ b/.js.env.example @@ -90,7 +90,10 @@ export MM_ENABLE_SETTINGS_PAGE_DEV_OPTIONS="true" # The endpoint used to submit errors and tracing data to Sentry for dev environment. # export MM_SENTRY_DSN_DEV= -# Multichain Feature flag -export MULTICHAIN_V1="" +# Per dapp selected network (Amon Hen) feature flag +export MM_PER_DAPP_SELECTED_NETWORK="" + +export MM_CHAIN_PERMISSIONS="" + #Multichain feature flag specific to UI changes export MM_MULTICHAIN_V1_ENABLED="" diff --git a/app/components/UI/PermissionsSummary/PermissionsSummary.test.tsx b/app/components/UI/PermissionsSummary/PermissionsSummary.test.tsx index ebc79900213..5ce8f2a4f7b 100644 --- a/app/components/UI/PermissionsSummary/PermissionsSummary.test.tsx +++ b/app/components/UI/PermissionsSummary/PermissionsSummary.test.tsx @@ -29,6 +29,24 @@ const mockInitialState = { }; describe('PermissionsSummary', () => { + it('should render correctly for network switch', () => { + const { toJSON } = renderWithProvider( + , + { state: mockInitialState }, + ); + expect(toJSON()).toMatchSnapshot(); + }); it('should render correctly', () => { const { toJSON } = renderWithProvider( { onUserAction?.(USER_INTENT.Confirm); + onConfirm?.(); }; const cancel = () => { onUserAction?.(USER_INTENT.Cancel); + onCancel?.(); }; const handleEditAccountsButtonPress = () => { @@ -208,21 +227,33 @@ const PermissionsSummary = ({ {strings('permissions.use_enabled_networks')} - - - - {strings('permissions.requesting_for')} - - - {networkName} - - - - - - + {isNetworkSwitch && ( + <> + + + + {strings('permissions.requesting_for')} + + + {chainName} + + + + + + )} + {!isNetworkSwitch && ( + + + + )} {!isNetworkSwitch && renderEndAccessory()} @@ -247,6 +278,7 @@ const PermissionsSummary = ({ })} + {/*TODO These should be conditional upon which permissions are being requested*/} {!isNetworkSwitch && renderAccountPermissionsRequestInfoCard()} {renderNetworkPermissionsRequestInfoCard()} diff --git a/app/components/UI/PermissionsSummary/PermissionsSummary.types.ts b/app/components/UI/PermissionsSummary/PermissionsSummary.types.ts index c80d27e198a..1be07727a7d 100644 --- a/app/components/UI/PermissionsSummary/PermissionsSummary.types.ts +++ b/app/components/UI/PermissionsSummary/PermissionsSummary.types.ts @@ -9,10 +9,16 @@ export interface PermissionsSummaryProps { onEdit?: () => void; onEditNetworks?: () => void; onBack?: () => void; + onCancel?: () => void; + onConfirm?: () => void; onUserAction?: React.Dispatch>; showActionButtons?: boolean; isAlreadyConnected?: boolean; isRenderedAsBottomSheet?: boolean; isDisconnectAllShown?: boolean; isNetworkSwitch?: boolean; + customNetworkInformation?: { + chainName: string; + chainId: string; + }; } diff --git a/app/components/UI/PermissionsSummary/__snapshots__/PermissionsSummary.test.tsx.snap b/app/components/UI/PermissionsSummary/__snapshots__/PermissionsSummary.test.tsx.snap index 07ce968d48c..b6d0fea2047 100644 --- a/app/components/UI/PermissionsSummary/__snapshots__/PermissionsSummary.test.tsx.snap +++ b/app/components/UI/PermissionsSummary/__snapshots__/PermissionsSummary.test.tsx.snap @@ -500,60 +500,6 @@ exports[`PermissionsSummary should render correctly 1`] = ` } } > - - - - Requesting for - - - Ethereum Main Network - - - `; + +exports[`PermissionsSummary should render correctly for network switch 1`] = ` + + + + + + + + + + a + + + + + + + + + app.uniswap.org wants to: + + + + + + + + + + Use your enabled networks + + + + + + Requesting for + + + Sepolia + + + + + + + + + + + + + + + + + Disconnect all + + + + + + + Cancel + + + + + Confirm + + + + + + +`; diff --git a/app/components/Views/AccountConnect/AccountConnect.tsx b/app/components/Views/AccountConnect/AccountConnect.tsx index 9d3747aed3a..9883878435c 100644 --- a/app/components/Views/AccountConnect/AccountConnect.tsx +++ b/app/components/Views/AccountConnect/AccountConnect.tsx @@ -345,7 +345,6 @@ const AccountConnect = (props: AccountConnectProps) => { }, approvedAccounts: selectedAddresses, }; - const connectedAccountLength = selectedAddresses.length; const activeAddress = selectedAddresses[0]; const activeAccountName = getAccountNameWithENS({ diff --git a/app/core/Engine.ts b/app/core/Engine.ts index a0bbf9344ed..15d700d6746 100644 --- a/app/core/Engine.ts +++ b/app/core/Engine.ts @@ -981,6 +981,10 @@ class Engine { caveatSpecifications: getCaveatSpecifications({ getInternalAccounts: accountsController.listAccounts.bind(accountsController), + findNetworkClientIdByChainId: + networkController.findNetworkClientIdByChainId.bind( + networkController, + ), }), // @ts-expect-error Typecast permissionType from getPermissionSpecifications to be of type PermissionType.RestrictedMethod permissionSpecifications: { diff --git a/app/core/Permissions/constants.ts b/app/core/Permissions/constants.ts index 4ee3ab20c49..95f22e51597 100644 --- a/app/core/Permissions/constants.ts +++ b/app/core/Permissions/constants.ts @@ -1,5 +1,6 @@ export const CaveatTypes = Object.freeze({ restrictReturnedAccounts: 'restrictReturnedAccounts', + restrictNetworkSwitching: 'restrictNetworkSwitching', }); export const RestrictedMethods = Object.freeze({ diff --git a/app/core/Permissions/specifications.js b/app/core/Permissions/specifications.js index e20588dc9b7..6d3b0dca2f1 100644 --- a/app/core/Permissions/specifications.js +++ b/app/core/Permissions/specifications.js @@ -21,19 +21,24 @@ import { CaveatTypes, RestrictedMethods } from './constants'; * The "keys" of all of permissions recognized by the PermissionController. * Permission keys and names have distinct meanings in the permission system. */ -const PermissionKeys = Object.freeze({ +export const PermissionKeys = Object.freeze({ ...RestrictedMethods, + permittedChains: 'endowment:permitted-chains', }); /** * Factory functions for all caveat types recognized by the * PermissionController. */ -const CaveatFactories = Object.freeze({ +export const CaveatFactories = Object.freeze({ [CaveatTypes.restrictReturnedAccounts]: (accounts) => ({ type: CaveatTypes.restrictReturnedAccounts, value: accounts, }), + [CaveatTypes.restrictNetworkSwitching]: (chainIds) => ({ + type: CaveatTypes.restrictNetworkSwitching, + value: chainIds, + }), }); /** @@ -52,9 +57,13 @@ const CaveatFactories = Object.freeze({ * * @param {{ * getInternalAccounts: () => import('@metamask/keyring-api').InternalAccount[], + * findNetworkClientIdByChainId: (chainId: `0x${string}`) => string, * }} options - Options bag. */ -export const getCaveatSpecifications = ({ getInternalAccounts }) => ({ +export const getCaveatSpecifications = ({ + getInternalAccounts, + findNetworkClientIdByChainId, +}) => ({ [CaveatTypes.restrictReturnedAccounts]: { type: CaveatTypes.restrictReturnedAccounts, @@ -74,6 +83,21 @@ export const getCaveatSpecifications = ({ getInternalAccounts }) => ({ validator: (caveat, _origin, _target) => validateCaveatAccounts(caveat.value, getInternalAccounts), }, + [CaveatTypes.restrictNetworkSwitching]: { + type: CaveatTypes.restrictNetworkSwitching, + validator: (caveat, _origin, _target) => + validateCaveatNetworks(caveat.value, findNetworkClientIdByChainId), + /** + * @param {any[]} leftValue + * @param {any[]} rightValue + * @returns {[any[], any[]]} + */ + merger: (leftValue, rightValue) => { + const newValue = Array.from(new Set([...leftValue, ...rightValue])); + const diff = newValue.filter((value) => !leftValue.includes(value)); + return [newValue, diff]; + }, + }, ///: BEGIN:ONLY_INCLUDE_IF(preinstalled-snaps,external-snaps) ...snapsCaveatsSpecifications, ...snapsEndowmentCaveatSpecifications, @@ -187,6 +211,46 @@ export const getPermissionSpecifications = ({ } }, }, + [PermissionKeys.permittedChains]: { + permissionType: PermissionType.Endowment, + targetName: PermissionKeys.permittedChains, + allowedCaveats: [CaveatTypes.restrictNetworkSwitching], + factory: (permissionOptions, requestData) => { + if (requestData === undefined) { + return constructPermission({ + ...permissionOptions, + }); + } + + if (!requestData.approvedChainIds) { + throw new Error( + `${PermissionKeys.permittedChains}: No approved networks specified.`, + ); + } + + return constructPermission({ + ...permissionOptions, + caveats: [ + CaveatFactories[CaveatTypes.restrictNetworkSwitching]( + requestData.approvedChainIds, + ), + ], + }); + }, + endowmentGetter: async (_getterOptions) => undefined, + validator: (permission, _origin, _target) => { + const { caveats } = permission; + if ( + !caveats || + caveats.length !== 1 || + caveats[0].type !== CaveatTypes.restrictNetworkSwitching + ) { + throw new Error( + `${PermissionKeys.permittedChains} error: Invalid caveats. There must be a single caveat of type "${CaveatTypes.restrictNetworkSwitching}".`, + ); + } + }, + }, }); /** @@ -226,6 +290,36 @@ function validateCaveatAccounts(accounts, getInternalAccounts) { }); } +/** + * Validates the networks associated with a caveat. Ensures that + * the networks value is an array of valid chain IDs. + * + * @param {string[]} chainIdsForCaveat - The list of chain IDs to validate. + * @param {function(string): string} findNetworkClientIdByChainId - Function to find network client ID by chain ID. + * @throws {Error} If the chainIdsForCaveat is not a non-empty array of valid chain IDs. + */ +function validateCaveatNetworks( + chainIdsForCaveat, + findNetworkClientIdByChainId, +) { + if (!Array.isArray(chainIdsForCaveat) || chainIdsForCaveat.length === 0) { + throw new Error( + `${PermissionKeys.permittedChains} error: Expected non-empty array of chainIds.`, + ); + } + + chainIdsForCaveat.forEach((chainId) => { + try { + findNetworkClientIdByChainId(chainId); + } catch (e) { + console.error(e); + throw new Error( + `${PermissionKeys.permittedChains} error: Received unrecognized chainId: "${chainId}". Please try adding the network first via wallet_addEthereumChain.`, + ); + } + }); +} + /** * All unrestricted methods recognized by the PermissionController. * Unrestricted methods are ignored by the permission system, but every diff --git a/app/core/Permissions/specifications.test.js b/app/core/Permissions/specifications.test.js index ce979964677..9dd34c5d5a4 100644 --- a/app/core/Permissions/specifications.test.js +++ b/app/core/Permissions/specifications.test.js @@ -2,6 +2,7 @@ import { CaveatTypes, RestrictedMethods } from './constants'; import { getCaveatSpecifications, getPermissionSpecifications, + PermissionKeys, unrestrictedMethods, } from './specifications'; import { EthAccountType, EthMethod } from '@metamask/keyring-api'; @@ -22,7 +23,7 @@ describe('PermissionController specifications', () => { describe('caveat specifications', () => { it('getCaveatSpecifications returns the expected specifications object', () => { const caveatSpecifications = getCaveatSpecifications({}); - expect(Object.keys(caveatSpecifications)).toHaveLength(12); + expect(Object.keys(caveatSpecifications)).toHaveLength(13); expect( caveatSpecifications[CaveatTypes.restrictReturnedAccounts].type, ).toStrictEqual(CaveatTypes.restrictReturnedAccounts); @@ -187,10 +188,13 @@ describe('PermissionController specifications', () => { describe('permission specifications', () => { it('getPermissionSpecifications returns the expected specifications object', () => { const permissionSpecifications = getPermissionSpecifications({}); - expect(Object.keys(permissionSpecifications)).toHaveLength(1); + expect(Object.keys(permissionSpecifications)).toHaveLength(2); expect( permissionSpecifications[RestrictedMethods.eth_accounts].targetName, ).toStrictEqual(RestrictedMethods.eth_accounts); + expect( + permissionSpecifications[PermissionKeys.permittedChains].targetName, + ).toStrictEqual(PermissionKeys.permittedChains); }); describe('eth_accounts', () => { diff --git a/app/core/RPCMethods/RPCMethodMiddleware.test.ts b/app/core/RPCMethods/RPCMethodMiddleware.test.ts index 70bfb724c8d..04deabc6eb9 100644 --- a/app/core/RPCMethods/RPCMethodMiddleware.test.ts +++ b/app/core/RPCMethods/RPCMethodMiddleware.test.ts @@ -398,6 +398,7 @@ describe('getRpcMethodMiddleware', () => { }), caveatSpecifications: getCaveatSpecifications({ getInternalAccounts: mockGetInternalAccounts, + findNetworkClientIdByChainId: jest.fn(), }), // @ts-expect-error Typecast permissionType from getPermissionSpecifications to be of type PermissionType.RestrictedMethod permissionSpecifications: { diff --git a/app/core/RPCMethods/lib/ethereum-chain-utils.js b/app/core/RPCMethods/lib/ethereum-chain-utils.js new file mode 100644 index 00000000000..9769933d577 --- /dev/null +++ b/app/core/RPCMethods/lib/ethereum-chain-utils.js @@ -0,0 +1,311 @@ +import { rpcErrors } from '@metamask/rpc-errors'; +import validUrl from 'valid-url'; +import { isSafeChainId } from '@metamask/controller-utils'; +import { jsonRpcRequest } from '../../../util/jsonRpcRequest'; +import { + getDecimalChainId, + isPrefixedFormattedHexString, + isChainPermissionsFeatureEnabled, +} from '../../../util/networks'; +import { + CaveatFactories, + PermissionKeys, +} from '../../../core/Permissions/specifications'; +import { CaveatTypes } from '../../../core/Permissions/constants'; +import { PermissionDoesNotExistError } from '@metamask/permission-controller'; + +const EVM_NATIVE_TOKEN_DECIMALS = 18; + +export function validateChainId(chainId) { + const _chainId = typeof chainId === 'string' && chainId.toLowerCase(); + + if (!isPrefixedFormattedHexString(_chainId)) { + throw rpcErrors.invalidParams( + `Expected 0x-prefixed, unpadded, non-zero hexadecimal string 'chainId'. Received:\n${chainId}`, + ); + } + + if (!isSafeChainId(_chainId)) { + throw rpcErrors.invalidParams( + `Invalid chain ID "${_chainId}": numerical value greater than max safe value. Received:\n${chainId}`, + ); + } + + return _chainId; +} + +export function validateAddEthereumChainParams(params) { + if (!params || !params?.[0] || typeof params[0] !== 'object') { + throw rpcErrors.invalidParams({ + message: `Expected single, object parameter. Received:\n${JSON.stringify( + params, + )}`, + }); + } + + const [ + { + chainId, + chainName: rawChainName = null, + blockExplorerUrls = null, + nativeCurrency = null, + rpcUrls, + }, + ] = params; + + const allowedKeys = { + chainId: true, + chainName: true, + blockExplorerUrls: true, + nativeCurrency: true, + rpcUrls: true, + iconUrls: true, + }; + + const extraKeys = Object.keys(params[0]).filter((key) => !allowedKeys[key]); + if (extraKeys.length) { + throw rpcErrors.invalidParams( + `Received unexpected keys on object parameter. Unsupported keys:\n${extraKeys}`, + ); + } + const _chainId = validateChainId(chainId); + + const firstValidRPCUrl = validateRpcUrls(rpcUrls); + + const firstValidBlockExplorerUrl = + validateBlockExplorerUrls(blockExplorerUrls); + + const chainName = validateChainName(rawChainName); + + const ticker = validateNativeCurrency(nativeCurrency); + + return { + chainId: _chainId, + chainName, + firstValidRPCUrl, + firstValidBlockExplorerUrl, + ticker, + }; +} + +function validateRpcUrls(rpcUrls) { + const dirtyFirstValidRPCUrl = Array.isArray(rpcUrls) + ? rpcUrls.find((rpcUrl) => validUrl.isHttpsUri(rpcUrl)) + : null; + + const firstValidRPCUrl = dirtyFirstValidRPCUrl + ? dirtyFirstValidRPCUrl.replace(/([^/])\/+$/g, '$1') + : dirtyFirstValidRPCUrl; + + if (!firstValidRPCUrl) { + throw rpcErrors.invalidParams( + `Expected an array with at least one valid string HTTPS url 'rpcUrls', Received:\n${rpcUrls}`, + ); + } + + return firstValidRPCUrl; +} + +function validateBlockExplorerUrls(blockExplorerUrls) { + const firstValidBlockExplorerUrl = + blockExplorerUrls !== null && Array.isArray(blockExplorerUrls) + ? blockExplorerUrls.find((blockExplorerUrl) => + validUrl.isHttpsUri(blockExplorerUrl), + ) + : null; + + if (blockExplorerUrls !== null && !firstValidBlockExplorerUrl) { + throw rpcErrors.invalidParams( + `Expected null or array with at least one valid string HTTPS URL 'blockExplorerUrl'. Received: ${blockExplorerUrls}`, + ); + } + + return firstValidBlockExplorerUrl; +} + +function validateChainName(rawChainName) { + if (typeof rawChainName !== 'string' || !rawChainName) { + throw rpcErrors.invalidParams({ + message: `Expected non-empty string 'chainName'. Received:\n${rawChainName}`, + }); + } + return rawChainName.length > 100 + ? rawChainName.substring(0, 100) + : rawChainName; +} + +function validateNativeCurrency(nativeCurrency) { + if (nativeCurrency !== null) { + if (typeof nativeCurrency !== 'object' || Array.isArray(nativeCurrency)) { + throw rpcErrors.invalidParams({ + message: `Expected null or object 'nativeCurrency'. Received:\n${nativeCurrency}`, + }); + } + if (nativeCurrency.decimals !== EVM_NATIVE_TOKEN_DECIMALS) { + throw rpcErrors.invalidParams({ + message: `Expected the number 18 for 'nativeCurrency.decimals' when 'nativeCurrency' is provided. Received: ${nativeCurrency.decimals}`, + }); + } + + if (!nativeCurrency.symbol || typeof nativeCurrency.symbol !== 'string') { + throw rpcErrors.invalidParams({ + message: `Expected a string 'nativeCurrency.symbol'. Received: ${nativeCurrency.symbol}`, + }); + } + } + const ticker = nativeCurrency?.symbol || 'ETH'; + + if (typeof ticker !== 'string' || ticker.length < 2 || ticker.length > 6) { + throw rpcErrors.invalidParams({ + message: `Expected 2-6 character string 'nativeCurrency.symbol'. Received:\n${ticker}`, + }); + } + + return ticker; +} + +export async function validateRpcEndpoint(rpcUrl, chainId) { + let endpointChainId; + try { + endpointChainId = await jsonRpcRequest(rpcUrl, 'eth_chainId'); + } catch (err) { + throw rpcErrors.internal({ + message: `Request for method 'eth_chainId on ${rpcUrl} failed`, + data: { networkErr: err }, + }); + } + if (chainId !== endpointChainId) { + throw rpcErrors.invalidParams({ + message: `Chain ID returned by RPC URL ${rpcUrl} does not match ${chainId}`, + data: { chainId: endpointChainId }, + }); + } +} + +export function findExistingNetwork(chainId, networkConfigurations) { + const existingEntry = Object.entries(networkConfigurations).find( + ([, networkConfiguration]) => networkConfiguration.chainId === chainId, + ); + if (existingEntry) { + const [, networkConfiguration] = existingEntry; + const networkConfigurationId = + networkConfiguration.rpcEndpoints[ + networkConfiguration.defaultRpcEndpointIndex + ].networkClientId; + return [networkConfigurationId, networkConfiguration]; + } + return; +} + +export async function switchToNetwork({ + network, + chainId, + controllers, + requestUserApproval, + analytics, + origin, + isAddNetworkFlow = false, +}) { + const { + CurrencyRateController, + NetworkController, + PermissionController, + SelectedNetworkController, + } = controllers; + const getCaveat = ({ target, caveatType }) => { + try { + return PermissionController.getCaveat(origin, target, caveatType); + } catch (e) { + if (e instanceof PermissionDoesNotExistError) { + // suppress expected error in case that the origin + // does not have the target permission yet + } else { + throw e; + } + } + + return undefined; + }; + const [networkConfigurationId, networkConfiguration] = network; + const requestData = { + rpcUrl: + networkConfiguration.rpcEndpoints[ + networkConfiguration.defaultRpcEndpointIndex + ], + chainId, + chainName: + networkConfiguration.name || + networkConfiguration.chainName || + networkConfiguration.nickname || + networkConfiguration.shortName, + ticker: networkConfiguration.ticker || 'ETH', + chainColor: networkConfiguration.color, + }; + const analyticsParams = { + chain_id: getDecimalChainId(chainId), + source: 'Custom Network API', + symbol: networkConfiguration?.ticker || 'ETH', + ...analytics, + }; + + // for some reason this extra step is necessary for accessing the env variable in test environment + const chainPermissionsFeatureEnabled = + { ...process.env }?.NODE_ENV === 'test' + ? { ...process.env }?.MM_CHAIN_PERMISSIONS === '1' + : isChainPermissionsFeatureEnabled; + + const { value: permissionedChainIds } = + getCaveat({ + target: PermissionKeys.permittedChains, + caveatType: CaveatTypes.restrictNetworkSwitching, + }) ?? {}; + + const shouldGrantPermissions = + chainPermissionsFeatureEnabled && + (!permissionedChainIds || !permissionedChainIds.includes(chainId)); + + const requestModalType = isAddNetworkFlow ? 'new' : 'switch'; + + const shouldShowRequestModal = + (!isAddNetworkFlow && shouldGrantPermissions) || + !chainPermissionsFeatureEnabled; + + if (shouldShowRequestModal) { + await requestUserApproval({ + type: 'SWITCH_ETHEREUM_CHAIN', + requestData: { ...requestData, type: requestModalType }, + }); + } + + if (shouldGrantPermissions) { + await PermissionController.grantPermissionsIncremental({ + subject: { origin }, + approvedPermissions: { + [PermissionKeys.permittedChains]: { + caveats: [ + CaveatFactories[CaveatTypes.restrictNetworkSwitching]([chainId]), + ], + }, + }, + }); + } + + const originHasAccountsPermission = PermissionController.hasPermission( + origin, + 'eth_accounts', + ); + + if (process.env.MM_PER_DAPP_SELECTED_NETWORK && originHasAccountsPermission) { + SelectedNetworkController.setNetworkClientIdForDomain( + origin, + networkConfigurationId || networkConfiguration.networkType, + ); + } else { + CurrencyRateController.updateExchangeRate(requestData.ticker); + NetworkController.setActiveNetwork( + networkConfigurationId || networkConfiguration.networkType, + ); + } + + return analyticsParams; +} diff --git a/app/core/RPCMethods/wallet_addEthereumChain.js b/app/core/RPCMethods/wallet_addEthereumChain.js index 1a0719dd8b0..6bbc661fce7 100644 --- a/app/core/RPCMethods/wallet_addEthereumChain.js +++ b/app/core/RPCMethods/wallet_addEthereumChain.js @@ -1,13 +1,7 @@ import { InteractionManager } from 'react-native'; -import validUrl from 'valid-url'; -import { ChainId, isSafeChainId } from '@metamask/controller-utils'; -import { jsonRpcRequest } from '../../util/jsonRpcRequest'; +import { ChainId } from '@metamask/controller-utils'; import Engine from '../Engine'; import { providerErrors, rpcErrors } from '@metamask/rpc-errors'; -import { - getDecimalChainId, - isPrefixedFormattedHexString, -} from '../../util/networks'; import { MetaMetricsEvents, MetaMetrics } from '../../core/Analytics'; import { selectChainId, @@ -15,10 +9,14 @@ import { } from '../../selectors/networkController'; import { store } from '../../store'; import checkSafeNetwork from './networkChecker.util'; +import { + validateAddEthereumChainParams, + validateRpcEndpoint, + switchToNetwork, +} from './lib/ethereum-chain-utils'; +import { getDecimalChainId } from '../../util/networks'; import { RpcEndpointType } from '@metamask/network-controller'; -const EVM_NATIVE_TOKEN_DECIMALS = 18; - const waitForInteraction = async () => new Promise((resolve) => { InteractionManager.runAfterInteractions(() => { @@ -46,104 +44,35 @@ const wallet_addEthereumChain = async ({ startApprovalFlow, endApprovalFlow, }) => { - const { CurrencyRateController, NetworkController, ApprovalController } = - Engine.context; - - if (!req.params?.[0] || typeof req.params[0] !== 'object') { - throw rpcErrors.invalidParams({ - message: `Expected single, object parameter. Received:\n${JSON.stringify( - req.params, - )}`, - }); - } + const { + CurrencyRateController, + NetworkController, + ApprovalController, + PermissionController, + SelectedNetworkController, + } = Engine.context; - const params = req.params[0]; + const { origin } = req; + const params = validateAddEthereumChainParams(req.params); const { chainId, - chainName: rawChainName = null, - blockExplorerUrls = null, - nativeCurrency = null, - rpcUrls, + chainName, + firstValidRPCUrl, + firstValidBlockExplorerUrl, + ticker, } = params; - const allowedKeys = { - chainId: true, - chainName: true, - blockExplorerUrls: true, - nativeCurrency: true, - rpcUrls: true, - iconUrls: true, - }; - - const extraKeys = Object.keys(params).filter((key) => !allowedKeys[key]); - if (extraKeys.length) { - throw rpcErrors.invalidParams( - `Received unexpected keys on object parameter. Unsupported keys:\n${extraKeys}`, - ); - } - - const dirtyFirstValidRPCUrl = Array.isArray(rpcUrls) - ? rpcUrls.find((rpcUrl) => validUrl.isHttpsUri(rpcUrl)) - : null; - // Remove trailing slashes - const firstValidRPCUrl = dirtyFirstValidRPCUrl - ? // https://github.com/MetaMask/mobile-planning/issues/1589 - dirtyFirstValidRPCUrl.replace(/([^/])\/+$/g, '$1') - : dirtyFirstValidRPCUrl; - - const firstValidBlockExplorerUrl = - blockExplorerUrls !== null && Array.isArray(blockExplorerUrls) - ? blockExplorerUrls.find((blockExplorerUrl) => - validUrl.isHttpsUri(blockExplorerUrl), - ) - : null; - - if (!firstValidRPCUrl) { - throw rpcErrors.invalidParams( - `Expected an array with at least one valid string HTTPS url 'rpcUrls', Received:\n${rpcUrls}`, - ); - } - - if (blockExplorerUrls !== null && !firstValidBlockExplorerUrl) { - throw rpcErrors.invalidParams( - `Expected null or array with at least one valid string HTTPS URL 'blockExplorerUrl'. Received: ${blockExplorerUrls}`, - ); - } - - const _chainId = typeof chainId === 'string' && chainId.toLowerCase(); - - if (!isPrefixedFormattedHexString(_chainId)) { - throw rpcErrors.invalidParams( - `Expected 0x-prefixed, unpadded, non-zero hexadecimal string 'chainId'. Received:\n${chainId}`, - ); - } - - if (!isSafeChainId(_chainId)) { - throw rpcErrors.invalidParams( - `Invalid chain ID "${_chainId}": numerical value greater than max safe value. Received:\n${chainId}`, - ); - } - - if (typeof rawChainName !== 'string' || !rawChainName) { - throw rpcErrors.invalidParams({ - message: `Expected non-empty string 'chainName'. Received:\n${rawChainName}`, - }); - } - - const chainName = rawChainName.slice(0, 100); - //TODO: Remove aurora from default chains in @metamask/controller-utils const actualChains = { ...ChainId, aurora: undefined }; - if (Object.values(actualChains).find((value) => value === _chainId)) { + if (Object.values(actualChains).find((value) => value === chainId)) { throw rpcErrors.invalidParams(`May not specify default MetaMask chain.`); } const networkConfigurations = selectNetworkConfigurations(store.getState()); const existingEntry = Object.entries(networkConfigurations).find( - ([, networkConfiguration]) => networkConfiguration.chainId === _chainId, + ([, networkConfiguration]) => networkConfiguration.chainId === chainId, ); - if (existingEntry) { const [chainId, networkConfiguration] = existingEntry; const currentChainId = selectChainId(store.getState()); @@ -188,38 +117,32 @@ const wallet_addEthereumChain = async ({ ); const analyticsParams = { - chain_id: getDecimalChainId(_chainId), + chain_id: getDecimalChainId(chainId), source: 'Custom Network API', symbol: networkConfiguration.ticker, ...analytics, }; - try { - await requestUserApproval({ - type: 'SWITCH_ETHEREUM_CHAIN', - requestData: { - rpcUrl: networkConfiguration.rpcUrl, - chainId: _chainId, - chainName: networkConfiguration.name, - ticker: networkConfiguration.nativeCurrency, - type: 'switch', - }, - }); - } catch (e) { - MetaMetrics.getInstance().trackEvent( - MetaMetricsEvents.NETWORK_REQUEST_REJECTED, - analyticsParams, - ); - throw providerErrors.userRejectedRequest(); - } - - CurrencyRateController.updateExchangeRate(networkConfiguration.ticker); const { networkClientId } = - networkConfiguration?.rpcEndpoints?.[ + networkConfiguration.rpcEndpoints[ networkConfiguration.defaultRpcEndpointIndex - ] ?? {}; + ]; - NetworkController.setActiveNetwork(networkClientId); + const network = [networkClientId, clonedNetwork]; + await switchToNetwork({ + network, + chainId, + controllers: { + CurrencyRateController, + NetworkController, + PermissionController, + SelectedNetworkController, + }, + requestUserApproval, + analytics, + origin, + isAddNetworkFlow: true, + }); MetaMetrics.getInstance().trackEvent( MetaMetricsEvents.NETWORK_SWITCHED, @@ -229,53 +152,9 @@ const wallet_addEthereumChain = async ({ res.result = null; return; } - - let endpointChainId; - - try { - endpointChainId = await jsonRpcRequest(firstValidRPCUrl, 'eth_chainId'); - } catch (err) { - throw rpcErrors.internal({ - message: `Request for method 'eth_chainId on ${firstValidRPCUrl} failed`, - data: { networkErr: err }, - }); - } - - if (_chainId !== endpointChainId) { - throw rpcErrors.invalidParams({ - message: `Chain ID returned by RPC URL ${firstValidRPCUrl} does not match ${_chainId}`, - data: { chainId: endpointChainId }, - }); - } - - if (nativeCurrency !== null) { - if (typeof nativeCurrency !== 'object' || Array.isArray(nativeCurrency)) { - throw rpcErrors.invalidParams({ - message: `Expected null or object 'nativeCurrency'. Received:\n${nativeCurrency}`, - }); - } - if (nativeCurrency.decimals !== EVM_NATIVE_TOKEN_DECIMALS) { - throw rpcErrors.invalidParams({ - message: `Expected the number 18 for 'nativeCurrency.decimals' when 'nativeCurrency' is provided. Received: ${nativeCurrency.decimals}`, - }); - } - - if (!nativeCurrency.symbol || typeof nativeCurrency.symbol !== 'string') { - throw rpcErrors.invalidParams({ - message: `Expected a string 'nativeCurrency.symbol'. Received: ${nativeCurrency.symbol}`, - }); - } - } - const ticker = nativeCurrency?.symbol || 'ETH'; - - if (typeof ticker !== 'string' || ticker.length < 2 || ticker.length > 6) { - throw rpcErrors.invalidParams({ - message: `Expected 2-6 character string 'nativeCurrency.symbol'. Received:\n${ticker}`, - }); - } - + await validateRpcEndpoint(firstValidRPCUrl, chainId); const requestData = { - chainId: _chainId, + chainId, blockExplorerUrl: firstValidBlockExplorerUrl, chainName, rpcUrl: firstValidRPCUrl, @@ -283,33 +162,25 @@ const wallet_addEthereumChain = async ({ }; const alerts = await checkSafeNetwork( - getDecimalChainId(_chainId), + getDecimalChainId(chainId), requestData.rpcUrl, requestData.chainName, requestData.ticker, ); - requestData.alerts = alerts; - const analyticsParamsAdd = { - chain_id: getDecimalChainId(_chainId), + MetaMetrics.getInstance().trackEvent(MetaMetricsEvents.NETWORK_REQUESTED, { + chain_id: getDecimalChainId(chainId), source: 'Custom Network API', symbol: ticker, ...analytics, - }; - - MetaMetrics.getInstance().trackEvent( - MetaMetricsEvents.NETWORK_REQUESTED, - analyticsParamsAdd, - ); - + }); // Remove all existing approvals, including other add network requests. ApprovalController.clear(providerErrors.userRejectedRequest()); // If existing approval request was an add network request, wait for // it to be rejected and for the corresponding approval flow to be ended. await waitForInteraction(); - const { id: approvalFlowId } = startApprovalFlow(); try { @@ -318,16 +189,21 @@ const wallet_addEthereumChain = async ({ type: 'ADD_ETHEREUM_CHAIN', requestData, }); - } catch (e) { + } catch (error) { MetaMetrics.getInstance().trackEvent( MetaMetricsEvents.NETWORK_REQUEST_REJECTED, - analyticsParamsAdd, + { + chain_id: getDecimalChainId(chainId), + source: 'Custom Network API', + symbol: ticker, + ...analytics, + }, ); throw providerErrors.userRejectedRequest(); } - const networkConfigurationId = await NetworkController.addNetwork({ + const networkConfiguration = await NetworkController.addNetwork({ chainId, - blockExplorerUrls, + blockExplorerUrls: [firstValidBlockExplorerUrl], defaultRpcEndpointIndex: 0, defaultBlockExplorerUrlIndex: 0, name: chainName, @@ -341,25 +217,38 @@ const wallet_addEthereumChain = async ({ ], }); - MetaMetrics.getInstance().trackEvent( - MetaMetricsEvents.NETWORK_ADDED, - analyticsParamsAdd, - ); - - await waitForInteraction(); - - await requestUserApproval({ - type: 'SWITCH_ETHEREUM_CHAIN', - requestData: { ...requestData, type: 'new' }, + MetaMetrics.getInstance().trackEvent(MetaMetricsEvents.NETWORK_ADDED, { + chain_id: getDecimalChainId(chainId), + source: 'Custom Network API', + symbol: ticker, + ...analytics, }); - CurrencyRateController.updateExchangeRate(ticker); const { networkClientId } = - networkConfigurationId?.rpcEndpoints?.[ - networkConfigurationId.defaultRpcEndpointIndex + networkConfiguration?.rpcEndpoints?.[ + networkConfiguration.defaultRpcEndpointIndex ] ?? {}; - NetworkController.setActiveNetwork(networkClientId); + const network = [networkClientId, networkConfiguration]; + const analyticsParams = await switchToNetwork({ + network, + chainId, + controllers: { + CurrencyRateController, + NetworkController, + PermissionController, + SelectedNetworkController, + }, + requestUserApproval, + analytics, + origin, + isAddNetworkFlow: true, + }); + + MetaMetrics.getInstance().trackEvent( + MetaMetricsEvents.NETWORK_SWITCHED, + analyticsParams, + ); } finally { endApprovalFlow({ id: approvalFlowId }); } diff --git a/app/core/RPCMethods/wallet_addEthereumChain.test.js b/app/core/RPCMethods/wallet_addEthereumChain.test.js index 1c382d609a0..f84e2ebcd29 100644 --- a/app/core/RPCMethods/wallet_addEthereumChain.test.js +++ b/app/core/RPCMethods/wallet_addEthereumChain.test.js @@ -2,6 +2,8 @@ import { InteractionManager } from 'react-native'; import { providerErrors } from '@metamask/rpc-errors'; import wallet_addEthereumChain from './wallet_addEthereumChain'; import Engine from '../Engine'; +import { CaveatFactories, PermissionKeys } from '../Permissions/specifications'; +import { CaveatTypes } from '../Permissions/constants'; import { mockNetworkState } from '../../util/test/network'; const mockEngine = Engine; @@ -14,6 +16,17 @@ const correctParams = { rpcUrls: ['https://rpc.gnosischain.com'], }; +const existingNetworkConfiguration = { + id: 'test-network-configuration-id', + chainId: '0x2', + rpcUrl: 'https://rpc.test-chain.com', + ticker: 'TST', + nickname: 'Test Chain', + rpcPrefs: { + blockExplorerUrl: 'https://explorer.test-chain.com', + }, +}; + jest.mock('../Engine', () => ({ init: () => mockEngine.init({}), context: { @@ -21,6 +34,7 @@ jest.mock('../Engine', () => ({ setActiveNetwork: jest.fn(), upsertNetworkConfiguration: jest.fn(), addNetwork: jest.fn(), + updateNetwork: jest.fn(), }, CurrencyRateController: { updateExchangeRate: jest.fn(), @@ -28,6 +42,15 @@ jest.mock('../Engine', () => ({ ApprovalController: { clear: jest.fn(), }, + PermissionController: { + hasPermission: jest.fn().mockReturnValue(true), + grantPermissionsIncremental: jest.fn(), + requestPermissionsIncremental: jest.fn(), + getCaveat: jest.fn(), + }, + SelectedNetworkController: { + setNetworkClientIdForDomain: jest.fn(), + }, }, })); @@ -37,12 +60,17 @@ jest.mock('../../store', () => ({ engine: { backgroundState: { NetworkController: { - ...mockNetworkState({ - chainId: '0x1', - id: 'Mainnet', - nickname: 'Mainnet', - ticker: 'ETH', - }), + ...mockNetworkState( + { + chainId: '0x1', + id: 'Mainnet', + nickname: 'Mainnet', + ticker: 'ETH', + }, + { + ...existingNetworkConfiguration, + }, + ), }, }, }, @@ -55,6 +83,7 @@ describe('RPC Method - wallet_addEthereumChain', () => { let otherOptions; beforeEach(() => { + jest.clearAllMocks(); otherOptions = { res: {}, addCustomNetworkRequest: {}, @@ -71,6 +100,8 @@ describe('RPC Method - wallet_addEthereumChain', () => { mockFetch = jest.fn().mockImplementation(async (url) => { if (url === 'https://rpc.gnosischain.com') { return { json: () => Promise.resolve({ result: '0x64' }) }; + } else if (url === 'https://different-rpc-url.com') { + return { json: () => Promise.resolve({ result: '0x2' }) }; } else if (url === 'https://chainid.network/chains.json') { return { json: () => @@ -262,6 +293,14 @@ describe('RPC Method - wallet_addEthereumChain', () => { describe('Approval Flow', () => { it('should start and end a new approval flow if chain does not already exist', async () => { + jest + .spyOn(Engine.context.NetworkController, 'addNetwork') + .mockResolvedValue({ + id: '1', + chainId: '0x64', + rpcEndpoints: [correctParams.rpcUrls[0]], + defaultRpcEndpointIndex: 0, + }); await wallet_addEthereumChain({ req: { params: [correctParams], @@ -301,4 +340,154 @@ describe('RPC Method - wallet_addEthereumChain', () => { expect(Engine.context.ApprovalController.clear).toBeCalledTimes(1); }); }); + + it('should not modify/add permissions', async () => { + const spyOnGrantPermissionsIncremental = jest.spyOn( + Engine.context.PermissionController, + 'grantPermissionsIncremental', + ); + await wallet_addEthereumChain({ + req: { + params: [correctParams], + }, + ...otherOptions, + }); + + expect(spyOnGrantPermissionsIncremental).toHaveBeenCalledTimes(0); + }); + + it('should correctly add and switch to a new chain when chain is not already in wallet state ', async () => { + const spyOnAddNetwork = jest + .spyOn(Engine.context.NetworkController, 'addNetwork') + .mockResolvedValue({ + id: '1', + chainId: '0x64', + rpcEndpoints: [correctParams.rpcUrls[0]], + defaultRpcEndpointIndex: 0, + }); + + const spyOnSetActiveNetwork = jest.spyOn( + Engine.context.NetworkController, + 'setActiveNetwork', + ); + const spyOnUpdateExchangeRate = jest.spyOn( + Engine.context.CurrencyRateController, + 'updateExchangeRate', + ); + + await wallet_addEthereumChain({ + req: { + params: [correctParams], + origin: 'https://example.com', + }, + ...otherOptions, + }); + + expect(spyOnAddNetwork).toHaveBeenCalledTimes(1); + expect(spyOnAddNetwork).toHaveBeenCalledWith( + expect.objectContaining({ + chainId: correctParams.chainId, + blockExplorerUrls: correctParams.blockExplorerUrls, + nativeCurrency: correctParams.nativeCurrency.symbol, + name: correctParams.chainName, + }), + ); + expect(spyOnSetActiveNetwork).toHaveBeenCalledTimes(1); + expect(spyOnUpdateExchangeRate).toHaveBeenCalledTimes(1); + }); + + it('should not add a networkConfiguration that has a chainId that already exists in wallet state, and should switch to the existing network', async () => { + const spyOnAddNetwork = jest.spyOn( + Engine.context.NetworkController, + 'addNetwork', + ); + + const spyOnSetActiveNetwork = jest.spyOn( + Engine.context.NetworkController, + 'setActiveNetwork', + ); + const spyOnUpdateExchangeRate = jest.spyOn( + Engine.context.CurrencyRateController, + 'updateExchangeRate', + ); + + const existingParams = { + chainId: existingNetworkConfiguration.chainId, + rpcUrls: ['https://different-rpc-url.com'], + chainName: existingNetworkConfiguration.nickname, + nativeCurrency: { + name: existingNetworkConfiguration.ticker, + symbol: existingNetworkConfiguration.ticker, + decimals: 18, + }, + }; + + await wallet_addEthereumChain({ + req: { + params: [existingParams], + origin: 'https://example.com', + }, + ...otherOptions, + }); + + expect(spyOnAddNetwork).not.toHaveBeenCalled(); + expect(spyOnSetActiveNetwork).toHaveBeenCalledTimes(1); + expect(spyOnUpdateExchangeRate).toHaveBeenCalledTimes(1); + }); + + describe('MM_CHAIN_PERMISSIONS is enabled', () => { + beforeAll(() => { + process.env.MM_CHAIN_PERMISSIONS = 1; + }); + afterAll(() => { + process.env.MM_CHAIN_PERMISSIONS = 0; + }); + afterEach(() => { + jest.clearAllMocks(); + }); + it('should grant permissions when chain is not already permitted', async () => { + const spyOnGrantPermissionsIncremental = jest.spyOn( + Engine.context.PermissionController, + 'grantPermissionsIncremental', + ); + await wallet_addEthereumChain({ + req: { + params: [correctParams], + origin: 'https://example.com', + }, + ...otherOptions, + }); + + expect(spyOnGrantPermissionsIncremental).toHaveBeenCalledTimes(1); + expect(spyOnGrantPermissionsIncremental).toHaveBeenCalledWith({ + subject: { origin: 'https://example.com' }, + approvedPermissions: { + [PermissionKeys.permittedChains]: { + caveats: [ + CaveatFactories[CaveatTypes.restrictNetworkSwitching](['0x64']), + ], + }, + }, + }); + }); + + it('should not grant permissions when chain is already permitted', async () => { + const spyOnGrantPermissionsIncremental = jest.spyOn( + Engine.context.PermissionController, + 'grantPermissionsIncremental', + ); + jest + .spyOn(Engine.context.PermissionController, 'getCaveat') + .mockReturnValue({ value: ['0x64'] }); + await wallet_addEthereumChain({ + req: { + params: [correctParams], + origin: 'https://example.com', + }, + ...otherOptions, + }); + + expect(spyOnGrantPermissionsIncremental).toHaveBeenCalledTimes(0); + }); + }); }); diff --git a/app/core/RPCMethods/wallet_switchEthereumChain.js b/app/core/RPCMethods/wallet_switchEthereumChain.js index d1cc4af3bc6..d025a1009ca 100644 --- a/app/core/RPCMethods/wallet_switchEthereumChain.js +++ b/app/core/RPCMethods/wallet_switchEthereumChain.js @@ -1,15 +1,13 @@ import Engine from '../Engine'; import { providerErrors, rpcErrors } from '@metamask/rpc-errors'; -import { - getDecimalChainId, - getDefaultNetworkByChainId, - isPrefixedFormattedHexString, -} from '../../util/networks'; import { MetaMetricsEvents, MetaMetrics } from '../../core/Analytics'; import { selectNetworkConfigurations } from '../../selectors/networkController'; import { store } from '../../store'; -import { NetworksTicker, isSafeChainId } from '@metamask/controller-utils'; -import { RestrictedMethods } from '../Permissions/constants'; +import { + validateChainId, + findExistingNetwork, + switchToNetwork, +} from './lib/ethereum-chain-utils'; const wallet_switchEthereumChain = async ({ req, @@ -25,7 +23,6 @@ const wallet_switchEthereumChain = async ({ } = Engine.context; const params = req.params?.[0]; const { origin } = req; - if (!params || typeof params !== 'object') { throw rpcErrors.invalidParams({ message: `Expected single, object parameter. Received:\n${JSON.stringify( @@ -33,9 +30,7 @@ const wallet_switchEthereumChain = async ({ )}`, }); } - const { chainId } = params; - const allowedKeys = { chainId: true, }; @@ -46,37 +41,16 @@ const wallet_switchEthereumChain = async ({ `Received unexpected keys on object parameter. Unsupported keys:\n${extraKeys}`, ); } - - const _chainId = typeof chainId === 'string' && chainId.toLowerCase(); - - if (!isPrefixedFormattedHexString(_chainId)) { - throw rpcErrors.invalidParams( - `Expected 0x-prefixed, unpadded, non-zero hexadecimal string 'chainId'. Received:\n${chainId}`, - ); - } - - if (!isSafeChainId(_chainId)) { - throw rpcErrors.invalidParams( - `Invalid chain ID "${_chainId}": numerical value greater than max safe value. Received:\n${chainId}`, - ); - } + const _chainId = validateChainId(chainId); const networkConfigurations = selectNetworkConfigurations(store.getState()); - - const existingNetworkDefault = getDefaultNetworkByChainId(_chainId); - const existingEntry = Object.entries(networkConfigurations).find( - ([, networkConfiguration]) => networkConfiguration.chainId === _chainId, - ); - - if (existingEntry || existingNetworkDefault) { + const existingNetwork = findExistingNetwork(_chainId, networkConfigurations); + if (existingNetwork) { const currentDomainSelectedNetworkClientId = - Engine.context.SelectedNetworkController.getNetworkClientIdForDomain( - origin, - ); - + SelectedNetworkController.getNetworkClientIdForDomain(origin); const { configuration: { chainId: currentDomainSelectedChainId }, - } = Engine.context.NetworkController.getNetworkClientById( + } = NetworkController.getNetworkClientById( currentDomainSelectedNetworkClientId, ) || { configuration: {} }; @@ -85,71 +59,20 @@ const wallet_switchEthereumChain = async ({ return; } - let networkConfigurationId, networkConfiguration; - if (existingEntry) { - [, networkConfiguration] = existingEntry; - networkConfigurationId = - networkConfiguration.rpcEndpoints[ - networkConfiguration.defaultRpcEndpointIndex - ].networkClientId; - } - - let requestData; - let analyticsParams = { - chain_id: getDecimalChainId(_chainId), - source: 'Switch Network API', - ...analytics, - }; - if (networkConfiguration) { - requestData = { - rpcUrl: - networkConfiguration.rpcEndpoints[ - networkConfiguration.defaultRpcEndpointIndex - ], - chainId: _chainId, - chainName: networkConfiguration.name, - ticker: networkConfiguration.nativeCurrency, - }; - analyticsParams = { - ...analyticsParams, - symbol: networkConfiguration?.ticker, - }; - } else { - requestData = { - chainId: _chainId, - chainColor: existingNetworkDefault.color, - chainName: existingNetworkDefault.shortName, - ticker: 'ETH', - }; - analyticsParams = { - ...analyticsParams, - }; - } - - await requestUserApproval({ - type: 'SWITCH_ETHEREUM_CHAIN', - requestData: { ...requestData, type: 'switch' }, - }); - - const originHasAccountsPermission = PermissionController.hasPermission( + const analyticsParams = await switchToNetwork({ + network: existingNetwork, + chainId: _chainId, + controllers: { + CurrencyRateController, + NetworkController, + PermissionController, + SelectedNetworkController, + }, + requestUserApproval, + analytics, origin, - RestrictedMethods.eth_accounts, - ); - - if (process.env.MULTICHAIN_V1 && originHasAccountsPermission) { - SelectedNetworkController.setNetworkClientIdForDomain( - origin, - networkConfigurationId || existingNetworkDefault.networkType, - ); - } else if (networkConfiguration) { - CurrencyRateController.updateExchangeRate(networkConfiguration.ticker); - NetworkController.setActiveNetwork(networkConfigurationId); - } else { - // TODO we will need to update this so that each network in the NetworksList has its own ticker - // if we ever add networks that don't have ETH as their base currency - CurrencyRateController.updateExchangeRate(NetworksTicker.mainnet); - NetworkController.setActiveNetwork(existingNetworkDefault.networkType); - } + isAddNetworkFlow: false, + }); MetaMetrics.getInstance().trackEvent( MetaMetricsEvents.NETWORK_SWITCHED, diff --git a/app/core/RPCMethods/wallet_switchEthereumChain.test.js b/app/core/RPCMethods/wallet_switchEthereumChain.test.js index 9f32cdc77de..032a8a6ca5d 100644 --- a/app/core/RPCMethods/wallet_switchEthereumChain.test.js +++ b/app/core/RPCMethods/wallet_switchEthereumChain.test.js @@ -1,4 +1,62 @@ import wallet_switchEthereumChain from './wallet_switchEthereumChain'; +import Engine from '../Engine'; +import { mockNetworkState } from '../../util/test/network'; + +const existingNetworkConfiguration = { + id: 'test-network-configuration-id', + chainId: '0x64', + rpcUrl: 'https://rpc.test-chain.com', + ticker: 'ETH', + nickname: 'Gnosis Chain', + rpcPrefs: { + blockExplorerUrl: 'https://explorer.test-chain.com', + }, +}; +jest.mock('../Engine', () => ({ + context: { + NetworkController: { + setActiveNetwork: jest.fn(), + getNetworkClientById: jest.fn(), + }, + CurrencyRateController: { + updateExchangeRate: jest.fn(), + }, + PermissionController: { + hasPermission: jest.fn().mockReturnValue(true), + grantPermissionsIncremental: jest.fn(), + getCaveat: jest.fn(), + }, + SelectedNetworkController: { + setNetworkClientIdForDomain: jest.fn(), + getNetworkClientIdForDomain: jest.fn(), + }, + }, +})); + +jest.mock('../../store', () => ({ + store: { + getState: jest.fn(() => ({ + engine: { + backgroundState: { + NetworkController: { + ...mockNetworkState( + { + chainId: '0x1', + id: 'Mainnet', + nickname: 'Mainnet', + ticker: 'ETH', + }, + { + ...existingNetworkConfiguration, + }, + ), + }, + }, + }, + })), + }, +})); + const correctParams = { chainId: '0x1', }; @@ -6,9 +64,14 @@ const correctParams = { const otherOptions = { res: {}, switchCustomNetworkRequest: {}, + requestUserApproval: jest.fn(), }; describe('RPC Method - wallet_switchEthereumChain', () => { + afterEach(() => { + jest.clearAllMocks(); + }); + it('should report missing params', async () => { try { await wallet_switchEthereumChain({ @@ -66,4 +129,129 @@ describe('RPC Method - wallet_switchEthereumChain', () => { ); } }); + + it('should should show a modal for user approval and not grant permissions', async () => { + const spyOnGrantPermissionsIncremental = jest.spyOn( + Engine.context.PermissionController, + 'grantPermissionsIncremental', + ); + jest + .spyOn( + Engine.context.SelectedNetworkController, + 'getNetworkClientIdForDomain', + ) + .mockReturnValue('mainnet'); + jest + .spyOn(Engine.context.NetworkController, 'getNetworkClientById') + .mockReturnValue({ configuration: { chainId: '0x1' } }); + const spyOnSetActiveNetwork = jest.spyOn( + Engine.context.NetworkController, + 'setActiveNetwork', + ); + await wallet_switchEthereumChain({ + req: { + params: [{ chainId: '0x64' }], + }, + ...otherOptions, + }); + expect(otherOptions.requestUserApproval).toHaveBeenCalled(); + expect(spyOnGrantPermissionsIncremental).not.toHaveBeenCalled(); + expect(spyOnSetActiveNetwork).toHaveBeenCalledWith( + 'test-network-configuration-id', + ); + }); + + describe('MM_CHAIN_PERMISSIONS is enabled', () => { + beforeAll(() => { + process.env.MM_CHAIN_PERMISSIONS = 1; + }); + afterAll(() => { + process.env.MM_CHAIN_PERMISSIONS = 0; + }); + it('should not change network permissions and should switch without user approval when chain is already permitted', async () => { + const spyOnGrantPermissionsIncremental = jest.spyOn( + Engine.context.PermissionController, + 'grantPermissionsIncremental', + ); + jest + .spyOn( + Engine.context.SelectedNetworkController, + 'getNetworkClientIdForDomain', + ) + .mockReturnValue('mainnet'); + jest + .spyOn(Engine.context.NetworkController, 'getNetworkClientById') + .mockReturnValue({ configuration: { chainId: '0x1' } }); + jest + .spyOn(Engine.context.PermissionController, 'getCaveat') + .mockReturnValue({ value: ['0x64'] }); + + const spyOnSetActiveNetwork = jest.spyOn( + Engine.context.NetworkController, + 'setActiveNetwork', + ); + await wallet_switchEthereumChain({ + req: { + params: [{ chainId: '0x64' }], + }, + ...otherOptions, + }); + + expect(otherOptions.requestUserApproval).not.toHaveBeenCalled(); + expect(spyOnGrantPermissionsIncremental).not.toHaveBeenCalled(); + expect(spyOnSetActiveNetwork).toHaveBeenCalledWith( + 'test-network-configuration-id', + ); + }); + + it('should add network permission and should switch with user approval when requested chain is not permitted', async () => { + const spyOnGrantPermissionsIncremental = jest.spyOn( + Engine.context.PermissionController, + 'grantPermissionsIncremental', + ); + jest + .spyOn( + Engine.context.SelectedNetworkController, + 'getNetworkClientIdForDomain', + ) + .mockReturnValue('mainnet'); + jest + .spyOn(Engine.context.NetworkController, 'getNetworkClientById') + .mockReturnValue({ configuration: { chainId: '0x1' } }); + const spyOnSetActiveNetwork = jest.spyOn( + Engine.context.NetworkController, + 'setActiveNetwork', + ); + jest + .spyOn(Engine.context.PermissionController, 'getCaveat') + .mockReturnValue({ value: [] }); + await wallet_switchEthereumChain({ + req: { + params: [{ chainId: '0x64' }], + origin: 'https://test.com', + }, + ...otherOptions, + }); + expect(otherOptions.requestUserApproval).toHaveBeenCalled(); + expect(spyOnGrantPermissionsIncremental).toHaveBeenCalledTimes(1); + expect(spyOnGrantPermissionsIncremental).toHaveBeenCalledWith({ + approvedPermissions: { + 'endowment:permitted-chains': { + caveats: [ + { + type: 'restrictNetworkSwitching', + value: ['0x64'], + }, + ], + }, + }, + subject: { + origin: 'https://test.com', + }, + }); + expect(spyOnSetActiveNetwork).toHaveBeenCalledWith( + 'test-network-configuration-id', + ); + }); + }); }); diff --git a/app/selectors/networkController.ts b/app/selectors/networkController.ts index ef937b27a36..d55ff91e0ea 100644 --- a/app/selectors/networkController.ts +++ b/app/selectors/networkController.ts @@ -156,3 +156,4 @@ export const selectNetworkClientId = createSelector( (networkControllerState: NetworkState) => networkControllerState.selectedNetworkClientId, ); + diff --git a/app/selectors/selectedNetworkController.ts b/app/selectors/selectedNetworkController.ts index 0588c165084..3cb73d681f5 100644 --- a/app/selectors/selectedNetworkController.ts +++ b/app/selectors/selectedNetworkController.ts @@ -95,7 +95,7 @@ export const makeSelectNetworkName = () => chainId, hostname, ) => { - if (!hostname || !process.env.MULTICHAIN_V1) return providerNetworkName; + if (!hostname || !process.env.MM_PER_DAPP_SELECTED_NETWORK) return providerNetworkName; const relevantNetworkClientId = domainNetworkClientId || globalNetworkClientId; return ( @@ -127,7 +127,7 @@ export const makeSelectNetworkImageSource = () => chainId, hostname, ) => { - if (!hostname || !process.env.MULTICHAIN_V1) + if (!hostname || !process.env.MM_PER_DAPP_SELECTED_NETWORK) return providerNetworkImageSource; const relevantNetworkClientId = domainNetworkClientId || globalNetworkClientId; @@ -162,7 +162,7 @@ export const makeSelectChainId = () => chainId, hostname, ) => { - if (!hostname || !process.env.MULTICHAIN_V1) { + if (!hostname || !process.env.MM_PER_DAPP_SELECTED_NETWORK) { return providerChainId; } const relevantNetworkClientId = @@ -193,7 +193,7 @@ export const makeSelectRpcUrl = () => chainId, hostname, ) => { - if (!hostname || !process.env.MULTICHAIN_V1) return providerRpcUrl; + if (!hostname || !process.env.MM_PER_DAPP_SELECTED_NETWORK) return providerRpcUrl; const relevantNetworkClientId = domainNetworkClientId || globalNetworkClientId; return networkConfigurations[chainId]?.rpcEndpoints.find( diff --git a/app/util/networks/index.js b/app/util/networks/index.js index 6a91fa89a9b..74052aa4137 100644 --- a/app/util/networks/index.js +++ b/app/util/networks/index.js @@ -586,3 +586,6 @@ export const deprecatedGetNetworkId = async () => { export const isMultichainVersion1Enabled = process.env.MM_MULTICHAIN_V1_ENABLED === '1'; + +export const isChainPermissionsFeatureEnabled = + process.env.MM_CHAIN_PERMISSIONS === '1'; diff --git a/app/util/test/network.ts b/app/util/test/network.ts index 8f8520a6fab..33fc493f727 100644 --- a/app/util/test/network.ts +++ b/app/util/test/network.ts @@ -40,7 +40,6 @@ export const mockNetworkState = ( 'rpcUrl' in network ? network.rpcUrl : `https://localhost/rpc/${network.chainId}`; - return { chainId: network.chainId, blockExplorerUrls: blockExplorer ? [blockExplorer] : [],