diff --git a/lib/kernels/src/cuda/ops/reverse_kernels.cu b/lib/kernels/src/cuda/ops/reverse_kernels.cu index 8e93fec0d6..2c25293c36 100644 --- a/lib/kernels/src/cuda/ops/reverse_kernels.cu +++ b/lib/kernels/src/cuda/ops/reverse_kernels.cu @@ -17,44 +17,9 @@ #include "kernels/reverse_kernels.h" namespace FlexFlow { - namespace Kernels { namespace Reverse { -// __global__ void reverse_forward_kernel(float const *in_ptr, -// float *out_ptr, -// coord_t num_out_blks, -// coord_t reverse_dim_size, -// coord_t in_blk_size) { -// CUDA_KERNEL_LOOP(i, num_out_blks * reverse_dim_size * in_blk_size) { -// coord_t out_idx = i; -// coord_t blk_idx = i / (reverse_dim_size * in_blk_size); -// i = i - blk_idx * (reverse_dim_size * in_blk_size); -// coord_t reverse_dim_idx = i / in_blk_size; -// i = i - reverse_dim_idx * in_blk_size; -// coord_t in_idx = blk_idx * (reverse_dim_size * in_blk_size) + -// (reverse_dim_size - 1 - reverse_dim_idx) * in_blk_size + -// i; -// out_ptr[out_idx] = in_ptr[in_idx]; -// } -// CUDA_KERNEL_LOOP(i, num_out_blks * reverse_dim_size * in_blk_size) { -// coord_t blk_idx = i / (reverse_dim_size * in_blk_size); -// i = i - blk_idx * (reverse_dim_size * in_blk_size); -// coord_t reverse_dim_idx = i / in_blk_size; -// i = i - reverse_dim_idx * in_blk_size; -// coord_t in_idx = blk_idx * (reverse_dim_size * in_blk_size) + -// (reverse_dim_size - 1 - reverse_dim_idx) * in_blk_size + -// i; -// out_ptr[i] = in_ptr[in_idx]; -// } -// } - -/* I mentioned this earlier, but I still think the reverse_forward_kernel code - is incorrect, even though it matches the code in inference/master? Whenever - I'm testing the code and printing out the output, I'm getting unexpected - outputs, and I think it's a result of modifying the loop index i in the - previous code? 
/**
 * Copies in_ptr to out_ptr while reversing the element order along one
 * dimension.
 *
 * The flat index space is treated as [num_out_blks][reverse_dim_size]
 * [in_blk_size]. For every flat output index the kernel reads the input
 * element that has the same outer-block and inner-block coordinates but the
 * mirrored coordinate (reverse_dim_size - 1 - idx) along the middle dimension.
 *
 * @param in_ptr            source buffer, read at the mirrored index
 * @param out_ptr           destination buffer, written at the flat loop index
 * @param num_out_blks      extent of the outer dimension
 * @param reverse_dim_size  extent of the dimension being reversed
 * @param in_blk_size       extent of the inner (contiguous) dimension
 *
 * NOTE(review): CUDA_KERNEL_LOOP is defined outside this file; it is presumed
 * to own `i` and advance it between iterations -- confirm. For that reason the
 * loop index is treated as read-only here: it is never reassigned, so it stays
 * valid both as the output index and for the macro's own increment.
 * NOTE(review): in_ptr and out_ptr are presumably distinct device buffers;
 * an in-place call would read elements that were already overwritten.
 */
__global__ void reverse_forward_kernel(float const *in_ptr,
                                       float *out_ptr,
                                       coord_t num_out_blks,
                                       coord_t reverse_dim_size,
                                       coord_t in_blk_size) {
  // Number of elements in one outer block; loop-invariant, so computed once.
  coord_t const blk_stride = reverse_dim_size * in_blk_size;

  CUDA_KERNEL_LOOP(i, num_out_blks * blk_stride) {
    // Decompose the flat index into (blk_idx, reverse_dim_idx, inner_idx)
    // WITHOUT modifying `i`.
    coord_t const blk_idx = i / blk_stride;
    coord_t const idx_within_blk = i % blk_stride;
    coord_t const reverse_dim_idx = idx_within_blk / in_blk_size;
    coord_t const inner_idx = idx_within_blk % in_blk_size;

    // Same outer/inner coordinates, mirrored coordinate along the reversed
    // dimension.
    coord_t const src_idx =
        blk_idx * blk_stride +
        (reverse_dim_size - 1 - reverse_dim_idx) * in_blk_size + inner_idx;

    out_ptr[i] = in_ptr[src_idx];
  }
}