diff --git a/keras/kokoro/github/ubuntu/gpu/build.sh b/keras/kokoro/github/ubuntu/gpu/build.sh index 5b523a3bc14..a1af4419e16 100644 --- a/keras/kokoro/github/ubuntu/gpu/build.sh +++ b/keras/kokoro/github/ubuntu/gpu/build.sh @@ -37,7 +37,6 @@ then --cov=keras fi -# TODO: Add test for JAX if [ "$KERAS_BACKEND" == "jax" ] then echo "JAX backend detected." @@ -59,7 +58,6 @@ then --cov=keras fi -# TODO: Add test for PyTorch if [ "$KERAS_BACKEND" == "torch" ] then echo "PyTorch backend detected." @@ -70,10 +68,8 @@ then python3 -c 'import torch;assert torch.cuda.is_available()' # TODO: Fix the failing Torch GPU CI tests. - # TODO: nn_test failures are on correctness tests. pytest keras --ignore keras/applications \ --ignore keras/layers/preprocessing/feature_space_test.py \ --ignore keras/layers/reshaping/flatten_test.py \ - --ignore keras/ops/nn_test.py \ --cov=keras fi diff --git a/keras/layers/convolutional/conv_transpose_test.py b/keras/layers/convolutional/conv_transpose_test.py index 62ac2321991..414efe83b50 100644 --- a/keras/layers/convolutional/conv_transpose_test.py +++ b/keras/layers/convolutional/conv_transpose_test.py @@ -16,6 +16,276 @@ ) +def np_conv1d_transpose( + x, + kernel_weights, + bias_weights, + strides, + padding, + output_padding, + data_format, + dilation_rate, +): + if data_format == "channels_first": + x = x.transpose((0, 2, 1)) + if isinstance(strides, (tuple, list)): + h_stride = strides[0] + else: + h_stride = strides + if isinstance(dilation_rate, (tuple, list)): + h_dilation = dilation_rate[0] + else: + h_dilation = dilation_rate + + h_kernel, ch_out, ch_in = kernel_weights.shape + n_batch, h_x, _ = x.shape + # Get output shape and padding + _, h_out, _ = compute_conv_transpose_output_shape( + x.shape, + kernel_weights.shape, + ch_out, + strides, + padding, + output_padding, + data_format, + dilation_rate, + ) + jax_padding = compute_conv_transpose_padding_args_for_jax( + input_shape=x.shape, + kernel_shape=kernel_weights.shape, + strides=strides, + padding=padding, + output_padding=output_padding, + dilation_rate=dilation_rate, + ) + h_pad_side1 = h_kernel - 1 - jax_padding[0][0] + + if h_dilation > 1: + # Increase kernel size + new_h_kernel = h_kernel + (h_dilation - 1) * (h_kernel - 1) + new_kenel_size_tuple = (new_h_kernel,) + new_kernel_weights = np.zeros( + (*new_kenel_size_tuple, ch_out, ch_in), + dtype=kernel_weights.dtype, + ) + new_kernel_weights[::h_dilation] = kernel_weights + kernel_weights = new_kernel_weights + h_kernel = kernel_weights.shape[0] + + # Compute output + output = np.zeros([n_batch, h_out + h_kernel, ch_out]) + for nb in range(n_batch): + for h_x_idx in range(h_x): + h_out_idx = h_x_idx * h_stride # Index in output + output[nb, h_out_idx : h_out_idx + h_kernel, :] += np.sum( + kernel_weights[:, :, :] * x[nb, h_x_idx, :], axis=-1 + ) + output = output + bias_weights + + # Cut padding results from output + output = output[:, h_pad_side1 : h_out + h_pad_side1] + if data_format == "channels_first": + output = output.transpose((0, 2, 1)) + return output + + +def np_conv2d_transpose( + x, + kernel_weights, + bias_weights, + strides, + padding, + output_padding, + data_format, + dilation_rate, +): + if data_format == "channels_first": + x = x.transpose((0, 2, 3, 1)) + if isinstance(strides, (tuple, list)): + h_stride, w_stride = strides + else: + h_stride = strides + w_stride = strides + if isinstance(dilation_rate, (tuple, list)): + h_dilation, w_dilation = dilation_rate + else: + h_dilation = dilation_rate + w_dilation = dilation_rate + + 
h_kernel, w_kernel, ch_out, ch_in = kernel_weights.shape + n_batch, h_x, w_x, _ = x.shape + # Get output shape and padding + _, h_out, w_out, _ = compute_conv_transpose_output_shape( + x.shape, + kernel_weights.shape, + ch_out, + strides, + padding, + output_padding, + data_format, + dilation_rate, + ) + jax_padding = compute_conv_transpose_padding_args_for_jax( + input_shape=x.shape, + kernel_shape=kernel_weights.shape, + strides=strides, + padding=padding, + output_padding=output_padding, + dilation_rate=dilation_rate, + ) + h_pad_side1 = h_kernel - 1 - jax_padding[0][0] + w_pad_side1 = w_kernel - 1 - jax_padding[1][0] + + if h_dilation > 1 or w_dilation > 1: + # Increase kernel size + new_h_kernel = h_kernel + (h_dilation - 1) * (h_kernel - 1) + new_w_kernel = w_kernel + (w_dilation - 1) * (w_kernel - 1) + new_kenel_size_tuple = (new_h_kernel, new_w_kernel) + new_kernel_weights = np.zeros( + (*new_kenel_size_tuple, ch_out, ch_in), + dtype=kernel_weights.dtype, + ) + new_kernel_weights[::h_dilation, ::w_dilation] = kernel_weights + kernel_weights = new_kernel_weights + h_kernel, w_kernel = kernel_weights.shape[:2] + + # Compute output + output = np.zeros([n_batch, h_out + h_kernel, w_out + w_kernel, ch_out]) + for nb in range(n_batch): + for h_x_idx in range(h_x): + h_out_idx = h_x_idx * h_stride # Index in output + for w_x_idx in range(w_x): + w_out_idx = w_x_idx * w_stride + output[ + nb, + h_out_idx : h_out_idx + h_kernel, + w_out_idx : w_out_idx + w_kernel, + :, + ] += np.sum( + kernel_weights[:, :, :, :] * x[nb, h_x_idx, w_x_idx, :], + axis=-1, + ) + output = output + bias_weights + + # Cut padding results from output + output = output[ + :, + h_pad_side1 : h_out + h_pad_side1, + w_pad_side1 : w_out + w_pad_side1, + ] + if data_format == "channels_first": + output = output.transpose((0, 3, 1, 2)) + return output + + +def np_conv3d_transpose( + x, + kernel_weights, + bias_weights, + strides, + padding, + output_padding, + data_format, + dilation_rate, +): + if data_format == "channels_first": + x = x.transpose((0, 2, 3, 4, 1)) + if isinstance(strides, (tuple, list)): + h_stride, w_stride, d_stride = strides + else: + h_stride = strides + w_stride = strides + d_stride = strides + if isinstance(dilation_rate, (tuple, list)): + h_dilation, w_dilation, d_dilation = dilation_rate + else: + h_dilation = dilation_rate + w_dilation = dilation_rate + d_dilation = dilation_rate + + h_kernel, w_kernel, d_kernel, ch_out, ch_in = kernel_weights.shape + n_batch, h_x, w_x, d_x, _ = x.shape + # Get output shape and padding + _, h_out, w_out, d_out, _ = compute_conv_transpose_output_shape( + x.shape, + kernel_weights.shape, + ch_out, + strides, + padding, + output_padding, + data_format, + dilation_rate, + ) + jax_padding = compute_conv_transpose_padding_args_for_jax( + input_shape=x.shape, + kernel_shape=kernel_weights.shape, + strides=strides, + padding=padding, + output_padding=output_padding, + dilation_rate=dilation_rate, + ) + h_pad_side1 = h_kernel - 1 - jax_padding[0][0] + w_pad_side1 = w_kernel - 1 - jax_padding[1][0] + d_pad_side1 = d_kernel - 1 - jax_padding[2][0] + + if h_dilation > 1 or w_dilation > 1 or d_dilation > 1: + # Increase kernel size + new_h_kernel = h_kernel + (h_dilation - 1) * (h_kernel - 1) + new_w_kernel = w_kernel + (w_dilation - 1) * (w_kernel - 1) + new_d_kernel = d_kernel + (d_dilation - 1) * (d_kernel - 1) + new_kenel_size_tuple = (new_h_kernel, new_w_kernel, new_d_kernel) + new_kernel_weights = np.zeros( + (*new_kenel_size_tuple, ch_out, ch_in), + 
dtype=kernel_weights.dtype, + ) + new_kernel_weights[ + ::h_dilation, ::w_dilation, ::d_dilation + ] = kernel_weights + kernel_weights = new_kernel_weights + h_kernel, w_kernel, d_kernel = kernel_weights.shape[:3] + + # Compute output + output = np.zeros( + [ + n_batch, + h_out + h_kernel, + w_out + w_kernel, + d_out + d_kernel, + ch_out, + ] + ) + for nb in range(n_batch): + for h_x_idx in range(h_x): + h_out_idx = h_x_idx * h_stride # Index in output + for w_x_idx in range(w_x): + w_out_idx = w_x_idx * w_stride + for d_x_idx in range(d_x): + d_out_idx = d_x_idx * d_stride + output[ + nb, + h_out_idx : h_out_idx + h_kernel, + w_out_idx : w_out_idx + w_kernel, + d_out_idx : d_out_idx + d_kernel, + :, + ] += np.sum( + kernel_weights[:, :, :, :, :] + * x[nb, h_x_idx, w_x_idx, d_x_idx, :], + axis=-1, + ) + output = output + bias_weights + + # Cut padding results from output + output = output[ + :, + h_pad_side1 : h_out + h_pad_side1, + w_pad_side1 : w_out + w_pad_side1, + d_pad_side1 : d_out + d_pad_side1, + ] + if data_format == "channels_first": + output = output.transpose((0, 4, 1, 2, 3)) + return output + + class ConvTransposeBasicTest(testing.TestCase, parameterized.TestCase): @parameterized.parameters( { @@ -258,276 +528,6 @@ def test_bad_init_args(self): class ConvTransposeCorrectnessTest(testing.TestCase, parameterized.TestCase): - def np_conv1d_transpose( - self, - x, - kernel_weights, - bias_weights, - strides, - padding, - output_padding, - data_format, - dilation_rate, - ): - if data_format == "channels_first": - x = x.transpose((0, 2, 1)) - if isinstance(strides, (tuple, list)): - h_stride = strides[0] - else: - h_stride = strides - if isinstance(dilation_rate, (tuple, list)): - h_dilation = dilation_rate[0] - else: - h_dilation = dilation_rate - - h_kernel, ch_out, ch_in = kernel_weights.shape - n_batch, h_x, _ = x.shape - # Get output shape and padding - _, h_out, _ = compute_conv_transpose_output_shape( - x.shape, - kernel_weights.shape, - ch_out, - strides, - padding, - output_padding, - data_format, - dilation_rate, - ) - jax_padding = compute_conv_transpose_padding_args_for_jax( - input_shape=x.shape, - kernel_shape=kernel_weights.shape, - strides=strides, - padding=padding, - output_padding=output_padding, - dilation_rate=dilation_rate, - ) - h_pad_side1 = h_kernel - 1 - jax_padding[0][0] - - if h_dilation > 1: - # Increase kernel size - new_h_kernel = h_kernel + (h_dilation - 1) * (h_kernel - 1) - new_kenel_size_tuple = (new_h_kernel,) - new_kernel_weights = np.zeros( - (*new_kenel_size_tuple, ch_out, ch_in), - dtype=kernel_weights.dtype, - ) - new_kernel_weights[::h_dilation] = kernel_weights - kernel_weights = new_kernel_weights - h_kernel = kernel_weights.shape[0] - - # Compute output - output = np.zeros([n_batch, h_out + h_kernel, ch_out]) - for nb in range(n_batch): - for h_x_idx in range(h_x): - h_out_idx = h_x_idx * h_stride # Index in output - output[nb, h_out_idx : h_out_idx + h_kernel, :] += np.sum( - kernel_weights[:, :, :] * x[nb, h_x_idx, :], axis=-1 - ) - output = output + bias_weights - - # Cut padding results from output - output = output[:, h_pad_side1 : h_out + h_pad_side1] - if data_format == "channels_first": - output = output.transpose((0, 2, 1)) - return output - - def np_conv2d_transpose( - self, - x, - kernel_weights, - bias_weights, - strides, - padding, - output_padding, - data_format, - dilation_rate, - ): - if data_format == "channels_first": - x = x.transpose((0, 2, 3, 1)) - if isinstance(strides, (tuple, list)): - h_stride, w_stride = strides 
- else: - h_stride = strides - w_stride = strides - if isinstance(dilation_rate, (tuple, list)): - h_dilation, w_dilation = dilation_rate - else: - h_dilation = dilation_rate - w_dilation = dilation_rate - - h_kernel, w_kernel, ch_out, ch_in = kernel_weights.shape - n_batch, h_x, w_x, _ = x.shape - # Get output shape and padding - _, h_out, w_out, _ = compute_conv_transpose_output_shape( - x.shape, - kernel_weights.shape, - ch_out, - strides, - padding, - output_padding, - data_format, - dilation_rate, - ) - jax_padding = compute_conv_transpose_padding_args_for_jax( - input_shape=x.shape, - kernel_shape=kernel_weights.shape, - strides=strides, - padding=padding, - output_padding=output_padding, - dilation_rate=dilation_rate, - ) - h_pad_side1 = h_kernel - 1 - jax_padding[0][0] - w_pad_side1 = w_kernel - 1 - jax_padding[1][0] - - if h_dilation > 1 or w_dilation > 1: - # Increase kernel size - new_h_kernel = h_kernel + (h_dilation - 1) * (h_kernel - 1) - new_w_kernel = w_kernel + (w_dilation - 1) * (w_kernel - 1) - new_kenel_size_tuple = (new_h_kernel, new_w_kernel) - new_kernel_weights = np.zeros( - (*new_kenel_size_tuple, ch_out, ch_in), - dtype=kernel_weights.dtype, - ) - new_kernel_weights[::h_dilation, ::w_dilation] = kernel_weights - kernel_weights = new_kernel_weights - h_kernel, w_kernel = kernel_weights.shape[:2] - - # Compute output - output = np.zeros([n_batch, h_out + h_kernel, w_out + w_kernel, ch_out]) - for nb in range(n_batch): - for h_x_idx in range(h_x): - h_out_idx = h_x_idx * h_stride # Index in output - for w_x_idx in range(w_x): - w_out_idx = w_x_idx * w_stride - output[ - nb, - h_out_idx : h_out_idx + h_kernel, - w_out_idx : w_out_idx + w_kernel, - :, - ] += np.sum( - kernel_weights[:, :, :, :] * x[nb, h_x_idx, w_x_idx, :], - axis=-1, - ) - output = output + bias_weights - - # Cut padding results from output - output = output[ - :, - h_pad_side1 : h_out + h_pad_side1, - w_pad_side1 : w_out + w_pad_side1, - ] - if data_format == "channels_first": - output = output.transpose((0, 3, 1, 2)) - return output - - def np_conv3d_transpose( - self, - x, - kernel_weights, - bias_weights, - strides, - padding, - output_padding, - data_format, - dilation_rate, - ): - if data_format == "channels_first": - x = x.transpose((0, 2, 3, 4, 1)) - if isinstance(strides, (tuple, list)): - h_stride, w_stride, d_stride = strides - else: - h_stride = strides - w_stride = strides - d_stride = strides - if isinstance(dilation_rate, (tuple, list)): - h_dilation, w_dilation, d_dilation = dilation_rate - else: - h_dilation = dilation_rate - w_dilation = dilation_rate - d_dilation = dilation_rate - - h_kernel, w_kernel, d_kernel, ch_out, ch_in = kernel_weights.shape - n_batch, h_x, w_x, d_x, _ = x.shape - # Get output shape and padding - _, h_out, w_out, d_out, _ = compute_conv_transpose_output_shape( - x.shape, - kernel_weights.shape, - ch_out, - strides, - padding, - output_padding, - data_format, - dilation_rate, - ) - jax_padding = compute_conv_transpose_padding_args_for_jax( - input_shape=x.shape, - kernel_shape=kernel_weights.shape, - strides=strides, - padding=padding, - output_padding=output_padding, - dilation_rate=dilation_rate, - ) - h_pad_side1 = h_kernel - 1 - jax_padding[0][0] - w_pad_side1 = w_kernel - 1 - jax_padding[1][0] - d_pad_side1 = d_kernel - 1 - jax_padding[2][0] - - if h_dilation > 1 or w_dilation > 1 or d_dilation > 1: - # Increase kernel size - new_h_kernel = h_kernel + (h_dilation - 1) * (h_kernel - 1) - new_w_kernel = w_kernel + (w_dilation - 1) * (w_kernel - 1) - 
new_d_kernel = d_kernel + (d_dilation - 1) * (d_kernel - 1) - new_kenel_size_tuple = (new_h_kernel, new_w_kernel, new_d_kernel) - new_kernel_weights = np.zeros( - (*new_kenel_size_tuple, ch_out, ch_in), - dtype=kernel_weights.dtype, - ) - new_kernel_weights[ - ::h_dilation, ::w_dilation, ::d_dilation - ] = kernel_weights - kernel_weights = new_kernel_weights - h_kernel, w_kernel, d_kernel = kernel_weights.shape[:3] - - # Compute output - output = np.zeros( - [ - n_batch, - h_out + h_kernel, - w_out + w_kernel, - d_out + d_kernel, - ch_out, - ] - ) - for nb in range(n_batch): - for h_x_idx in range(h_x): - h_out_idx = h_x_idx * h_stride # Index in output - for w_x_idx in range(w_x): - w_out_idx = w_x_idx * w_stride - for d_x_idx in range(d_x): - d_out_idx = d_x_idx * d_stride - output[ - nb, - h_out_idx : h_out_idx + h_kernel, - w_out_idx : w_out_idx + w_kernel, - d_out_idx : d_out_idx + d_kernel, - :, - ] += np.sum( - kernel_weights[:, :, :, :, :] - * x[nb, h_x_idx, w_x_idx, d_x_idx, :], - axis=-1, - ) - output = output + bias_weights - - # Cut padding results from output - output = output[ - :, - h_pad_side1 : h_out + h_pad_side1, - w_pad_side1 : w_out + w_pad_side1, - d_pad_side1 : d_out + d_pad_side1, - ] - if data_format == "channels_first": - output = output.transpose((0, 4, 1, 2, 3)) - return output - @parameterized.parameters( { "filters": 5, @@ -587,7 +587,7 @@ def test_conv1d_transpose( layer.bias.assign(bias_weights) outputs = layer(inputs) - expected = self.np_conv1d_transpose( + expected = np_conv1d_transpose( inputs, kernel_weights, bias_weights, @@ -667,7 +667,7 @@ def test_conv2d_transpose( layer.bias.assign(bias_weights) outputs = layer(inputs) - expected = self.np_conv2d_transpose( + expected = np_conv2d_transpose( inputs, kernel_weights, bias_weights, @@ -738,7 +738,7 @@ def test_conv3d_transpose( layer.bias.assign(bias_weights) outputs = layer(inputs) - expected = self.np_conv3d_transpose( + expected = np_conv3d_transpose( inputs, kernel_weights, bias_weights, @@ -776,7 +776,7 @@ def test_conv1d_transpose_consistency( ) # Exepected result - expected_res = self.np_conv1d_transpose( + expected_res = np_conv1d_transpose( x=input, kernel_weights=kernel_weights, bias_weights=np.zeros(shape=(1,)), diff --git a/keras/layers/pooling/average_pooling_test.py b/keras/layers/pooling/average_pooling_test.py index 436e2769f6e..6256876b6d6 100644 --- a/keras/layers/pooling/average_pooling_test.py +++ b/keras/layers/pooling/average_pooling_test.py @@ -8,6 +8,132 @@ from keras import testing +def _same_padding(input_size, pool_size, stride): + if input_size % stride == 0: + return max(pool_size - stride, 0) + else: + return max(pool_size - (input_size % stride), 0) + + +def np_avgpool1d(x, pool_size, strides, padding, data_format): + if data_format == "channels_first": + x = x.swapaxes(1, 2) + if isinstance(pool_size, (tuple, list)): + pool_size = pool_size[0] + if isinstance(strides, (tuple, list)): + h_stride = strides[0] + else: + h_stride = strides + + if padding == "same": + n_batch, h_x, ch_x = x.shape + pad_value = _same_padding(h_x, pool_size, h_stride) + npad = [(0, 0)] * x.ndim + npad[1] = (0, pad_value) + x = np.pad(x, pad_width=npad, mode="edge") + + n_batch, h_x, ch_x = x.shape + out_h = int((h_x - pool_size) / h_stride) + 1 + + stride_shape = (n_batch, out_h, ch_x, pool_size) + strides = ( + x.strides[0], + h_stride * x.strides[1], + x.strides[2], + x.strides[1], + ) + windows = as_strided(x, shape=stride_shape, strides=strides) + out = np.mean(windows, axis=(3,)) + if 
data_format == "channels_first": + out = out.swapaxes(1, 2) + return out + + +def np_avgpool2d(x, pool_size, strides, padding, data_format): + if data_format == "channels_first": + x = x.transpose((0, 2, 3, 1)) + if isinstance(pool_size, int): + pool_size = (pool_size, pool_size) + if isinstance(strides, int): + strides = (strides, strides) + + h_pool_size, w_pool_size = pool_size + h_stride, w_stride = strides + if padding == "same": + n_batch, h_x, w_x, ch_x = x.shape + h_padding = _same_padding(h_x, h_pool_size, h_stride) + w_padding = _same_padding(w_x, w_pool_size, w_stride) + npad = [(0, 0)] * x.ndim + npad[1] = (0, h_padding) + npad[2] = (0, w_padding) + x = np.pad(x, pad_width=npad, mode="edge") + + n_batch, h_x, w_x, ch_x = x.shape + out_h = int((h_x - h_pool_size) / h_stride) + 1 + out_w = int((w_x - w_pool_size) / w_stride) + 1 + + stride_shape = (n_batch, out_h, out_w, ch_x, *pool_size) + strides = ( + x.strides[0], + h_stride * x.strides[1], + w_stride * x.strides[2], + x.strides[3], + x.strides[1], + x.strides[2], + ) + windows = as_strided(x, shape=stride_shape, strides=strides) + out = np.mean(windows, axis=(4, 5)) + if data_format == "channels_first": + out = out.transpose((0, 3, 1, 2)) + return out + + +def np_avgpool3d(x, pool_size, strides, padding, data_format): + if data_format == "channels_first": + x = x.transpose((0, 2, 3, 4, 1)) + + if isinstance(pool_size, int): + pool_size = (pool_size, pool_size, pool_size) + if isinstance(strides, int): + strides = (strides, strides, strides) + + h_pool_size, w_pool_size, d_pool_size = pool_size + h_stride, w_stride, d_stride = strides + + if padding == "same": + n_batch, h_x, w_x, d_x, ch_x = x.shape + h_padding = _same_padding(h_x, h_pool_size, h_stride) + w_padding = _same_padding(w_x, w_pool_size, w_stride) + d_padding = _same_padding(d_x, d_pool_size, d_stride) + npad = [(0, 0)] * x.ndim + npad[1] = (0, h_padding) + npad[2] = (0, w_padding) + npad[3] = (0, d_padding) + x = np.pad(x, pad_width=npad, mode="symmetric") + + n_batch, h_x, w_x, d_x, ch_x = x.shape + out_h = int((h_x - h_pool_size) / h_stride) + 1 + out_w = int((w_x - w_pool_size) / w_stride) + 1 + out_d = int((d_x - d_pool_size) / d_stride) + 1 + + stride_shape = (n_batch, out_h, out_w, out_d, ch_x, *pool_size) + strides = ( + x.strides[0], + h_stride * x.strides[1], + w_stride * x.strides[2], + d_stride * x.strides[3], + x.strides[4], + x.strides[1], + x.strides[2], + x.strides[3], + ) + windows = as_strided(x, shape=stride_shape, strides=strides) + out = np.mean(windows, axis=(5, 6, 7)) + if data_format == "channels_first": + out = out.transpose((0, 4, 1, 2, 3)) + return out + + @pytest.mark.requires_trainable_backend class AveragePoolingBasicTest(testing.TestCase, parameterized.TestCase): @parameterized.parameters( @@ -111,128 +237,6 @@ def test_average_pooling3d( class AveragePoolingCorrectnessTest(testing.TestCase, parameterized.TestCase): - def _same_padding(self, input_size, pool_size, stride): - if input_size % stride == 0: - return max(pool_size - stride, 0) - else: - return max(pool_size - (input_size % stride), 0) - - def _np_avgpool1d(self, x, pool_size, strides, padding, data_format): - if data_format == "channels_first": - x = x.swapaxes(1, 2) - if isinstance(pool_size, (tuple, list)): - pool_size = pool_size[0] - if isinstance(strides, (tuple, list)): - h_stride = strides[0] - else: - h_stride = strides - - if padding == "same": - n_batch, h_x, ch_x = x.shape - pad_value = self._same_padding(h_x, pool_size, h_stride) - npad = [(0, 0)] * x.ndim - 
npad[1] = (0, pad_value) - x = np.pad(x, pad_width=npad, mode="edge") - - n_batch, h_x, ch_x = x.shape - out_h = int((h_x - pool_size) / h_stride) + 1 - - stride_shape = (n_batch, out_h, ch_x, pool_size) - strides = ( - x.strides[0], - h_stride * x.strides[1], - x.strides[2], - x.strides[1], - ) - windows = as_strided(x, shape=stride_shape, strides=strides) - out = np.mean(windows, axis=(3,)) - if data_format == "channels_first": - out = out.swapaxes(1, 2) - return out - - def _np_avgpool2d(self, x, pool_size, strides, padding, data_format): - if data_format == "channels_first": - x = x.transpose((0, 2, 3, 1)) - if isinstance(pool_size, int): - pool_size = (pool_size, pool_size) - if isinstance(strides, int): - strides = (strides, strides) - - h_pool_size, w_pool_size = pool_size - h_stride, w_stride = strides - if padding == "same": - n_batch, h_x, w_x, ch_x = x.shape - h_padding = self._same_padding(h_x, h_pool_size, h_stride) - w_padding = self._same_padding(w_x, w_pool_size, w_stride) - npad = [(0, 0)] * x.ndim - npad[1] = (0, h_padding) - npad[2] = (0, w_padding) - x = np.pad(x, pad_width=npad, mode="edge") - - n_batch, h_x, w_x, ch_x = x.shape - out_h = int((h_x - h_pool_size) / h_stride) + 1 - out_w = int((w_x - w_pool_size) / w_stride) + 1 - - stride_shape = (n_batch, out_h, out_w, ch_x, *pool_size) - strides = ( - x.strides[0], - h_stride * x.strides[1], - w_stride * x.strides[2], - x.strides[3], - x.strides[1], - x.strides[2], - ) - windows = as_strided(x, shape=stride_shape, strides=strides) - out = np.mean(windows, axis=(4, 5)) - if data_format == "channels_first": - out = out.transpose((0, 3, 1, 2)) - return out - - def _np_avgpool3d(self, x, pool_size, strides, padding, data_format): - if data_format == "channels_first": - x = x.transpose((0, 2, 3, 4, 1)) - - if isinstance(pool_size, int): - pool_size = (pool_size, pool_size, pool_size) - if isinstance(strides, int): - strides = (strides, strides, strides) - - h_pool_size, w_pool_size, d_pool_size = pool_size - h_stride, w_stride, d_stride = strides - - if padding == "same": - n_batch, h_x, w_x, d_x, ch_x = x.shape - h_padding = self._same_padding(h_x, h_pool_size, h_stride) - w_padding = self._same_padding(w_x, w_pool_size, w_stride) - d_padding = self._same_padding(d_x, d_pool_size, d_stride) - npad = [(0, 0)] * x.ndim - npad[1] = (0, h_padding) - npad[2] = (0, w_padding) - npad[3] = (0, d_padding) - x = np.pad(x, pad_width=npad, mode="symmetric") - - n_batch, h_x, w_x, d_x, ch_x = x.shape - out_h = int((h_x - h_pool_size) / h_stride) + 1 - out_w = int((w_x - w_pool_size) / w_stride) + 1 - out_d = int((d_x - d_pool_size) / d_stride) + 1 - - stride_shape = (n_batch, out_h, out_w, out_d, ch_x, *pool_size) - strides = ( - x.strides[0], - h_stride * x.strides[1], - w_stride * x.strides[2], - d_stride * x.strides[3], - x.strides[4], - x.strides[1], - x.strides[2], - x.strides[3], - ) - windows = as_strided(x, shape=stride_shape, strides=strides) - out = np.mean(windows, axis=(5, 6, 7)) - if data_format == "channels_first": - out = out.transpose((0, 4, 1, 2, 3)) - return out - @parameterized.parameters( (2, 1, "valid", "channels_last"), (2, 1, "valid", "channels_first"), @@ -249,7 +253,7 @@ def test_average_pooling1d(self, pool_size, strides, padding, data_format): data_format=data_format, ) outputs = layer(inputs) - expected = self._np_avgpool1d( + expected = np_avgpool1d( inputs, pool_size, strides, padding, data_format ) self.assertAllClose(outputs, expected) @@ -276,7 +280,7 @@ def test_average_pooling1d_same_padding( 
data_format=data_format, ) outputs = layer(inputs) - expected = self._np_avgpool1d( + expected = np_avgpool1d( inputs, pool_size, strides, padding, data_format ) self.assertAllClose(outputs, expected) @@ -294,7 +298,7 @@ def test_average_pooling2d(self, pool_size, strides, padding, data_format): data_format=data_format, ) outputs = layer(inputs) - expected = self._np_avgpool2d( + expected = np_avgpool2d( inputs, pool_size, strides, padding, data_format ) self.assertAllClose(outputs, expected) @@ -320,7 +324,7 @@ def test_average_pooling2d_same_padding( data_format=data_format, ) outputs = layer(inputs) - expected = self._np_avgpool2d( + expected = np_avgpool2d( inputs, pool_size, strides, padding, data_format ) self.assertAllClose(outputs, expected) @@ -341,7 +345,7 @@ def test_average_pooling3d(self, pool_size, strides, padding, data_format): data_format=data_format, ) outputs = layer(inputs) - expected = self._np_avgpool3d( + expected = np_avgpool3d( inputs, pool_size, strides, padding, data_format ) self.assertAllClose(outputs, expected) @@ -368,7 +372,7 @@ def test_average_pooling3d_same_padding( data_format=data_format, ) outputs = layer(inputs) - expected = self._np_avgpool3d( + expected = np_avgpool3d( inputs, pool_size, strides, padding, data_format ) self.assertAllClose(outputs, expected) diff --git a/keras/layers/pooling/max_pooling_test.py b/keras/layers/pooling/max_pooling_test.py index 7f8dd437a83..418a77f8327 100644 --- a/keras/layers/pooling/max_pooling_test.py +++ b/keras/layers/pooling/max_pooling_test.py @@ -7,6 +7,132 @@ from keras import testing +def _same_padding(input_size, pool_size, stride): + if input_size % stride == 0: + return max(pool_size - stride, 0) + else: + return max(pool_size - (input_size % stride), 0) + + +def np_maxpool1d(x, pool_size, strides, padding, data_format): + if data_format == "channels_first": + x = x.swapaxes(1, 2) + if isinstance(pool_size, (tuple, list)): + pool_size = pool_size[0] + if isinstance(strides, (tuple, list)): + h_stride = strides[0] + else: + h_stride = strides + + if padding == "same": + n_batch, h_x, ch_x = x.shape + pad_value = _same_padding(h_x, pool_size, h_stride) + npad = [(0, 0)] * x.ndim + npad[1] = (0, pad_value) + x = np.pad(x, pad_width=npad, mode="constant", constant_values=-np.inf) + + n_batch, h_x, ch_x = x.shape + out_h = int((h_x - pool_size) / h_stride) + 1 + + stride_shape = (n_batch, out_h, ch_x, pool_size) + strides = ( + x.strides[0], + h_stride * x.strides[1], + x.strides[2], + x.strides[1], + ) + windows = as_strided(x, shape=stride_shape, strides=strides) + out = np.max(windows, axis=(3,)) + if data_format == "channels_first": + out = out.swapaxes(1, 2) + return out + + +def np_maxpool2d(x, pool_size, strides, padding, data_format): + if data_format == "channels_first": + x = x.transpose((0, 2, 3, 1)) + if isinstance(pool_size, int): + pool_size = (pool_size, pool_size) + if isinstance(strides, int): + strides = (strides, strides) + + h_pool_size, w_pool_size = pool_size + h_stride, w_stride = strides + if padding == "same": + n_batch, h_x, w_x, ch_x = x.shape + h_padding = _same_padding(h_x, h_pool_size, h_stride) + w_padding = _same_padding(w_x, w_pool_size, w_stride) + npad = [(0, 0)] * x.ndim + npad[1] = (0, h_padding) + npad[2] = (0, w_padding) + x = np.pad(x, pad_width=npad, mode="constant", constant_values=-np.inf) + + n_batch, h_x, w_x, ch_x = x.shape + out_h = int((h_x - h_pool_size) / h_stride) + 1 + out_w = int((w_x - w_pool_size) / w_stride) + 1 + + stride_shape = (n_batch, out_h, out_w, 
ch_x, *pool_size) + strides = ( + x.strides[0], + h_stride * x.strides[1], + w_stride * x.strides[2], + x.strides[3], + x.strides[1], + x.strides[2], + ) + windows = as_strided(x, shape=stride_shape, strides=strides) + out = np.max(windows, axis=(4, 5)) + if data_format == "channels_first": + out = out.transpose((0, 3, 1, 2)) + return out + + +def np_maxpool3d(x, pool_size, strides, padding, data_format): + if data_format == "channels_first": + x = x.transpose((0, 2, 3, 4, 1)) + + if isinstance(pool_size, int): + pool_size = (pool_size, pool_size, pool_size) + if isinstance(strides, int): + strides = (strides, strides, strides) + + h_pool_size, w_pool_size, d_pool_size = pool_size + h_stride, w_stride, d_stride = strides + + if padding == "same": + n_batch, h_x, w_x, d_x, ch_x = x.shape + h_padding = _same_padding(h_x, h_pool_size, h_stride) + w_padding = _same_padding(w_x, w_pool_size, w_stride) + d_padding = _same_padding(d_x, d_pool_size, d_stride) + npad = [(0, 0)] * x.ndim + npad[1] = (0, h_padding) + npad[2] = (0, w_padding) + npad[3] = (0, d_padding) + x = np.pad(x, pad_width=npad, mode="constant", constant_values=-np.inf) + + n_batch, h_x, w_x, d_x, ch_x = x.shape + out_h = int((h_x - h_pool_size) / h_stride) + 1 + out_w = int((w_x - w_pool_size) / w_stride) + 1 + out_d = int((d_x - d_pool_size) / d_stride) + 1 + + stride_shape = (n_batch, out_h, out_w, out_d, ch_x, *pool_size) + strides = ( + x.strides[0], + h_stride * x.strides[1], + w_stride * x.strides[2], + d_stride * x.strides[3], + x.strides[4], + x.strides[1], + x.strides[2], + x.strides[3], + ) + windows = as_strided(x, shape=stride_shape, strides=strides) + out = np.max(windows, axis=(5, 6, 7)) + if data_format == "channels_first": + out = out.transpose((0, 4, 1, 2, 3)) + return out + + @pytest.mark.requires_trainable_backend class MaxPoolingBasicTest(testing.TestCase, parameterized.TestCase): @parameterized.parameters( @@ -110,134 +236,6 @@ def test_max_pooling3d( class MaxPoolingCorrectnessTest(testing.TestCase, parameterized.TestCase): - def _same_padding(self, input_size, pool_size, stride): - if input_size % stride == 0: - return max(pool_size - stride, 0) - else: - return max(pool_size - (input_size % stride), 0) - - def _np_maxpool1d(self, x, pool_size, strides, padding, data_format): - if data_format == "channels_first": - x = x.swapaxes(1, 2) - if isinstance(pool_size, (tuple, list)): - pool_size = pool_size[0] - if isinstance(strides, (tuple, list)): - h_stride = strides[0] - else: - h_stride = strides - - if padding == "same": - n_batch, h_x, ch_x = x.shape - pad_value = self._same_padding(h_x, pool_size, h_stride) - npad = [(0, 0)] * x.ndim - npad[1] = (0, pad_value) - x = np.pad( - x, pad_width=npad, mode="constant", constant_values=-np.inf - ) - - n_batch, h_x, ch_x = x.shape - out_h = int((h_x - pool_size) / h_stride) + 1 - - stride_shape = (n_batch, out_h, ch_x, pool_size) - strides = ( - x.strides[0], - h_stride * x.strides[1], - x.strides[2], - x.strides[1], - ) - windows = as_strided(x, shape=stride_shape, strides=strides) - out = np.max(windows, axis=(3,)) - if data_format == "channels_first": - out = out.swapaxes(1, 2) - return out - - def _np_maxpool2d(self, x, pool_size, strides, padding, data_format): - if data_format == "channels_first": - x = x.transpose((0, 2, 3, 1)) - if isinstance(pool_size, int): - pool_size = (pool_size, pool_size) - if isinstance(strides, int): - strides = (strides, strides) - - h_pool_size, w_pool_size = pool_size - h_stride, w_stride = strides - if padding == "same": - 
n_batch, h_x, w_x, ch_x = x.shape - h_padding = self._same_padding(h_x, h_pool_size, h_stride) - w_padding = self._same_padding(w_x, w_pool_size, w_stride) - npad = [(0, 0)] * x.ndim - npad[1] = (0, h_padding) - npad[2] = (0, w_padding) - x = np.pad( - x, pad_width=npad, mode="constant", constant_values=-np.inf - ) - - n_batch, h_x, w_x, ch_x = x.shape - out_h = int((h_x - h_pool_size) / h_stride) + 1 - out_w = int((w_x - w_pool_size) / w_stride) + 1 - - stride_shape = (n_batch, out_h, out_w, ch_x, *pool_size) - strides = ( - x.strides[0], - h_stride * x.strides[1], - w_stride * x.strides[2], - x.strides[3], - x.strides[1], - x.strides[2], - ) - windows = as_strided(x, shape=stride_shape, strides=strides) - out = np.max(windows, axis=(4, 5)) - if data_format == "channels_first": - out = out.transpose((0, 3, 1, 2)) - return out - - def _np_maxpool3d(self, x, pool_size, strides, padding, data_format): - if data_format == "channels_first": - x = x.transpose((0, 2, 3, 4, 1)) - - if isinstance(pool_size, int): - pool_size = (pool_size, pool_size, pool_size) - if isinstance(strides, int): - strides = (strides, strides, strides) - - h_pool_size, w_pool_size, d_pool_size = pool_size - h_stride, w_stride, d_stride = strides - - if padding == "same": - n_batch, h_x, w_x, d_x, ch_x = x.shape - h_padding = self._same_padding(h_x, h_pool_size, h_stride) - w_padding = self._same_padding(w_x, w_pool_size, w_stride) - d_padding = self._same_padding(d_x, d_pool_size, d_stride) - npad = [(0, 0)] * x.ndim - npad[1] = (0, h_padding) - npad[2] = (0, w_padding) - npad[3] = (0, d_padding) - x = np.pad( - x, pad_width=npad, mode="constant", constant_values=-np.inf - ) - - n_batch, h_x, w_x, d_x, ch_x = x.shape - out_h = int((h_x - h_pool_size) / h_stride) + 1 - out_w = int((w_x - w_pool_size) / w_stride) + 1 - out_d = int((d_x - d_pool_size) / d_stride) + 1 - - stride_shape = (n_batch, out_h, out_w, out_d, ch_x, *pool_size) - strides = ( - x.strides[0], - h_stride * x.strides[1], - w_stride * x.strides[2], - d_stride * x.strides[3], - x.strides[4], - x.strides[1], - x.strides[2], - x.strides[3], - ) - windows = as_strided(x, shape=stride_shape, strides=strides) - out = np.max(windows, axis=(5, 6, 7)) - if data_format == "channels_first": - out = out.transpose((0, 4, 1, 2, 3)) - return out - @parameterized.parameters( (2, 1, "valid", "channels_last"), (2, 1, "valid", "channels_first"), @@ -256,7 +254,7 @@ def test_max_pooling1d(self, pool_size, strides, padding, data_format): data_format=data_format, ) outputs = layer(inputs) - expected = self._np_maxpool1d( + expected = np_maxpool1d( inputs, pool_size, strides, padding, data_format ) self.assertAllClose(outputs, expected) @@ -278,7 +276,7 @@ def test_max_pooling2d(self, pool_size, strides, padding, data_format): data_format=data_format, ) outputs = layer(inputs) - expected = self._np_maxpool2d( + expected = np_maxpool2d( inputs, pool_size, strides, padding, data_format ) self.assertAllClose(outputs, expected) @@ -299,7 +297,7 @@ def test_max_pooling3d(self, pool_size, strides, padding, data_format): data_format=data_format, ) outputs = layer(inputs) - expected = self._np_maxpool3d( + expected = np_maxpool3d( inputs, pool_size, strides, padding, data_format ) self.assertAllClose(outputs, expected) diff --git a/keras/ops/nn_test.py b/keras/ops/nn_test.py index 12d39832aed..b988a8d4d9e 100644 --- a/keras/ops/nn_test.py +++ b/keras/ops/nn_test.py @@ -1,11 +1,20 @@ import numpy as np import pytest -import tensorflow as tf from absl.testing import parameterized from 
keras import backend from keras import testing from keras.backend.common.keras_tensor import KerasTensor +from keras.layers.convolutional.conv_test import np_conv1d +from keras.layers.convolutional.conv_test import np_conv2d +from keras.layers.convolutional.conv_test import np_conv3d +from keras.layers.convolutional.conv_transpose_test import np_conv1d_transpose +from keras.layers.convolutional.conv_transpose_test import np_conv2d_transpose +from keras.layers.convolutional.depthwise_conv_test import np_depthwise_conv2d +from keras.layers.pooling.average_pooling_test import np_avgpool1d +from keras.layers.pooling.average_pooling_test import np_avgpool2d +from keras.layers.pooling.max_pooling_test import np_maxpool1d +from keras.layers.pooling.max_pooling_test import np_maxpool2d from keras.ops import nn as knn @@ -788,22 +797,24 @@ def test_max_pool(self): x = np.arange(120, dtype=float).reshape([2, 20, 3]) self.assertAllClose( knn.max_pool(x, 2, 1, padding="valid"), - tf.nn.max_pool1d(x, 2, 1, padding="VALID"), + np_maxpool1d(x, 2, 1, padding="valid", data_format="channels_last"), ) self.assertAllClose( knn.max_pool(x, 2, 2, padding="same"), - tf.nn.max_pool1d(x, 2, 2, padding="SAME"), + np_maxpool1d(x, 2, 2, padding="same", data_format="channels_last"), ) # Test 2D max pooling. x = np.arange(540, dtype=float).reshape([2, 10, 9, 3]) self.assertAllClose( knn.max_pool(x, 2, 1, padding="valid"), - tf.nn.max_pool2d(x, 2, 1, padding="VALID"), + np_maxpool2d(x, 2, 1, padding="valid", data_format="channels_last"), ) self.assertAllClose( knn.max_pool(x, 2, (2, 1), padding="same"), - tf.nn.max_pool2d(x, 2, (2, 1), padding="SAME"), + np_maxpool2d( + x, 2, (2, 1), padding="same", data_format="channels_last" + ), ) def test_average_pool_valid_padding(self): @@ -811,14 +822,14 @@ def test_average_pool_valid_padding(self): x = np.arange(120, dtype=float).reshape([2, 20, 3]) self.assertAllClose( knn.average_pool(x, 2, 1, padding="valid"), - tf.nn.avg_pool1d(x, 2, 1, padding="VALID"), + np_avgpool1d(x, 2, 1, padding="valid", data_format="channels_last"), ) # Test 2D max pooling. x = np.arange(540, dtype=float).reshape([2, 10, 9, 3]) self.assertAllClose( knn.average_pool(x, 2, 1, padding="valid"), - tf.nn.avg_pool2d(x, 2, 1, padding="VALID"), + np_avgpool2d(x, 2, 1, padding="valid", data_format="channels_last"), ) @pytest.mark.skipif( @@ -830,14 +841,16 @@ def test_average_pool_same_padding(self): x = np.arange(120, dtype=float).reshape([2, 20, 3]) self.assertAllClose( knn.average_pool(x, 2, 2, padding="same"), - tf.nn.avg_pool1d(x, 2, 2, padding="SAME"), + np_avgpool1d(x, 2, 2, padding="same", data_format="channels_last"), ) # Test 2D max pooling. 
x = np.arange(540, dtype=float).reshape([2, 10, 9, 3]) self.assertAllClose( knn.average_pool(x, 2, (2, 1), padding="same"), - tf.nn.avg_pool2d(x, 2, (2, 1), padding="SAME"), + np_avgpool2d( + x, 2, (2, 1), padding="same", data_format="channels_last" + ), ) @parameterized.product( @@ -859,12 +872,15 @@ def test_conv_1d(self, strides, padding, dilation_rate): padding=padding, dilation_rate=dilation_rate, ) - expected = tf.nn.conv1d( + expected = np_conv1d( inputs_1d, kernel, - strides, - padding=padding.upper(), - dilations=dilation_rate, + bias_weights=np.zeros((2,)), + strides=strides, + padding=padding.lower(), + data_format="channels_last", + dilation_rate=dilation_rate, + groups=1, ) self.assertAllClose(outputs, expected) @@ -873,19 +889,55 @@ def test_conv_2d(self): kernel = np.arange(24, dtype=float).reshape([2, 2, 3, 2]) outputs = knn.conv(inputs_2d, kernel, 1, padding="valid") - expected = tf.nn.conv2d(inputs_2d, kernel, 1, padding="VALID") + expected = np_conv2d( + inputs_2d, + kernel, + bias_weights=np.zeros((2,)), + strides=1, + padding="valid", + data_format="channels_last", + dilation_rate=1, + groups=1, + ) self.assertAllClose(outputs, expected) outputs = knn.conv(inputs_2d, kernel, (1, 2), padding="valid") - expected = tf.nn.conv2d(inputs_2d, kernel, (1, 2), padding="VALID") + expected = np_conv2d( + inputs_2d, + kernel, + bias_weights=np.zeros((2,)), + strides=(1, 2), + padding="valid", + data_format="channels_last", + dilation_rate=1, + groups=1, + ) self.assertAllClose(outputs, expected) outputs = knn.conv(inputs_2d, kernel, (1, 2), padding="same") - expected = tf.nn.conv2d(inputs_2d, kernel, (1, 2), padding="SAME") + expected = np_conv2d( + inputs_2d, + kernel, + bias_weights=np.zeros((2,)), + strides=(1, 2), + padding="same", + data_format="channels_last", + dilation_rate=1, + groups=1, + ) self.assertAllClose(outputs, expected) outputs = knn.conv(inputs_2d, kernel, 2, padding="same") - expected = tf.nn.conv2d(inputs_2d, kernel, 2, padding="SAME") + expected = np_conv2d( + inputs_2d, + kernel, + bias_weights=np.zeros((2,)), + strides=2, + padding="same", + data_format="channels_last", + dilation_rate=1, + groups=1, + ) self.assertAllClose(outputs, expected) # Test group > 1. 
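
Note on the hunks above and below: they replace the `tf.nn` oracles with the pure-numpy references added earlier in this diff, which is what lets `nn_test.py` drop its `tensorflow` import entirely. A minimal sketch of how the numpy reference is exercised, mirroring the calls in the updated tests — the input shape here is an illustrative assumption, since the actual `inputs_2d` fixture is defined outside the hunk context:

import numpy as np
from keras.layers.convolutional.conv_test import np_conv2d

x = np.random.rand(2, 10, 10, 3)  # assumed NHWC input; 3 channels to match the kernel
kernel = np.arange(24, dtype=float).reshape([2, 2, 3, 2])  # same kernel as the test above
# knn.conv applies no bias, so the reference is given zeros of shape (ch_out,).
expected = np_conv2d(
    x,
    kernel,
    bias_weights=np.zeros((2,)),
    strides=1,
    padding="valid",
    data_format="channels_last",
    dilation_rate=1,
    groups=1,
)
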
@@ -894,8 +946,15 @@ def test_conv_2d(self): outputs = knn.conv( inputs_2d, kernel, 2, padding="same", dilation_rate=1 ) - expected = tf.nn.conv2d( - inputs_2d, kernel, 2, padding="SAME", dilations=1 + expected = np_conv2d( + inputs_2d, + kernel, + bias_weights=np.zeros((6,)), + strides=2, + padding="same", + data_format="channels_last", + dilation_rate=1, + groups=1, ) self.assertAllClose(outputs, expected) @@ -906,12 +965,15 @@ def test_conv_2d(self): padding="same", dilation_rate=(2, 1), ) - expected = tf.nn.conv2d( + expected = np_conv2d( inputs_2d, kernel, - 1, - padding="SAME", - dilations=(2, 1), + bias_weights=np.zeros((6,)), + strides=1, + padding="same", + data_format="channels_last", + dilation_rate=(2, 1), + groups=1, ) self.assertAllClose(outputs, expected) @@ -920,8 +982,15 @@ def test_conv_3d(self): kernel = np.arange(162, dtype=float).reshape([3, 3, 3, 3, 2]) outputs = knn.conv(inputs_3d, kernel, 1, padding="valid") - expected = tf.nn.conv3d( - inputs_3d, kernel, (1, 1, 1, 1, 1), padding="VALID" + expected = np_conv3d( + inputs_3d, + kernel, + bias_weights=np.zeros((2,)), + strides=(1, 1, 1), + padding="valid", + data_format="channels_last", + dilation_rate=1, + groups=1, ) self.assertAllClose(outputs, expected, rtol=1e-5, atol=1e-5) @@ -932,24 +1001,41 @@ def test_conv_3d(self): padding="valid", dilation_rate=(1, 1, 1), ) - expected = tf.nn.conv3d( + expected = np_conv3d( inputs_3d, kernel, - (1, 1, 1, 1, 1), - padding="VALID", - dilations=(1, 1, 1, 1, 1), + bias_weights=np.zeros((2,)), + strides=(1, 1, 1), + padding="valid", + data_format="channels_last", + dilation_rate=(1, 1, 1), + groups=1, ) self.assertAllClose(outputs, expected, rtol=1e-5, atol=1e-5) outputs = knn.conv(inputs_3d, kernel, 2, padding="valid") - expected = tf.nn.conv3d( - inputs_3d, kernel, (1, 2, 2, 2, 1), padding="VALID" + expected = np_conv3d( + inputs_3d, + kernel, + bias_weights=np.zeros((2,)), + strides=2, + padding="valid", + data_format="channels_last", + dilation_rate=1, + groups=1, ) self.assertAllClose(outputs, expected, rtol=1e-5, atol=1e-5) outputs = knn.conv(inputs_3d, kernel, 2, padding="same") - expected = tf.nn.conv3d( - inputs_3d, kernel, (1, 2, 2, 2, 1), padding="SAME" + expected = np_conv3d( + inputs_3d, + kernel, + bias_weights=np.zeros((2,)), + strides=2, + padding="same", + data_format="channels_last", + dilation_rate=1, + groups=1, ) self.assertAllClose(outputs, expected, rtol=1e-5, atol=1e-5) @@ -958,28 +1044,52 @@ def test_depthwise_conv_2d(self): kernel = np.arange(24, dtype=float).reshape([2, 2, 3, 2]) outputs = knn.depthwise_conv(inputs_2d, kernel, 1, padding="valid") - expected = tf.nn.depthwise_conv2d( - inputs_2d, kernel, (1, 1, 1, 1), padding="VALID" + expected = np_depthwise_conv2d( + inputs_2d, + kernel, + bias_weights=np.zeros((6,)), + strides=1, + padding="valid", + data_format="channels_last", + dilation_rate=1, ) self.assertAllClose(outputs, expected) outputs = knn.depthwise_conv(inputs_2d, kernel, (1, 1), padding="valid") - expected = tf.nn.depthwise_conv2d( - inputs_2d, kernel, (1, 1, 1, 1), padding="VALID" + expected = np_depthwise_conv2d( + inputs_2d, + kernel, + bias_weights=np.zeros((6,)), + strides=(1, 1), + padding="valid", + data_format="channels_last", + dilation_rate=1, ) self.assertAllClose(outputs, expected) outputs = knn.depthwise_conv(inputs_2d, kernel, (2, 2), padding="same") - expected = tf.nn.depthwise_conv2d( - inputs_2d, kernel, (1, 2, 2, 1), padding="SAME" + expected = np_depthwise_conv2d( + inputs_2d, + kernel, + bias_weights=np.zeros((6,)), + 
strides=(2, 2), + padding="same", + data_format="channels_last", + dilation_rate=1, ) self.assertAllClose(outputs, expected) outputs = knn.depthwise_conv( inputs_2d, kernel, 1, padding="same", dilation_rate=(2, 2) ) - expected = tf.nn.depthwise_conv2d( - inputs_2d, kernel, (1, 1, 1, 1), padding="SAME", dilations=(2, 2) + expected = np_depthwise_conv2d( + inputs_2d, + kernel, + bias_weights=np.zeros((6,)), + strides=1, + padding="same", + data_format="channels_last", + dilation_rate=(2, 2), ) self.assertAllClose(outputs, expected) @@ -992,12 +1102,25 @@ def test_separable_conv_2d(self): outputs = knn.separable_conv( inputs_2d, depthwise_kernel, pointwise_kernel, 1, padding="valid" ) - expected = tf.nn.separable_conv2d( + # Depthwise followed by pointwise conv + expected_depthwise = np_depthwise_conv2d( inputs_2d, depthwise_kernel, + np.zeros(6), + strides=1, + padding="valid", + data_format="channels_last", + dilation_rate=1, + ) + expected = np_conv2d( + expected_depthwise, pointwise_kernel, - (1, 1, 1, 1), - padding="VALID", + np.zeros(6 * 12), + strides=1, + padding="valid", + data_format="channels_last", + dilation_rate=1, + groups=1, ) self.assertAllClose(outputs, expected) @@ -1008,24 +1131,30 @@ def test_separable_conv_2d(self): (1, 1), padding="valid", ) - expected = tf.nn.separable_conv2d( - inputs_2d, - depthwise_kernel, - pointwise_kernel, - (1, 1, 1, 1), - padding="VALID", - ) self.assertAllClose(outputs, expected) outputs = knn.separable_conv( inputs_2d, depthwise_kernel, pointwise_kernel, 2, padding="same" ) - expected = tf.nn.separable_conv2d( + # Depthwise followed by pointwise conv + expected_depthwise = np_depthwise_conv2d( inputs_2d, depthwise_kernel, + np.zeros(6), + strides=2, + padding="same", + data_format="channels_last", + dilation_rate=1, + ) + expected = np_conv2d( + expected_depthwise, pointwise_kernel, - (1, 2, 2, 1), - padding="SAME", + np.zeros(6 * 12), + strides=1, + padding="same", + data_format="channels_last", + dilation_rate=1, + groups=1, ) self.assertAllClose(outputs, expected) @@ -1037,13 +1166,25 @@ def test_separable_conv_2d(self): padding="same", dilation_rate=(2, 2), ) - expected = tf.nn.separable_conv2d( + # Depthwise followed by pointwise conv + expected_depthwise = np_depthwise_conv2d( inputs_2d, depthwise_kernel, + np.zeros(6), + strides=1, + padding="same", + data_format="channels_last", + dilation_rate=(2, 2), + ) + expected = np_conv2d( + expected_depthwise, pointwise_kernel, - (1, 1, 1, 1), - padding="SAME", - dilations=(2, 2), + np.zeros(6 * 12), + strides=1, + padding="same", + data_format="channels_last", + dilation_rate=1, + groups=1, ) self.assertAllClose(outputs, expected) @@ -1051,14 +1192,28 @@ def test_conv_transpose_1d(self): inputs_1d = np.arange(24, dtype=float).reshape([2, 4, 3]) kernel = np.arange(30, dtype=float).reshape([2, 5, 3]) outputs = knn.conv_transpose(inputs_1d, kernel, 2, padding="valid") - expected = tf.nn.conv_transpose( - inputs_1d, kernel, [2, 8, 5], 2, padding="VALID" + expected = np_conv1d_transpose( + inputs_1d, + kernel, + bias_weights=np.zeros(5), + strides=2, + output_padding=None, + padding="valid", + data_format="channels_last", + dilation_rate=1, ) self.assertAllClose(outputs, expected) outputs = knn.conv_transpose(inputs_1d, kernel, 2, padding="same") - expected = tf.nn.conv_transpose( - inputs_1d, kernel, [2, 8, 5], 2, padding="SAME" + expected = np_conv1d_transpose( + inputs_1d, + kernel, + bias_weights=np.zeros(5), + strides=2, + output_padding=None, + padding="same", + data_format="channels_last", 
+ dilation_rate=1, ) self.assertAllClose(outputs, expected) @@ -1067,46 +1222,70 @@ def test_conv_transpose_2d(self): kernel = np.arange(60, dtype=float).reshape([2, 2, 5, 3]) outputs = knn.conv_transpose(inputs_2d, kernel, (2, 2), padding="valid") - expected = tf.nn.conv_transpose( - inputs_2d, kernel, [2, 8, 8, 5], (2, 2), padding="VALID" + expected = np_conv2d_transpose( + inputs_2d, + kernel, + bias_weights=np.zeros(5), + strides=(2, 2), + output_padding=None, + padding="valid", + data_format="channels_last", + dilation_rate=1, ) self.assertAllClose(outputs, expected) outputs = knn.conv_transpose(inputs_2d, kernel, 2, padding="same") - expected = tf.nn.conv_transpose( - inputs_2d, kernel, [2, 8, 8, 5], 2, padding="SAME" + expected = np_conv2d_transpose( + inputs_2d, + kernel, + bias_weights=np.zeros(5), + strides=2, + output_padding=None, + padding="same", + data_format="channels_last", + dilation_rate=1, ) self.assertAllClose(outputs, expected) def test_one_hot(self): # Test 1D one-hot. indices_1d = np.array([0, 1, 2, 3]) - self.assertAllClose( - knn.one_hot(indices_1d, 4), tf.one_hot(indices_1d, 4) - ) + self.assertAllClose(knn.one_hot(indices_1d, 4), np.eye(4)[indices_1d]) self.assertAllClose( knn.one_hot(indices_1d, 4, axis=0), - tf.one_hot(indices_1d, 4, axis=0), + np.eye(4)[indices_1d], ) # Test 2D one-hot. indices_2d = np.array([[0, 1], [2, 3]]) - self.assertAllClose( - knn.one_hot(indices_2d, 4), tf.one_hot(indices_2d, 4) - ) + self.assertAllClose(knn.one_hot(indices_2d, 4), np.eye(4)[indices_2d]) self.assertAllClose( knn.one_hot(indices_2d, 4, axis=2), - tf.one_hot(indices_2d, 4, axis=2), + np.eye(4)[indices_2d], ) self.assertAllClose( knn.one_hot(indices_2d, 4, axis=1), - tf.one_hot(indices_2d, 4, axis=1), + np.transpose(np.eye(4)[indices_2d], (0, 2, 1)), ) # Test 1D one-hot with negative inputs indices_1d = np.array([0, -1, -1, 3]) self.assertAllClose( - knn.one_hot(indices_1d, 4), tf.one_hot(indices_1d, 4) + knn.one_hot(indices_1d, 4), + np.array( + [ + [1, 0, 0, 0], + [0, 0, 0, 0], + [0, 0, 0, 0], + [ + 0, + 0, + 0, + 1, + ], + ], + dtype=np.float32, + ), ) def test_binary_crossentropy(self):
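
On the `one_hot` changes above: plain numpy row indexing, `np.eye(num_classes)[indices]`, reproduces `tf.one_hot` only for non-negative indices, because numpy wraps negative indices around to the last rows of the identity matrix, while one-hot semantics map them to an all-zero row. That is why the negative-input case spells out its expected array by hand instead of using `np.eye`. A hedged sketch of a reference that also covers negatives — the helper name `np_one_hot` is hypothetical and not part of this diff:

import numpy as np

def np_one_hot(indices, num_classes):
    indices = np.asarray(indices)
    # Clip first so negative indices cannot wrap to the last rows,
    # then zero out the rows that came from negative indices.
    out = np.eye(num_classes)[np.clip(indices, 0, num_classes - 1)]
    out[indices < 0] = 0.0
    return out

# Matches the hand-written expectation for [0, -1, -1, 3] above.
print(np_one_hot(np.array([0, -1, -1, 3]), 4))
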