diff --git a/source/adapters/cuda/async_alloc.cpp b/source/adapters/cuda/async_alloc.cpp
index 6dade0803d..1cc48a80d2 100644
--- a/source/adapters/cuda/async_alloc.cpp
+++ b/source/adapters/cuda/async_alloc.cpp
@@ -42,7 +42,9 @@ UR_APIEXPORT ur_result_t urEnqueueUSMDeviceAllocExp(
 
   if (pPool) {
     assert(pPool->usesCudaPool());
-
+    UR_CHECK_ERROR(
+        cuMemAllocFromPoolAsync(reinterpret_cast<CUdeviceptr *>(ppMem), size,
+                                pPool->getCudaPool(), CuStream));
   } else {
     UR_CHECK_ERROR(cuMemAllocAsync(reinterpret_cast<CUdeviceptr *>(ppMem), size,
                                    CuStream));
diff --git a/source/adapters/cuda/usm.cpp b/source/adapters/cuda/usm.cpp
index e717d17f29..2849981125 100644
--- a/source/adapters/cuda/usm.cpp
+++ b/source/adapters/cuda/usm.cpp
@@ -440,9 +440,50 @@ ur_usm_pool_handle_t_::ur_usm_pool_handle_t_(ur_context_handle_t Context,
   if (!(PoolDesc->flags & UR_USM_POOL_FLAG_USE_NATIVE_MEMORY_POOL_EXP))
     throw;
 
+  CUmemPoolProps MemPoolProps{}; // Zero-init: CUDA requires unset fields be 0
+
+  const void *pNext = PoolDesc->pNext;
+  while (pNext != nullptr) {
+    const ur_base_desc_t *BaseDesc = static_cast<const ur_base_desc_t *>(pNext);
+    switch (BaseDesc->stype) {
+    case UR_STRUCTURE_TYPE_USM_POOL_LIMITS_DESC: {
+      const ur_usm_pool_limits_desc_t *Limits =
+          reinterpret_cast<const ur_usm_pool_limits_desc_t *>(BaseDesc);
+      MemPoolProps.maxSize = Limits->maxPoolableSize;
+      std::ignore = Limits->minDriverAllocSize; // FIXME: We don't do anything
+                                                // with this. Can we/do we need
+                                                // to do something with this?
+      break;
+    }
+    default: {
+      throw UsmAllocationException(UR_RESULT_ERROR_INVALID_ARGUMENT);
+    }
+    }
+    pNext = BaseDesc->pNext;
+  }
+
   // TODO: what flags should be used here. Moreover what flags should have
   // UR counterparts?
-  UR_CHECK_ERROR(cuMemPoolCreate(&CUmemPool, 0));
+  MemPoolProps.allocType =
+      CU_MEM_ALLOCATION_TYPE_PINNED; // Is this valid? CUDA docs say:
+                                     //
+                                     // "This allocation type is 'pinned', i.e.
+                                     // cannot migrate from its current
+                                     // location while the application is
+                                     // actively using it"
+                                     //
+                                     // Alternatives are *_INVALID (default) and
+                                     // *_MAX.
+  MemPoolProps.location.id = Device->getIndex(); // Docs are not clear on what
+                                                 // this id is for. I am
+                                                 // assuming it is used for
+                                                 // device id. I have made a
+                                                 // forum post here:
+  // https://forums.developer.nvidia.com/t/incomplete-description-in-cumemlocation-v1-struct-reference/318701
+  MemPoolProps.location.type =
+      CU_MEM_LOCATION_TYPE_DEVICE; // Alternatives are:
+                                   // HOST, HOST_NUMA and HOST_NUMA_CURRENT
+  UR_CHECK_ERROR(cuMemPoolCreate(&CUmemPool, &MemPoolProps));
 }
 
 bool ur_usm_pool_handle_t_::hasUMFPool(umf_memory_pool_t *umf_pool) {