Skip to content

Commit

Permalink
Rewrite HiZ shader.
Browse files Browse the repository at this point in the history
It was broken for some NPOT patterns.
  • Loading branch information
Themaister committed Nov 24, 2024
1 parent b2dd359 commit 7eb1dd7
Show file tree
Hide file tree
Showing 2 changed files with 180 additions and 57 deletions.
227 changes: 175 additions & 52 deletions assets/shaders/post/hiz.comp
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ void write_image4(ivec2 coord, int mip, vec4 v)
imageStore(uImages[mip], coord + ivec2(1, 1), v.wwww);
}

shared float shared_buffer[256 / 16];
shared float shared_buffer[128];
shared bool shared_is_last_workgroup;

mat4 fetch_4x4_texture(ivec2 base_coord)
Expand All @@ -96,13 +96,10 @@ vec4 fetch_2x2_image_mip6(ivec2 base_coord)
return vec4(d0, d1, d2, d3);
}

mat4 fetch_4x4_image_mip6(ivec2 base_coord)
float fetch_image_mip6(ivec2 coord)
{
vec4 q0 = fetch_2x2_image_mip6(base_coord + ivec2(0, 0));
vec4 q1 = fetch_2x2_image_mip6(base_coord + ivec2(2, 0));
vec4 q2 = fetch_2x2_image_mip6(base_coord + ivec2(0, 2));
vec4 q3 = fetch_2x2_image_mip6(base_coord + ivec2(2, 2));
return mat4(q0, q1, q2, q3);
ivec2 max_coord = mip_resolution(6) - 1;
return imageLoad(uImages[6], min(coord, max_coord)).x;
}

mat4 write_mip0_transformed(mat4 M, ivec2 base_coord)
Expand All @@ -121,7 +118,11 @@ mat4 write_mip0_transformed(mat4 M, ivec2 base_coord)
return mat4(q00, q10, q01, q11);
}

float reduce_mip_registers(mat4 M, ivec2 base_coord, int mip, bool full_res_pass)
// For LOD 0 to 6, it is expected that the division is exact,
// i.e., the lower resolution mip is exactly half resolution.
// This way we avoid needing to fold in neighbors.

float reduce_mip_registers(mat4 M, ivec2 base_coord, int mip)
{
vec4 q00 = M[0];
vec4 q10 = M[1];
Expand All @@ -135,49 +136,165 @@ float reduce_mip_registers(mat4 M, ivec2 base_coord, int mip, bool full_res_pass
float d01 = reduce(q01);
float d11 = reduce(q11);

if (!full_res_pass)
q00 = vec4(d00, d10, d01, d11);
write_image4(base_coord, mip, q00);

return reduce(q00);
}

void reduce_mip_shared(ivec2 base_coord, int mip)
{
ivec2 mip_res_higher = mip_resolution(mip - 1);
ivec2 mip_res_target = mip_resolution(mip);

bool horiz_fold = base_coord.x + 1 == mip_res_target.x && (mip_res_higher.x & 1) != 0;
bool vert_fold = base_coord.y + 1 == mip_res_target.y && (mip_res_higher.y & 1) != 0;
bool diag_fold = horiz_fold && vert_fold;

// Ping-pong the shared buffer to avoid double barrier.
uint out_offset = (mip & 1) * 64;
uint in_offset = 64 - out_offset;

float d00 = shared_buffer[in_offset + base_coord.y * 32 + base_coord.x * 2];
float d10 = shared_buffer[in_offset + base_coord.y * 32 + base_coord.x * 2 + 1];
float d01 = shared_buffer[in_offset + base_coord.y * 32 + base_coord.x * 2 + 16];
float d11 = shared_buffer[in_offset + base_coord.y * 32 + base_coord.x * 2 + 17];

float reduced = reduce(vec4(d00, d10, d01, d11));
if (horiz_fold)
{
if (base_coord.x + 1 == mip_res.x) // LOD math chops off data. Need to fold border values into the reduction.
{
d00 = REDUCE_OPERATOR(d00, d10);
d01 = REDUCE_OPERATOR(d01, d11);
}
reduced = REDUCE_OPERATOR(reduced, shared_buffer[in_offset + base_coord.y * 32 + base_coord.x * 2 + 2]);
reduced = REDUCE_OPERATOR(reduced, shared_buffer[in_offset + base_coord.y * 32 + base_coord.x * 2 + 18]);
}

if (base_coord.y + 1 == mip_res.y)
{
d01 = REDUCE_OPERATOR(d01, d00);
d11 = REDUCE_OPERATOR(d11, d10);
}
if (vert_fold)
{
reduced = REDUCE_OPERATOR(reduced, shared_buffer[in_offset + base_coord.y * 32 + base_coord.x * 2 + 32]);
reduced = REDUCE_OPERATOR(reduced, shared_buffer[in_offset + base_coord.y * 32 + base_coord.x * 2 + 33]);
}

q00 = vec4(d00, d10, d01, d11);
write_image4(base_coord, mip, q00);
if (diag_fold)
reduced = REDUCE_OPERATOR(reduced, shared_buffer[in_offset + base_coord.y * 32 + base_coord.x * 2 + 34]);

return reduce(q00);
shared_buffer[out_offset + base_coord.y * 16 + base_coord.x] = reduced;
write_image(base_coord, mip, reduced);
}

float reduce_mips_simd16(ivec2 base_coord, uint local_index, int mip, float d, bool full_res_pass)
void reduce_mip_lod7_lod8(ivec2 base_coord_lod8)
{
ivec2 mip_res6 = mip_resolution(6);
ivec2 mip_res7 = mip_resolution(7);
ivec2 base_coord_lod6 = base_coord_lod8 * 4;
ivec2 base_coord_lod7 = base_coord_lod8 * 2;

float d00 = reduce(fetch_2x2_image_mip6(base_coord_lod6 + ivec2(0, 0)));
float d10 = reduce(fetch_2x2_image_mip6(base_coord_lod6 + ivec2(2, 0)));
float d01 = reduce(fetch_2x2_image_mip6(base_coord_lod6 + ivec2(0, 2)));
float d11 = reduce(fetch_2x2_image_mip6(base_coord_lod6 + ivec2(2, 2)));

// NPOT folding for LOD 7. Our group will write the edge,
// so need to fold in any last neighbor in previous LOD which may contribute,
// but would be otherwise lost to the rounding down.

bool horiz_fold = base_coord_lod7.x + 2 == mip_res7.x && (mip_res6.x & 1) != 0;
bool vert_fold = base_coord_lod7.y + 2 == mip_res7.y && (mip_res6.y & 1) != 0;
bool diag_fold = horiz_fold && vert_fold;

if (horiz_fold)
{
float d20 = REDUCE_OPERATOR(
fetch_image_mip6(base_coord_lod6 + ivec2(4, 0)),
fetch_image_mip6(base_coord_lod6 + ivec2(4, 1)));

float d21 = REDUCE_OPERATOR(
fetch_image_mip6(base_coord_lod6 + ivec2(4, 2)),
fetch_image_mip6(base_coord_lod6 + ivec2(4, 3)));

d10 = REDUCE_OPERATOR(d10, d20);
d11 = REDUCE_OPERATOR(d11, d21);
}

if (vert_fold)
{
float d02 = REDUCE_OPERATOR(
fetch_image_mip6(base_coord_lod6 + ivec2(0, 4)),
fetch_image_mip6(base_coord_lod6 + ivec2(1, 4)));

float d12 = REDUCE_OPERATOR(
fetch_image_mip6(base_coord_lod6 + ivec2(2, 4)),
fetch_image_mip6(base_coord_lod6 + ivec2(3, 4)));

d01 = REDUCE_OPERATOR(d01, d02);
d11 = REDUCE_OPERATOR(d11, d12);
}

if (diag_fold)
{
float d22 = fetch_image_mip6(base_coord_lod6 + ivec2(4, 4));
d11 = REDUCE_OPERATOR(d11, d22);
}

// If the edge pixels will be dropped, fold them into the top-left pixel.
horiz_fold = base_coord_lod7.x + 1 == mip_res7.x;
vert_fold = base_coord_lod7.y + 1 == mip_res7.y;
diag_fold = horiz_fold && vert_fold;

if (horiz_fold)
{
d00 = REDUCE_OPERATOR(d00, d10);
d01 = REDUCE_OPERATOR(d01, d11);
}

if (vert_fold)
{
d00 = REDUCE_OPERATOR(d00, d01);
d10 = REDUCE_OPERATOR(d10, d11);
}

if (diag_fold)
d00 = REDUCE_OPERATOR(d00, d11);

vec4 quad = vec4(d00, d10, d01, d11);
write_image4(base_coord_lod7, 7, quad);

if (registers.mips > 8)
{
float lod8 = reduce(quad);
shared_buffer[base_coord_lod8.y * 16 + base_coord_lod8.x] = lod8;

// If writes to mip 8 may be sliced, fixup.
if (((mip_res7.x | mip_res7.y) & 1) != 0)
{
barrier();

ivec2 mip_res8 = mip_resolution(8);
horiz_fold = base_coord_lod8.x + 1 == mip_res8.x && (mip_res7.x & 1) != 0;
vert_fold = base_coord_lod8.y + 1 == mip_res8.y && (mip_res7.y & 1) != 0;
diag_fold = horiz_fold && vert_fold;

if (horiz_fold)
lod8 = REDUCE_OPERATOR(lod8, shared_buffer[base_coord_lod8.y * 16 + base_coord_lod8.x + 1]);
if (vert_fold)
lod8 = REDUCE_OPERATOR(lod8, shared_buffer[base_coord_lod8.y * 16 + base_coord_lod8.x + 16]);
if (diag_fold)
lod8 = REDUCE_OPERATOR(lod8, shared_buffer[base_coord_lod8.y * 16 + base_coord_lod8.x + 17]);
if (horiz_fold || vert_fold)
shared_buffer[base_coord_lod8.y * 16 + base_coord_lod8.x] = lod8;
}

write_image(base_coord_lod8, 8, lod8);
}
}

float reduce_mips_simd16(ivec2 base_coord, uint local_index, int mip, float d)
{
ivec2 mip_res = mip_resolution(mip);
float d_horiz, d_vert, d_diag;
bool swap_horiz, swap_vert;

// It is possible that our thread is barely in range, but horiz/vert neighbor is not.
#define CUTOFF_REDUCE() { \
swap_horiz = base_coord.x + 1 == mip_res.x; \
swap_vert = base_coord.y + 1 == mip_res.y; \
if (swap_horiz) \
d = REDUCE_OPERATOR(d, d_horiz); \
if (swap_vert) \
d = REDUCE_OPERATOR(d, d_vert); \
if (swap_vert && swap_horiz) \
d = REDUCE_OPERATOR(d, d_diag); }

d_horiz = subgroupQuadSwapHorizontal(d);
d_vert = subgroupQuadSwapVertical(d);
d_diag = subgroupQuadSwapDiagonal(d);
if (!full_res_pass)
CUTOFF_REDUCE();
write_image(base_coord, mip, d);

if (registers.mips > mip + 1)
Expand All @@ -190,8 +307,6 @@ float reduce_mips_simd16(ivec2 base_coord, uint local_index, int mip, float d, b
d_horiz = subgroupShuffleXor(d, SHUFFLE_X1);
d_vert = subgroupShuffleXor(d, SHUFFLE_Y1);
d_diag = subgroupShuffleXor(d, SHUFFLE_X1 | SHUFFLE_Y1);
if (!full_res_pass)
CUTOFF_REDUCE();
if ((local_index & 3) == 0)
write_image(base_coord, mip + 1, d);
}
Expand All @@ -215,12 +330,12 @@ void main()
// Write LOD 1, Compute LOD 2
if (registers.mips <= 1)
return;
float d = reduce_mip_registers(M, base_coord >> 1, 1, true);
float d = reduce_mip_registers(M, base_coord >> 1, 1);
if (registers.mips <= 2)
return;

// Write LOD 2, Compute LOD 3-4
d = reduce_mips_simd16(base_coord >> 2, local_index, 2, d, true);
d = reduce_mips_simd16(base_coord >> 2, local_index, 2, d);
if (registers.mips <= 4)
return;

Expand All @@ -231,7 +346,7 @@ void main()

// Write LOD 4, Compute LOD 5-6.
if (local_index < 16)
d = reduce_mips_simd16(ivec2(gl_WorkGroupID.xy * 4u + local_coord), local_index, 4, shared_buffer[local_index], true);
d = reduce_mips_simd16(ivec2(gl_WorkGroupID.xy * 4u + local_coord), local_index, 4, shared_buffer[local_index]);

// Write LOD 6.
if (registers.mips <= 6)
Expand All @@ -253,24 +368,32 @@ void main()
if (local_index == 0)
atomic_counter = 0u;

// Write LOD 7, Compute LOD 8
base_coord = ivec2(local_coord) * 4;
d = reduce_mip_registers(fetch_4x4_image_mip6(base_coord), base_coord >> 1, 7, false);
if (registers.mips <= 8)
// At this point, the mip resolutions may be non-POT and things get spicy.

// Write LOD 7-8, Compute LOD 8
reduce_mip_lod7_lod8(ivec2(local_coord));

if (registers.mips <= 9)
return;
barrier();
if (local_index < 64)
reduce_mip_shared(ivec2(local_coord), 9);

// Write LOD 8-9, Compute LOD 10
d = reduce_mips_simd16(ivec2(local_coord), local_index, 8, d, false);
if (registers.mips <= 10)
return;
if ((local_index & 15) == 0)
shared_buffer[local_index >> 4] = d;
barrier();

if (local_index < 16)
d = reduce_mips_simd16(ivec2(local_coord), local_index, 10, shared_buffer[local_index], false);
reduce_mip_shared(ivec2(local_coord), 10);

if (registers.mips <= 11)
return;
barrier();
if (local_index < 4)
reduce_mip_shared(ivec2(local_coord), 11);

if (registers.mips <= 12)
return;
barrier();
if (local_index == 0)
write_image(ivec2(0), 12, d);
reduce_mip_shared(ivec2(0), 12);
}
10 changes: 5 additions & 5 deletions tests/hiz.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,18 +40,18 @@ int main()
Device dev;
dev.set_context(ctx);

constexpr unsigned WIDTH = 7;
constexpr unsigned HEIGHT = 3;
constexpr unsigned WIDTH = 7 * 64;
constexpr unsigned HEIGHT = 7 * 64;

float values[HEIGHT][WIDTH];
std::vector<float> values(WIDTH * HEIGHT);
for (unsigned y = 0; y < HEIGHT; y++)
for (unsigned x = 0; x < WIDTH; x++)
values[y][x] = float(y * 100000 + x);
values[y * WIDTH + x] = float(x + y);

auto info = ImageCreateInfo::immutable_2d_image(WIDTH, HEIGHT, VK_FORMAT_R32_SFLOAT);
info.usage = VK_IMAGE_USAGE_SAMPLED_BIT;
ImageInitialData init = {};
init.data = values;
init.data = values.data();
auto img = dev.create_image(info, &init);

info.usage = VK_IMAGE_USAGE_STORAGE_BIT | VK_IMAGE_USAGE_SAMPLED_BIT;
Expand Down

0 comments on commit 7eb1dd7

Please sign in to comment.