diff --git a/packages/transaction-pay-controller/CHANGELOG.md b/packages/transaction-pay-controller/CHANGELOG.md index 50cecaa5100..cb5fce07023 100644 --- a/packages/transaction-pay-controller/CHANGELOG.md +++ b/packages/transaction-pay-controller/CHANGELOG.md @@ -7,6 +7,14 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +### Added + +- **BREAKING:** Add `AssetsControllerGetStateForTransactionPayAction` to the `AllowedActions` messenger type ([#8163](https://github.com/MetaMask/core/pull/8163)) + +### Changed + +- `getTokenBalance`, `getTokenInfo`, and `getTokenFiatRate` now source token metadata, balances, and pricing from `AssetsControllerGetStateForTransactionPayAction` when the `assetsUnifyState` remote feature flag is enabled, falling back to individual controller state calls otherwise ([#8163](https://github.com/MetaMask/core/pull/8163)) + ## [16.5.0] ### Added diff --git a/packages/transaction-pay-controller/package.json b/packages/transaction-pay-controller/package.json index c7ef24607ca..b6104c7a9fd 100644 --- a/packages/transaction-pay-controller/package.json +++ b/packages/transaction-pay-controller/package.json @@ -51,6 +51,7 @@ "@ethersproject/abi": "^5.7.0", "@ethersproject/contracts": "^5.7.0", "@ethersproject/providers": "^5.7.0", + "@metamask/assets-controller": "^2.3.0", "@metamask/assets-controllers": "^100.2.1", "@metamask/base-controller": "^9.0.0", "@metamask/bridge-controller": "^69.1.0", diff --git a/packages/transaction-pay-controller/src/tests/messenger-mock.ts b/packages/transaction-pay-controller/src/tests/messenger-mock.ts index b6d69d4f22d..89028d7973c 100644 --- a/packages/transaction-pay-controller/src/tests/messenger-mock.ts +++ b/packages/transaction-pay-controller/src/tests/messenger-mock.ts @@ -127,6 +127,8 @@ export function getMessengerMock({ TransactionControllerEstimateGasBatchAction['handler'] > = jest.fn(); + const getAssetsControllerStateMock = jest.fn(); + const messenger: RootMessenger = new Messenger({ namespace: MOCK_ANY_NAMESPACE, }); @@ -241,12 +243,18 @@ export function getMessengerMock({ 'TransactionController:estimateGasBatch', estimateGasBatchMock, ); + + messenger.registerActionHandler( + 'AssetsController:getStateForTransactionPay', + getAssetsControllerStateMock, + ); } const publish = messenger.publish.bind(messenger); return { addTransactionMock, + getAssetsControllerStateMock, addTransactionBatchMock, estimateGasMock, estimateGasBatchMock, diff --git a/packages/transaction-pay-controller/src/types.ts b/packages/transaction-pay-controller/src/types.ts index ff99540dac5..8a3a71714ca 100644 --- a/packages/transaction-pay-controller/src/types.ts +++ b/packages/transaction-pay-controller/src/types.ts @@ -1,3 +1,4 @@ +import type { AssetsControllerGetStateForTransactionPayAction } from '@metamask/assets-controller'; import type { CurrencyRateControllerActions, TokenBalancesControllerGetStateAction, @@ -38,6 +39,7 @@ import type { CONTROLLER_NAME, TransactionPayStrategy } from './constants'; export type AllowedActions = | AccountTrackerControllerGetStateAction + | AssetsControllerGetStateForTransactionPayAction | BridgeControllerActions | BridgeStatusControllerActions | CurrencyRateControllerActions diff --git a/packages/transaction-pay-controller/src/utils/feature-flags.test.ts b/packages/transaction-pay-controller/src/utils/feature-flags.test.ts index 8711713267c..6a550f9d0ae 100644 --- a/packages/transaction-pay-controller/src/utils/feature-flags.test.ts +++ b/packages/transaction-pay-controller/src/utils/feature-flags.test.ts @@ -9,6 +9,7 @@ import { DEFAULT_RELAY_QUOTE_URL, DEFAULT_SLIPPAGE, DEFAULT_STRATEGY_ORDER, + getAssetsUnifyStateFeature, getFallbackGas, DEFAULT_RELAY_EXECUTE_URL, getRelayOriginGasOverhead, @@ -562,6 +563,85 @@ describe('Feature Flags Utils', () => { }); }); + describe('getAssetsUnifyStateFeature', () => { + type AssetsUnifyingState = + | { + enabled: boolean; + featureVersion: string | null; + } + | undefined; + + const failureCases: { + description: string; + assetsUnifyingState: AssetsUnifyingState; + }[] = [ + { + description: 'returns false when assetsUnifyState is not set', + assetsUnifyingState: undefined, + }, + { + description: 'returns false when assetsUnifyState.enabled is false', + assetsUnifyingState: { + enabled: false, + featureVersion: '1', + }, + }, + { + description: + 'returns false when featureVersion does not match expected version', + assetsUnifyingState: { + enabled: true, + featureVersion: '2', + }, + }, + ]; + + const successCases = [ + { + description: + 'returns true when assetsUnifyState is enabled and featureVersion matches', + assetsUnifyingState: { + enabled: true, + featureVersion: '1', + }, + }, + ]; + + const arrangeMocks = (assetsUnifyState: AssetsUnifyingState): void => { + const defaultRemoteFeatureFlagsState = + getDefaultRemoteFeatureFlagControllerState(); + getRemoteFeatureFlagControllerStateMock.mockReturnValue({ + ...defaultRemoteFeatureFlagsState, + remoteFeatureFlags: { + ...defaultRemoteFeatureFlagsState.remoteFeatureFlags, + ...(assetsUnifyState ? { assetsUnifyState } : {}), + }, + }); + }; + + it.each(failureCases)( + '$description', + ({ assetsUnifyingState }: (typeof failureCases)[number]) => { + arrangeMocks(assetsUnifyingState); + + const result = getAssetsUnifyStateFeature(messenger); + + expect(result).toBe(false); + }, + ); + + it.each(successCases)( + '$description', + ({ assetsUnifyingState }: (typeof successCases)[number]) => { + arrangeMocks(assetsUnifyingState); + + const result = getAssetsUnifyStateFeature(messenger); + + expect(result).toBe(true); + }, + ); + }); + describe('getStrategyOrder', () => { it('returns default strategy order when none is set', () => { const strategyOrder = getStrategyOrder(messenger); diff --git a/packages/transaction-pay-controller/src/utils/feature-flags.ts b/packages/transaction-pay-controller/src/utils/feature-flags.ts index 1307150cad0..bdaf08be06a 100644 --- a/packages/transaction-pay-controller/src/utils/feature-flags.ts +++ b/packages/transaction-pay-controller/src/utils/feature-flags.ts @@ -297,6 +297,31 @@ export function getSlippage( return slippage; } +/** + * Get the AssetsUnifyState feature flag state. + * + * @param messenger - Controller messenger. + * @returns True if the assets unify state feature is enabled, false otherwise. + */ +export function getAssetsUnifyStateFeature( + messenger: TransactionPayControllerMessenger, +): boolean { + const state = messenger.call('RemoteFeatureFlagController:getState'); + const assetsUnifyState = state.remoteFeatureFlags.assetsUnifyState as + | { + enabled: boolean; + featureVersion: string | null; + } + | undefined; + + const AssetsUnifyStateFeatureVersion = '1'; + + return ( + Boolean(assetsUnifyState?.enabled) && + assetsUnifyState?.featureVersion === AssetsUnifyStateFeatureVersion + ); +} + /** * Get a value from a record using a case-insensitive key lookup. * diff --git a/packages/transaction-pay-controller/src/utils/token.test.ts b/packages/transaction-pay-controller/src/utils/token.test.ts index 46a0a03017b..e24338cca8d 100644 --- a/packages/transaction-pay-controller/src/utils/token.test.ts +++ b/packages/transaction-pay-controller/src/utils/token.test.ts @@ -10,13 +10,13 @@ import { getTokenBalance, getTokenInfo, getTokenFiatRate, - getAllTokenBalances, getNativeToken, isSameToken, getLiveTokenBalance, normalizeTokenAddress, TokenAddressTarget, } from './token'; +import { getDefaultRemoteFeatureFlagControllerState } from '../../../remote-feature-flag-controller/src/remote-feature-flag-controller'; import { CHAIN_ID_POLYGON, NATIVE_TOKEN_ADDRESS, @@ -50,6 +50,8 @@ const PROVIDER_MOCK = { request: jest.fn() }; describe('Token Utils', () => { const { messenger, + getAssetsControllerStateMock, + getRemoteFeatureFlagControllerStateMock, getTokensControllerStateMock, getNetworkClientByIdMock, getTokenBalanceControllerStateMock, @@ -68,6 +70,10 @@ describe('Token Utils', () => { mockBalanceOf = jest.fn(); mockGetBalance = jest.fn(); + getRemoteFeatureFlagControllerStateMock.mockReturnValue({ + ...getDefaultRemoteFeatureFlagControllerState(), + }); + findNetworkClientIdByChainIdMock.mockReturnValue(NETWORK_CLIENT_ID_MOCK); getNetworkClientByIdMock.mockReturnValue({ @@ -84,7 +90,44 @@ describe('Token Utils', () => { })); }); + function enableAssetsUnifyState(): void { + getRemoteFeatureFlagControllerStateMock.mockReturnValue({ + ...getDefaultRemoteFeatureFlagControllerState(), + remoteFeatureFlags: { + assetsUnifyState: { + enabled: true, + featureVersion: '1', + minimumVersion: null, + }, + }, + }); + } + describe('getTokenInfo', () => { + it('returns decimals and symbol from AssetsController when assets unify state feature is enabled', () => { + enableAssetsUnifyState(); + getAssetsControllerStateMock.mockReturnValue({ + allTokens: { + [CHAIN_ID_MOCK]: { + test123: [ + { + address: TOKEN_ADDRESS_MOCK.toLowerCase() as Hex, + decimals: DECIMALS_MOCK, + symbol: SYMBOL_MOCK, + }, + ], + }, + }, + }); + + const result = getTokenInfo(messenger, TOKEN_ADDRESS_MOCK, CHAIN_ID_MOCK); + + expect(result).toStrictEqual({ + decimals: DECIMALS_MOCK, + symbol: SYMBOL_MOCK, + }); + }); + it('returns decimals and symbol from controller state', () => { getTokensControllerStateMock.mockReturnValue({ allTokens: { @@ -192,6 +235,51 @@ describe('Token Utils', () => { }); describe('getTokenBalance', () => { + it('returns token balance from AssetsController when assets unify state feature is enabled', () => { + enableAssetsUnifyState(); + getAssetsControllerStateMock.mockReturnValue({ + tokenBalances: { + [FROM_MOCK]: { + [CHAIN_ID_MOCK]: { + [TOKEN_ADDRESS_MOCK]: BALANCE_MOCK, + }, + }, + }, + }); + + const result = getTokenBalance( + messenger, + FROM_MOCK, + CHAIN_ID_MOCK, + TOKEN_ADDRESS_MOCK.toLowerCase() as Hex, + ); + + expect(result).toBe('291'); + }); + + it('returns native balance from AssetsController when assets unify state feature is enabled', () => { + enableAssetsUnifyState(); + getAssetsControllerStateMock.mockReturnValue({ + tokenBalances: {}, + accountsByChainId: { + [CHAIN_ID_MOCK]: { + [FROM_MOCK]: { + balance: '0x123', + }, + }, + }, + }); + + const result = getTokenBalance( + messenger, + FROM_MOCK, + CHAIN_ID_MOCK, + NATIVE_TOKEN_ADDRESS, + ); + + expect(result).toBe('291'); + }); + it('returns balance from controller state', () => { getTokenBalanceControllerStateMock.mockReturnValue({ tokenBalances: { @@ -278,6 +366,37 @@ describe('Token Utils', () => { }); describe('getTokenFiatRate', () => { + it('returns fiat rates from AssetsController when assets unify state feature is enabled', () => { + enableAssetsUnifyState(); + + getAssetsControllerStateMock.mockReturnValue({ + marketData: { + [CHAIN_ID_MOCK]: { + [TOKEN_ADDRESS_MOCK]: { + price: 2.0, + }, + }, + }, + currencyRates: { + [TICKER_MOCK]: { + conversionRate: 3.0, + usdConversionRate: 4.0, + }, + }, + }); + + const result = getTokenFiatRate( + messenger, + TOKEN_ADDRESS_MOCK, + CHAIN_ID_MOCK, + ); + + expect(result).toStrictEqual({ + fiatRate: '6', + usdRate: '8', + }); + }); + it('returns fiat rates', () => { findNetworkClientIdByChainIdMock.mockReturnValue(NETWORK_CLIENT_ID_MOCK); @@ -630,55 +749,6 @@ describe('Token Utils', () => { }); }); - describe('getAllTokenBalances', () => { - it('returns all token balances including native token', () => { - getTokenBalanceControllerStateMock.mockReturnValue({ - tokenBalances: { - [FROM_MOCK]: { - '0x1': { - [TOKEN_ADDRESS_MOCK]: '0x10', - [TOKEN_ADDRESS_2_MOCK]: '0x20', - }, - '0x2': { - [TOKEN_ADDRESS_MOCK]: '0x30', - }, - }, - }, - }); - - getAccountTrackerControllerStateMock.mockReturnValue({ - accountsByChainId: { - '0x1': { - [FROM_MOCK]: { - balance: '0x40', - }, - }, - '0x2': { - [FROM_MOCK]: { - balance: '0x50', - }, - }, - '0x3': { - [FROM_MOCK]: { - balance: '0x60', - }, - }, - }, - }); - - const result = getAllTokenBalances(messenger, FROM_MOCK); - - expect(result).toStrictEqual([ - { chainId: '0x1', tokenAddress: TOKEN_ADDRESS_MOCK, balance: '16' }, - { chainId: '0x1', tokenAddress: TOKEN_ADDRESS_2_MOCK, balance: '32' }, - { chainId: '0x1', tokenAddress: NATIVE_TOKEN_ADDRESS, balance: '64' }, - { chainId: '0x2', tokenAddress: TOKEN_ADDRESS_MOCK, balance: '48' }, - { chainId: '0x2', tokenAddress: NATIVE_TOKEN_ADDRESS, balance: '80' }, - { chainId: '0x3', tokenAddress: NATIVE_TOKEN_ADDRESS, balance: '96' }, - ]); - }); - }); - describe('isSameToken', () => { it('returns true for same address and chain', () => { const token1 = { address: TOKEN_ADDRESS_MOCK, chainId: CHAIN_ID_MOCK }; diff --git a/packages/transaction-pay-controller/src/utils/token.ts b/packages/transaction-pay-controller/src/utils/token.ts index cc3d33f448c..405d576e635 100644 --- a/packages/transaction-pay-controller/src/utils/token.ts +++ b/packages/transaction-pay-controller/src/utils/token.ts @@ -1,11 +1,12 @@ import { Contract } from '@ethersproject/contracts'; import { Web3Provider } from '@ethersproject/providers'; +import { TokensControllerState } from '@metamask/assets-controllers'; import { toChecksumHexAddress } from '@metamask/controller-utils'; import { abiERC20 } from '@metamask/metamask-eth-abis'; import type { Hex } from '@metamask/utils'; import { BigNumber } from 'bignumber.js'; -import { uniq } from 'lodash'; +import { getAssetsUnifyStateFeature } from './feature-flags'; import { CHAIN_ID_POLYGON, NATIVE_TOKEN_ADDRESS, @@ -49,18 +50,32 @@ export function getTokenBalance( chainId: Hex, tokenAddress: Hex, ): string { - const tokenBalanceControllerState = messenger.call( - 'TokenBalancesController:getState', - ); + const assetsUnifyStateFeatureEnabled = getAssetsUnifyStateFeature(messenger); + + let tokenBalances; + let accountsByChainId; + if (assetsUnifyStateFeatureEnabled) { + const assetsControllerState = messenger.call( + 'AssetsController:getStateForTransactionPay', + ); + + tokenBalances = assetsControllerState?.tokenBalances; + accountsByChainId = assetsControllerState?.accountsByChainId; + } else { + tokenBalances = messenger.call( + 'TokenBalancesController:getState', + )?.tokenBalances; + accountsByChainId = messenger.call( + 'AccountTrackerController:getState', + )?.accountsByChainId; + } const normalizedAccount = account.toLowerCase() as Hex; const normalizedTokenAddress = toChecksumHexAddress(tokenAddress) as Hex; const isNative = normalizedTokenAddress === getNativeToken(chainId); const balanceHex = - tokenBalanceControllerState.tokenBalances?.[normalizedAccount]?.[chainId]?.[ - normalizedTokenAddress - ]; + tokenBalances?.[normalizedAccount]?.[chainId]?.[normalizedTokenAddress]; if (!isNative && balanceHex === undefined) { return '0'; @@ -70,12 +85,7 @@ export function getTokenBalance( return new BigNumber(balanceHex, 16).toString(10); } - const accountTrackerControllerState = messenger.call( - 'AccountTrackerController:getState', - ); - - const chainAccounts = - accountTrackerControllerState.accountsByChainId?.[chainId]; + const chainAccounts = accountsByChainId?.[chainId]; const checksumAccount = toChecksumHexAddress(normalizedAccount) as Hex; const nativeBalanceHex = chainAccounts?.[checksumAccount]?.balance as Hex; @@ -83,55 +93,6 @@ export function getTokenBalance( return new BigNumber(nativeBalanceHex ?? '0x0', 16).toString(10); } -/** - * Get the token balance for a specific account and token. - * - * @param messenger - Controller messenger. - * @param account - Address of the account. - * @returns The token balance as a BigNumber. - */ -export function getAllTokenBalances( - messenger: TransactionPayControllerMessenger, - account: Hex, -): { - balance: string; - chainId: Hex; - tokenAddress: Hex; -}[] { - const tokenBalanceControllerState = messenger.call( - 'TokenBalancesController:getState', - ); - - const accountTrackerControllerState = messenger.call( - 'AccountTrackerController:getState', - ); - - const nativeChainIds = Object.keys( - accountTrackerControllerState.accountsByChainId, - ) as Hex[]; - - const normalizedAccount = account.toLowerCase() as Hex; - - const balancesByTokenByChain = - tokenBalanceControllerState.tokenBalances?.[normalizedAccount]; - - const tokenChainIds = Object.keys(balancesByTokenByChain) as Hex[]; - const chainIds = uniq([...tokenChainIds, ...nativeChainIds]); - - return chainIds.flatMap((chainId) => { - const tokenAddresses = [ - ...(Object.keys(balancesByTokenByChain[chainId] ?? {}) as Hex[]), - getNativeToken(chainId), - ]; - - return tokenAddresses.map((tokenAddress) => ({ - chainId, - tokenAddress, - balance: getTokenBalance(messenger, account, chainId, tokenAddress), - })); - }); -} - /** * Get the token decimals for a specific token. * @@ -145,13 +106,23 @@ export function getTokenInfo( tokenAddress: Hex, chainId: Hex, ): { decimals: number; symbol: string } | undefined { - const controllerState = messenger.call('TokensController:getState'); + const assetsUnifyStateFeatureEnabled = getAssetsUnifyStateFeature(messenger); + + let allTokens: TokensControllerState['allTokens']; + if (assetsUnifyStateFeatureEnabled) { + allTokens = messenger.call( + 'AssetsController:getStateForTransactionPay', + )?.allTokens; + } else { + allTokens = messenger.call('TokensController:getState')?.allTokens; + } + const normalizedTokenAddress = tokenAddress.toLowerCase() as Hex; const isNative = normalizedTokenAddress === getNativeToken(chainId).toLowerCase(); - const token = Object.values(controllerState.allTokens?.[chainId] ?? {}) + const token = Object.values(allTokens?.[chainId] ?? {}) .flat() .find( (singleToken) => @@ -188,23 +159,35 @@ export function getTokenFiatRate( tokenAddress: Hex, chainId: Hex, ): FiatRates | undefined { + const assetsUnifyStateFeatureEnabled = getAssetsUnifyStateFeature(messenger); + + let marketData; + let currencyRates; + if (assetsUnifyStateFeatureEnabled) { + const assetsControllerState = messenger.call( + 'AssetsController:getStateForTransactionPay', + ); + + marketData = assetsControllerState?.marketData; + currencyRates = assetsControllerState?.currencyRates; + } else { + marketData = messenger.call('TokenRatesController:getState')?.marketData; + currencyRates = messenger.call( + 'CurrencyRateController:getState', + )?.currencyRates; + } + const ticker = getTicker(chainId, messenger); if (!ticker) { return undefined; } - const rateControllerState = messenger.call('TokenRatesController:getState'); - - const currencyRateControllerState = messenger.call( - 'CurrencyRateController:getState', - ); - const normalizedTokenAddress = toChecksumHexAddress(tokenAddress) as Hex; const isNative = normalizedTokenAddress === getNativeToken(chainId); const tokenToNativeRate = - rateControllerState.marketData?.[chainId]?.[normalizedTokenAddress]?.price; + marketData?.[chainId]?.[normalizedTokenAddress]?.price; if (tokenToNativeRate === undefined && !isNative) { return undefined; @@ -213,7 +196,7 @@ export function getTokenFiatRate( const { conversionRate: nativeToFiatRate, usdConversionRate: nativeToUsdRate, - } = currencyRateControllerState.currencyRates?.[ticker] ?? { + } = currencyRates?.[ticker] ?? { conversionRate: null, usdConversionRate: null, }; diff --git a/packages/transaction-pay-controller/tsconfig.build.json b/packages/transaction-pay-controller/tsconfig.build.json index aa8dd9cb92e..06516b30eb0 100644 --- a/packages/transaction-pay-controller/tsconfig.build.json +++ b/packages/transaction-pay-controller/tsconfig.build.json @@ -6,6 +6,9 @@ "rootDir": "./src" }, "references": [ + { + "path": "../assets-controller/tsconfig.build.json" + }, { "path": "../assets-controllers/tsconfig.build.json" }, diff --git a/packages/transaction-pay-controller/tsconfig.json b/packages/transaction-pay-controller/tsconfig.json index 0452cae3d20..5d8d3ad7be3 100644 --- a/packages/transaction-pay-controller/tsconfig.json +++ b/packages/transaction-pay-controller/tsconfig.json @@ -4,6 +4,9 @@ "baseUrl": "./" }, "references": [ + { + "path": "../assets-controller" + }, { "path": "../assets-controllers" }, diff --git a/yarn.lock b/yarn.lock index e54f57b6b9f..6ffd6ff456e 100644 --- a/yarn.lock +++ b/yarn.lock @@ -5420,6 +5420,7 @@ __metadata: "@ethersproject/abi": "npm:^5.7.0" "@ethersproject/contracts": "npm:^5.7.0" "@ethersproject/providers": "npm:^5.7.0" + "@metamask/assets-controller": "npm:^2.3.0" "@metamask/assets-controllers": "npm:^100.2.1" "@metamask/auto-changelog": "npm:^3.4.4" "@metamask/base-controller": "npm:^9.0.0"