diff --git a/examples/compare_both.py b/examples/compare_both.py index 7c26881..184a971 100644 --- a/examples/compare_both.py +++ b/examples/compare_both.py @@ -18,63 +18,114 @@ import numpy as np import torch +import cv2 -from steerable.SCFpyr import SCFpyr +from steerable.SCFpyr_NumPy import SCFpyr_NumPy from steerable.SCFpyr_PyTorch import SCFpyr_PyTorch -from steerable.visualize import visualize - -import cortex.vision -import cv2 +import steerable.utils as utils ################################################################################ ################################################################################ # Common image_file = './assets/lena.jpg' -im = cv2.imread(image_file, cv2.IMREAD_GRAYSCALE) -im = cortex.vision.resize(im, out_height=200, out_width=200) -im = im.astype(np.float32)/255. +image = cv2.imread(image_file, cv2.IMREAD_GRAYSCALE) +image = cv2.resize(image, (200,200)) # Number of pyramid levels pyr_height = 5 -################################################################################ +# Number of orientation bands +pyr_nbands = 4 + +# Tolerance for error checking +tolerance = 1e-3 + ################################################################################ # NumPy -# Build the complex steerable pyramid -pyr = SCFpyr(height=pyr_height) -coeff_numpy = pyr.build(im) +pyr_numpy = SCFpyr_NumPy(pyr_height, pyr_nbands, scale_factor=2) +coeff_numpy = pyr_numpy.build(image) +reconstruction_numpy = pyr_numpy.reconstruct(coeff_numpy) ################################################################################ # PyTorch -device = torch.device('cpu') -batch_size = 16 +device = torch.device('cuda:0') -# Create a batch of images -im_batch = np.tile(im, (batch_size,1,1)) +im_batch = torch.from_numpy(image[None,None,:,:]) +im_batch = im_batch.to(device).float() -# Move to Torch on the GPU -im_batch = torch.from_numpy(im_batch) -im_batch = im_batch.to(device) +pyr_torch = SCFpyr_PyTorch(pyr_height, pyr_nbands, device=device) +coeff_torch = pyr_torch.build(im_batch) -# Build the complex steerable pyramid -pyr = SCFpyr_PyTorch(height=pyr_height, scale_factor=2, device=device) -coeff_torch = pyr.build(im_batch) - -# HighPass: coeff[0] : highpass -# BandPass Scale 1: coeff[1][0], coeff[1][1], coeff[1][2], coeff[1][3] -# BandPass Scale 2: coeff[2][0], coeff[2][1], coeff[2][2], coeff[2][3] +# Just extract a single example from the batch +# Also moves the example to CPU and NumPy +coeff_torch = utils.extract_from_batch(coeff_torch, 0) ################################################################################ +# Check correctness + +print('#'*60) +assert len(coeff_numpy) == len(coeff_torch) + +for level, _ in enumerate(coeff_numpy): + + print('Pyramid Level {level}'.format(level=level)) + coeff_level_numpy = coeff_numpy[level] + coeff_level_torch = coeff_torch[level] + + assert type(coeff_level_numpy) == type(coeff_level_torch) + + if isinstance(coeff_level_numpy, np.ndarray): + + # Low- or High-Pass + print(' NumPy. min = {min:.3f}, max = {max:.3f},' + ' mean = {mean:.3f}, std = {std:.3f}'.format( + min=np.min(coeff_level_numpy), max=np.max(coeff_level_numpy), + mean=np.mean(coeff_level_numpy), std=np.std(coeff_level_numpy) + )) + print(' PyTorch. min = {min:.3f}, max = {max:.3f},' + ' mean = {mean:.3f}, std = {std:.3f}'.format( + min=np.min(coeff_level_torch), max=np.max(coeff_level_torch), + mean=np.mean(coeff_level_torch), std=np.std(coeff_level_torch) + )) + + # Check numerical correctness + assert np.allclose(coeff_level_numpy, coeff_level_torch, atol=tolerance) + + elif isinstance(coeff_level_numpy, list): + + # Intermediate bands + for band, _ in enumerate(coeff_level_numpy): + + band_numpy = coeff_level_numpy[band] + band_torch = coeff_level_torch[band] + + print(' Orientation Band {}'.format(band)) + print(' NumPy. min = {min:.3f}, max = {max:.3f},' + ' mean = {mean:.3f}, std = {std:.3f}'.format( + min=np.min(band_numpy), max=np.max(band_numpy), + mean=np.mean(band_numpy), std=np.std(band_numpy) + )) + print(' PyTorch. min = {min:.3f}, max = {max:.3f},' + ' mean = {mean:.3f}, std = {std:.3f}'.format( + min=np.min(band_torch), max=np.max(band_torch), + mean=np.mean(band_torch), std=np.std(band_torch) + )) + + # Check numerical correctness + assert np.allclose(band_numpy, band_torch, atol=tolerance) -bands_viz_numpy = visualize(coeff_numpy) -bands_viz_torch = visualize(coeff_torch, example_idx=0) - -cv2.imshow('NumPy Bands', bands_viz_numpy) -cv2.imshow('PyTorch Bands', bands_viz_torch) -cv2.waitKey(0) +################################################################################ +# Visualize +coeff_grid_numpy = utils.make_grid_coeff(coeff_numpy, normalize=True) +coeff_grid_torch = utils.make_grid_coeff(coeff_torch, normalize=True) +cv2.imshow('image', image) +cv2.imshow('coeff numpy', coeff_grid_numpy) +cv2.imshow('coeff torch', coeff_grid_torch) +cv2.imshow('reconstruction numpy', reconstruction_numpy) +cv2.waitKey(0) diff --git a/examples/example_numpy.py b/examples/example_numpy.py index 47befb2..cea87f0 100644 --- a/examples/example_numpy.py +++ b/examples/example_numpy.py @@ -16,9 +16,9 @@ from __future__ import division from __future__ import print_function +import time import argparse import numpy as np -import time from steerable.SCFpyr_NumPy import SCFpyr_NumPy import steerable.utils as utils diff --git a/examples/example_numpy_reconstruct.py b/examples/example_numpy_reconstruct.py new file mode 100644 index 0000000..9da739c --- /dev/null +++ b/examples/example_numpy_reconstruct.py @@ -0,0 +1,85 @@ +# MIT License +# +# Copyright (c) 2018 Tom Runia +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to conditions. +# +# Author: Tom Runia +# Date Created: 2018-12-04 + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import argparse +import numpy as np +import cv2 + +from steerable.SCFpyr_NumPy import SCFpyr_NumPy +import steerable.utils as utils + +################################################################################ +################################################################################ + +if __name__ == "__main__": + + parser = argparse.ArgumentParser() + parser.add_argument('--image_file', type=str, default='./assets/patagonia.jpg') + parser.add_argument('--batch_size', type=int, default='1') + parser.add_argument('--image_size', type=int, default='200') + parser.add_argument('--pyr_nlevels', type=int, default='5') + parser.add_argument('--pyr_nbands', type=int, default='4') + parser.add_argument('--pyr_scale_factor', type=int, default='2') + parser.add_argument('--visualize', type=bool, default=True) + config = parser.parse_args() + + ############################################################################ + # Build the complex steerable pyramid + + pyr = SCFpyr_NumPy( + height=config.pyr_nlevels, + nbands=config.pyr_nbands, + scale_factor=config.pyr_scale_factor, + ) + + ############################################################################ + # Create a batch and feed-forward + + image = cv2.imread('./assets/lena.jpg', cv2.IMREAD_GRAYSCALE) + image = cv2.resize(image, (200,200)) + + # TODO: rescaling to the range [0,1] does not work...? + #image = image.astype(np.float32)/255. + + # Decompose into steerable pyramid + coeff = pyr.build(image) + + # Reconstruct the image from the pyramid coefficients + reconstruction = pyr.reconstruct(coeff) + + reconstruction = reconstruction.astype(np.float32) + reconstruction = np.ascontiguousarray(reconstruction) + reconstruction /= 255. + + ############################################################################ + + tolerance = 1e-4 + print('image', np.mean(image), np.std(image)) + print('reconstruction', np.mean(reconstruction), np.std(reconstruction)) + print('allclose', np.allclose(image, reconstruction, atol=tolerance)) + + ############################################################################ + # Visualization + + if config.visualize: + coeff_grid = utils.make_grid_coeff(coeff, normalize=True) + cv2.imshow('image', image) + cv2.imshow('reconstruction', reconstruction) + cv2.imshow('coeff', coeff_grid) + cv2.waitKey(0) + diff --git a/steerable/SCFpyr_NumPy.py b/steerable/SCFpyr_NumPy.py index 810837b..216d685 100644 --- a/steerable/SCFpyr_NumPy.py +++ b/steerable/SCFpyr_NumPy.py @@ -73,9 +73,9 @@ def build(self, im): assert len(im.shape) == 2, 'Input im must be grayscale' height, width = im.shape - # Check whether im shape allows the pyramid M - max_height_pyr = int(np.floor(np.log2(min(width, height))) - 2) - assert max_height_pyr >= self.height, 'Cannot buid pyramid heigher than {} levels'.format(max_height_pyr) + # Check whether image size is sufficient for number of levels + if self.height > int(np.floor(np.log2(min(width, height))) - 2): + raise RuntimeError('Cannot build {} levels, image too small.'.format(self.height)) # Prepare a grid log_rad, angle = math_utils.prepare_grid(height, width) @@ -85,7 +85,6 @@ def build(self, im): Yrcos = np.sqrt(Yrcos) YIrcos = np.sqrt(1 - Yrcos**2) - lo0mask = pointOp(log_rad, YIrcos, Xrcos) hi0mask = pointOp(log_rad, Yrcos, Xrcos) @@ -102,7 +101,6 @@ def build(self, im): hi0dft = imdft * hi0mask hi0 = np.fft.ifft2(np.fft.ifftshift(hi0dft)) coeff.insert(0, hi0.real) - return coeff @@ -111,7 +109,8 @@ def _build_levels(self, lodft, log_rad, angle, Xrcos, Yrcos, height): if height <= 1: # Low-pass - lo0 = np.fft.ifft2(np.fft.ifftshift(lodft)) + lo0 = np.fft.ifftshift(lodft) + lo0 = np.fft.ifft2(lo0) coeff = [lo0.real] else: @@ -188,7 +187,8 @@ def reconstruct(self, coeff): hidft = np.fft.fftshift(np.fft.fft2(coeff[0])) outdft = tempdft * lo0mask + hidft * hi0mask - return np.fft.ifft2(np.fft.ifftshift(outdft)).real.astype(int) + reconstruction = np.fft.ifft2(np.fft.ifftshift(outdft)).real.astype(int) + return reconstruction def _reconstruct_levels(self, coeff, log_rad, Xrcos, Yrcos, angle): diff --git a/steerable/SCFpyr_PyTorch.py b/steerable/SCFpyr_PyTorch.py index e1b4250..46b38ab 100644 --- a/steerable/SCFpyr_PyTorch.py +++ b/steerable/SCFpyr_PyTorch.py @@ -77,12 +77,12 @@ def build(self, im_batch): assert im_batch.dim() == 4, 'Image batch must be of shape [N,C,H,W]' assert im_batch.shape[1] == 1, 'Second dimension must be 1 encoding grayscale image' - height, width = im_batch.shape[2], im_batch.shape[2] im_batch = im_batch.squeeze(1) # flatten channels dim - - # Check whether im shape allows the pyramid M - max_height_pyr = int(np.floor(np.log2(min(width, height))) - 2) - assert max_height_pyr >= self.height, 'Cannot buid pyramid with more than {} levels'.format(max_height_pyr) + height, width = im_batch.shape[2], im_batch.shape[1] + + # Check whether image size is sufficient for number of levels + if self.height > int(np.floor(np.log2(min(width, height))) - 2): + raise RuntimeError('Cannot build {} levels, image too small.'.format(self.height)) # Prepare a grid log_rad, angle = math_utils.prepare_grid(height, width) @@ -112,12 +112,10 @@ def build(self, im_batch): # High-pass hi0dft = batch_dft * hi0mask - - hi0 = torch.ifft(hi0dft, signal_ndim=2) - hi0 = math_utils.batch_ifftshift2d(hi0) + hi0 = math_utils.batch_ifftshift2d(hi0dft) + hi0 = torch.ifft(hi0, signal_ndim=2) hi0_real = torch.unbind(hi0, -1)[0] coeff.insert(0, hi0_real) - return coeff @@ -126,10 +124,9 @@ def _build_levels(self, lodft, log_rad, angle, Xrcos, Yrcos, height): if height <= 1: # Low-pass - lo0 = torch.ifft(lodft, signal_ndim=2) - lo0 = math_utils.batch_fftshift2d(lo0) + lo0 = math_utils.batch_ifftshift2d(lodft) + lo0 = torch.ifft(lo0, signal_ndim=2) lo0_real = torch.unbind(lo0, -1)[0] - # TODO: check correctess of these ops... coeff = [lo0_real] else: @@ -210,15 +207,13 @@ def _build_levels(self, lodft, log_rad, angle, Xrcos, Yrcos, height): def reconstruct(self, coeff): - raise NotImplementedError('Reconstruction using PyTorch is work in progres...') - if self.nbands != len(coeff[1]): raise Exception("Unmatched number of orientations") - M, N = coeff[0].shape - log_rad, angle = math_utils.utils.prepare_grid(M, N) + height, width = coeff[0].shape[2], coeff[0].shape[1] + log_rad, angle = math_utils.prepare_grid(height, width) - Xrcos, Yrcos = math_utils.utils.rcosFn(1, -0.5) + Xrcos, Yrcos = math_utils.rcosFn(1, -0.5) Yrcos = np.sqrt(Yrcos) YIrcos = np.sqrt(np.abs(1 - Yrcos*Yrcos)) @@ -227,11 +222,14 @@ def reconstruct(self, coeff): tempdft = self._reconstruct_levels(coeff[1:], log_rad, Xrcos, Yrcos, angle) - hidft = np.fft.fftshift(np.fft.fft2(coeff[0])) + hidft = torch.fft(coeff[0], signal_ndim=2) + hidft = math_utils.batch_fftshift2d(hidft) + outdft = tempdft * lo0mask + hidft * hi0mask - return np.fft.ifft2(np.fft.ifftshift(outdft)).real.astype(int) - + reconstruction = math_utils.batch_fftshift2d(outdft) + reconstruction = torch.ifft(reconstruction, signal_ndim=2).real.astype(int) + return reconstruction def _reconstruct_levels(self, coeff, log_rad, Xrcos, Yrcos, angle): diff --git a/steerable/math_utils.py b/steerable/math_utils.py index 7f5b93d..f7fded7 100644 --- a/steerable/math_utils.py +++ b/steerable/math_utils.py @@ -22,45 +22,26 @@ ################################################################################ ################################################################################ -def batch_flip_halves(x, axis): - split = torch.chunk(x, 2, axis) - return torch.cat((split[1], split[0]), dim=axis) +def roll_n(X, axis, n): + f_idx = tuple(slice(None, None, None) if i != axis else slice(0, n, None) for i in range(X.dim())) + b_idx = tuple(slice(None, None, None) if i != axis else slice(n, None, None) for i in range(X.dim())) + front = X[f_idx] + back = X[b_idx] + return torch.cat([back, front], axis) def batch_fftshift2d(x): - assert isinstance(x, torch.Tensor), 'input must be a torch.Tensor' - assert x.dim() == 4, 'input tensor must be of shape [N,H,W,2]' - assert x.shape[-1] == 2, 'input tensor must be of shape [N,H,W,2]' - x = batch_flip_halves(x, axis=1) # top,bottom - x = batch_flip_halves(x, axis=2) # left,right - return x + real, imag = torch.unbind(x, -1) + for dim in range(1, len(real.size())): + real = roll_n(real, axis=dim, n=real.size(dim)//2) + imag = roll_n(imag, axis=dim, n=imag.size(dim)//2) + return torch.stack((real, imag), -1) # last dim=2 (real&imag) def batch_ifftshift2d(x): - ndim = x.dim() - assert ndim == 4, 'input tensor must be of shape [N,H,W,2]' - assert x.shape[-1] == 2, 'input tensor must be of shape [N,H,W,2]' - x = batch_flip_halves(x, axis=2) # left,right - x = batch_flip_halves(x, axis=1) # top,bottom - return x - -# def roll_n(X, axis, n): -# # Source: https://github.com/locuslab/pytorch_fft/blob/master/pytorch_fft/fft/fft.py#L230 -# f_idx = tuple(slice(None, None, None) if i != axis else slice(0, n, None) for i in range(X.dim())) -# b_idx = tuple(slice(None, None, None) if i != axis else slice(n, None, None) for i in range(X.dim())) -# front = X[f_idx] -# back = X[b_idx] -# return torch.cat([back, front], axis) - -# def fftshift(real, imag): -# for dim in range(1, len(real.size())): -# real = roll_n(real, axis=dim, n=real.size(dim)//2) -# imag = roll_n(imag, axis=dim, n=imag.size(dim)//2) -# return torch.stack((real, imag), -1) # last dim=2 (real/imag) - -# def ifftshift(real, imag): -# for dim in range(len(real.size()) - 1, 0, -1): -# real = roll_n(real, axis=dim, n=real.size(dim)//2) -# imag = roll_n(imag, axis=dim, n=imag.size(dim)//2) -# return torch.stack((real, imag), -1) # last dim=2 (real/imag) + real, imag = torch.unbind(x, -1) + for dim in range(len(real.size()) - 1, 0, -1): + real = roll_n(real, axis=dim, n=real.size(dim)//2) + imag = roll_n(imag, axis=dim, n=imag.size(dim)//2) + return torch.stack((real, imag), -1) # last dim=2 (real&imag) ################################################################################ ################################################################################ @@ -92,3 +73,26 @@ def getlist(coeff): straight = [bands for scale in coeff[1:-1] for bands in scale] straight = [coeff[0]] + straight + [coeff[-1]] return straight + +################################################################################ +# Alternative fftshift implementation + +# def batch_flip_halves(x, axis): +# split = torch.chunk(x, 2, axis) +# return torch.cat((split[1], split[0]), dim=axis) + +# def batch_fftshift2d_v1(x): +# assert isinstance(x, torch.Tensor), 'input must be a torch.Tensor' +# assert x.dim() == 4, 'input tensor must be of shape [N,H,W,2]' +# assert x.shape[-1] == 2, 'input tensor must be of shape [N,H,W,2]' +# x = batch_flip_halves(x, axis=1) # top,bottom +# x = batch_flip_halves(x, axis=2) # left,right +# return x + +# def batch_ifftshift2d_v1(x): +# ndim = x.dim() +# assert ndim == 4, 'input tensor must be of shape [N,H,W,2]' +# assert x.shape[-1] == 2, 'input tensor must be of shape [N,H,W,2]' +# x = batch_flip_halves(x, axis=2) # left,right +# x = batch_flip_halves(x, axis=1) # top,bottom +# return x \ No newline at end of file diff --git a/tests/test_ifft.py b/tests/test_ifft.py index 1b08e3f..aa003b8 100644 --- a/tests/test_ifft.py +++ b/tests/test_ifft.py @@ -83,8 +83,6 @@ ################################################################################ # Tolerance checking - - all_close_real = np.allclose(np.real(fft_numpy), np.real(fft_torch), atol=tolerance) all_close_imag = np.allclose(np.imag(fft_numpy), np.imag(fft_torch), atol=tolerance) print('fft allclose real: {}'.format(all_close_real))