Skip to content

Commit

Permalink
Rework of the RollOff() roll_off() transform (#209)
Browse files Browse the repository at this point in the history
* designing a filter for the roll_off() transform

* pulling roll_off_filter() into transforms

* updating wideband.py to use the new roll_off() transform

* converting over tests to use the updated roll-off function

* Adjusting coding style

---------

Co-authored-by: Garrett Vanhoy <gmvanhoy@gmail.com>
  • Loading branch information
MattCarrickPL and gvanhoy authored Sep 18, 2023
1 parent 57ec69f commit 8997b84
Show file tree
Hide file tree
Showing 6 changed files with 52 additions and 63 deletions.
2 changes: 1 addition & 1 deletion tests/test_transforms_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def generate_data():
IQImbalance(3, np.pi / 180, 0.05),
IQImbalance(3, np.pi / 180, 0.05),
),
("roll_off", RollOff(0.05, 0.98), RollOff(0.05, 0.98)),
("roll_off", RollOff(0.25, 0.1), RollOff(0.25, 0.1)),
("add_slope", AddSlope(), AddSlope()),
("spectral_inversion", SpectralInversion(), SpectralInversion()),
("channel_swap", ChannelSwap(), ChannelSwap()),
Expand Down
2 changes: 1 addition & 1 deletion tests/test_transforms_figures.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def generate_data(modulation_name):
IQImbalance(3, np.pi / 180, 0.05),
IQImbalance(3, np.pi / 180, 0.05),
),
("roll_off", RollOff(0.05, 0.98), RollOff(0.05, 0.98)),
("roll_off", RollOff(0.25, 0.1), RollOff(0.25, 0.1)),
("add_slope", AddSlope(), AddSlope()),
("spectral_inversion", SpectralInversion(), SpectralInversion()),
("channel_swap", ChannelSwap(), ChannelSwap()),
Expand Down
7 changes: 2 additions & 5 deletions torchsig/datasets/wideband.py
Original file line number Diff line number Diff line change
Expand Up @@ -969,11 +969,8 @@ def __init__(
RandomMagRescale(start=(0, 0.9), scale=(-4, 4)), 0.5
),
RollOff(
low_freq=(0.00, 0.05),
upper_freq=(0.95, 1.00),
low_cut_apply=0.5,
upper_cut_apply=0.5,
order=(6, 20),
cutoff=(0.25, 0.5),
cfo=(-0.1, 0.1),
),
RandomConvolve(num_taps=(2, 5), alpha=(0.1, 0.4)),
RayleighFadingChannel((0.001, 0.01)),
Expand Down
36 changes: 11 additions & 25 deletions torchsig/transforms/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from numba import complex64, float64, int64, njit
from torchsig.utils.types import RandomDistribution
from torchsig.utils.dsp import low_pass
from torchsig.utils.dsp import roll_off_filter
from scipy import interpolate
from scipy import signal as sp
import numpy as np
Expand Down Expand Up @@ -866,45 +867,30 @@ def amplitude_reversal(tensor: np.ndarray) -> np.ndarray:

def roll_off(
tensor: np.ndarray,
lowercutfreq: float,
uppercutfreq: float,
fltorder: int,
cutoff: float,
cfo: float,
) -> np.ndarray:
"""Applies front-end filter to tensor. Rolls off lower/upper edges of bandwidth
Args:
tensor: (:class:`numpy.ndarray`):
(batch_size, vector_length, ...)-sized tensor.
lowercutfreq (:obj:`float`):
lower bandwidth cut-off to begin linear roll-off
cutoff (:obj:`float`):
cutoff frequency for the roll-off filter, within 0.25 to 0.5 (representing fs/4 to fs/2)
uppercutfreq (:obj:`float`):
upper bandwidth cut-off to begin linear roll-off
fltorder (:obj:`int`):
order of each FIR filter to be applied
cfo (:obj:`float`):
center frequency offset (CFO) for the filter, within -0.1 to 0.1 (representing -fs/10 to fs/10)
Returns:
transformed (:class:`numpy.ndarray`):
Tensor that has undergone front-end filtering.
"""
if (lowercutfreq == 0) & (uppercutfreq == 1):
return tensor

elif uppercutfreq == 1:
if fltorder % 2 == 0:
fltorder += 1
bandwidth = uppercutfreq - lowercutfreq
center_freq = lowercutfreq - 0.5 + bandwidth / 2
taps = low_pass(
cutoff=bandwidth / 2, transition_bandwidth=(0.5 - bandwidth / 2) / 4
)
sinusoid = np.exp(
2j * np.pi * center_freq * np.linspace(0, len(taps) - 1, len(taps))
)
taps = taps * sinusoid

# design the roll-off filter
taps = roll_off_filter ( cutoff, cfo )

return sp.convolve(tensor, taps, mode="same")


Expand Down
41 changes: 10 additions & 31 deletions torchsig/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -2652,58 +2652,37 @@ class RollOff(SignalTransform):
"""Applies a band-edge RF roll-off effect simulating front end filtering
Args:
low_freq (:py:class:`~torchsig.types.RandomDistribution`):
cutoff (:py:class:`~torchsig.types.RandomDistribution`):
upper_freq (:py:class:`~torchsig.types.RandomDistribution`):
low_cut_apply (:obj:`float`):
Probability that the low frequency provided above is applied
upper_cut_apply (:obj:`float`):
Probability that the upper frequency provided above is applied
order (:py:class:`~torchsig.types.RandomDistribution`):
cfo (:py:class:`~torchsig.types.RandomDistribution`):
"""

def __init__(
self,
low_freq: FloatParameter = UniformContinuousRD(0.00, 0.05),
upper_freq: FloatParameter = UniformContinuousRD(0.95, 1.00),
low_cut_apply: float = 0.5,
upper_cut_apply: float = 0.5,
order: FloatParameter = UniformContinuousRD(6, 20),
cutoff: FloatParameter = UniformContinuousRD(0.25, 0.5),
cfo: FloatParameter = UniformContinuousRD(-0.1, 0.1),
) -> None:
super(RollOff, self).__init__()
self.low_freq = RandomDistribution.to_distribution(low_freq)
self.upper_freq = RandomDistribution.to_distribution(upper_freq)
self.low_cut_apply = low_cut_apply
self.upper_cut_apply = upper_cut_apply
self.order = RandomDistribution.to_distribution(order)
self.cutoff = RandomDistribution.to_distribution(cutoff)
self.cfo = RandomDistribution.to_distribution(cfo)
self.string = (
self.__class__.__name__
+ "("
+ "low_freq={}, ".format(low_freq)
+ "upper_freq={}, ".format(upper_freq)
+ "upper_cut_apply={}, ".format(upper_cut_apply)
+ "order={}".format(order)
+ "cutoff={}, ".format(cutoff)
+ "cfo={}, ".format(cfo)
+ ")"
)

def __repr__(self) -> str:
return self.string

def __call__(self, data: Any) -> Any:
low_freq = self.low_freq() if np.random.rand() < self.low_cut_apply else 0.0
upper_freq = (
self.upper_freq() if np.random.rand() < self.upper_cut_apply else 1.0
)
order = self.order()
if isinstance(data, SignalData):
assert data.iq_data is not None
data.iq_data = F.roll_off(data.iq_data, low_freq, upper_freq, int(order))
data.iq_data = F.roll_off(data.iq_data, self.cutoff(), self.cfo())
else:
data = F.roll_off(data, low_freq, upper_freq, int(order))
data = F.roll_off(data, self.cutoff(), self.cfo())
return data


Expand Down
27 changes: 27 additions & 0 deletions torchsig/utils/dsp.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,33 @@
import numpy as np


def roll_off_filter(cutoff: float, cfo: float):
"""Designs a filter to apply a randomized roll-off factor for the roll_off() impairment. When the parameters
are within the specified ranges, the roll-off filter will provide a slight LPF effect with the attenuation
at -fs/2 and +fs/2 within 0 dB to -12 dB gain.
Args:
cutoff (float): filter cutoff-frequency, from 0.25 to 0.5 (fs/4 to fs/2)
cfo (float): center frequency offset, from -0.1 to 0.1 (-fs/10 to fs/10)
"""
# design the time indexing
half_len = 2
filt_order = np.arange(-half_len, half_len + 1)

# compute the sinc LPF
sinc_lpf = np.sinc(2 * filt_order * cutoff)

# calculate the Bartlett taper which removes the nulls in the LPF
window = sp.windows.bartlett(filt_order.shape[0])

# compute the frequency shifter
freq_shift = np.exp(2j * np.pi * cfo * filt_order)

taps = sinc_lpf * window * freq_shift
return taps


def convolve(signal: np.ndarray, taps: np.ndarray) -> np.ndarray:
return sp.convolve(signal, taps, "same")

Expand Down

0 comments on commit 8997b84

Please sign in to comment.