Skip to content

Commit

Permalink
Merge pull request #125 from Themaister/meshlet-cull
Browse files Browse the repository at this point in the history
Implement basic mesh culling system in tests
  • Loading branch information
Themaister authored Dec 15, 2023
2 parents 2b04fae + bb9239c commit 67f573f
Show file tree
Hide file tree
Showing 14 changed files with 976 additions and 245 deletions.
376 changes: 296 additions & 80 deletions assets/shaders/inc/meshlet_payload_decode.h

Large diffs are not rendered by default.

116 changes: 116 additions & 0 deletions assets/shaders/inc/meshlet_render.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
#ifndef MESHLET_RENDER_H_
#define MESHLET_RENDER_H_

#ifndef MESHLET_RENDER_DESCRIPTOR_SET
#error "Must define MESHLET_RENDER_DESCRIPTOR_SET before including meshlet_render.h"
#endif

#ifndef MESHLET_RENDER_AABB_BINDING
#error "Must define MESHLET_RENDER_AABB_BINDING before including meshlet_render.h"
#endif

#ifndef MESHLET_RENDER_TRANSFORM_BINDING
#error "Must define MESHLET_RENDER_TRANSFORM_BINDING before including meshlet_render.h"
#endif

#ifndef MESHLET_RENDER_BOUND_BINDING
#error "Must define MESHLET_RENDER_BOUND_BINDING before including meshlet_render.h"
#endif

#ifndef MESHLET_RENDER_FRUSTUM_BINDING
#error "Must define MESHLET_RENDER_GROUP_BOUND_BINDING before including meshlet_render.h"
#endif

#ifndef MESHLET_RENDER_TASKS_BINDING
#error "Must define MESHLET_RENDER_TASKS_BINDING before including meshlet_render.h"
#endif

struct AABB
{
vec3 lo; float pad0; vec3 hi; float pad;
};

struct Bound
{
vec4 center_radius;
vec4 cone;
};

layout(set = MESHLET_RENDER_DESCRIPTOR_SET, binding = MESHLET_RENDER_BOUND_BINDING, std430) readonly buffer Bounds
{
Bound data[];
} bounds;

layout(set = MESHLET_RENDER_DESCRIPTOR_SET, binding = MESHLET_RENDER_AABB_BINDING, std430) readonly buffer AABBSSBO
{
AABB data[];
} aabb;

layout(set = MESHLET_RENDER_DESCRIPTOR_SET, binding = MESHLET_RENDER_TRANSFORM_BINDING, std430) readonly buffer Transforms
{
mat4 data[];
} transforms;

layout(set = MESHLET_RENDER_DESCRIPTOR_SET, binding = MESHLET_RENDER_FRUSTUM_BINDING, std140) uniform Frustum
{
vec4 planes[6];
} frustum;

struct TaskInfo
{
uint aabb_instance;
uint node_instance;
uint node_count_material_index; // Skinning
uint mesh_index_count;
};

layout(set = MESHLET_RENDER_DESCRIPTOR_SET, binding = MESHLET_RENDER_TASKS_BINDING, std430) readonly buffer Tasks
{
TaskInfo data[];
} task_info;

bool frustum_cull(vec3 lo, vec3 hi)
{
bool ret = true;
for (int i = 0; i < 6 && ret; i++)
{
vec4 p = frustum.planes[i];
bvec3 high_mask = greaterThan(p.xyz, vec3(0.0));
vec3 max_coord = mix(lo, hi, high_mask);
if (dot(vec4(max_coord, 1.0), p) < 0.0)
ret = false;
}
return ret;
}

bool cluster_cull(mat4 M, Bound bound, vec3 camera_pos)
{
vec3 bound_center = (M * vec4(bound.center_radius.xyz, 1.0)).xyz;

float s0 = dot(M[0].xyz, M[0].xyz);
float s1 = dot(M[1].xyz, M[1].xyz);
float s2 = dot(M[2].xyz, M[2].xyz);

float max_scale_factor = sqrt(max(max(s0, s1), s2));
float effective_radius = bound.center_radius.w * max_scale_factor;

// Cluster cone cull.
bool ret = true;

vec4 cone = bound.cone;
if (cone.w < 1.0)
{
cone = vec4(normalize(mat3(M) * cone.xyz), cone.w);
ret = dot(bound_center - camera_pos, cone.xyz) <= cone.w * length(bound_center - camera_pos) + effective_radius;
}

for (int i = 0; i < 6 && ret; i++)
{
vec4 p = frustum.planes[i];
if (dot(vec4(bound_center, 1.0), p) < -effective_radius)
ret = false;
}
return ret;
}

#endif
45 changes: 34 additions & 11 deletions scene-export/meshlet_export.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -342,9 +342,9 @@ static void find_linear_predictor(uint16_t *predictor,
predictor[4 + i] = uint16_t(b[i]);
}

static void encode_stream(std::vector <uint32_t> &out_payload_buffer,
Stream &stream, u8vec4 (&stream_buffer)[MaxElements],
unsigned num_elements)
static size_t encode_stream(std::vector <uint32_t> &out_payload_buffer,
Stream &stream, u8vec4 (&stream_buffer)[MaxElements],
unsigned num_elements)
{
stream.offset_from_base_u32 = uint32_t(out_payload_buffer.size());

Expand Down Expand Up @@ -430,6 +430,8 @@ static void encode_stream(std::vector <uint32_t> &out_payload_buffer,
for (unsigned i = 0; i < required_bits.w; i++)
out_payload_buffer.push_back(extract_bit_plane(&stream_buffer[chunk_index * 32][3], i));
}

return out_payload_buffer.size() - stream.offset_from_base_u32;
}

static void encode_mesh(Encoded &encoded,
Expand All @@ -446,6 +448,7 @@ static void encode_mesh(Encoded &encoded,

std::unordered_map <uint32_t, uint32_t> vbo_remap;
uint32_t primitive_index = 0;
size_t words_per_stream[MaxU32Streams] = {};

for (uint32_t meshlet_index = 0; meshlet_index < num_meshlets; meshlet_index++)
{
Expand Down Expand Up @@ -481,7 +484,8 @@ static void encode_mesh(Encoded &encoded,
stream_buffer[i] = u8vec4(i0, i1, i2, 0);
}

encode_stream(encoded.payload, meshlet.u32_streams[0], stream_buffer, analysis_result.num_primitives);
words_per_stream[0] +=
encode_stream(encoded.payload, meshlet.u32_streams[0], stream_buffer, analysis_result.num_primitives);

// Handle spill region just in case.
uint64_t vbo_remapping[MaxVertices + 3];
Expand All @@ -502,14 +506,19 @@ static void encode_mesh(Encoded &encoded,
memcpy(stream_buffer[i].data, &payload, sizeof(payload));
}

encode_stream(encoded.payload, meshlet.u32_streams[stream_index + 1], stream_buffer,
analysis_result.num_vertices);
words_per_stream[stream_index + 1] +=
encode_stream(encoded.payload, meshlet.u32_streams[stream_index + 1], stream_buffer,
analysis_result.num_vertices);
}

mesh.meshlets.push_back(meshlet);
base_vertex_offset += analysis_result.num_vertices;
primitive_index += primitives_to_process;
}

for (unsigned i = 0; i < MaxU32Streams; i++)
if (words_per_stream[i])
LOGI("Stream[%u] = %zu bytes.\n", i, words_per_stream[i] * sizeof(uint32_t));
}

static bool export_encoded_mesh(const std::string &path, const Encoded &encoded)
Expand Down Expand Up @@ -675,7 +684,7 @@ bool export_mesh_to_meshlet(const std::string &path, SceneFormats::Mesh mesh, Me
out_vertex_redirection_buffer.data(), local_index_buffer.data(),
reinterpret_cast<const uint32_t *>(mesh.indices.data()), mesh.count,
position_buffer[0].data, positions.size(), sizeof(vec3),
max_vertices, max_primitives, 0.75f);
max_vertices, max_primitives, 0.5f);

meshlets.resize(num_meshlets);

Expand Down Expand Up @@ -703,7 +712,7 @@ bool export_mesh_to_meshlet(const std::string &path, SceneFormats::Mesh mesh, Me
std::vector<meshopt_Bounds> bounds;
bounds.clear();
bounds.reserve(num_meshlets);
for (auto &meshlet: out_meshlets)
for (auto &meshlet : out_meshlets)
{
auto bound = meshopt_computeClusterBounds(
out_index_buffer[meshlet.offset].data, meshlet.count * 3,
Expand All @@ -719,14 +728,28 @@ bool export_mesh_to_meshlet(const std::string &path, SceneFormats::Mesh mesh, Me

assert(bounds.size() == encoded.mesh.meshlets.size());
const auto *pbounds = bounds.data();
for (auto &meshlet: encoded.mesh.meshlets)
for (auto &meshlet : encoded.mesh.meshlets)
{
memcpy(meshlet.bound.center, pbounds->center, sizeof(float) * 3);
meshlet.bound.radius = pbounds->radius;
memcpy(meshlet.bound.cone_axis_cutoff, pbounds->cone_axis_s8, sizeof(pbounds->cone_axis_s8));
meshlet.bound.cone_axis_cutoff[3] = pbounds->cone_cutoff_s8;
memcpy(meshlet.bound.cone_axis_cutoff, pbounds->cone_axis, sizeof(pbounds->cone_axis));
meshlet.bound.cone_axis_cutoff[3] = pbounds->cone_cutoff;
pbounds++;
}

LOGI("Exported meshlet:\n");
LOGI(" %zu meshlets\n", encoded.mesh.meshlets.size());
LOGI(" %zu payload bytes\n", encoded.payload.size() * sizeof(uint32_t));
LOGI(" %u total indices\n", mesh.count);
LOGI(" %zu total attributes\n", mesh.positions.size() / mesh.position_stride);

size_t uncompressed_bytes = mesh.indices.size();
uncompressed_bytes += mesh.positions.size();
if (style != MeshStyle::Wireframe)
uncompressed_bytes += mesh.attributes.size();

LOGI(" %zu uncompressed bytes\n\n\n", uncompressed_bytes);

return export_encoded_mesh(path, encoded);
}
}
Expand Down
138 changes: 108 additions & 30 deletions tests/assets/shaders/meshlet_cull.comp
Original file line number Diff line number Diff line change
@@ -1,26 +1,20 @@
#version 450
#extension GL_EXT_scalar_block_layout : require
#if MESHLET_PAYLOAD_SUBGROUP
#extension GL_KHR_shader_subgroup_ballot : require
#extension GL_KHR_shader_subgroup_vote : require
#extension GL_KHR_shader_subgroup_shuffle : require
#endif

layout(local_size_x = 32) in;

struct AABB
{
vec4 lo, hi;
};

layout(set = 0, binding = 0, std430) readonly buffer AABBSSBO
{
AABB data[];
} aabb;

layout(set = 0, binding = 1, std430) readonly buffer Transforms
{
mat4 data[];
} transforms;

layout(set = 0, binding = 2, std430) readonly buffer Tasks
{
uvec4 data[];
} task_info;
#define MESHLET_RENDER_DESCRIPTOR_SET 0
#define MESHLET_RENDER_BOUND_BINDING 6
#define MESHLET_RENDER_AABB_BINDING 0
#define MESHLET_RENDER_TRANSFORM_BINDING 1
#define MESHLET_RENDER_FRUSTUM_BINDING 7
#define MESHLET_RENDER_TASKS_BINDING 2
#include "meshlet_render.h"

struct Draw
{
Expand All @@ -39,31 +33,115 @@ layout(set = 0, binding = 4, std430) writeonly buffer OutputDraws
Draw data[];
} output_draws;

layout(set = 0, binding = 5, std430) writeonly buffer CompactedDraws
layout(set = 0, binding = 5, scalar) writeonly buffer CompactedDraws
{
uvec2 data[];
uvec3 data[];
} output_draw_info;

layout(push_constant, std430) uniform Registers
{
vec3 camera_pos;
uint count;
} registers;

#if !MESHLET_PAYLOAD_SUBGROUP
shared uint ballot_value;
shared uint global_offset;
uvec4 ballot(bool v)
{
barrier();
if (gl_LocalInvocationIndex == 0)
ballot_value = 0;
barrier();
if (v)
atomicOr(ballot_value, 1u << gl_LocalInvocationIndex);
barrier();
return uvec4(ballot_value, 0, 0, 0);
}

uint ballotBitCount(uvec4 v)
{
return bitCount(v.x);
}

uint ballotExclusiveBitCount(uvec4 v)
{
uint mask = (1u << gl_LocalInvocationIndex) - 1u;
return bitCount(v.x & mask);
}
#define local_invocation_id gl_LocalInvocationIndex
#else
#define ballot(v) subgroupBallot(v)
#define ballotBitCount(v) subgroupBallotBitCount(v)
#define ballotExclusiveBitCount(v) subgroupBallotExclusiveBitCount(v)
#define local_invocation_id gl_SubgroupInvocationID
#endif

void main()
{
uvec4 command_payload;
uint task_index = gl_GlobalInvocationID.x;
TaskInfo task;
uint task_index = gl_WorkGroupID.x * gl_WorkGroupSize.x + local_invocation_id;
bool task_needs_work = false;
if (task_index < registers.count)
{
command_payload = task_info.data[task_index];
uint offset = command_payload.w & ~31u;
uint count = bitfieldExtract(command_payload.w, 0, 5) + 1;
task = task_info.data[task_index];

// Precull the group.
AABB aabb = aabb.data[task.aabb_instance];
task_needs_work = frustum_cull(aabb.lo, aabb.hi);
}

uint b = ballot(task_needs_work).x;

while (b != 0)
{
int lane = findLSB(b);
b &= ~(1u << lane);

#if MESHLET_PAYLOAD_SUBGROUP
uint node_instance = subgroupShuffle(task.node_instance, lane);
uint node_count_material_index = subgroupShuffle(task.node_count_material_index, lane);
uint mesh_index_count = subgroupShuffle(task.mesh_index_count, lane);
#else
TaskInfo tmp_task = task_info.data[gl_WorkGroupID.x * gl_WorkGroupSize.x + lane];
uint node_instance = tmp_task.node_instance;
uint node_count_material_index = tmp_task.node_count_material_index;
uint mesh_index_count = tmp_task.mesh_index_count;
#endif

uint offset = mesh_index_count & ~31u;
uint count = bitfieldExtract(mesh_index_count, 0, 5) + 1;

uint meshlet_index = offset + local_invocation_id;

bool alloc_draw = false;
if (local_invocation_id < count)
{
mat4 M = transforms.data[node_instance];
Bound b = bounds.data[meshlet_index];
alloc_draw = cluster_cull(M, b, registers.camera_pos);
}

uvec4 ballot = ballot(alloc_draw);
uint draw_count = ballotBitCount(ballot);
uint local_offset = ballotExclusiveBitCount(ballot);

#if MESHLET_PAYLOAD_SUBGROUP
uint global_offset;
if (subgroupElect())
global_offset = atomicAdd(output_draws.count, draw_count);
global_offset = subgroupBroadcastFirst(global_offset);
#else
// WAR barrier is implied here in earlier ballot.
if (gl_LocalInvocationIndex == 0)
global_offset = atomicAdd(output_draws.count, draw_count);
barrier();
#endif

uint draw_offset = atomicAdd(output_draws.count, count);
for (uint i = 0; i < count; i++)
if (alloc_draw)
{
output_draws.data[draw_offset + i] = input_draws.data[offset + i];
output_draw_info.data[draw_offset + i] = command_payload.yz;
output_draws.data[global_offset + local_offset] = input_draws.data[meshlet_index];
output_draw_info.data[global_offset + local_offset] = uvec3(node_instance, node_count_material_index, meshlet_index);
}
}
}
Loading

0 comments on commit 67f573f

Please sign in to comment.