Skip to content

Commit

Permalink
Fix compilation
Browse files Browse the repository at this point in the history
  • Loading branch information
hdelan committed Dec 20, 2024
1 parent 7cc5b39 commit 6896037
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 66 deletions.
107 changes: 41 additions & 66 deletions source/adapters/cuda/usm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -379,67 +379,6 @@ ur_result_t USMHostMemoryProvider::allocateImpl(void **ResultPtr, size_t Size,
ur_usm_pool_handle_t_::ur_usm_pool_handle_t_(ur_context_handle_t Context,
ur_usm_pool_desc_t *PoolDesc)
: Context{Context} {
if (PoolDesc->flags & UR_USM_POOL_FLAG_USE_NATIVE_MEMORY_POOL_EXP) {
// TODO: this should only use the host
}
const void *pNext = PoolDesc->pNext;
while (pNext != nullptr) {
const ur_base_desc_t *BaseDesc = static_cast<const ur_base_desc_t *>(pNext);
switch (BaseDesc->stype) {
case UR_STRUCTURE_TYPE_USM_POOL_LIMITS_DESC: {
const ur_usm_pool_limits_desc_t *Limits =
reinterpret_cast<const ur_usm_pool_limits_desc_t *>(BaseDesc);
for (auto &config : DisjointPoolConfigs.Configs) {
config.MaxPoolableSize = Limits->maxPoolableSize;
config.SlabMinSize = Limits->minDriverAllocSize;
}
break;
}
default: {
throw UsmAllocationException(UR_RESULT_ERROR_INVALID_ARGUMENT);
}
}
pNext = BaseDesc->pNext;
}

auto MemProvider =
umf::memoryProviderMakeUnique<USMHostMemoryProvider>(Context, nullptr)
.second;

HostMemPool =
umf::poolMakeUniqueFromOps(
umfDisjointPoolOps(), std::move(MemProvider),
&this->DisjointPoolConfigs.Configs[usm::DisjointPoolMemType::Host])
.second;

for (const auto &Device : Context->getDevices()) {
MemProvider =
umf::memoryProviderMakeUnique<USMDeviceMemoryProvider>(Context, Device)
.second;
DeviceMemPool = umf::poolMakeUniqueFromOps(
umfDisjointPoolOps(), std::move(MemProvider),
&this->DisjointPoolConfigs
.Configs[usm::DisjointPoolMemType::Device])
.second;
MemProvider =
umf::memoryProviderMakeUnique<USMSharedMemoryProvider>(Context, Device)
.second;
SharedMemPool = umf::poolMakeUniqueFromOps(
umfDisjointPoolOps(), std::move(MemProvider),
&this->DisjointPoolConfigs
.Configs[usm::DisjointPoolMemType::Shared])
.second;
Context->addPool(this);
}
}

ur_usm_pool_handle_t_::ur_usm_pool_handle_t_(ur_context_handle_t Context,
ur_device_handle_t Device,
ur_usm_pool_desc_t *PoolDesc)
: Context{Context} {
if (PoolDesc->flags & UR_USM_POOL_FLAG_USE_NATIVE_MEMORY_POOL_EXP) {
// TODO: this should only use the host
}
const void *pNext = PoolDesc->pNext;
while (pNext != nullptr) {
const ur_base_desc_t *BaseDesc = static_cast<const ur_base_desc_t *>(pNext);
Expand Down Expand Up @@ -494,6 +433,18 @@ ur_usm_pool_handle_t_::ur_usm_pool_handle_t_(ur_context_handle_t Context,
}
}

ur_usm_pool_handle_t_::ur_usm_pool_handle_t_(ur_context_handle_t Context,
ur_device_handle_t Device,
ur_usm_pool_desc_t *PoolDesc)
: Context{Context}, Device{Device} {
if (!(PoolDesc->flags & UR_USM_POOL_FLAG_USE_NATIVE_MEMORY_POOL_EXP))
throw;

// TODO: what flags should be used here. Moreover what flags should have
// UR counterparts?
UR_CHECK_ERROR(cuMemPoolCreate(&CUmemPool, 0));
}

bool ur_usm_pool_handle_t_::hasUMFPool(umf_memory_pool_t *umf_pool) {
return DeviceMemPool.get() == umf_pool || SharedMemPool.get() == umf_pool ||
HostMemPool.get() == umf_pool;
Expand All @@ -507,11 +458,7 @@ UR_APIEXPORT ur_result_t UR_APICALL urUSMPoolCreate(
ur_usm_pool_handle_t *Pool ///< [out] pointer to USM memory pool
) {
// Without pool tracking we can't free pool allocations.
#ifndef UMF_ENABLE_POOL_TRACKING
// We don't need UMF to use native mem pools
if (!(PoolDesc->flags & UR_USM_POOL_FLAG_USE_NATIVE_MEMORY_POOL_EXP))
return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
#endif
#ifdef UMF_ENABLE_POOL_TRACKING
if (PoolDesc->flags & UR_USM_POOL_FLAG_ZERO_INITIALIZE_BLOCK) {
return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
}
Expand All @@ -526,6 +473,34 @@ UR_APIEXPORT ur_result_t UR_APICALL urUSMPoolCreate(
return UR_RESULT_ERROR_UNKNOWN;
}
return UR_RESULT_SUCCESS;
#else
std::ignore = Context;
std::ignore = PoolDesc;
std::ignore = Pool;
return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
#endif
}

UR_APIEXPORT ur_result_t UR_APICALL urUSMPoolCreateExp(
ur_context_handle_t Context, ///< [in] handle of the context object
ur_device_handle_t Device, ///< [in] handle of the device object
ur_usm_pool_desc_t *PoolDesc, ///< [in] pointer to USM pool descriptor.
///< Can be chained with
///< ::ur_usm_pool_limits_desc_t
ur_usm_pool_handle_t *Pool ///< [out] pointer to USM memory pool
) {
// This entry point only supports native mem pools
if (!(PoolDesc->flags & UR_USM_POOL_FLAG_USE_NATIVE_MEMORY_POOL_EXP))
return UR_RESULT_ERROR_UNSUPPORTED_FEATURE;
try {
*Pool = reinterpret_cast<ur_usm_pool_handle_t>(
new ur_usm_pool_handle_t_(Context, Device, PoolDesc));
} catch (ur_result_t err) {
return err;
} catch (...) {
return UR_RESULT_ERROR_UNKNOWN;
}
return UR_RESULT_SUCCESS;
}

UR_APIEXPORT ur_result_t UR_APICALL urUSMPoolRetain(
Expand Down
5 changes: 5 additions & 0 deletions source/adapters/cuda/usm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ struct ur_usm_pool_handle_t_ {
std::atomic_uint32_t RefCount = 1;

ur_context_handle_t Context = nullptr;
ur_device_handle_t Device = nullptr;

usm::DisjointPoolAllConfigs DisjointPoolConfigs =
usm::DisjointPoolAllConfigs();
Expand All @@ -34,6 +35,10 @@ struct ur_usm_pool_handle_t_ {
ur_usm_pool_handle_t_(ur_context_handle_t Context,
ur_usm_pool_desc_t *PoolDesc);

// TODO: do we need the context param?
ur_usm_pool_handle_t_(ur_context_handle_t Context, ur_device_handle_t Device,
ur_usm_pool_desc_t *PoolDesc);

uint32_t incrementReferenceCount() noexcept { return ++RefCount; }

uint32_t decrementReferenceCount() noexcept { return --RefCount; }
Expand Down

0 comments on commit 6896037

Please sign in to comment.