From 57ec69f8adee60e22ed03ac3157482878ef73202 Mon Sep 17 00:00:00 2001 From: MattCarrickPL <120057274+MattCarrickPL@users.noreply.github.com> Date: Mon, 18 Sep 2023 11:10:56 -0400 Subject: [PATCH] Fixing quantization transform (#207) * quantization now performed on real and imaginary levels independently, rather than the complex magnitude * Added eps * Removing bias. * Fixing quantize * Satisfying MyPy * Satisfying MyPy --------- Co-authored-by: Garrett Vanhoy --- torchsig/transforms/functional.py | 21 +++++++++++---------- torchsig/utils/cm_plotter.py | 5 +++-- torchsig/utils/visualize.py | 6 +++--- 3 files changed, 17 insertions(+), 15 deletions(-) diff --git a/torchsig/transforms/functional.py b/torchsig/transforms/functional.py index b5564d8..9316382 100644 --- a/torchsig/transforms/functional.py +++ b/torchsig/transforms/functional.py @@ -1026,29 +1026,30 @@ def quantize( """ # Setup quantization resolution/bins - max_value = max(np.abs(tensor)) + 1e-9 + real_max = np.max(np.abs(tensor.real)) + imag_max = np.max(np.abs(tensor.imag)) + + max_value = max(real_max, imag_max) bins = np.linspace(-max_value, max_value, num_levels + 1) # Digitize to bins - quantized_real = np.digitize(tensor.real, bins) - quantized_imag = np.digitize(tensor.imag, bins) + bins_real = np.digitize(tensor.real, bins) + bins_imag = np.digitize(tensor.imag, bins) if round_type == "floor": - quantized_real -= 1 - quantized_imag -= 1 + bins_real -= 1 + bins_imag -= 1 # Revert to values - quantized_real = bins[quantized_real] - quantized_imag = bins[quantized_imag] + quantized_real = bins[bins_real] + quantized_imag = bins[bins_imag] if round_type == "nearest": bin_size = np.diff(bins)[0] quantized_real -= bin_size / 2 quantized_imag -= bin_size / 2 - quantized_tensor = quantized_real + 1j * quantized_imag - - return quantized_tensor + return quantized_real + 1j * quantized_imag def clip(tensor: np.ndarray, clip_percentage: float) -> np.ndarray: diff --git a/torchsig/utils/cm_plotter.py b/torchsig/utils/cm_plotter.py index 75969b7..ddea63e 100644 --- a/torchsig/utils/cm_plotter.py +++ b/torchsig/utils/cm_plotter.py @@ -1,5 +1,6 @@ from sklearn.metrics import confusion_matrix from matplotlib import pyplot as plt +from matplotlib.colors import Colormap from typing import Optional import numpy as np @@ -13,7 +14,7 @@ def plot_confusion_matrix( text: bool = True, rotate_x_text: int = 90, figsize: tuple = (16, 9), - cmap: plt.cm = plt.cm.Blues, + cmap: str = "Blues", ): """Function to help plot confusion matrices @@ -60,7 +61,7 @@ def plot_confusion_matrix( color="white" if cm[i, j] > thresh else "black", ) if len(classes) == 2: - plt.axis([-0.5, 1.5, 1.5, -0.5]) + plt.axis((-0.5, 1.5, 1.5, -0.5)) fig.tight_layout() return ax diff --git a/torchsig/utils/visualize.py b/torchsig/utils/visualize.py index b4512aa..dede717 100644 --- a/torchsig/utils/visualize.py +++ b/torchsig/utils/visualize.py @@ -176,7 +176,7 @@ def _visualize(self, iq_data: np.ndarray, targets: np.ndarray) -> Figure: ts = np.arange(len(cwt_matrix[0])) / self.sample_rate plt.imshow( np.abs(cwt_matrix), - extent=[ts[0], ts[-1], freqs[-1], freqs[0]], + extent=[ts[0], ts[-1], freqs[-1], freqs[0]], # type: ignore vmin=0, vmax=np.abs(cwt_matrix).max(), aspect="auto", @@ -458,7 +458,7 @@ def _visualize( # type: ignore title = "Data" plt.xticks([]) plt.yticks([]) - plt.title(title) + plt.title(str(title)) return figure @@ -521,7 +521,7 @@ def _visualize(self, data: np.ndarray, targets: np.ndarray) -> Figure: title = "Data" plt.xticks([]) plt.yticks([]) - plt.title(title) + plt.title(str(title)) return figure