Skip to content

Commit

Permalink
add metal acceleration structures (#65)
Browse files Browse the repository at this point in the history
Co-authored-by: Simon Kallweit <skallweit@nvidia.com>
  • Loading branch information
westlicht and skallweitNV authored Oct 2, 2024
1 parent 61c3b5c commit e64b712
Show file tree
Hide file tree
Showing 15 changed files with 686 additions and 40 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
.DS_Store
.vscode/
build/
62 changes: 59 additions & 3 deletions include/slang-rhi.h
Original file line number Diff line number Diff line change
Expand Up @@ -392,6 +392,7 @@ enum class NativeHandleType
MTLRenderPipelineState = 0x00040007,
MTLSharedEvent = 0x00040008,
MTLSamplerState = 0x00040009,
MTLAccelerationStructure = 0x0004000a,

CUdevice = 0x00050001,
CUdeviceptr = 0x00050002,
Expand Down Expand Up @@ -820,9 +821,19 @@ enum class AccelerationStructureInstanceFlags : uint32_t
};
SLANG_RHI_ENUM_CLASS_OPERATORS(AccelerationStructureInstanceFlags);

// The layout of this struct is intentionally consistent with D3D12_RAYTRACING_INSTANCE_DESC
// and VkAccelerationStructureInstanceKHR.
struct AccelerationStructureInstanceDesc
/// Identifies which instance descriptor layout a block of memory uses.
/// Each enumerator corresponds to one of the AccelerationStructureInstanceDesc* structs.
enum class AccelerationStructureInstanceDescType
{
    /// AccelerationStructureInstanceDescGeneric
    Generic = 0,
    /// AccelerationStructureInstanceDescD3D12
    D3D12 = 1,
    /// AccelerationStructureInstanceDescVulkan
    Vulkan = 2,
    /// AccelerationStructureInstanceDescOptix
    Optix = 3,
    /// AccelerationStructureInstanceDescMetal
    Metal = 4
};

/// Generic instance descriptor.
/// The layout of this struct is intentionally consistent with D3D12_RAYTRACING_INSTANCE_DESC
/// and VkAccelerationStructureInstanceKHR for fast conversion.
struct AccelerationStructureInstanceDescGeneric
{
float transform[3][4];
uint32_t instanceID : 24;
Expand All @@ -832,6 +843,51 @@ struct AccelerationStructureInstanceDesc
AccelerationStructureHandle accelerationStructure;
};

/// Instance descriptor matching D3D12_RAYTRACING_INSTANCE_DESC.
/// Field names follow the D3D12 naming so the struct can be filled exactly like the native one.
struct AccelerationStructureInstanceDescD3D12
{
    /// 3x4 transform (three rows of four floats).
    float Transform[3][4];
    /// 24-bit instance identifier, packed with InstanceMask into one 32-bit word.
    uint32_t InstanceID : 24;
    /// 8-bit instance mask.
    uint32_t InstanceMask : 8;
    /// 24-bit hit group index contribution, packed with Flags into one 32-bit word.
    uint32_t InstanceContributionToHitGroupIndex : 24;
    /// 8-bit instance flags.
    uint32_t Flags : 8;
    /// 64-bit acceleration structure reference.
    uint64_t AccelerationStructure;
};

// Lock the layout: 48 bytes transform + 2 packed 32-bit words + 64-bit reference.
static_assert(
    sizeof(AccelerationStructureInstanceDescD3D12) == 64,
    "AccelerationStructureInstanceDescD3D12 must be 64 bytes to match D3D12_RAYTRACING_INSTANCE_DESC"
);

/// Instance descriptor matching VkAccelerationStructureInstanceKHR.
/// Field names follow the Vulkan naming so the struct can be filled exactly like the native one.
struct AccelerationStructureInstanceDescVulkan
{
    /// Transform, 12 floats (48 bytes).
    /// NOTE(review): VkTransformMatrixKHR declares this as float matrix[3][4]; the array here has the
    /// same size but swapped dimensions, so indexing it as [row][column] does not address the same
    /// elements as the native struct -- confirm the intended dimension order.
    float transform[4][3];
    /// 24-bit custom instance index, packed with mask into one 32-bit word.
    uint32_t instanceCustomIndex : 24;
    /// 8-bit visibility mask.
    uint32_t mask : 8;
    /// 24-bit shader binding table record offset, packed with flags into one 32-bit word.
    uint32_t instanceShaderBindingTableRecordOffset : 24;
    /// 8-bit instance flags.
    uint32_t flags : 8;
    /// 64-bit acceleration structure reference.
    uint64_t accelerationStructureReference;
};

// Lock the layout: 48 bytes transform + 2 packed 32-bit words + 64-bit reference.
static_assert(
    sizeof(AccelerationStructureInstanceDescVulkan) == 64,
    "AccelerationStructureInstanceDescVulkan must be 64 bytes to match VkAccelerationStructureInstanceKHR"
);

/// Instance descriptor matching OptixInstance.
struct AccelerationStructureInstanceDescOptix
{
    /// 3x4 transform (three rows of four floats, 48 bytes).
    float transform[3][4];
    /// Instance identifier (full 32-bit field, unlike the 24-bit generic instanceID).
    uint32_t instanceId;
    /// Shader binding table offset.
    uint32_t sbtOffset;
    /// Visibility mask.
    uint32_t visibilityMask;
    /// Instance flags.
    uint32_t flags;
    /// 64-bit traversable handle of the referenced acceleration structure.
    uint64_t traversableHandle;
    /// Trailing padding; brings the struct to 80 bytes.
    uint32_t pad[2];
};

// Lock the layout: 48 bytes transform + 4 x 32-bit fields + 64-bit handle + 8 bytes padding.
static_assert(
    sizeof(AccelerationStructureInstanceDescOptix) == 80,
    "AccelerationStructureInstanceDescOptix must be 80 bytes to match OptixInstance"
);

/// Instance descriptor matching MTLAccelerationStructureUserIDInstanceDescriptor.
struct AccelerationStructureInstanceDescMetal
{
    /// Transform as four groups of three floats (48 bytes), i.e. transposed relative to the
    /// 3x4 transform of the generic descriptor.
    float transform[4][3];
    /// Instance options (Metal's counterpart of the generic instance flags).
    uint32_t options;
    /// Instance mask.
    uint32_t mask;
    /// Offset into the intersection function table.
    uint32_t intersectionFunctionTableOffset;
    /// Index of the referenced acceleration structure (an index, not an address).
    uint32_t accelerationStructureIndex;
    /// User-defined instance identifier.
    uint32_t userID;
};

// Lock the layout: 48 bytes transform + 5 x 32-bit fields, no padding.
static_assert(
    sizeof(AccelerationStructureInstanceDescMetal) == 68,
    "AccelerationStructureInstanceDescMetal must be 68 bytes to match "
    "MTLAccelerationStructureUserIDInstanceDescriptor"
);

struct AccelerationStructureAABB
{
float minX;
Expand Down
136 changes: 136 additions & 0 deletions include/slang-rhi/acceleration-structure-utils.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
#pragma once

#include <slang-rhi.h>

namespace rhi {

/// Returns the instance descriptor layout used by the given device type.
/// Device types without a dedicated layout map to the generic descriptor.
inline AccelerationStructureInstanceDescType getAccelerationStructureInstanceDescType(DeviceType deviceType)
{
    using DescType = AccelerationStructureInstanceDescType;

    if (deviceType == DeviceType::D3D12)
        return DescType::D3D12;
    if (deviceType == DeviceType::Vulkan)
        return DescType::Vulkan;
    if (deviceType == DeviceType::Metal)
        return DescType::Metal;
    if (deviceType == DeviceType::CUDA)
        return DescType::Optix; // CUDA devices use the OptiX instance layout.

    return DescType::Generic;
}

/// Convenience overload: looks up the device type from the device's info and
/// returns the matching instance descriptor layout.
/// \param device Device to query; dereferenced without a null check.
inline AccelerationStructureInstanceDescType getAccelerationStructureInstanceDescType(IDevice* device)
{
    const DeviceType deviceType = device->getDeviceInfo().deviceType;
    return getAccelerationStructureInstanceDescType(deviceType);
}

/// Returns the size in bytes of the instance descriptor struct for the given layout type.
/// Yields 0 if \p type is not one of the known enumerators.
inline Size getAccelerationStructureInstanceDescSize(AccelerationStructureInstanceDescType type)
{
    Size descSize = 0;
    switch (type)
    {
    case AccelerationStructureInstanceDescType::Generic:
        descSize = sizeof(AccelerationStructureInstanceDescGeneric);
        break;
    case AccelerationStructureInstanceDescType::D3D12:
        descSize = sizeof(AccelerationStructureInstanceDescD3D12);
        break;
    case AccelerationStructureInstanceDescType::Vulkan:
        descSize = sizeof(AccelerationStructureInstanceDescVulkan);
        break;
    case AccelerationStructureInstanceDescType::Optix:
        descSize = sizeof(AccelerationStructureInstanceDescOptix);
        break;
    case AccelerationStructureInstanceDescType::Metal:
        descSize = sizeof(AccelerationStructureInstanceDescMetal);
        break;
    }
    return descSize;
}

/// Converts a single generic instance descriptor into the layout selected by \p dstType.
///
/// \param dstType Layout to write into \p dst.
/// \param dst Destination memory. Must be large enough for the descriptor struct that corresponds
///            to \p dstType. The Optix and Metal paths write through a typed pointer, so \p dst is
///            assumed to be suitably aligned for those structs.
/// \param src Generic descriptor to convert. Must not be null.
inline void convertAccelerationStructureInstanceDesc(
    AccelerationStructureInstanceDescType dstType,
    void* dst,
    const AccelerationStructureInstanceDescGeneric* src
)
{
    switch (dstType)
    {
    case AccelerationStructureInstanceDescType::Generic:
        ::memcpy(dst, src, sizeof(AccelerationStructureInstanceDescGeneric));
        break;
    case AccelerationStructureInstanceDescType::D3D12:
    {
        // The generic layout is bit-compatible with the D3D12 layout, so a raw copy suffices.
        static_assert(
            sizeof(AccelerationStructureInstanceDescD3D12) == sizeof(AccelerationStructureInstanceDescGeneric)
        );
        ::memcpy(dst, src, sizeof(AccelerationStructureInstanceDescD3D12));
        break;
    }
    case AccelerationStructureInstanceDescType::Vulkan:
    {
        // The generic layout is bit-compatible with the Vulkan layout, so a raw copy suffices.
        static_assert(
            sizeof(AccelerationStructureInstanceDescVulkan) == sizeof(AccelerationStructureInstanceDescGeneric)
        );
        ::memcpy(dst, src, sizeof(AccelerationStructureInstanceDescVulkan));
        break;
    }
    case AccelerationStructureInstanceDescType::Optix:
    {
        auto dstOptix = static_cast<AccelerationStructureInstanceDescOptix*>(dst);
        // Copy the full 3x4 float transform (48 bytes). A hard-coded 36 here would copy only
        // nine of the twelve floats and leave the rest of the matrix uninitialized.
        static_assert(
            sizeof(AccelerationStructureInstanceDescOptix::transform) ==
            sizeof(AccelerationStructureInstanceDescGeneric::transform)
        );
        ::memcpy(dstOptix->transform, src->transform, sizeof(dstOptix->transform));
        dstOptix->instanceId = src->instanceID;
        dstOptix->sbtOffset = src->instanceContributionToHitGroupIndex;
        dstOptix->visibilityMask = src->instanceMask;
        // Generic flags match the Optix flags.
        // TriangleFacingCullDisable -> OPTIX_INSTANCE_FLAG_DISABLE_TRIANGLE_FACE_CULLING
        // TriangleFrontCounterClockwise -> OPTIX_INSTANCE_FLAG_FLIP_TRIANGLE_FACING
        // ForceOpaque -> OPTIX_INSTANCE_FLAG_DISABLE_ANYHIT
        // NoOpaque -> OPTIX_INSTANCE_FLAG_ENFORCE_ANYHIT
        dstOptix->flags = static_cast<uint32_t>(src->flags);
        dstOptix->traversableHandle = src->accelerationStructure.value;
        // Zero the padding so the destination is fully defined.
        dstOptix->pad[0] = 0;
        dstOptix->pad[1] = 0;
        break;
    }
    case AccelerationStructureInstanceDescType::Metal:
    {
        auto dstMetal = static_cast<AccelerationStructureInstanceDescMetal*>(dst);
        // Transpose the transform matrix: generic is [3][4], Metal is [4][3].
        for (int row = 0; row < 3; ++row)
        {
            for (int col = 0; col < 4; ++col)
            {
                dstMetal->transform[col][row] = src->transform[row][col];
            }
        }
        // Generic flags match the Metal options.
        // TriangleFacingCullDisable -> DisableTriangleCulling
        // TriangleFrontCounterClockwise -> TriangleFrontFacingWindingCounterClockwise
        // ForceOpaque -> Opaque
        // NoOpaque -> NonOpaque
        dstMetal->options = static_cast<uint32_t>(src->flags);
        dstMetal->mask = src->instanceMask;
        dstMetal->intersectionFunctionTableOffset = src->instanceContributionToHitGroupIndex;
        // The Metal descriptor stores a 32-bit index; the 64-bit handle value is narrowed to it.
        dstMetal->accelerationStructureIndex = static_cast<uint32_t>(src->accelerationStructure.value);
        dstMetal->userID = src->instanceID;
        break;
    }
    }
}

/// Converts an array of generic instance descriptors into the layout selected by \p dstType.
///
/// \param count Number of descriptors to convert.
/// \param dstType Layout to write into \p dst.
/// \param dst Start of the destination array.
/// \param dstStride Distance in bytes between consecutive destination descriptors.
/// \param src Start of the source array of generic descriptors.
/// \param srcStride Distance in bytes between consecutive source descriptors.
inline void convertAccelerationStructureInstanceDescs(
    GfxCount count,
    AccelerationStructureInstanceDescType dstType,
    void* dst,
    Size dstStride,
    const AccelerationStructureInstanceDescGeneric* src,
    Size srcStride
)
{
    // Walk both arrays as raw bytes so arbitrary strides can be honored.
    uint8_t* dstCursor = static_cast<uint8_t*>(dst);
    const uint8_t* srcCursor = reinterpret_cast<const uint8_t*>(src);

    for (GfxIndex index = 0; index < count; ++index)
    {
        convertAccelerationStructureInstanceDesc(
            dstType,
            dstCursor,
            reinterpret_cast<const AccelerationStructureInstanceDescGeneric*>(srcCursor)
        );
        dstCursor += dstStride;
        srcCursor += srcStride;
    }
}

} // namespace rhi
30 changes: 30 additions & 0 deletions src/metal/metal-acceleration-structure.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
#include "metal-acceleration-structure.h"
#include "metal-device.h"

namespace rhi::metal {

/// Releases this object's slot in the device's acceleration structure list:
/// the index is put on the free list, the list entry is cleared, and the list is marked dirty.
AccelerationStructureImpl::~AccelerationStructureImpl()
{
    auto& registry = m_device->m_accelerationStructures;
    registry.freeList.push_back(m_globalIndex);
    registry.list[m_globalIndex] = nullptr;
    registry.dirty = true;
}

/// Writes the underlying Metal acceleration structure pointer into \p outHandle.
/// \param outHandle Receives the handle type and value; dereferenced without a null check.
/// \return Always SLANG_OK.
Result AccelerationStructureImpl::getNativeHandle(NativeHandle* outHandle)
{
    auto* nativeObject = m_accelerationStructure.get();
    outHandle->type = NativeHandleType::MTLAccelerationStructure;
    outHandle->value = reinterpret_cast<uint64_t>(nativeObject);
    return SLANG_OK;
}

/// Returns the handle for this acceleration structure.
/// On Metal the handle carries this object's global index rather than an address.
AccelerationStructureHandle AccelerationStructureImpl::getHandle()
{
    AccelerationStructureHandle handle{m_globalIndex};
    return handle;
}

/// Returns the device address of this acceleration structure.
/// This implementation always reports 0.
DeviceAddress AccelerationStructureImpl::getDeviceAddress()
{
    const DeviceAddress noAddress = 0;
    return noAddress;
}

} // namespace rhi::metal
29 changes: 29 additions & 0 deletions src/metal/metal-acceleration-structure.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
#pragma once

#include "metal-base.h"

namespace rhi::metal {

/// Metal implementation of an acceleration structure object.
class AccelerationStructureImpl : public AccelerationStructure
{
public:
    /// Device this object was created with (raw, non-owning pointer).
    /// The destructor uses it to release this object's slot in the device's acceleration structure list.
    DeviceImpl* m_device;
    /// Underlying Metal acceleration structure object.
    NS::SharedPtr<MTL::AccelerationStructure> m_accelerationStructure;
    /// Index of this object in the device's acceleration structure list; also returned by getHandle().
    /// NOTE(review): not initialized by the constructor -- presumably assigned by the device when the
    /// object is created. Confirm, because the destructor uses it unconditionally.
    uint32_t m_globalIndex;

public:
    /// Stores the device pointer and forwards \p desc to the base class.
    AccelerationStructureImpl(DeviceImpl* device, const AccelerationStructureDesc& desc)
        : AccelerationStructure(desc)
        , m_device(device)
    {
    }

    ~AccelerationStructureImpl();

    // IAccelerationStructure implementation
    virtual SLANG_NO_THROW Result SLANG_MCALL getNativeHandle(NativeHandle* outHandle) override;
    // SLANG_MCALL added for consistency with the sibling overrides.
    virtual SLANG_NO_THROW AccelerationStructureHandle SLANG_MCALL getHandle() override;
    virtual SLANG_NO_THROW DeviceAddress SLANG_MCALL getDeviceAddress() override;
};

} // namespace rhi::metal
16 changes: 16 additions & 0 deletions src/metal/metal-command-buffer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,17 @@ MTL::ComputeCommandEncoder* CommandBufferImpl::getMetalComputeCommandEncoder()
return m_metalComputeCommandEncoder.get();
}

/// Returns the acceleration structure command encoder for this command buffer, creating it on demand.
/// When a new encoder has to be created, endMetalCommandEncoder() is called first to end
/// whichever encoder is currently open.
MTL::AccelerationStructureCommandEncoder* CommandBufferImpl::getMetalAccelerationStructureCommandEncoder()
{
    // Fast path: an acceleration structure encoder is already open.
    if (m_metalAccelerationStructureCommandEncoder)
    {
        return m_metalAccelerationStructureCommandEncoder.get();
    }

    endMetalCommandEncoder();
    m_metalAccelerationStructureCommandEncoder =
        NS::RetainPtr(m_commandBuffer->accelerationStructureCommandEncoder());
    return m_metalAccelerationStructureCommandEncoder.get();
}

MTL::BlitCommandEncoder* CommandBufferImpl::getMetalBlitCommandEncoder()
{
if (!m_metalBlitCommandEncoder)
Expand All @@ -103,6 +114,11 @@ void CommandBufferImpl::endMetalCommandEncoder()
m_metalComputeCommandEncoder->endEncoding();
m_metalComputeCommandEncoder.reset();
}
if (m_metalAccelerationStructureCommandEncoder)
{
m_metalAccelerationStructureCommandEncoder->endEncoding();
m_metalAccelerationStructureCommandEncoder.reset();
}
if (m_metalBlitCommandEncoder)
{
m_metalBlitCommandEncoder->endEncoding();
Expand Down
2 changes: 2 additions & 0 deletions src/metal/metal-command-buffer.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ class CommandBufferImpl : public ICommandBuffer, public ComObject

NS::SharedPtr<MTL::RenderCommandEncoder> m_metalRenderCommandEncoder;
NS::SharedPtr<MTL::ComputeCommandEncoder> m_metalComputeCommandEncoder;
NS::SharedPtr<MTL::AccelerationStructureCommandEncoder> m_metalAccelerationStructureCommandEncoder;
NS::SharedPtr<MTL::BlitCommandEncoder> m_metalBlitCommandEncoder;

// Command buffers are deallocated by its command pool,
Expand All @@ -38,6 +39,7 @@ class CommandBufferImpl : public ICommandBuffer, public ComObject

MTL::RenderCommandEncoder* getMetalRenderCommandEncoder(MTL::RenderPassDescriptor* renderPassDesc);
MTL::ComputeCommandEncoder* getMetalComputeCommandEncoder();
MTL::AccelerationStructureCommandEncoder* getMetalAccelerationStructureCommandEncoder();
MTL::BlitCommandEncoder* getMetalBlitCommandEncoder();
void endMetalCommandEncoder();

Expand Down
45 changes: 45 additions & 0 deletions src/metal/metal-command-encoder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include "metal-texture.h"
#include "metal-texture-view.h"
#include "metal-util.h"
#include "metal-acceleration-structure.h"

namespace rhi::metal {

Expand Down Expand Up @@ -586,6 +587,33 @@ void RayTracingCommandEncoderImpl::buildAccelerationStructure(
AccelerationStructureQueryDesc* queryDescs
)
{
MTL::AccelerationStructureCommandEncoder* encoder = m_commandBuffer->getMetalAccelerationStructureCommandEncoder();

AccelerationStructureDescBuilder builder;
builder.build(desc, m_commandBuffer->m_device->getAccelerationStructureArray(), getDebugCallback());

switch (desc.mode)
{
case AccelerationStructureBuildMode::Build:
encoder->buildAccelerationStructure(
static_cast<AccelerationStructureImpl*>(dst)->m_accelerationStructure.get(),
builder.descriptor.get(),
static_cast<BufferImpl*>(scratchBuffer.buffer)->m_buffer.get(),
scratchBuffer.offset
);
break;
case AccelerationStructureBuildMode::Update:
encoder->refitAccelerationStructure(
static_cast<AccelerationStructureImpl*>(src)->m_accelerationStructure.get(),
builder.descriptor.get(),
static_cast<AccelerationStructureImpl*>(dst)->m_accelerationStructure.get(),
static_cast<BufferImpl*>(scratchBuffer.buffer)->m_buffer.get(),
scratchBuffer.offset
);
break;
}

// TODO handle queryDescs
}

void RayTracingCommandEncoderImpl::copyAccelerationStructure(
Expand All @@ -594,6 +622,23 @@ void RayTracingCommandEncoderImpl::copyAccelerationStructure(
AccelerationStructureCopyMode mode
)
{
MTL::AccelerationStructureCommandEncoder* encoder = m_commandBuffer->getMetalAccelerationStructureCommandEncoder();

switch (mode)
{
case AccelerationStructureCopyMode::Clone:
encoder->copyAccelerationStructure(
static_cast<AccelerationStructureImpl*>(src)->m_accelerationStructure.get(),
static_cast<AccelerationStructureImpl*>(dst)->m_accelerationStructure.get()
);
break;
case AccelerationStructureCopyMode::Compact:
encoder->copyAndCompactAccelerationStructure(
static_cast<AccelerationStructureImpl*>(src)->m_accelerationStructure.get(),
static_cast<AccelerationStructureImpl*>(dst)->m_accelerationStructure.get()
);
break;
}
}

void RayTracingCommandEncoderImpl::queryAccelerationStructureProperties(
Expand Down
Loading

0 comments on commit e64b712

Please sign in to comment.