Skip to content

Commit

Permalink
Merge pull request #2 from tomrunia/reconstruction
Browse files Browse the repository at this point in the history
bugfix low-pass
  • Loading branch information
tomrunia authored Dec 11, 2018
2 parents 72169ac + eb10b14 commit b3e8b6a
Show file tree
Hide file tree
Showing 7 changed files with 233 additions and 97 deletions.
115 changes: 83 additions & 32 deletions examples/compare_both.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,63 +18,114 @@

import numpy as np
import torch
import cv2

from steerable.SCFpyr import SCFpyr
from steerable.SCFpyr_NumPy import SCFpyr_NumPy
from steerable.SCFpyr_PyTorch import SCFpyr_PyTorch
from steerable.visualize import visualize

import cortex.vision
import cv2
import steerable.utils as utils

################################################################################
################################################################################
# Common

image_file = './assets/lena.jpg'
im = cv2.imread(image_file, cv2.IMREAD_GRAYSCALE)
im = cortex.vision.resize(im, out_height=200, out_width=200)
im = im.astype(np.float32)/255.
image = cv2.imread(image_file, cv2.IMREAD_GRAYSCALE)
image = cv2.resize(image, (200,200))

# Number of pyramid levels
pyr_height = 5

################################################################################
# Number of orientation bands
pyr_nbands = 4

# Tolerance for error checking
tolerance = 1e-3

################################################################################
# NumPy

# Build the complex steerable pyramid
pyr = SCFpyr(height=pyr_height)
coeff_numpy = pyr.build(im)
pyr_numpy = SCFpyr_NumPy(pyr_height, pyr_nbands, scale_factor=2)
coeff_numpy = pyr_numpy.build(image)
reconstruction_numpy = pyr_numpy.reconstruct(coeff_numpy)

################################################################################
# PyTorch

device = torch.device('cpu')
batch_size = 16
device = torch.device('cuda:0')

# Create a batch of images
im_batch = np.tile(im, (batch_size,1,1))
im_batch = torch.from_numpy(image[None,None,:,:])
im_batch = im_batch.to(device).float()

# Move to Torch on the GPU
im_batch = torch.from_numpy(im_batch)
im_batch = im_batch.to(device)
pyr_torch = SCFpyr_PyTorch(pyr_height, pyr_nbands, device=device)
coeff_torch = pyr_torch.build(im_batch)

# Build the complex steerable pyramid
pyr = SCFpyr_PyTorch(height=pyr_height, scale_factor=2, device=device)
coeff_torch = pyr.build(im_batch)

# HighPass: coeff[0] : highpass
# BandPass Scale 1: coeff[1][0], coeff[1][1], coeff[1][2], coeff[1][3]
# BandPass Scale 2: coeff[2][0], coeff[2][1], coeff[2][2], coeff[2][3]
# Just extract a single example from the batch
# Also moves the example to CPU and NumPy
coeff_torch = utils.extract_from_batch(coeff_torch, 0)

################################################################################
# Check correctness

print('#'*60)
assert len(coeff_numpy) == len(coeff_torch)

for level, _ in enumerate(coeff_numpy):

print('Pyramid Level {level}'.format(level=level))
coeff_level_numpy = coeff_numpy[level]
coeff_level_torch = coeff_torch[level]

assert type(coeff_level_numpy) == type(coeff_level_torch)

if isinstance(coeff_level_numpy, np.ndarray):

# Low- or High-Pass
print(' NumPy. min = {min:.3f}, max = {max:.3f},'
' mean = {mean:.3f}, std = {std:.3f}'.format(
min=np.min(coeff_level_numpy), max=np.max(coeff_level_numpy),
mean=np.mean(coeff_level_numpy), std=np.std(coeff_level_numpy)
))
print(' PyTorch. min = {min:.3f}, max = {max:.3f},'
' mean = {mean:.3f}, std = {std:.3f}'.format(
min=np.min(coeff_level_torch), max=np.max(coeff_level_torch),
mean=np.mean(coeff_level_torch), std=np.std(coeff_level_torch)
))

# Check numerical correctness
assert np.allclose(coeff_level_numpy, coeff_level_torch, atol=tolerance)

elif isinstance(coeff_level_numpy, list):

# Intermediate bands
for band, _ in enumerate(coeff_level_numpy):

band_numpy = coeff_level_numpy[band]
band_torch = coeff_level_torch[band]

print(' Orientation Band {}'.format(band))
print(' NumPy. min = {min:.3f}, max = {max:.3f},'
' mean = {mean:.3f}, std = {std:.3f}'.format(
min=np.min(band_numpy), max=np.max(band_numpy),
mean=np.mean(band_numpy), std=np.std(band_numpy)
))
print(' PyTorch. min = {min:.3f}, max = {max:.3f},'
' mean = {mean:.3f}, std = {std:.3f}'.format(
min=np.min(band_torch), max=np.max(band_torch),
mean=np.mean(band_torch), std=np.std(band_torch)
))

# Check numerical correctness
assert np.allclose(band_numpy, band_torch, atol=tolerance)

bands_viz_numpy = visualize(coeff_numpy)
bands_viz_torch = visualize(coeff_torch, example_idx=0)

cv2.imshow('NumPy Bands', bands_viz_numpy)
cv2.imshow('PyTorch Bands', bands_viz_torch)
cv2.waitKey(0)
################################################################################
# Visualize

coeff_grid_numpy = utils.make_grid_coeff(coeff_numpy, normalize=True)
coeff_grid_torch = utils.make_grid_coeff(coeff_torch, normalize=True)

cv2.imshow('image', image)
cv2.imshow('coeff numpy', coeff_grid_numpy)
cv2.imshow('coeff torch', coeff_grid_torch)
cv2.imshow('reconstruction numpy', reconstruction_numpy)

cv2.waitKey(0)
2 changes: 1 addition & 1 deletion examples/example_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@
from __future__ import division
from __future__ import print_function

import time
import argparse
import numpy as np
import time

from steerable.SCFpyr_NumPy import SCFpyr_NumPy
import steerable.utils as utils
Expand Down
85 changes: 85 additions & 0 deletions examples/example_numpy_reconstruct.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
# MIT License
#
# Copyright (c) 2018 Tom Runia
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to conditions.
#
# Author: Tom Runia
# Date Created: 2018-12-04

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import numpy as np
import cv2

from steerable.SCFpyr_NumPy import SCFpyr_NumPy
import steerable.utils as utils

################################################################################
################################################################################

def str2bool(value):
    """Parse a command-line boolean ('true'/'false', 'yes'/'no', '1'/'0').

    argparse's ``type=bool`` is a trap: ``bool('False')`` is ``True`` because
    every non-empty string is truthy, so the flag could never be switched off.

    Raises:
        argparse.ArgumentTypeError: if the text is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    if text in ('true', 't', 'yes', 'y', '1'):
        return True
    # The empty string stays falsy, as it was under ``type=bool``.
    if text in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError(
        'expected a boolean value, got {!r}'.format(value))


def parse_args(argv=None):
    """Build the command-line parser and parse ``argv``.

    Args:
        argv: list of argument strings; ``None`` means ``sys.argv[1:]``.

    Returns:
        The populated ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser()
    # The default is the image this example has always loaded, so a run
    # without arguments behaves exactly as before.
    parser.add_argument('--image_file', type=str, default='./assets/lena.jpg')
    # Kept for command-line compatibility; this example handles one image.
    parser.add_argument('--batch_size', type=int, default=1)
    parser.add_argument('--image_size', type=int, default=200)
    parser.add_argument('--pyr_nlevels', type=int, default=5)
    parser.add_argument('--pyr_nbands', type=int, default=4)
    parser.add_argument('--pyr_scale_factor', type=int, default=2)
    parser.add_argument('--visualize', type=str2bool, default=True)
    return parser.parse_args(argv)


def main(argv=None):
    """Decompose an image into a steerable pyramid and reconstruct it.

    Prints mean/std of the input and of the reconstruction (both scaled to
    [0, 1]), the maximum absolute error and whether the two agree within
    tolerance; optionally shows the images and the coefficient grid.

    Raises:
        IOError: if the image file cannot be read.
    """
    config = parse_args(argv)

    ############################################################################
    # Build the complex steerable pyramid

    pyr = SCFpyr_NumPy(
        height=config.pyr_nlevels,
        nbands=config.pyr_nbands,
        scale_factor=config.pyr_scale_factor,
    )

    ############################################################################
    # Load the image and feed-forward

    image = cv2.imread(config.image_file, cv2.IMREAD_GRAYSCALE)
    if image is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise IOError(
            'Could not read image file: {}'.format(config.image_file))
    image = cv2.resize(image, (config.image_size, config.image_size))

    # NOTE: the pyramid is fed the raw 8-bit image on purpose. As of this
    # commit SCFpyr_NumPy.reconstruct() casts its result to int, so an input
    # rescaled to [0, 1] would be truncated to zeros.

    # Decompose into steerable pyramid
    coeff = pyr.build(image)

    # Reconstruct the image from the pyramid coefficients
    reconstruction = pyr.reconstruct(coeff)

    # Bring both images to the range [0, 1] so they can be compared
    # (and so that cv2.imshow renders the float image correctly).
    image_norm = image.astype(np.float32) / 255.
    reconstruction = reconstruction.astype(np.float32)
    reconstruction = np.ascontiguousarray(reconstruction)
    reconstruction /= 255.

    ############################################################################

    # One gray level (the integer truncation inside reconstruct() can cost up
    # to that much) plus the original 1e-4 of floating-point slack.
    tolerance = 1. / 255 + 1e-4
    print('image', np.mean(image_norm), np.std(image_norm))
    print('reconstruction', np.mean(reconstruction), np.std(reconstruction))
    print('max abs error', np.max(np.abs(image_norm - reconstruction)))
    print('allclose', np.allclose(image_norm, reconstruction, atol=tolerance))

    ############################################################################
    # Visualization

    if config.visualize:
        coeff_grid = utils.make_grid_coeff(coeff, normalize=True)
        cv2.imshow('image', image)
        cv2.imshow('reconstruction', reconstruction)
        cv2.imshow('coeff', coeff_grid)
        cv2.waitKey(0)


if __name__ == "__main__":
    main()

14 changes: 7 additions & 7 deletions steerable/SCFpyr_NumPy.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,9 +73,9 @@ def build(self, im):
assert len(im.shape) == 2, 'Input im must be grayscale'
height, width = im.shape

# Check whether im shape allows the pyramid M
max_height_pyr = int(np.floor(np.log2(min(width, height))) - 2)
assert max_height_pyr >= self.height, 'Cannot buid pyramid heigher than {} levels'.format(max_height_pyr)
# Check whether image size is sufficient for number of levels
if self.height > int(np.floor(np.log2(min(width, height))) - 2):
raise RuntimeError('Cannot build {} levels, image too small.'.format(self.height))

# Prepare a grid
log_rad, angle = math_utils.prepare_grid(height, width)
Expand All @@ -85,7 +85,6 @@ def build(self, im):
Yrcos = np.sqrt(Yrcos)

YIrcos = np.sqrt(1 - Yrcos**2)

lo0mask = pointOp(log_rad, YIrcos, Xrcos)
hi0mask = pointOp(log_rad, Yrcos, Xrcos)

Expand All @@ -102,7 +101,6 @@ def build(self, im):
hi0dft = imdft * hi0mask
hi0 = np.fft.ifft2(np.fft.ifftshift(hi0dft))
coeff.insert(0, hi0.real)

return coeff


Expand All @@ -111,7 +109,8 @@ def _build_levels(self, lodft, log_rad, angle, Xrcos, Yrcos, height):
if height <= 1:

# Low-pass
lo0 = np.fft.ifft2(np.fft.ifftshift(lodft))
lo0 = np.fft.ifftshift(lodft)
lo0 = np.fft.ifft2(lo0)
coeff = [lo0.real]

else:
Expand Down Expand Up @@ -188,7 +187,8 @@ def reconstruct(self, coeff):
hidft = np.fft.fftshift(np.fft.fft2(coeff[0]))
outdft = tempdft * lo0mask + hidft * hi0mask

return np.fft.ifft2(np.fft.ifftshift(outdft)).real.astype(int)
reconstruction = np.fft.ifft2(np.fft.ifftshift(outdft)).real.astype(int)
return reconstruction

def _reconstruct_levels(self, coeff, log_rad, Xrcos, Yrcos, angle):

Expand Down
38 changes: 18 additions & 20 deletions steerable/SCFpyr_PyTorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,12 +77,12 @@ def build(self, im_batch):
assert im_batch.dim() == 4, 'Image batch must be of shape [N,C,H,W]'
assert im_batch.shape[1] == 1, 'Second dimension must be 1 encoding grayscale image'

height, width = im_batch.shape[2], im_batch.shape[2]
im_batch = im_batch.squeeze(1) # flatten channels dim

# Check whether im shape allows the pyramid M
max_height_pyr = int(np.floor(np.log2(min(width, height))) - 2)
assert max_height_pyr >= self.height, 'Cannot buid pyramid with more than {} levels'.format(max_height_pyr)
height, width = im_batch.shape[2], im_batch.shape[1]

# Check whether image size is sufficient for number of levels
if self.height > int(np.floor(np.log2(min(width, height))) - 2):
raise RuntimeError('Cannot build {} levels, image too small.'.format(self.height))

# Prepare a grid
log_rad, angle = math_utils.prepare_grid(height, width)
Expand Down Expand Up @@ -112,12 +112,10 @@ def build(self, im_batch):

# High-pass
hi0dft = batch_dft * hi0mask

hi0 = torch.ifft(hi0dft, signal_ndim=2)
hi0 = math_utils.batch_ifftshift2d(hi0)
hi0 = math_utils.batch_ifftshift2d(hi0dft)
hi0 = torch.ifft(hi0, signal_ndim=2)
hi0_real = torch.unbind(hi0, -1)[0]
coeff.insert(0, hi0_real)

return coeff


Expand All @@ -126,10 +124,9 @@ def _build_levels(self, lodft, log_rad, angle, Xrcos, Yrcos, height):
if height <= 1:

# Low-pass
lo0 = torch.ifft(lodft, signal_ndim=2)
lo0 = math_utils.batch_fftshift2d(lo0)
lo0 = math_utils.batch_ifftshift2d(lodft)
lo0 = torch.ifft(lo0, signal_ndim=2)
lo0_real = torch.unbind(lo0, -1)[0]
# TODO: check correctess of these ops...
coeff = [lo0_real]

else:
Expand Down Expand Up @@ -210,15 +207,13 @@ def _build_levels(self, lodft, log_rad, angle, Xrcos, Yrcos, height):

def reconstruct(self, coeff):

raise NotImplementedError('Reconstruction using PyTorch is work in progres...')

if self.nbands != len(coeff[1]):
raise Exception("Unmatched number of orientations")

M, N = coeff[0].shape
log_rad, angle = math_utils.utils.prepare_grid(M, N)
height, width = coeff[0].shape[2], coeff[0].shape[1]
log_rad, angle = math_utils.prepare_grid(height, width)

Xrcos, Yrcos = math_utils.utils.rcosFn(1, -0.5)
Xrcos, Yrcos = math_utils.rcosFn(1, -0.5)
Yrcos = np.sqrt(Yrcos)
YIrcos = np.sqrt(np.abs(1 - Yrcos*Yrcos))

Expand All @@ -227,11 +222,14 @@ def reconstruct(self, coeff):

tempdft = self._reconstruct_levels(coeff[1:], log_rad, Xrcos, Yrcos, angle)

hidft = np.fft.fftshift(np.fft.fft2(coeff[0]))
hidft = torch.fft(coeff[0], signal_ndim=2)
hidft = math_utils.batch_fftshift2d(hidft)

outdft = tempdft * lo0mask + hidft * hi0mask

return np.fft.ifft2(np.fft.ifftshift(outdft)).real.astype(int)

reconstruction = math_utils.batch_fftshift2d(outdft)
reconstruction = torch.ifft(reconstruction, signal_ndim=2).real.astype(int)
return reconstruction

def _reconstruct_levels(self, coeff, log_rad, Xrcos, Yrcos, angle):

Expand Down
Loading

0 comments on commit b3e8b6a

Please sign in to comment.