-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add metal acceleration structures (#65)
Co-authored-by: Simon Kallweit <skallweit@nvidia.com>
- Loading branch information
1 parent
61c3b5c
commit e64b712
Showing
15 changed files
with
686 additions
and
40 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,3 @@ | ||
.DS_Store | ||
.vscode/ | ||
build/ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,136 @@ | ||
#pragma once | ||
|
||
#include <slang-rhi.h> | ||
|
||
namespace rhi { | ||
|
||
inline AccelerationStructureInstanceDescType getAccelerationStructureInstanceDescType(DeviceType deviceType) | ||
{ | ||
switch (deviceType) | ||
{ | ||
case DeviceType::D3D12: | ||
return AccelerationStructureInstanceDescType::D3D12; | ||
case DeviceType::Vulkan: | ||
return AccelerationStructureInstanceDescType::Vulkan; | ||
case DeviceType::Metal: | ||
return AccelerationStructureInstanceDescType::Metal; | ||
case DeviceType::CUDA: | ||
return AccelerationStructureInstanceDescType::Optix; | ||
} | ||
return AccelerationStructureInstanceDescType::Generic; | ||
} | ||
|
||
inline AccelerationStructureInstanceDescType getAccelerationStructureInstanceDescType(IDevice* device) | ||
{ | ||
return getAccelerationStructureInstanceDescType(device->getDeviceInfo().deviceType); | ||
} | ||
|
||
inline Size getAccelerationStructureInstanceDescSize(AccelerationStructureInstanceDescType type) | ||
{ | ||
switch (type) | ||
{ | ||
case AccelerationStructureInstanceDescType::Generic: | ||
return sizeof(AccelerationStructureInstanceDescGeneric); | ||
case AccelerationStructureInstanceDescType::D3D12: | ||
return sizeof(AccelerationStructureInstanceDescD3D12); | ||
case AccelerationStructureInstanceDescType::Vulkan: | ||
return sizeof(AccelerationStructureInstanceDescVulkan); | ||
case AccelerationStructureInstanceDescType::Optix: | ||
return sizeof(AccelerationStructureInstanceDescOptix); | ||
case AccelerationStructureInstanceDescType::Metal: | ||
return sizeof(AccelerationStructureInstanceDescMetal); | ||
} | ||
return 0; | ||
} | ||
|
||
inline void convertAccelerationStructureInstanceDesc( | ||
AccelerationStructureInstanceDescType dstType, | ||
void* dst, | ||
const AccelerationStructureInstanceDescGeneric* src | ||
) | ||
{ | ||
switch (dstType) | ||
{ | ||
case AccelerationStructureInstanceDescType::Generic: | ||
::memcpy(dst, src, sizeof(AccelerationStructureInstanceDescGeneric)); | ||
break; | ||
case AccelerationStructureInstanceDescType::D3D12: | ||
{ | ||
static_assert( | ||
sizeof(AccelerationStructureInstanceDescD3D12) == sizeof(AccelerationStructureInstanceDescGeneric) | ||
); | ||
auto dstD3D12 = reinterpret_cast<AccelerationStructureInstanceDescD3D12*>(dst); | ||
::memcpy(dstD3D12, src, sizeof(AccelerationStructureInstanceDescD3D12)); | ||
break; | ||
} | ||
case AccelerationStructureInstanceDescType::Vulkan: | ||
{ | ||
static_assert( | ||
sizeof(AccelerationStructureInstanceDescVulkan) == sizeof(AccelerationStructureInstanceDescGeneric) | ||
); | ||
auto dstVulkan = reinterpret_cast<AccelerationStructureInstanceDescVulkan*>(dst); | ||
::memcpy(dstVulkan, src, sizeof(AccelerationStructureInstanceDescVulkan)); | ||
break; | ||
} | ||
case AccelerationStructureInstanceDescType::Optix: | ||
{ | ||
auto dstOptix = reinterpret_cast<AccelerationStructureInstanceDescOptix*>(dst); | ||
::memcpy(dstOptix->transform, src->transform, 36); | ||
dstOptix->instanceId = src->instanceID; | ||
dstOptix->sbtOffset = src->instanceContributionToHitGroupIndex; | ||
dstOptix->visibilityMask = src->instanceMask; | ||
// Generic flags match the Optix flags. | ||
// TriangleFacingCullDisable -> OPTIX_INSTANCE_FLAG_DISABLE_TRIANGLE_FACE_CULLING | ||
// TriangleFrontCounterClockwise -> OPTIX_INSTANCE_FLAG_FLIP_TRIANGLE_FACING | ||
// ForceOpaque -> OPTIX_INSTANCE_FLAG_DISABLE_ANYHIT | ||
// NoOpaque -> OPTIX_INSTANCE_FLAG_ENFORCE_ANYHIT | ||
dstOptix->flags = (uint32_t)src->flags; | ||
dstOptix->traversableHandle = src->accelerationStructure.value; | ||
dstOptix->pad[0] = 0; | ||
dstOptix->pad[1] = 0; | ||
break; | ||
} | ||
case AccelerationStructureInstanceDescType::Metal: | ||
{ | ||
auto dstMetal = reinterpret_cast<AccelerationStructureInstanceDescMetal*>(dst); | ||
// Transpose the transform matrix. | ||
for (int i = 0; i < 3; ++i) | ||
{ | ||
for (int j = 0; j < 4; ++j) | ||
{ | ||
dstMetal->transform[j][i] = src->transform[i][j]; | ||
} | ||
} | ||
// Generic flags match the Metal options. | ||
// TriangleFacingCullDisable -> DisableTriangleCulling | ||
// TriangleFrontCounterClockwise -> TriangleFrontFacingWindingCounterClockwise | ||
// ForceOpaque -> Opaque | ||
// NoOpaque -> NonOpaque | ||
dstMetal->options = (uint32_t)src->flags; | ||
dstMetal->mask = src->instanceMask; | ||
dstMetal->intersectionFunctionTableOffset = src->instanceContributionToHitGroupIndex; | ||
dstMetal->accelerationStructureIndex = src->accelerationStructure.value; | ||
dstMetal->userID = src->instanceID; | ||
break; | ||
} | ||
} | ||
} | ||
|
||
inline void convertAccelerationStructureInstanceDescs( | ||
GfxCount count, | ||
AccelerationStructureInstanceDescType dstType, | ||
void* dst, | ||
Size dstStride, | ||
const AccelerationStructureInstanceDescGeneric* src, | ||
Size srcStride | ||
) | ||
{ | ||
for (GfxIndex i = 0; i < count; ++i) | ||
{ | ||
convertAccelerationStructureInstanceDesc(dstType, dst, src); | ||
dst = (uint8_t*)dst + dstStride; | ||
src = (const AccelerationStructureInstanceDescGeneric*)((const uint8_t*)src + srcStride); | ||
} | ||
} | ||
|
||
} // namespace rhi |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
#include "metal-acceleration-structure.h" | ||
#include "metal-device.h" | ||
|
||
namespace rhi::metal { | ||
|
||
AccelerationStructureImpl::~AccelerationStructureImpl() | ||
{ | ||
m_device->m_accelerationStructures.freeList.push_back(m_globalIndex); | ||
m_device->m_accelerationStructures.list[m_globalIndex] = nullptr; | ||
m_device->m_accelerationStructures.dirty = true; | ||
} | ||
|
||
Result AccelerationStructureImpl::getNativeHandle(NativeHandle* outHandle) | ||
{ | ||
outHandle->type = NativeHandleType::MTLAccelerationStructure; | ||
outHandle->value = (uint64_t)m_accelerationStructure.get(); | ||
return SLANG_OK; | ||
} | ||
|
||
AccelerationStructureHandle AccelerationStructureImpl::getHandle() | ||
{ | ||
return AccelerationStructureHandle{m_globalIndex}; | ||
} | ||
|
||
DeviceAddress AccelerationStructureImpl::getDeviceAddress() | ||
{ | ||
return 0; | ||
} | ||
|
||
} // namespace rhi::metal |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
#pragma once | ||
|
||
#include "metal-base.h" | ||
|
||
namespace rhi::metal { | ||
|
||
class AccelerationStructureImpl : public AccelerationStructure | ||
{ | ||
public: | ||
DeviceImpl* m_device; | ||
NS::SharedPtr<MTL::AccelerationStructure> m_accelerationStructure; | ||
uint32_t m_globalIndex; | ||
|
||
public: | ||
AccelerationStructureImpl(DeviceImpl* device, const AccelerationStructureDesc& desc) | ||
: AccelerationStructure(desc) | ||
, m_device(device) | ||
{ | ||
} | ||
|
||
~AccelerationStructureImpl(); | ||
|
||
// IAccelerationStructure implementation | ||
virtual SLANG_NO_THROW Result SLANG_MCALL getNativeHandle(NativeHandle* outHandle) override; | ||
virtual SLANG_NO_THROW AccelerationStructureHandle getHandle() override; | ||
virtual SLANG_NO_THROW DeviceAddress SLANG_MCALL getDeviceAddress() override; | ||
}; | ||
|
||
} // namespace rhi::metal |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.