Skip to content

Commit

Permalink
Fixing quantization transform (#207)
Browse files Browse the repository at this point in the history
* quantization now performed on real and imaginary levels independently, rather than the complex magnitude

* Added eps

* Removing bias.

* Fixing quantize

* Satisfying MyPy

* Satisfying MyPy

---------

Co-authored-by: Garrett Vanhoy <gmvanhoy@gmail.com>
  • Loading branch information
MattCarrickPL and gvanhoy authored Sep 18, 2023
1 parent 877568e commit 57ec69f
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 15 deletions.
21 changes: 11 additions & 10 deletions torchsig/transforms/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -1026,29 +1026,30 @@ def quantize(
"""
# Setup quantization resolution/bins
max_value = max(np.abs(tensor)) + 1e-9
real_max = np.max(np.abs(tensor.real))
imag_max = np.max(np.abs(tensor.imag))

max_value = max(real_max, imag_max)
bins = np.linspace(-max_value, max_value, num_levels + 1)

# Digitize to bins
quantized_real = np.digitize(tensor.real, bins)
quantized_imag = np.digitize(tensor.imag, bins)
bins_real = np.digitize(tensor.real, bins)
bins_imag = np.digitize(tensor.imag, bins)

if round_type == "floor":
quantized_real -= 1
quantized_imag -= 1
bins_real -= 1
bins_imag -= 1

# Revert to values
quantized_real = bins[quantized_real]
quantized_imag = bins[quantized_imag]
quantized_real = bins[bins_real]
quantized_imag = bins[bins_imag]

if round_type == "nearest":
bin_size = np.diff(bins)[0]
quantized_real -= bin_size / 2
quantized_imag -= bin_size / 2

quantized_tensor = quantized_real + 1j * quantized_imag

return quantized_tensor
return quantized_real + 1j * quantized_imag


def clip(tensor: np.ndarray, clip_percentage: float) -> np.ndarray:
Expand Down
5 changes: 3 additions & 2 deletions torchsig/utils/cm_plotter.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from sklearn.metrics import confusion_matrix
from matplotlib import pyplot as plt
from matplotlib.colors import Colormap
from typing import Optional
import numpy as np

Expand All @@ -13,7 +14,7 @@ def plot_confusion_matrix(
text: bool = True,
rotate_x_text: int = 90,
figsize: tuple = (16, 9),
cmap: plt.cm = plt.cm.Blues,
cmap: str = "Blues",
):
"""Function to help plot confusion matrices
Expand Down Expand Up @@ -60,7 +61,7 @@ def plot_confusion_matrix(
color="white" if cm[i, j] > thresh else "black",
)
if len(classes) == 2:
plt.axis([-0.5, 1.5, 1.5, -0.5])
plt.axis((-0.5, 1.5, 1.5, -0.5))
fig.tight_layout()

return ax
6 changes: 3 additions & 3 deletions torchsig/utils/visualize.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ def _visualize(self, iq_data: np.ndarray, targets: np.ndarray) -> Figure:
ts = np.arange(len(cwt_matrix[0])) / self.sample_rate
plt.imshow(
np.abs(cwt_matrix),
extent=[ts[0], ts[-1], freqs[-1], freqs[0]],
extent=[ts[0], ts[-1], freqs[-1], freqs[0]], # type: ignore
vmin=0,
vmax=np.abs(cwt_matrix).max(),
aspect="auto",
Expand Down Expand Up @@ -458,7 +458,7 @@ def _visualize( # type: ignore
title = "Data"
plt.xticks([])
plt.yticks([])
plt.title(title)
plt.title(str(title))

return figure

Expand Down Expand Up @@ -521,7 +521,7 @@ def _visualize(self, data: np.ndarray, targets: np.ndarray) -> Figure:
title = "Data"
plt.xticks([])
plt.yticks([])
plt.title(title)
plt.title(str(title))

return figure

Expand Down

0 comments on commit 57ec69f

Please sign in to comment.