diff --git a/wavelets_pytorch/transform.py b/wavelets_pytorch/transform.py index 63d78d5..6de9f2e 100644 --- a/wavelets_pytorch/transform.py +++ b/wavelets_pytorch/transform.py @@ -299,29 +299,29 @@ def cwt(self, x): self.signal_length = signal_length # Move to GPU and perform CWT computation - x = torch.from_numpy(x).type(torch.FloatTensor) - x.requires_grad_(requires_grad=False) + #x = torch.from_numpy(x).type(torch.FloatTensor) + #x.requires_grad_(requires_grad=False) if self._cuda: x = x.cuda() cwt = self._extractor(x) - # Move back to CPU - cwt = cwt.detach() - if self._cuda: cwt = cwt.cpu() - cwt = cwt.numpy() - - if self.complex_wavelet: - # Combine real and imag parts, returns object of shape - # [n_batch,n_scales,signal_length] of type np.complex128 - cwt = (cwt[:,:,0,:] + cwt[:,:,1,:]*1j).astype(self.output_dtype) - else: - # Just squeeze the chn_out dimension (=1) to obtain an object of shape - # [n_batch,n_scales,signal_length] of type np.float64 - cwt = np.squeeze(cwt, 2).astype(self.output_dtype) - - # Squeeze batch dimension if single example - if num_examples == 1: - cwt = cwt.squeeze(0) + ## Move back to CPU + #cwt = cwt.detach() + #if self._cuda: cwt = cwt.cpu() + #cwt = cwt.numpy() + + #if self.complex_wavelet: + # # Combine real and imag parts, returns object of shape + # # [n_batch,n_scales,signal_length] of type np.complex128 + # cwt = (cwt[:,:,0,:] + cwt[:,:,1,:]*1j).astype(self.output_dtype) + #else: + # # Just squeeze the chn_out dimension (=1) to obtain an object of shape + # # [n_batch,n_scales,signal_length] of type np.float64 + # cwt = np.squeeze(cwt, 2).astype(self.output_dtype) + + ## Squeeze batch dimension if single example + #if num_examples == 1: + # cwt = cwt.squeeze(0) return cwt @property @@ -340,4 +340,4 @@ def signal_length(self): @signal_length.setter def signal_length(self, value): super(WaveletTransformTorch, self.__class__).signal_length.fset(self, value) - self._extractor.set_filters(self._filters) \ No newline at end of file + self._extractor.set_filters(self._filters) diff --git a/wavelets_pytorch/wavelets.py b/wavelets_pytorch/wavelets.py index eb09dff..e8cbfbf 100644 --- a/wavelets_pytorch/wavelets.py +++ b/wavelets_pytorch/wavelets.py @@ -21,7 +21,7 @@ import scipy.signal import scipy.optimize import scipy.special -from scipy.misc import factorial +from scipy.special import factorial __all__ = ['Morlet', 'Paul', 'DOG', 'Ricker', 'Marr', 'Mexican_hat'] @@ -380,4 +380,4 @@ def __init__(self): # aliases for DOG2 Marr = Ricker -Mexican_hat = Ricker \ No newline at end of file +Mexican_hat = Ricker