diff --git a/react/src/components/ResourceAllocationFormItems.tsx b/react/src/components/ResourceAllocationFormItems.tsx index 5d52127d94..7d06023e48 100644 --- a/react/src/components/ResourceAllocationFormItems.tsx +++ b/react/src/components/ResourceAllocationFormItems.tsx @@ -24,7 +24,6 @@ import { } from './ImageEnvironmentSelectFormItems'; import InputNumberWithSlider from './InputNumberWithSlider'; import ResourceGroupSelectForCurrentProject from './ResourceGroupSelectForCurrentProject'; -import { ACCELERATOR_UNIT_MAP } from './ResourceNumber'; import ResourcePresetSelect from './ResourcePresetSelect'; import { CaretDownOutlined } from '@ant-design/icons'; import { @@ -928,7 +927,7 @@ const ResourceAllocationFormItems: React.FC< }, tooltip: { formatter: (value = 0) => { - return `${value} ${ACCELERATOR_UNIT_MAP[currentAcceleratorType]}`; + return `${value} ${resourceSlotsDetails?.[currentAcceleratorType]?.display_unit || ''}`; }, open: currentImageAcceleratorLimits.length <= 0 @@ -992,7 +991,8 @@ const ResourceAllocationFormItems: React.FC< return { value: name, label: - ACCELERATOR_UNIT_MAP[name] || 'UNIT', + resourceSlotsDetails?.[name] + ?.display_unit || 'UNIT', disabled: currentImageAcceleratorLimits.length > 0 && diff --git a/react/src/components/ResourceNumber.tsx b/react/src/components/ResourceNumber.tsx index 9670535791..a54e7cef23 100644 --- a/react/src/components/ResourceNumber.tsx +++ b/react/src/components/ResourceNumber.tsx @@ -1,32 +1,21 @@ import { iSizeToSize } from '../helper'; -import { useResourceSlotsDetails } from '../hooks/backendai'; +import { + BaseResourceSlotName, + KnownAcceleratorResourceSlotName, + ResourceSlotName, + useResourceSlotsDetails, +} from '../hooks/backendai'; import { useCurrentResourceGroupValue } from '../hooks/useCurrentProject'; import Flex from './Flex'; import { Tooltip, Typography, theme } from 'antd'; import _ from 'lodash'; import React, { ReactElement } from 'react'; -import { useTranslation } from 'react-i18next'; - -export const ACCELERATOR_UNIT_MAP: { - [key: string]: string; -} = { - 'cuda.device': 'GPU', - 'cuda.shares': 'FGPU', - 'rocm.device': 'GPU', - 'tpu.device': 'TPU', - 'ipu.device': 'IPU', - 'atom.device': 'ATOM', - 'atom-plus.device': 'ATOM+', - 'gaudi2.device': 'Gaudi 2', - 'warboy.device': 'Warboy', - 'hyperaccel-lpu.device': 'Hyperaccel LPU', -}; export type ResourceOpts = { shmem?: number; }; interface ResourceNumberProps { - type: string; + type: ResourceSlotName | string; extra?: ReactElement; opts?: ResourceOpts; value: string; @@ -35,7 +24,9 @@ interface ResourceNumberProps { } type ResourceTypeInfo = { - [key in string]: V; + [key in KnownAcceleratorResourceSlotName]: V; +} & { + [key in BaseResourceSlotName]: V; }; const ResourceNumber: React.FC = ({ type, @@ -112,7 +103,7 @@ const MWCIconWrap: React.FC<{ size?: number; children: string }> = ({ }; interface AccTypeIconProps extends Omit, 'src'> { - type: string; + type: ResourceSlotName | string; showIcon?: boolean; showUnit?: boolean; showTooltip?: boolean; @@ -126,33 +117,24 @@ export const ResourceTypeIcon: React.FC = ({ showTooltip = true, ...props }) => { - const { t } = useTranslation(); - - const resourceTypeIconSrcMap: ResourceTypeInfo< - [ReactElement | string, string] - > = { - cpu: [ - developer_board, - t('session.core'), - ], - mem: [memory, 'GiB'], - 'cuda.device': ['/resources/icons/file_type_cuda.svg', 'GPU'], - 'cuda.shares': ['/resources/icons/file_type_cuda.svg', 'FGPU'], - 'rocm.device': ['/resources/icons/rocm.svg', 'GPU'], - 'tpu.device': [view_module, 'TPU'], - 'ipu.device': [view_module, 'IPU'], - 'atom.device': ['/resources/icons/rebel.svg', 'ATOM'], - 'atom-plus.device': ['/resources/icons/rebel.svg', 'ATOM+'], - 'gaudi2.device': ['/resources/icons/gaudi.svg', 'Gaudi 2'], - 'warboy.device': ['/resources/icons/furiosa.svg', 'Warboy'], - 'hyperaccel-lpu.device': [ - '/resources/icons/npu_generic.svg', - 'Hyperaccel LPU', - ], + const resourceTypeIconSrcMap: ResourceTypeInfo = { + cpu: developer_board, + mem: memory, + 'cuda.device': '/resources/icons/file_type_cuda.svg', + 'cuda.shares': '/resources/icons/file_type_cuda.svg', + 'rocm.device': '/resources/icons/rocm.svg', + 'tpu.device': view_module, + 'ipu.device': view_module, + 'atom.device': '/resources/icons/rebel.svg', + 'atom-plus.device': '/resources/icons/rebel.svg', + 'gaudi2.device': '/resources/icons/gaudi.svg', + 'warboy.device': '/resources/icons/furiosa.svg', + 'hyperaccel-lpu.device': '/resources/icons/npu_generic.svg', }; const content = - typeof resourceTypeIconSrcMap[type]?.[0] === 'string' ? ( + typeof resourceTypeIconSrcMap[type as KnownAcceleratorResourceSlotName] === + 'string' ? ( = ({ ...(props.style || {}), }} // @ts-ignore - src={resourceTypeIconSrcMap[type]?.[0] || ''} + src={resourceTypeIconSrcMap[type] || ''} alt={type} /> ) : ( - {resourceTypeIconSrcMap[type]?.[0] || type} + {resourceTypeIconSrcMap[type as KnownAcceleratorResourceSlotName] || + type} ); return showTooltip ? ( - // {content} ) : ( {content} diff --git a/react/src/components/ResourcePresetSelect.tsx b/react/src/components/ResourcePresetSelect.tsx index fc12bc1ad3..e293a4a182 100644 --- a/react/src/components/ResourcePresetSelect.tsx +++ b/react/src/components/ResourcePresetSelect.tsx @@ -1,6 +1,6 @@ import { localeCompare } from '../helper'; import { useUpdatableState } from '../hooks'; -import { useResourceSlots } from '../hooks/backendai'; +import { ResourceSlotName, useResourceSlots } from '../hooks/backendai'; import useControllableState from '../hooks/useControllableState'; import Flex from './Flex'; import ResourceNumber from './ResourceNumber'; @@ -123,7 +123,7 @@ const ResourcePresetSelect: React.FC = ({ // @ts-ignore options: _.map(resource_presets, (preset, index) => { const slotsInfo: { - [key in string]: string; + [key in ResourceSlotName]: string; } = JSON.parse(preset?.resource_slots); const disabled = allocatablePresetNames ? !allocatablePresetNames.includes(preset?.name || '') @@ -145,7 +145,7 @@ const ResourcePresetSelect: React.FC = ({ > {_.map( _.omitBy(slotsInfo, (slot, key) => - _.isEmpty(resourceSlots[key]), + _.isEmpty(resourceSlots[key as ResourceSlotName]), ), (slot, key) => { return ( diff --git a/react/src/components/ServiceLauncherPageContent.tsx b/react/src/components/ServiceLauncherPageContent.tsx index 14a157cb2b..cc52606753 100644 --- a/react/src/components/ServiceLauncherPageContent.tsx +++ b/react/src/components/ServiceLauncherPageContent.tsx @@ -9,6 +9,7 @@ import { useSuspendedBackendaiClient, useWebUINavigate, } from '../hooks'; +import { KnownAcceleratorResourceSlotName } from '../hooks/backendai'; import { useSuspenseTanQuery, useTanMutation } from '../hooks/reactQueryAlias'; import BAIModal, { DEFAULT_BAI_MODAL_Z_INDEX } from './BAIModal'; import EnvVarFormList, { EnvVarFormListValue } from './EnvVarFormList'; @@ -57,20 +58,12 @@ interface ServiceCreateConfigResourceOptsType { shmem?: number | string; } -interface ServiceCreateConfigResourceType { +type ServiceCreateConfigResourceType = { cpu: number | string; mem: string; - 'cuda.device'?: number | string; - 'cuda.shares'?: number | string; - 'rocm.device'?: number | string; - 'tpu.device'?: number | string; - 'ipu.device'?: number | string; - 'atom.device'?: number | string; - 'gaudi2.device'?: number | string; - 'atom-plus.device'?: number | string; - 'warboy.device'?: number | string; - 'hyperaccel-lpu.device'?: number | string; -} +} & { + [key in KnownAcceleratorResourceSlotName]?: number | string; +}; export interface MountOptionType { mount_destination?: string; type?: string; diff --git a/react/src/hooks/backendai.tsx b/react/src/hooks/backendai.tsx index 50d9dc42d6..dbef4ec780 100644 --- a/react/src/hooks/backendai.tsx +++ b/react/src/hooks/backendai.tsx @@ -8,6 +8,22 @@ import { import _ from 'lodash'; import { useEffect, useState } from 'react'; +export type BaseResourceSlotName = 'cpu' | 'mem'; +export type KnownAcceleratorResourceSlotName = + | 'cuda.device' + | 'cuda.shares' + | 'rocm.device' + | 'tpu.device' + | 'ipu.device' + | 'atom.device' + | 'atom-plus.device' + | 'gaudi2.device' + | 'warboy.device' + | 'hyperaccel-lpu.device'; + +export type ResourceSlotName = + | BaseResourceSlotName + | KnownAcceleratorResourceSlotName; export interface QuotaScope { id: string; quota_scope_id: string; @@ -23,19 +39,7 @@ export const useResourceSlots = () => { const [key, checkUpdate] = useUpdatableState('first'); const baiClient = useSuspendedBackendaiClient(); const { data: resourceSlots } = useSuspenseTanQuery<{ - cpu?: string; - mem?: string; - 'cuda.shares'?: string; - 'cuda.device'?: string; - 'rocm.device'?: string; - 'tpu.device'?: string; - 'ipu.device'?: string; - 'atom.device'?: string; - 'atom-plus.device'?: string; - 'gaudi2.device'?: string; - 'warboy.device'?: string; - 'hyperaccel-lpu.device'?: string; - [key: string]: string | undefined; + [key in ResourceSlotName]?: string; }>({ queryKey: ['useResourceSlots', key], queryFn: () => { diff --git a/react/src/hooks/useResourceLimitAndRemaining.test.ts b/react/src/hooks/useResourceLimitAndRemaining.test.ts new file mode 100644 index 0000000000..7fb3a5e0ff --- /dev/null +++ b/react/src/hooks/useResourceLimitAndRemaining.test.ts @@ -0,0 +1,49 @@ +import { isMatchingMaxPerContainer } from './useResourceLimitAndRemaining'; +import exp from 'constants'; + +describe('getConfigName', () => { + test('should match unknown devices', () => { + expect( + isMatchingMaxPerContainer('maxCUDADevicesPerContainer', 'cuda.device'), + ).toBe(true); + expect( + isMatchingMaxPerContainer('maxCUDASharesPerContainer', 'cuda.shares'), + ).toBe(true); + expect( + isMatchingMaxPerContainer('maxROCMDevicesPerContainer', 'rocm.device'), + ).toBe(true); + expect( + isMatchingMaxPerContainer('maxTPUDevicesPerContainer', 'tpu.device'), + ).toBe(true); + expect( + isMatchingMaxPerContainer('maxIPUDevicesPerContainer', 'ipu.device'), + ).toBe(true); + expect( + isMatchingMaxPerContainer('maxATOMDevicesPerContainer', 'atom.device'), + ).toBe(true); + expect( + isMatchingMaxPerContainer( + 'maxATOMPLUSDevicesPerContainer', + 'atom-plus.device', + ), + ).toBe(true); + expect( + isMatchingMaxPerContainer( + 'maxGaudi2DevicesPerContainer', + 'gaudi2.device', + ), + ).toBe(true); + expect( + isMatchingMaxPerContainer( + 'maxWarboyDevicesPerContainer', + 'warboy.device', + ), + ).toBe(true); + expect( + isMatchingMaxPerContainer( + 'maxHyperaccelLPUDevicesPerContainer', + 'hyperaccel-lpu.device', + ), + ).toBe(true); + }); +}); diff --git a/react/src/hooks/useResourceLimitAndRemaining.tsx b/react/src/hooks/useResourceLimitAndRemaining.tsx index 8416102579..4da0a17674 100644 --- a/react/src/hooks/useResourceLimitAndRemaining.tsx +++ b/react/src/hooks/useResourceLimitAndRemaining.tsx @@ -5,7 +5,24 @@ import { addNumberWithUnits, iSizeToSize } from '../helper'; import { useResourceSlots } from '../hooks/backendai'; import { useSuspenseTanQuery } from './reactQueryAlias'; import _ from 'lodash'; +import { useMemo } from 'react'; +const maxPerContainerRegex = /^max([A-Za-z0-9]+)PerContainer$/; + +export const isMatchingMaxPerContainer = (configName: string, key: string) => { + const match = configName.match(maxPerContainerRegex); + if (match) { + const configLowerCase = match[1].toLowerCase(); + const keyLowerCase = key.replaceAll(/[.-]/g, '').toLowerCase(); + // Because some accelerator names are not the same as the config name, we need to check if the config name is a substring of the accelerator name + // cuda.shares => maxCUDASharesPerContainer + // cuda.device => maxCUDADevicesPerContainer (Not maxCUDADevicePerContainer) + return ( + configLowerCase === keyLowerCase || configLowerCase === keyLowerCase + 's' + ); + } + return false; +}; export interface MergedResourceLimits { accelerators: { [key: string]: @@ -195,6 +212,14 @@ export const useResourceLimitAndRemaining = ({ }, ), }; + const perContainerConfigs = useMemo( + () => + _.omitBy(baiClient._config, (value, key) => { + return !maxPerContainerRegex.test(key); + }), + [baiClient._config], + ); + const resourceLimits: MergedResourceLimits = { cpu: resourceSlots?.cpu === undefined @@ -265,19 +290,11 @@ export const useResourceLimitAndRemaining = ({ accelerators: _.reduce( acceleratorSlots, (result, value, key) => { - const configName = - { - 'cuda.device': 'maxCUDADevicesPerContainer', - 'cuda.shares': 'maxCUDASharesPerContainer', - 'rocm.device': 'maxROCMDevicesPerContainer', - 'tpu.device': 'maxTPUDevicesPerContainer', - 'ipu.device': 'maxIPUDevicesPerContainer', - 'atom.device': 'maxATOMDevicesPerContainer', - 'atom-plus.device': 'maxATOMPlusDevicesPerContainer', - 'gaudi2.device': 'maxGaudi2DevicesPerContainer', - 'warboy.device': 'maxWarboyDevicesPerContainer', - 'hyperaccel-lpu.device': 'maxHyperaccelLPUDevicesPerContainer', // FIXME: add maxLPUDevicesPerContainer to config - }[key] || 'cuda.device'; // FIXME: temporally `cuda.device` config, when undefined + const perContainerLimit = + _.find(perContainerConfigs, (configValue, configName) => { + return isMatchingMaxPerContainer(configName, key); + }) ?? baiClient._config['cuda.device']; // FIXME: temporally `cuda.device` config, when undefined + result[key] = { min: parseInt( _.filter( @@ -288,7 +305,7 @@ export const useResourceLimitAndRemaining = ({ )?.[0]?.min || '0', ), max: _.min([ - baiClient._config[configName] || 8, + perContainerLimit || 8, // scaling group all cpu (using + remaining), string type resourceGroupResourceSize.accelerators[key], ]),