Merge pull request sony#71 from sony/fix_issue_68

Fix issue 68

elad-c authored Nov 9, 2023
2 parents 16895f2 + 882f20c commit 7ce4bfd

Showing 6 changed files with 120 additions and 5 deletions.
@@ -74,6 +74,8 @@ def __init__(self,
if per_channel:
assert input_rank is not None, f'Input rank is missing in per channel quantization'
assert channel_axis is not None, f'Channel axis is missing in per channel quantization'
assert -input_rank <= channel_axis < input_rank, \
f'Channel axis out of range. Must be {-input_rank} <= channel_axis < {input_rank}'
assert len(threshold) >= 1, f'In per-channel quantization threshold list should be of length >= 1 ' \
f'but is {len(threshold)} '
else:
@@ -112,6 +114,7 @@ def __init__(self,
# If per-channel quantization is being used and the channel axis is not the last axis,
# create a permutation vector to move the channel axis to the last position
self.perm_vec = list(np.arange(self.input_rank))
channel_axis = self.perm_vec[self.channel_axis]
self.perm_vec[channel_axis] = self.input_rank - 1
self.perm_vec[self.input_rank - 1] = channel_axis
else:
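The two hunks above add a bounds check on channel_axis and normalize a negative axis to its positive index before the channel axis is swapped into the last position of the permutation vector. A standalone sketch of that logic, using assumed example values (rank-4 weights, channel_axis=-2) rather than the library's own constructor:

import numpy as np

# Standalone sketch of the added logic; input_rank and channel_axis are assumed example values.
input_rank = 4
channel_axis = -2

# Added bounds check: both negative and positive axis values must refer to a real axis.
assert -input_rank <= channel_axis < input_rank, \
    f'Channel axis out of range. Must be {-input_rank} <= channel_axis < {input_rank}'

perm_vec = list(np.arange(input_rank))   # identity permutation [0, 1, 2, 3]
channel_axis = perm_vec[channel_axis]    # added normalization: negative index -2 resolves to 2
perm_vec[channel_axis] = input_rank - 1  # swap the channel axis with the last axis
perm_vec[input_rank - 1] = channel_axis
# perm_vec is now [0, 1, 3, 2]; without the added normalization the same swap would
# produce [0, 1, 3, -2], which is not a valid permutation for moving the channel axis last.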
@@ -78,12 +78,19 @@ def __init__(self,
if per_channel:
assert input_rank is not None, f'Input rank is missing in per channel quantization'
assert channel_axis is not None, f'Channel axis is missing in per channel quantization'
assert len(self.min_range_np) >= 1, f'In per-channel quantization min ranges list should be of length >= 1 but is {len(self.min_range_np)}'
assert len(self.max_range_np) >= 1, f'In per-channel quantization max ranges list should be of length >= 1 but is {len(self.max_range_np)}'
assert -input_rank <= channel_axis < input_rank, \
f'Channel axis out of range. Must be {-input_rank} <= channel_axis < {input_rank}'
assert len(self.min_range_np) >= 1, \
f'In per-channel quantization min ranges list should be of length >= 1 but is {len(self.min_range_np)}'
assert len(self.max_range_np) >= 1, \
f'In per-channel quantization max ranges list should be of length >= 1 but is {len(self.max_range_np)}'
else:
assert len(self.min_range_np) == 1, f'In per-tensor quantization min/max should be of length 1 but is {len(self.min_range)}'
assert len(self.min_range_np) == 1, f'In per-tensor quantization min_range should be of length 1 but is {len(self.min_range_np)}'
assert len(self.max_range_np) == 1, f'In per-tensor quantization max_range should be of length 1 but is {len(self.max_range_np)}'
assert len(self.min_range_np) == 1, \
f'In per-tensor quantization min/max should be of length 1 but is {len(self.min_range)}'
assert len(self.min_range_np) == 1, \
f'In per-tensor quantization min_range should be of length 1 but is {len(self.min_range_np)}'
assert len(self.max_range_np) == 1, \
f'In per-tensor quantization max_range should be of length 1 but is {len(self.max_range_np)}'
self.min_range_np = self.min_range_np[0]
self.max_range_np = self.max_range_np[0]

@@ -97,6 +104,7 @@ def __init__(self,
# If per-channel quantization is being used and the channel axis is not the last axis,
# create a permutation vector to move the channel axis to the last position
self.perm_vec = list(np.arange(self.input_rank))
channel_axis = self.perm_vec[self.channel_axis]
self.perm_vec[channel_axis] = self.input_rank - 1
self.perm_vec[self.input_rank - 1] = channel_axis
else:
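Once the channel axis has been moved to the last position, per-channel parameter vectors of length C (min/max ranges here, thresholds in the symmetric case) broadcast along that axis. A hedged sketch of that usage with assumed shapes and values; tf.minimum/tf.maximum stand in for the quantizer's internal clipping and are not the library's code:

import numpy as np
import tensorflow as tf

# Assumed example: a rank-4 tensor with 3 channels on axis -3 (i.e. axis 1).
input_rank, channel_axis = 4, -3
perm_vec = list(np.arange(input_rank))
channel_axis = perm_vec[channel_axis]          # -3 resolves to 1
perm_vec[channel_axis] = input_rank - 1
perm_vec[input_rank - 1] = channel_axis        # perm_vec == [0, 3, 2, 1]

x = tf.constant(np.random.rand(2, 3, 4, 5), dtype=tf.float32)
x_perm = tf.transpose(x, perm=perm_vec)        # shape (2, 5, 4, 3): the 3 channels are now last
min_range = tf.constant([-1.0, -0.5, -2.0])    # one value per channel, broadcasts over the last axis
max_range = tf.constant([1.0, 0.5, 2.0])
clipped = tf.maximum(tf.minimum(x_perm, max_range), min_range)
x_back = tf.transpose(clipped, perm=perm_vec)  # the swap permutation is its own inverse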
@@ -140,3 +140,31 @@ def test_missing_input_rank_uniform_quantizer(self):
max_range=[4., 3.],
channel_axis=1)
self.assertEqual('Input rank is missing in per channel quantization', str(e.exception))

def test_out_of_range_channel_axis_POT_quantizer(self):
with self.assertRaises(Exception) as e:
WeightsPOTInferableQuantizer(num_bits=8,
per_channel=True,
threshold=[3., 2.],
channel_axis=-6,
input_rank=4)
self.assertEqual('Channel axis out of range. Must be -4 <= channel_axis < 4', str(e.exception))

def test_out_of_range_channel_axis_symmetric_quantizer(self):
with self.assertRaises(Exception) as e:
WeightsSymmetricInferableQuantizer(num_bits=8,
per_channel=True,
threshold=[3., 2.],
channel_axis=-6,
input_rank=4)
self.assertEqual('Channel axis out of range. Must be -4 <= channel_axis < 4', str(e.exception))

def test_out_of_range_channel_axis_uniform_quantizer(self):
with self.assertRaises(Exception) as e:
WeightsUniformInferableQuantizer(num_bits=8,
per_channel=True,
min_range=[3., 2.],
max_range=[4., 3.],
channel_axis=-6,
input_rank=4)
self.assertEqual('Channel axis out of range. Must be -4 <= channel_axis < 4', str(e.exception))
@@ -429,3 +429,23 @@ def test_missing_input_rank_pot_lut_quantizer(self):
lut_values=np.asarray([-25, 25]),
channel_axis=None,
per_channel=True)

def test_out_of_range_channel_axis_lut_pot_quantizer(self):
with self.assertRaises(Exception) as e:
WeightsLUTPOTInferableQuantizer(num_bits=2,
lut_values=list(range(4)),
per_channel=True,
threshold=[3., 2.],
channel_axis=-6,
input_rank=4)
self.assertEqual('Channel axis out of range. Must be -4 <= channel_axis < 4', str(e.exception))

def test_out_of_range_channel_axis_lut_symmetric_quantizer(self):
with self.assertRaises(Exception) as e:
WeightsLUTSymmetricInferableQuantizer(num_bits=2,
lut_values=list(range(4)),
per_channel=True,
threshold=[3., 2.],
channel_axis=-6,
input_rank=4)
self.assertEqual('Channel axis out of range. Must be -4 <= channel_axis < 4', str(e.exception))
@@ -361,3 +361,34 @@ def test_uniform_weights_quantizer_zero_not_in_range(self):
self.assertTrue(0 in np.unique(channel_slice_i),
f'zero should be in quantization range, but quantized values are in set: '
f'{np.unique(channel_slice_i)}')

def test_negative_channel_axis_POT_quantizer(self):
quantizer = WeightsPOTInferableQuantizer(num_bits=8,
per_channel=True,
threshold=[1.] * 4,
channel_axis=-2,
input_rank=4)
input_tensor = tf.constant(np.random.rand(2, 3, 4, 5), dtype=tf.float32)
fake_quantized_tensor = quantizer(input_tensor)
self.assertTrue(np.linalg.norm(fake_quantized_tensor - input_tensor) < 0.04)

def test_negative_channel_axis_symmetric_quantizer(self):
quantizer = WeightsSymmetricInferableQuantizer(num_bits=8,
per_channel=True,
threshold=[0.99] * 3,
channel_axis=-3,
input_rank=4)
input_tensor = tf.constant(np.random.rand(2, 3, 4, 5), dtype=tf.float32)
fake_quantized_tensor = quantizer(input_tensor)
self.assertTrue(np.linalg.norm(fake_quantized_tensor - input_tensor) < 0.04)

def test_negative_channel_axis_uniform_quantizer(self):
quantizer = WeightsUniformInferableQuantizer(num_bits=8,
per_channel=True,
min_range=[-0.99, -0.99],
max_range=[0.99, 0.99],
channel_axis=-4,
input_rank=4)
input_tensor = tf.constant(np.random.rand(2, 3, 4, 5), dtype=tf.float32)
fake_quantized_tensor = quantizer(input_tensor)
self.assertTrue(np.linalg.norm(fake_quantized_tensor - input_tensor) < 0.04)
@@ -17,6 +17,8 @@
import numpy as np
import tensorflow as tf

from mct_quantizers.keras.quantizers.weights_inferable_quantizers.weights_lut_pot_inferable_quantizer import \
WeightsLUTPOTInferableQuantizer
from mct_quantizers.keras.quantizers.weights_inferable_quantizers.weights_lut_symmetric_inferable_quantizer import \
WeightsLUTSymmetricInferableQuantizer

@@ -205,3 +207,26 @@ def test_weights_pot_lut_quantizer(self):
per_channel=per_channel, channel_axis=channel_axis,
input_rank=input_rank, lut_values_bitwidth=lut_values_bitwidth,
eps=eps)

def test_negative_channel_axis_lut_pot_quantizer(self):
quantizer = WeightsLUTPOTInferableQuantizer(num_bits=8,
lut_values=list(range(-128, 128)),
per_channel=True,
threshold=[1.] * 4,
channel_axis=-2,
input_rank=4)
input_tensor = tf.constant(np.random.rand(2, 3, 4, 5), dtype=tf.float32)
fake_quantized_tensor = quantizer(input_tensor)
self.assertTrue(np.linalg.norm(fake_quantized_tensor - input_tensor) < 0.04)

def test_negative_channel_axis_lut_symmetric_quantizer(self):
quantizer = WeightsLUTSymmetricInferableQuantizer(num_bits=8,
lut_values=list(range(-128, 128)),
per_channel=True,
threshold=[0.99] * 3,
channel_axis=-3,
input_rank=4)
input_tensor = tf.constant(np.random.rand(2, 3, 4, 5), dtype=tf.float32)
fake_quantized_tensor = quantizer(input_tensor)
self.assertTrue(np.linalg.norm(fake_quantized_tensor - input_tensor) < 0.04)
