diff --git a/examples/compare_both.py b/examples/compare_both.py index 7c26881..184a971 100644 --- a/examples/compare_both.py +++ b/examples/compare_both.py @@ -18,63 +18,114 @@ import numpy as np import torch +import cv2 -from steerable.SCFpyr import SCFpyr +from steerable.SCFpyr_NumPy import SCFpyr_NumPy from steerable.SCFpyr_PyTorch import SCFpyr_PyTorch -from steerable.visualize import visualize - -import cortex.vision -import cv2 +import steerable.utils as utils ################################################################################ ################################################################################ # Common image_file = './assets/lena.jpg' -im = cv2.imread(image_file, cv2.IMREAD_GRAYSCALE) -im = cortex.vision.resize(im, out_height=200, out_width=200) -im = im.astype(np.float32)/255. +image = cv2.imread(image_file, cv2.IMREAD_GRAYSCALE) +image = cv2.resize(image, (200,200)) # Number of pyramid levels pyr_height = 5 -################################################################################ +# Number of orientation bands +pyr_nbands = 4 + +# Tolerance for error checking +tolerance = 1e-3 + ################################################################################ # NumPy -# Build the complex steerable pyramid -pyr = SCFpyr(height=pyr_height) -coeff_numpy = pyr.build(im) +pyr_numpy = SCFpyr_NumPy(pyr_height, pyr_nbands, scale_factor=2) +coeff_numpy = pyr_numpy.build(image) +reconstruction_numpy = pyr_numpy.reconstruct(coeff_numpy) ################################################################################ # PyTorch -device = torch.device('cpu') -batch_size = 16 +device = torch.device('cuda:0') -# Create a batch of images -im_batch = np.tile(im, (batch_size,1,1)) +im_batch = torch.from_numpy(image[None,None,:,:]) +im_batch = im_batch.to(device).float() -# Move to Torch on the GPU -im_batch = torch.from_numpy(im_batch) -im_batch = im_batch.to(device) +pyr_torch = SCFpyr_PyTorch(pyr_height, pyr_nbands, device=device) +coeff_torch = pyr_torch.build(im_batch) -# Build the complex steerable pyramid -pyr = SCFpyr_PyTorch(height=pyr_height, scale_factor=2, device=device) -coeff_torch = pyr.build(im_batch) - -# HighPass: coeff[0] : highpass -# BandPass Scale 1: coeff[1][0], coeff[1][1], coeff[1][2], coeff[1][3] -# BandPass Scale 2: coeff[2][0], coeff[2][1], coeff[2][2], coeff[2][3] +# Just extract a single example from the batch +# Also moves the example to CPU and NumPy +coeff_torch = utils.extract_from_batch(coeff_torch, 0) ################################################################################ +# Check correctness + +print('#'*60) +assert len(coeff_numpy) == len(coeff_torch) + +for level, _ in enumerate(coeff_numpy): + + print('Pyramid Level {level}'.format(level=level)) + coeff_level_numpy = coeff_numpy[level] + coeff_level_torch = coeff_torch[level] + + assert type(coeff_level_numpy) == type(coeff_level_torch) + + if isinstance(coeff_level_numpy, np.ndarray): + + # Low- or High-Pass + print(' NumPy. min = {min:.3f}, max = {max:.3f},' + ' mean = {mean:.3f}, std = {std:.3f}'.format( + min=np.min(coeff_level_numpy), max=np.max(coeff_level_numpy), + mean=np.mean(coeff_level_numpy), std=np.std(coeff_level_numpy) + )) + print(' PyTorch. min = {min:.3f}, max = {max:.3f},' + ' mean = {mean:.3f}, std = {std:.3f}'.format( + min=np.min(coeff_level_torch), max=np.max(coeff_level_torch), + mean=np.mean(coeff_level_torch), std=np.std(coeff_level_torch) + )) + + # Check numerical correctness + assert np.allclose(coeff_level_numpy, coeff_level_torch, atol=tolerance) + + elif isinstance(coeff_level_numpy, list): + + # Intermediate bands + for band, _ in enumerate(coeff_level_numpy): + + band_numpy = coeff_level_numpy[band] + band_torch = coeff_level_torch[band] + + print(' Orientation Band {}'.format(band)) + print(' NumPy. min = {min:.3f}, max = {max:.3f},' + ' mean = {mean:.3f}, std = {std:.3f}'.format( + min=np.min(band_numpy), max=np.max(band_numpy), + mean=np.mean(band_numpy), std=np.std(band_numpy) + )) + print(' PyTorch. min = {min:.3f}, max = {max:.3f},' + ' mean = {mean:.3f}, std = {std:.3f}'.format( + min=np.min(band_torch), max=np.max(band_torch), + mean=np.mean(band_torch), std=np.std(band_torch) + )) + + # Check numerical correctness + assert np.allclose(band_numpy, band_torch, atol=tolerance) -bands_viz_numpy = visualize(coeff_numpy) -bands_viz_torch = visualize(coeff_torch, example_idx=0) - -cv2.imshow('NumPy Bands', bands_viz_numpy) -cv2.imshow('PyTorch Bands', bands_viz_torch) -cv2.waitKey(0) +################################################################################ +# Visualize +coeff_grid_numpy = utils.make_grid_coeff(coeff_numpy, normalize=True) +coeff_grid_torch = utils.make_grid_coeff(coeff_torch, normalize=True) +cv2.imshow('image', image) +cv2.imshow('coeff numpy', coeff_grid_numpy) +cv2.imshow('coeff torch', coeff_grid_torch) +cv2.imshow('reconstruction numpy', reconstruction_numpy) +cv2.waitKey(0) diff --git a/examples/example_numpy_reconstruct.py b/examples/example_numpy_reconstruct.py index 3cb0853..9da739c 100644 --- a/examples/example_numpy_reconstruct.py +++ b/examples/example_numpy_reconstruct.py @@ -16,7 +16,6 @@ from __future__ import division from __future__ import print_function -import time import argparse import numpy as np import cv2 @@ -53,6 +52,8 @@ image = cv2.imread('./assets/lena.jpg', cv2.IMREAD_GRAYSCALE) image = cv2.resize(image, (200,200)) + + # TODO: rescaling to the range [0,1] does not work...? #image = image.astype(np.float32)/255. # Decompose into steerable pyramid diff --git a/steerable/SCFpyr_NumPy.py b/steerable/SCFpyr_NumPy.py index fed5527..216d685 100644 --- a/steerable/SCFpyr_NumPy.py +++ b/steerable/SCFpyr_NumPy.py @@ -73,7 +73,7 @@ def build(self, im): assert len(im.shape) == 2, 'Input im must be grayscale' height, width = im.shape - # Check whether im shape allows the pyramid M + # Check whether image size is sufficient for number of levels if self.height > int(np.floor(np.log2(min(width, height))) - 2): raise RuntimeError('Cannot build {} levels, image too small.'.format(self.height)) @@ -101,7 +101,6 @@ def build(self, im): hi0dft = imdft * hi0mask hi0 = np.fft.ifft2(np.fft.ifftshift(hi0dft)) coeff.insert(0, hi0.real) - return coeff @@ -110,7 +109,8 @@ def _build_levels(self, lodft, log_rad, angle, Xrcos, Yrcos, height): if height <= 1: # Low-pass - lo0 = np.fft.ifft2(np.fft.ifftshift(lodft)) + lo0 = np.fft.ifftshift(lodft) + lo0 = np.fft.ifft2(lo0) coeff = [lo0.real] else: diff --git a/steerable/SCFpyr_PyTorch.py b/steerable/SCFpyr_PyTorch.py index e358615..bb55286 100644 --- a/steerable/SCFpyr_PyTorch.py +++ b/steerable/SCFpyr_PyTorch.py @@ -77,12 +77,12 @@ def build(self, im_batch): assert im_batch.dim() == 4, 'Image batch must be of shape [N,C,H,W]' assert im_batch.shape[1] == 1, 'Second dimension must be 1 encoding grayscale image' - height, width = im_batch.shape[2], im_batch.shape[1] im_batch = im_batch.squeeze(1) # flatten channels dim - - # Check whether im shape allows the pyramid M - max_height_pyr = int(np.floor(np.log2(min(width, height))) - 2) - assert max_height_pyr >= self.height, 'Cannot buid pyramid with more than {} levels'.format(max_height_pyr) + height, width = im_batch.shape[2], im_batch.shape[1] + + # Check whether image size is sufficient for number of levels + if self.height > int(np.floor(np.log2(min(width, height))) - 2): + raise RuntimeError('Cannot build {} levels, image too small.'.format(self.height)) # Prepare a grid log_rad, angle = math_utils.prepare_grid(height, width) @@ -112,12 +112,10 @@ def build(self, im_batch): # High-pass hi0dft = batch_dft * hi0mask - - hi0 = torch.ifft(hi0dft, signal_ndim=2) - hi0 = math_utils.batch_ifftshift2d(hi0) + hi0 = math_utils.batch_ifftshift2d(hi0dft) + hi0 = torch.ifft(hi0, signal_ndim=2) hi0_real = torch.unbind(hi0, -1)[0] coeff.insert(0, hi0_real) - return coeff @@ -126,10 +124,9 @@ def _build_levels(self, lodft, log_rad, angle, Xrcos, Yrcos, height): if height <= 1: # Low-pass - lo0 = torch.ifft(lodft, signal_ndim=2) - lo0 = math_utils.batch_fftshift2d(lo0) + lo0 = math_utils.batch_ifftshift2d(lodft) + lo0 = torch.ifft(lo0, signal_ndim=2) lo0_real = torch.unbind(lo0, -1)[0] - # TODO: check correctess of these ops... coeff = [lo0_real] else: @@ -225,13 +222,13 @@ def reconstruct(self, coeff): tempdft = self._reconstruct_levels(coeff[1:], log_rad, Xrcos, Yrcos, angle) - hidft = torch.fft(coeff[0]) + hidft = torch.fft(coeff[0], signal_ndim=2) hidft = math_utils.batch_fftshift2d(hidft) outdft = tempdft * lo0mask + hidft * hi0mask reconstruction = math_utils.batch_fftshift2d(outdft) - reconstruction = torch.ifft(reconstruction).real.astype(int) + reconstruction = torch.ifft(reconstruction, signal_ndim=2).real.astype(int) return reconstruction def _reconstruct_levels(self, coeff, log_rad, Xrcos, Yrcos, angle): diff --git a/steerable/math_utils.py b/steerable/math_utils.py index 7f5b93d..f7fded7 100644 --- a/steerable/math_utils.py +++ b/steerable/math_utils.py @@ -22,45 +22,26 @@ ################################################################################ ################################################################################ -def batch_flip_halves(x, axis): - split = torch.chunk(x, 2, axis) - return torch.cat((split[1], split[0]), dim=axis) +def roll_n(X, axis, n): + f_idx = tuple(slice(None, None, None) if i != axis else slice(0, n, None) for i in range(X.dim())) + b_idx = tuple(slice(None, None, None) if i != axis else slice(n, None, None) for i in range(X.dim())) + front = X[f_idx] + back = X[b_idx] + return torch.cat([back, front], axis) def batch_fftshift2d(x): - assert isinstance(x, torch.Tensor), 'input must be a torch.Tensor' - assert x.dim() == 4, 'input tensor must be of shape [N,H,W,2]' - assert x.shape[-1] == 2, 'input tensor must be of shape [N,H,W,2]' - x = batch_flip_halves(x, axis=1) # top,bottom - x = batch_flip_halves(x, axis=2) # left,right - return x + real, imag = torch.unbind(x, -1) + for dim in range(1, len(real.size())): + real = roll_n(real, axis=dim, n=real.size(dim)//2) + imag = roll_n(imag, axis=dim, n=imag.size(dim)//2) + return torch.stack((real, imag), -1) # last dim=2 (real&imag) def batch_ifftshift2d(x): - ndim = x.dim() - assert ndim == 4, 'input tensor must be of shape [N,H,W,2]' - assert x.shape[-1] == 2, 'input tensor must be of shape [N,H,W,2]' - x = batch_flip_halves(x, axis=2) # left,right - x = batch_flip_halves(x, axis=1) # top,bottom - return x - -# def roll_n(X, axis, n): -# # Source: https://github.com/locuslab/pytorch_fft/blob/master/pytorch_fft/fft/fft.py#L230 -# f_idx = tuple(slice(None, None, None) if i != axis else slice(0, n, None) for i in range(X.dim())) -# b_idx = tuple(slice(None, None, None) if i != axis else slice(n, None, None) for i in range(X.dim())) -# front = X[f_idx] -# back = X[b_idx] -# return torch.cat([back, front], axis) - -# def fftshift(real, imag): -# for dim in range(1, len(real.size())): -# real = roll_n(real, axis=dim, n=real.size(dim)//2) -# imag = roll_n(imag, axis=dim, n=imag.size(dim)//2) -# return torch.stack((real, imag), -1) # last dim=2 (real/imag) - -# def ifftshift(real, imag): -# for dim in range(len(real.size()) - 1, 0, -1): -# real = roll_n(real, axis=dim, n=real.size(dim)//2) -# imag = roll_n(imag, axis=dim, n=imag.size(dim)//2) -# return torch.stack((real, imag), -1) # last dim=2 (real/imag) + real, imag = torch.unbind(x, -1) + for dim in range(len(real.size()) - 1, 0, -1): + real = roll_n(real, axis=dim, n=real.size(dim)//2) + imag = roll_n(imag, axis=dim, n=imag.size(dim)//2) + return torch.stack((real, imag), -1) # last dim=2 (real&imag) ################################################################################ ################################################################################ @@ -92,3 +73,26 @@ def getlist(coeff): straight = [bands for scale in coeff[1:-1] for bands in scale] straight = [coeff[0]] + straight + [coeff[-1]] return straight + +################################################################################ +# Alternative fftshift implementation + +# def batch_flip_halves(x, axis): +# split = torch.chunk(x, 2, axis) +# return torch.cat((split[1], split[0]), dim=axis) + +# def batch_fftshift2d_v1(x): +# assert isinstance(x, torch.Tensor), 'input must be a torch.Tensor' +# assert x.dim() == 4, 'input tensor must be of shape [N,H,W,2]' +# assert x.shape[-1] == 2, 'input tensor must be of shape [N,H,W,2]' +# x = batch_flip_halves(x, axis=1) # top,bottom +# x = batch_flip_halves(x, axis=2) # left,right +# return x + +# def batch_ifftshift2d_v1(x): +# ndim = x.dim() +# assert ndim == 4, 'input tensor must be of shape [N,H,W,2]' +# assert x.shape[-1] == 2, 'input tensor must be of shape [N,H,W,2]' +# x = batch_flip_halves(x, axis=2) # left,right +# x = batch_flip_halves(x, axis=1) # top,bottom +# return x \ No newline at end of file diff --git a/tests/test_ifft.py b/tests/test_ifft.py index 1b08e3f..aa003b8 100644 --- a/tests/test_ifft.py +++ b/tests/test_ifft.py @@ -83,8 +83,6 @@ ################################################################################ # Tolerance checking - - all_close_real = np.allclose(np.real(fft_numpy), np.real(fft_torch), atol=tolerance) all_close_imag = np.allclose(np.imag(fft_numpy), np.imag(fft_torch), atol=tolerance) print('fft allclose real: {}'.format(all_close_real))