Skip to content

Commit

Permalink
Add missing tests for ops/math.py (keras-team#19493)
Browse files Browse the repository at this point in the history
* addtest1

* fix2

* fix3

* fix4
  • Loading branch information
Faisal-Alsrheed authored Apr 11, 2024
1 parent 37d7900 commit 5d99879
Showing 1 changed file with 90 additions and 0 deletions.
90 changes: 90 additions & 0 deletions keras/ops/math_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1166,3 +1166,93 @@ def test_istft_low_rank_input(self):
low_rank_input = np.random.rand(3)
with self.assertRaisesRegex(ValueError, "Input should have rank >= 2"):
istft_op.compute_output_spec((low_rank_input, low_rank_input))

def test_input_not_tuple_or_list_raises_error(self):
irfft_op = kmath.IRFFT()
invalid_input = np.array([1, 2, 3])
with self.assertRaisesRegex(
ValueError, "Input `x` should be a tuple of two tensors"
):
irfft_op.compute_output_spec(invalid_input)

def test_input_tuple_with_less_than_two_elements_raises_error(self):
irfft_op = kmath.IRFFT()
too_short_input = (np.array([1, 2, 3]),)
with self.assertRaisesRegex(
ValueError, "Input `x` should be a tuple of two tensors"
):
irfft_op.compute_output_spec(too_short_input)

def test_input_tuple_with_more_than_two_elements_raises_error(self):
irfft_op = kmath.IRFFT()
too_long_input = (
np.array([1, 2, 3]),
np.array([4, 5, 6]),
np.array([7, 8, 9]),
)
with self.assertRaisesRegex(
ValueError, "Input `x` should be a tuple of two tensors"
):
irfft_op.compute_output_spec(too_long_input)

def test_mismatched_shapes_input_validation(self):
irfft_op = kmath.IRFFT()

# Create real and imaginary parts with mismatched shapes
real_part = np.array([1, 2, 3])
imag_part = np.array([[1, 2], [3, 4]])

with self.assertRaisesRegex(
ValueError,
"Both the real and imaginary parts should have the same shape",
):
irfft_op.compute_output_spec((real_part, imag_part))

def test_insufficient_rank_input_validation(self):
irfft_op = kmath.IRFFT()

# Create real and imaginary parts with insufficient rank (0D)
real_part = np.array(1)
imag_part = np.array(1)

with self.assertRaisesRegex(ValueError, "Input should have rank >= 1"):
irfft_op.compute_output_spec((real_part, imag_part))

def test_with_specified_fft_length(self):
fft_length = 10
irfft_op = kmath.IRFFT(fft_length=fft_length)

real_part = np.random.rand(4, 8)
imag_part = np.random.rand(4, 8)

expected_shape = real_part.shape[:-1] + (fft_length,)
output_shape = irfft_op.compute_output_spec(
(real_part, imag_part)
).shape

self.assertEqual(output_shape, expected_shape)

def test_inferred_fft_length_with_defined_last_dimension(self):
irfft_op = kmath.IRFFT()

real_part = np.random.rand(4, 8)
imag_part = np.random.rand(4, 8)

inferred_fft_length = 2 * (real_part.shape[-1] - 1)
expected_shape = real_part.shape[:-1] + (inferred_fft_length,)
output_shape = irfft_op.compute_output_spec(
(real_part, imag_part)
).shape

self.assertEqual(output_shape, expected_shape)

def test_undefined_fft_length_and_last_dimension(self):
irfft_op = kmath.IRFFT()

real_part = KerasTensor(shape=(4, None), dtype="float32")
imag_part = KerasTensor(shape=(4, None), dtype="float32")

output_spec = irfft_op.compute_output_spec((real_part, imag_part))
expected_shape = real_part.shape[:-1] + (None,)

self.assertEqual(output_spec.shape, expected_shape)

0 comments on commit 5d99879

Please sign in to comment.