Skip to content

Commit

Permalink
bugfix in low-pass
Browse files Browse the repository at this point in the history
  • Loading branch information
tomrunia committed Dec 11, 2018
1 parent 98da898 commit eb10b14
Show file tree
Hide file tree
Showing 6 changed files with 138 additions and 87 deletions.
115 changes: 83 additions & 32 deletions examples/compare_both.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,63 +18,114 @@

import numpy as np
import torch
import cv2

from steerable.SCFpyr import SCFpyr
from steerable.SCFpyr_NumPy import SCFpyr_NumPy
from steerable.SCFpyr_PyTorch import SCFpyr_PyTorch
from steerable.visualize import visualize

import cortex.vision
import cv2
import steerable.utils as utils

################################################################################
################################################################################
# Common

image_file = './assets/lena.jpg'
im = cv2.imread(image_file, cv2.IMREAD_GRAYSCALE)
im = cortex.vision.resize(im, out_height=200, out_width=200)
im = im.astype(np.float32)/255.
image = cv2.imread(image_file, cv2.IMREAD_GRAYSCALE)
image = cv2.resize(image, (200,200))

# Number of pyramid levels
pyr_height = 5

################################################################################
# Number of orientation bands
pyr_nbands = 4

# Tolerance for error checking
tolerance = 1e-3

################################################################################
# NumPy

# Build the complex steerable pyramid
pyr = SCFpyr(height=pyr_height)
coeff_numpy = pyr.build(im)
pyr_numpy = SCFpyr_NumPy(pyr_height, pyr_nbands, scale_factor=2)
coeff_numpy = pyr_numpy.build(image)
reconstruction_numpy = pyr_numpy.reconstruct(coeff_numpy)

################################################################################
# PyTorch

device = torch.device('cpu')
batch_size = 16
device = torch.device('cuda:0')

# Create a batch of images
im_batch = np.tile(im, (batch_size,1,1))
im_batch = torch.from_numpy(image[None,None,:,:])
im_batch = im_batch.to(device).float()

# Move to Torch on the GPU
im_batch = torch.from_numpy(im_batch)
im_batch = im_batch.to(device)
pyr_torch = SCFpyr_PyTorch(pyr_height, pyr_nbands, device=device)
coeff_torch = pyr_torch.build(im_batch)

# Build the complex steerable pyramid
pyr = SCFpyr_PyTorch(height=pyr_height, scale_factor=2, device=device)
coeff_torch = pyr.build(im_batch)

# HighPass: coeff[0] : highpass
# BandPass Scale 1: coeff[1][0], coeff[1][1], coeff[1][2], coeff[1][3]
# BandPass Scale 2: coeff[2][0], coeff[2][1], coeff[2][2], coeff[2][3]
# Just extract a single example from the batch
# Also moves the example to CPU and NumPy
coeff_torch = utils.extract_from_batch(coeff_torch, 0)

################################################################################
# Check correctness

print('#'*60)
assert len(coeff_numpy) == len(coeff_torch)

for level, _ in enumerate(coeff_numpy):

print('Pyramid Level {level}'.format(level=level))
coeff_level_numpy = coeff_numpy[level]
coeff_level_torch = coeff_torch[level]

assert type(coeff_level_numpy) == type(coeff_level_torch)

if isinstance(coeff_level_numpy, np.ndarray):

# Low- or High-Pass
print(' NumPy. min = {min:.3f}, max = {max:.3f},'
' mean = {mean:.3f}, std = {std:.3f}'.format(
min=np.min(coeff_level_numpy), max=np.max(coeff_level_numpy),
mean=np.mean(coeff_level_numpy), std=np.std(coeff_level_numpy)
))
print(' PyTorch. min = {min:.3f}, max = {max:.3f},'
' mean = {mean:.3f}, std = {std:.3f}'.format(
min=np.min(coeff_level_torch), max=np.max(coeff_level_torch),
mean=np.mean(coeff_level_torch), std=np.std(coeff_level_torch)
))

# Check numerical correctness
assert np.allclose(coeff_level_numpy, coeff_level_torch, atol=tolerance)

elif isinstance(coeff_level_numpy, list):

# Intermediate bands
for band, _ in enumerate(coeff_level_numpy):

band_numpy = coeff_level_numpy[band]
band_torch = coeff_level_torch[band]

print(' Orientation Band {}'.format(band))
print(' NumPy. min = {min:.3f}, max = {max:.3f},'
' mean = {mean:.3f}, std = {std:.3f}'.format(
min=np.min(band_numpy), max=np.max(band_numpy),
mean=np.mean(band_numpy), std=np.std(band_numpy)
))
print(' PyTorch. min = {min:.3f}, max = {max:.3f},'
' mean = {mean:.3f}, std = {std:.3f}'.format(
min=np.min(band_torch), max=np.max(band_torch),
mean=np.mean(band_torch), std=np.std(band_torch)
))

# Check numerical correctness
assert np.allclose(band_numpy, band_torch, atol=tolerance)

bands_viz_numpy = visualize(coeff_numpy)
bands_viz_torch = visualize(coeff_torch, example_idx=0)

cv2.imshow('NumPy Bands', bands_viz_numpy)
cv2.imshow('PyTorch Bands', bands_viz_torch)
cv2.waitKey(0)
################################################################################
# Visualize

coeff_grid_numpy = utils.make_grid_coeff(coeff_numpy, normalize=True)
coeff_grid_torch = utils.make_grid_coeff(coeff_torch, normalize=True)

cv2.imshow('image', image)
cv2.imshow('coeff numpy', coeff_grid_numpy)
cv2.imshow('coeff torch', coeff_grid_torch)
cv2.imshow('reconstruction numpy', reconstruction_numpy)

cv2.waitKey(0)
3 changes: 2 additions & 1 deletion examples/example_numpy_reconstruct.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
from __future__ import division
from __future__ import print_function

import time
import argparse
import numpy as np
import cv2
Expand Down Expand Up @@ -53,6 +52,8 @@

image = cv2.imread('./assets/lena.jpg', cv2.IMREAD_GRAYSCALE)
image = cv2.resize(image, (200,200))

# TODO: rescaling to the range [0,1] does not work...?
#image = image.astype(np.float32)/255.

# Decompose into steerable pyramid
Expand Down
6 changes: 3 additions & 3 deletions steerable/SCFpyr_NumPy.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def build(self, im):
assert len(im.shape) == 2, 'Input im must be grayscale'
height, width = im.shape

# Check whether im shape allows the pyramid M
# Check whether image size is sufficient for number of levels
if self.height > int(np.floor(np.log2(min(width, height))) - 2):
raise RuntimeError('Cannot build {} levels, image too small.'.format(self.height))

Expand Down Expand Up @@ -101,7 +101,6 @@ def build(self, im):
hi0dft = imdft * hi0mask
hi0 = np.fft.ifft2(np.fft.ifftshift(hi0dft))
coeff.insert(0, hi0.real)

return coeff


Expand All @@ -110,7 +109,8 @@ def _build_levels(self, lodft, log_rad, angle, Xrcos, Yrcos, height):
if height <= 1:

# Low-pass
lo0 = np.fft.ifft2(np.fft.ifftshift(lodft))
lo0 = np.fft.ifftshift(lodft)
lo0 = np.fft.ifft2(lo0)
coeff = [lo0.real]

else:
Expand Down
25 changes: 11 additions & 14 deletions steerable/SCFpyr_PyTorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,12 +77,12 @@ def build(self, im_batch):
assert im_batch.dim() == 4, 'Image batch must be of shape [N,C,H,W]'
assert im_batch.shape[1] == 1, 'Second dimension must be 1 encoding grayscale image'

height, width = im_batch.shape[2], im_batch.shape[1]
im_batch = im_batch.squeeze(1) # flatten channels dim

# Check whether im shape allows the pyramid M
max_height_pyr = int(np.floor(np.log2(min(width, height))) - 2)
assert max_height_pyr >= self.height, 'Cannot buid pyramid with more than {} levels'.format(max_height_pyr)
height, width = im_batch.shape[2], im_batch.shape[1]

# Check whether image size is sufficient for number of levels
if self.height > int(np.floor(np.log2(min(width, height))) - 2):
raise RuntimeError('Cannot build {} levels, image too small.'.format(self.height))

# Prepare a grid
log_rad, angle = math_utils.prepare_grid(height, width)
Expand Down Expand Up @@ -112,12 +112,10 @@ def build(self, im_batch):

# High-pass
hi0dft = batch_dft * hi0mask

hi0 = torch.ifft(hi0dft, signal_ndim=2)
hi0 = math_utils.batch_ifftshift2d(hi0)
hi0 = math_utils.batch_ifftshift2d(hi0dft)
hi0 = torch.ifft(hi0, signal_ndim=2)
hi0_real = torch.unbind(hi0, -1)[0]
coeff.insert(0, hi0_real)

return coeff


Expand All @@ -126,10 +124,9 @@ def _build_levels(self, lodft, log_rad, angle, Xrcos, Yrcos, height):
if height <= 1:

# Low-pass
lo0 = torch.ifft(lodft, signal_ndim=2)
lo0 = math_utils.batch_fftshift2d(lo0)
lo0 = math_utils.batch_ifftshift2d(lodft)
lo0 = torch.ifft(lo0, signal_ndim=2)
lo0_real = torch.unbind(lo0, -1)[0]
# TODO: check correctness of these ops...
coeff = [lo0_real]

else:
Expand Down Expand Up @@ -225,13 +222,13 @@ def reconstruct(self, coeff):

tempdft = self._reconstruct_levels(coeff[1:], log_rad, Xrcos, Yrcos, angle)

hidft = torch.fft(coeff[0])
hidft = torch.fft(coeff[0], signal_ndim=2)
hidft = math_utils.batch_fftshift2d(hidft)

outdft = tempdft * lo0mask + hidft * hi0mask

reconstruction = math_utils.batch_fftshift2d(outdft)
reconstruction = torch.ifft(reconstruction).real.astype(int)
reconstruction = torch.ifft(reconstruction, signal_ndim=2).real.astype(int)
return reconstruction

def _reconstruct_levels(self, coeff, log_rad, Xrcos, Yrcos, angle):
Expand Down
74 changes: 39 additions & 35 deletions steerable/math_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,45 +22,26 @@
################################################################################
################################################################################

def batch_flip_halves(x, axis):
    """Swap the two halves of ``x`` along ``axis``.

    The tensor is split into two chunks with ``torch.chunk`` and the chunks
    are concatenated in reverse order.

    Args:
        x: torch.Tensor to rearrange.
        axis: dimension along which the halves are swapped.

    Returns:
        A tensor of the same shape as ``x`` with the second chunk placed
        before the first one along ``axis``. If the axis is too short to be
        split into two chunks, ``x`` is returned unchanged.
    """
    split = torch.chunk(x, 2, axis)
    # torch.chunk may return fewer chunks than requested (e.g. for an axis of
    # length 1); there is nothing to swap then, and indexing split[1] would
    # raise an IndexError.
    if len(split) < 2:
        return x
    return torch.cat((split[1], split[0]), dim=axis)
def roll_n(X, axis, n):
    """Rotate ``X`` cyclically along ``axis``, moving the first ``n`` entries to the end.

    The result equals the concatenation of ``X[n:]`` and ``X[:n]`` taken
    along ``axis``. Adapted from pytorch_fft
    (https://github.com/locuslab/pytorch_fft, pytorch_fft/fft/fft.py).

    Args:
        X: torch.Tensor to rotate.
        axis: dimension to rotate along; negative values count from the
            last dimension.
        n: number of leading entries along ``axis`` that are moved to the back.

    Returns:
        A new tensor with the same shape as ``X``.
    """
    # A negative axis never equals any index in range(X.dim()); without this
    # normalisation neither slice would be restricted and the function would
    # return X concatenated with itself instead of a rotation.
    if axis < 0:
        axis += X.dim()
    f_idx = tuple(slice(0, n) if i == axis else slice(None) for i in range(X.dim()))
    b_idx = tuple(slice(n, None) if i == axis else slice(None) for i in range(X.dim()))
    front = X[f_idx]
    back = X[b_idx]
    return torch.cat([back, front], axis)

def batch_fftshift2d(x):
    """Shift the zero-frequency component of each spectrum in a batch to the centre.

    Every dimension between the first (batch) and the last (real/imag) is
    rolled forward by ``size // 2``, matching ``numpy.fft.fftshift`` for both
    even and odd sizes.

    Args:
        x: torch.Tensor whose last dimension has size 2 and holds the real
            and imaginary parts, e.g. of shape [N,H,W,2].

    Returns:
        The shifted tensor, same shape as ``x``.

    Raises:
        ValueError: if the last dimension of ``x`` does not have size 2.
    """
    if x.shape[-1] != 2:
        raise ValueError('input tensor must be of shape [N,H,W,2]')
    shifted = x
    for dim in range(1, x.dim() - 1):
        # Rolling forward by size // 2 puts index 0 at position size // 2.
        # Moving the first size // 2 entries to the back (the previous
        # roll_n-based body) only coincides with this for even sizes; for
        # odd sizes it computes the inverse shift.
        shifted = torch.roll(shifted, shifts=x.shape[dim] // 2, dims=dim)
    return shifted

def batch_ifftshift2d(x):
    """Undo ``batch_fftshift2d``: move the centred zero-frequency component back to index 0.

    Every dimension between the first (batch) and the last (real/imag) is
    rolled backward by ``size // 2``, matching ``numpy.fft.ifftshift`` for
    both even and odd sizes.

    Args:
        x: torch.Tensor whose last dimension has size 2 and holds the real
            and imaginary parts, e.g. of shape [N,H,W,2].

    Returns:
        The shifted tensor, same shape as ``x``.

    Raises:
        ValueError: if the last dimension of ``x`` does not have size 2.
    """
    if x.shape[-1] != 2:
        raise ValueError('input tensor must be of shape [N,H,W,2]')
    shifted = x
    for dim in range(x.dim() - 2, 0, -1):
        # Rolling backward by size // 2 is the exact inverse of the forward
        # shift. Swapping chunk halves (the superseded body) is only
        # equivalent for even sizes.
        shifted = torch.roll(shifted, shifts=-(x.shape[dim] // 2), dims=dim)
    return shifted

################################################################################
################################################################################
Expand Down Expand Up @@ -92,3 +73,26 @@ def getlist(coeff):
straight = [bands for scale in coeff[1:-1] for bands in scale]
straight = [coeff[0]] + straight + [coeff[-1]]
return straight

################################################################################
# Alternative fftshift implementation

# def batch_flip_halves(x, axis):
# split = torch.chunk(x, 2, axis)
# return torch.cat((split[1], split[0]), dim=axis)

# def batch_fftshift2d_v1(x):
# assert isinstance(x, torch.Tensor), 'input must be a torch.Tensor'
# assert x.dim() == 4, 'input tensor must be of shape [N,H,W,2]'
# assert x.shape[-1] == 2, 'input tensor must be of shape [N,H,W,2]'
# x = batch_flip_halves(x, axis=1) # top,bottom
# x = batch_flip_halves(x, axis=2) # left,right
# return x

# def batch_ifftshift2d_v1(x):
# ndim = x.dim()
# assert ndim == 4, 'input tensor must be of shape [N,H,W,2]'
# assert x.shape[-1] == 2, 'input tensor must be of shape [N,H,W,2]'
# x = batch_flip_halves(x, axis=2) # left,right
# x = batch_flip_halves(x, axis=1) # top,bottom
# return x
2 changes: 0 additions & 2 deletions tests/test_ifft.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,6 @@
################################################################################
# Tolerance checking



all_close_real = np.allclose(np.real(fft_numpy), np.real(fft_torch), atol=tolerance)
all_close_imag = np.allclose(np.imag(fft_numpy), np.imag(fft_torch), atol=tolerance)
print('fft allclose real: {}'.format(all_close_real))
Expand Down

0 comments on commit eb10b14

Please sign in to comment.