diff --git a/assets/shaders/post/hiz.comp b/assets/shaders/post/hiz.comp index 42b59ca9..d580be59 100644 --- a/assets/shaders/post/hiz.comp +++ b/assets/shaders/post/hiz.comp @@ -73,7 +73,7 @@ void write_image4(ivec2 coord, int mip, vec4 v) imageStore(uImages[mip], coord + ivec2(1, 1), v.wwww); } -shared float shared_buffer[256 / 16]; +shared float shared_buffer[128]; shared bool shared_is_last_workgroup; mat4 fetch_4x4_texture(ivec2 base_coord) @@ -96,13 +96,10 @@ vec4 fetch_2x2_image_mip6(ivec2 base_coord) return vec4(d0, d1, d2, d3); } -mat4 fetch_4x4_image_mip6(ivec2 base_coord) +float fetch_image_mip6(ivec2 coord) { - vec4 q0 = fetch_2x2_image_mip6(base_coord + ivec2(0, 0)); - vec4 q1 = fetch_2x2_image_mip6(base_coord + ivec2(2, 0)); - vec4 q2 = fetch_2x2_image_mip6(base_coord + ivec2(0, 2)); - vec4 q3 = fetch_2x2_image_mip6(base_coord + ivec2(2, 2)); - return mat4(q0, q1, q2, q3); + ivec2 max_coord = mip_resolution(6) - 1; + return imageLoad(uImages[6], min(coord, max_coord)).x; } mat4 write_mip0_transformed(mat4 M, ivec2 base_coord) @@ -121,7 +118,11 @@ mat4 write_mip0_transformed(mat4 M, ivec2 base_coord) return mat4(q00, q10, q01, q11); } -float reduce_mip_registers(mat4 M, ivec2 base_coord, int mip, bool full_res_pass) +// For LOD 0 to 6, it is expected that the division is exact, +// i.e., the lower resolution mip is exactly half resolution. +// This way we avoid needing to fold in neighbors. + +float reduce_mip_registers(mat4 M, ivec2 base_coord, int mip) { vec4 q00 = M[0]; vec4 q10 = M[1]; @@ -135,49 +136,165 @@ float reduce_mip_registers(mat4 M, ivec2 base_coord, int mip, bool full_res_pass float d01 = reduce(q01); float d11 = reduce(q11); - if (!full_res_pass) + q00 = vec4(d00, d10, d01, d11); + write_image4(base_coord, mip, q00); + + return reduce(q00); +} + +void reduce_mip_shared(ivec2 base_coord, int mip) +{ + ivec2 mip_res_higher = mip_resolution(mip - 1); + ivec2 mip_res_target = mip_resolution(mip); + + bool horiz_fold = base_coord.x + 1 == mip_res_target.x && (mip_res_higher.x & 1) != 0; + bool vert_fold = base_coord.y + 1 == mip_res_target.y && (mip_res_higher.y & 1) != 0; + bool diag_fold = horiz_fold && vert_fold; + + // Ping-pong the shared buffer to avoid double barrier. + uint out_offset = (mip & 1) * 64; + uint in_offset = 64 - out_offset; + + float d00 = shared_buffer[in_offset + base_coord.y * 32 + base_coord.x * 2]; + float d10 = shared_buffer[in_offset + base_coord.y * 32 + base_coord.x * 2 + 1]; + float d01 = shared_buffer[in_offset + base_coord.y * 32 + base_coord.x * 2 + 16]; + float d11 = shared_buffer[in_offset + base_coord.y * 32 + base_coord.x * 2 + 17]; + + float reduced = reduce(vec4(d00, d10, d01, d11)); + if (horiz_fold) { - if (base_coord.x + 1 == mip_res.x) // LOD math chops off data. Need to fold border values into the reduction. - { - d00 = REDUCE_OPERATOR(d00, d10); - d01 = REDUCE_OPERATOR(d01, d11); - } + reduced = REDUCE_OPERATOR(reduced, shared_buffer[in_offset + base_coord.y * 32 + base_coord.x * 2 + 2]); + reduced = REDUCE_OPERATOR(reduced, shared_buffer[in_offset + base_coord.y * 32 + base_coord.x * 2 + 18]); + } - if (base_coord.y + 1 == mip_res.y) - { - d01 = REDUCE_OPERATOR(d01, d00); - d11 = REDUCE_OPERATOR(d11, d10); - } + if (vert_fold) + { + reduced = REDUCE_OPERATOR(reduced, shared_buffer[in_offset + base_coord.y * 32 + base_coord.x * 2 + 32]); + reduced = REDUCE_OPERATOR(reduced, shared_buffer[in_offset + base_coord.y * 32 + base_coord.x * 2 + 33]); } - q00 = vec4(d00, d10, d01, d11); - write_image4(base_coord, mip, q00); + if (diag_fold) + reduced = REDUCE_OPERATOR(reduced, shared_buffer[in_offset + base_coord.y * 32 + base_coord.x * 2 + 34]); - return reduce(q00); + shared_buffer[out_offset + base_coord.y * 16 + base_coord.x] = reduced; + write_image(base_coord, mip, reduced); } -float reduce_mips_simd16(ivec2 base_coord, uint local_index, int mip, float d, bool full_res_pass) +void reduce_mip_lod7_lod8(ivec2 base_coord_lod8) +{ + ivec2 mip_res6 = mip_resolution(6); + ivec2 mip_res7 = mip_resolution(7); + ivec2 base_coord_lod6 = base_coord_lod8 * 4; + ivec2 base_coord_lod7 = base_coord_lod8 * 2; + + float d00 = reduce(fetch_2x2_image_mip6(base_coord_lod6 + ivec2(0, 0))); + float d10 = reduce(fetch_2x2_image_mip6(base_coord_lod6 + ivec2(2, 0))); + float d01 = reduce(fetch_2x2_image_mip6(base_coord_lod6 + ivec2(0, 2))); + float d11 = reduce(fetch_2x2_image_mip6(base_coord_lod6 + ivec2(2, 2))); + + // NPOT folding for LOD 7. Our group will write the edge, + // so need to fold in any last neighbor in previous LOD which may contribute, + // but would be otherwise lost to the rounding down. + + bool horiz_fold = base_coord_lod7.x + 2 == mip_res7.x && (mip_res6.x & 1) != 0; + bool vert_fold = base_coord_lod7.y + 2 == mip_res7.y && (mip_res6.y & 1) != 0; + bool diag_fold = horiz_fold && vert_fold; + + if (horiz_fold) + { + float d20 = REDUCE_OPERATOR( + fetch_image_mip6(base_coord_lod6 + ivec2(4, 0)), + fetch_image_mip6(base_coord_lod6 + ivec2(4, 1))); + + float d21 = REDUCE_OPERATOR( + fetch_image_mip6(base_coord_lod6 + ivec2(4, 2)), + fetch_image_mip6(base_coord_lod6 + ivec2(4, 3))); + + d10 = REDUCE_OPERATOR(d10, d20); + d11 = REDUCE_OPERATOR(d11, d21); + } + + if (vert_fold) + { + float d02 = REDUCE_OPERATOR( + fetch_image_mip6(base_coord_lod6 + ivec2(0, 4)), + fetch_image_mip6(base_coord_lod6 + ivec2(1, 4))); + + float d12 = REDUCE_OPERATOR( + fetch_image_mip6(base_coord_lod6 + ivec2(2, 4)), + fetch_image_mip6(base_coord_lod6 + ivec2(3, 4))); + + d01 = REDUCE_OPERATOR(d01, d02); + d11 = REDUCE_OPERATOR(d11, d12); + } + + if (diag_fold) + { + float d22 = fetch_image_mip6(base_coord_lod6 + ivec2(4, 4)); + d11 = REDUCE_OPERATOR(d11, d22); + } + + // If the edge pixels will be dropped, fold them into the top-left pixel. + horiz_fold = base_coord_lod7.x + 1 == mip_res7.x; + vert_fold = base_coord_lod7.y + 1 == mip_res7.y; + diag_fold = horiz_fold && vert_fold; + + if (horiz_fold) + { + d00 = REDUCE_OPERATOR(d00, d10); + d01 = REDUCE_OPERATOR(d01, d11); + } + + if (vert_fold) + { + d00 = REDUCE_OPERATOR(d00, d01); + d10 = REDUCE_OPERATOR(d10, d11); + } + + if (diag_fold) + d00 = REDUCE_OPERATOR(d00, d11); + + vec4 quad = vec4(d00, d10, d01, d11); + write_image4(base_coord_lod7, 7, quad); + + if (registers.mips > 8) + { + float lod8 = reduce(quad); + shared_buffer[base_coord_lod8.y * 16 + base_coord_lod8.x] = lod8; + + // If writes to mip 8 may be sliced, fixup. + if (((mip_res7.x | mip_res7.y) & 1) != 0) + { + barrier(); + + ivec2 mip_res8 = mip_resolution(8); + horiz_fold = base_coord_lod8.x + 1 == mip_res8.x && (mip_res7.x & 1) != 0; + vert_fold = base_coord_lod8.y + 1 == mip_res8.y && (mip_res7.y & 1) != 0; + diag_fold = horiz_fold && vert_fold; + + if (horiz_fold) + lod8 = REDUCE_OPERATOR(lod8, shared_buffer[base_coord_lod8.y * 16 + base_coord_lod8.x + 1]); + if (vert_fold) + lod8 = REDUCE_OPERATOR(lod8, shared_buffer[base_coord_lod8.y * 16 + base_coord_lod8.x + 16]); + if (diag_fold) + lod8 = REDUCE_OPERATOR(lod8, shared_buffer[base_coord_lod8.y * 16 + base_coord_lod8.x + 17]); + if (horiz_fold || vert_fold) + shared_buffer[base_coord_lod8.y * 16 + base_coord_lod8.x] = lod8; + } + + write_image(base_coord_lod8, 8, lod8); + } +} + +float reduce_mips_simd16(ivec2 base_coord, uint local_index, int mip, float d) { ivec2 mip_res = mip_resolution(mip); float d_horiz, d_vert, d_diag; bool swap_horiz, swap_vert; - // It is possible that our thread is barely in range, but horiz/vert neighbor is not. -#define CUTOFF_REDUCE() { \ - swap_horiz = base_coord.x + 1 == mip_res.x; \ - swap_vert = base_coord.y + 1 == mip_res.y; \ - if (swap_horiz) \ - d = REDUCE_OPERATOR(d, d_horiz); \ - if (swap_vert) \ - d = REDUCE_OPERATOR(d, d_vert); \ - if (swap_vert && swap_horiz) \ - d = REDUCE_OPERATOR(d, d_diag); } - d_horiz = subgroupQuadSwapHorizontal(d); d_vert = subgroupQuadSwapVertical(d); d_diag = subgroupQuadSwapDiagonal(d); - if (!full_res_pass) - CUTOFF_REDUCE(); write_image(base_coord, mip, d); if (registers.mips > mip + 1) @@ -190,8 +307,6 @@ float reduce_mips_simd16(ivec2 base_coord, uint local_index, int mip, float d, b d_horiz = subgroupShuffleXor(d, SHUFFLE_X1); d_vert = subgroupShuffleXor(d, SHUFFLE_Y1); d_diag = subgroupShuffleXor(d, SHUFFLE_X1 | SHUFFLE_Y1); - if (!full_res_pass) - CUTOFF_REDUCE(); if ((local_index & 3) == 0) write_image(base_coord, mip + 1, d); } @@ -215,12 +330,12 @@ void main() // Write LOD 1, Compute LOD 2 if (registers.mips <= 1) return; - float d = reduce_mip_registers(M, base_coord >> 1, 1, true); + float d = reduce_mip_registers(M, base_coord >> 1, 1); if (registers.mips <= 2) return; // Write LOD 2, Compute LOD 3-4 - d = reduce_mips_simd16(base_coord >> 2, local_index, 2, d, true); + d = reduce_mips_simd16(base_coord >> 2, local_index, 2, d); if (registers.mips <= 4) return; @@ -231,7 +346,7 @@ void main() // Write LOD 4, Compute LOD 5-6. if (local_index < 16) - d = reduce_mips_simd16(ivec2(gl_WorkGroupID.xy * 4u + local_coord), local_index, 4, shared_buffer[local_index], true); + d = reduce_mips_simd16(ivec2(gl_WorkGroupID.xy * 4u + local_coord), local_index, 4, shared_buffer[local_index]); // Write LOD 6. if (registers.mips <= 6) @@ -253,24 +368,32 @@ void main() if (local_index == 0) atomic_counter = 0u; - // Write LOD 7, Compute LOD 8 - base_coord = ivec2(local_coord) * 4; - d = reduce_mip_registers(fetch_4x4_image_mip6(base_coord), base_coord >> 1, 7, false); - if (registers.mips <= 8) + // At this point, the mip resolutions may be non-POT and things get spicy. + + // Write LOD 7-8, Compute LOD 8 + reduce_mip_lod7_lod8(ivec2(local_coord)); + + if (registers.mips <= 9) return; + barrier(); + if (local_index < 64) + reduce_mip_shared(ivec2(local_coord), 9); - // Write LOD 8-9, Compute LOD 10 - d = reduce_mips_simd16(ivec2(local_coord), local_index, 8, d, false); if (registers.mips <= 10) return; - if ((local_index & 15) == 0) - shared_buffer[local_index >> 4] = d; barrier(); - if (local_index < 16) - d = reduce_mips_simd16(ivec2(local_coord), local_index, 10, shared_buffer[local_index], false); + reduce_mip_shared(ivec2(local_coord), 10); + + if (registers.mips <= 11) + return; + barrier(); + if (local_index < 4) + reduce_mip_shared(ivec2(local_coord), 11); + if (registers.mips <= 12) return; + barrier(); if (local_index == 0) - write_image(ivec2(0), 12, d); + reduce_mip_shared(ivec2(0), 12); } diff --git a/tests/hiz.cpp b/tests/hiz.cpp index ca5f308d..2c401873 100644 --- a/tests/hiz.cpp +++ b/tests/hiz.cpp @@ -40,18 +40,18 @@ int main() Device dev; dev.set_context(ctx); - constexpr unsigned WIDTH = 7; - constexpr unsigned HEIGHT = 3; + constexpr unsigned WIDTH = 7 * 64; + constexpr unsigned HEIGHT = 7 * 64; - float values[HEIGHT][WIDTH]; + std::vector values(WIDTH * HEIGHT); for (unsigned y = 0; y < HEIGHT; y++) for (unsigned x = 0; x < WIDTH; x++) - values[y][x] = float(y * 100000 + x); + values[y * WIDTH + x] = float(x + y); auto info = ImageCreateInfo::immutable_2d_image(WIDTH, HEIGHT, VK_FORMAT_R32_SFLOAT); info.usage = VK_IMAGE_USAGE_SAMPLED_BIT; ImageInitialData init = {}; - init.data = values; + init.data = values.data(); auto img = dev.create_image(info, &init); info.usage = VK_IMAGE_USAGE_STORAGE_BIT | VK_IMAGE_USAGE_SAMPLED_BIT;