Skip to content

Commit

Permalink
Merge pull request #3 from tomrunia/reconstruction
Browse files Browse the repository at this point in the history
Reconstruction
  • Loading branch information
tomrunia authored Dec 13, 2018
2 parents b3e8b6a + 241cfc8 commit 0b6514d
Show file tree
Hide file tree
Showing 9 changed files with 135 additions and 225 deletions.
43 changes: 23 additions & 20 deletions examples/compare_both.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,9 @@
pyr_numpy = SCFpyr_NumPy(pyr_height, pyr_nbands, scale_factor=2)
coeff_numpy = pyr_numpy.build(image)
reconstruction_numpy = pyr_numpy.reconstruct(coeff_numpy)
reconstruction_numpy = reconstruction_numpy.astype(np.uint8)

print('#'*60)

################################################################################
# PyTorch
Expand All @@ -58,9 +61,10 @@

pyr_torch = SCFpyr_PyTorch(pyr_height, pyr_nbands, device=device)
coeff_torch = pyr_torch.build(im_batch)
reconstruction_torch = pyr_torch.reconstruct(coeff_torch)
reconstruction_torch = reconstruction_torch.cpu().numpy()[0,]

# Just extract a single example from the batch
# Also moves the example to CPU and NumPy
# Extract first example from the batch and move to CPU
coeff_torch = utils.extract_from_batch(coeff_torch, 0)

################################################################################
Expand All @@ -75,21 +79,20 @@
coeff_level_numpy = coeff_numpy[level]
coeff_level_torch = coeff_torch[level]

assert type(coeff_level_numpy) == type(coeff_level_torch)
assert isinstance(coeff_level_torch, type(coeff_level_numpy))

if isinstance(coeff_level_numpy, np.ndarray):

# Low- or High-Pass
print(' NumPy. min = {min:.3f}, max = {max:.3f},'
' mean = {mean:.3f}, std = {std:.3f}'.format(
min=np.min(coeff_level_numpy), max=np.max(coeff_level_numpy),
mean=np.mean(coeff_level_numpy), std=np.std(coeff_level_numpy)
))
min=np.min(coeff_level_numpy), max=np.max(coeff_level_numpy),
mean=np.mean(coeff_level_numpy), std=np.std(coeff_level_numpy)))

print(' PyTorch. min = {min:.3f}, max = {max:.3f},'
' mean = {mean:.3f}, std = {std:.3f}'.format(
min=np.min(coeff_level_torch), max=np.max(coeff_level_torch),
mean=np.mean(coeff_level_torch), std=np.std(coeff_level_torch)
))
min=np.min(coeff_level_torch), max=np.max(coeff_level_torch),
mean=np.mean(coeff_level_torch), std=np.std(coeff_level_torch)))

# Check numerical correctness
assert np.allclose(coeff_level_numpy, coeff_level_torch, atol=tolerance)
Expand All @@ -105,27 +108,27 @@
print(' Orientation Band {}'.format(band))
print(' NumPy. min = {min:.3f}, max = {max:.3f},'
' mean = {mean:.3f}, std = {std:.3f}'.format(
min=np.min(band_numpy), max=np.max(band_numpy),
mean=np.mean(band_numpy), std=np.std(band_numpy)
))
min=np.min(band_numpy), max=np.max(band_numpy),
mean=np.mean(band_numpy), std=np.std(band_numpy)))

print(' PyTorch. min = {min:.3f}, max = {max:.3f},'
' mean = {mean:.3f}, std = {std:.3f}'.format(
min=np.min(band_torch), max=np.max(band_torch),
mean=np.mean(band_torch), std=np.std(band_torch)
))
min=np.min(band_torch), max=np.max(band_torch),
mean=np.mean(band_torch), std=np.std(band_torch)))

# Check numerical correctness
assert np.allclose(band_numpy, band_torch, atol=tolerance)

################################################################################
# Visualize

coeff_grid_numpy = utils.make_grid_coeff(coeff_numpy, normalize=True)
coeff_grid_torch = utils.make_grid_coeff(coeff_torch, normalize=True)
coeff_grid_numpy = utils.make_grid_coeff(coeff_numpy, normalize=False)
coeff_grid_torch = utils.make_grid_coeff(coeff_torch, normalize=False)

cv2.imshow('image', image)
cv2.imshow('coeff numpy', coeff_grid_numpy)
cv2.imshow('coeff torch', coeff_grid_torch)
cv2.imshow('reconstruction numpy', reconstruction_numpy)
cv2.imshow('coeff numpy', np.ascontiguousarray(coeff_grid_numpy))
cv2.imshow('coeff torch', np.ascontiguousarray(coeff_grid_torch))
cv2.imshow('reconstruction numpy', reconstruction_numpy.astype(np.uint8))
cv2.imshow('reconstruction torch', reconstruction_torch.astype(np.uint8))

cv2.waitKey(0)
2 changes: 1 addition & 1 deletion examples/example_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@
if config.visualize:
import cv2
coeff_grid = utils.make_grid_coeff(coeff, normalize=True)
cv2.imshow('image', im_batch_numpy[0,])
cv2.imshow('image', (im_batch_numpy[0,]*255.).astype(np.uint8))
cv2.imshow('coeff', coeff_grid)
cv2.waitKey(0)

85 changes: 0 additions & 85 deletions examples/example_numpy_reconstruct.py

This file was deleted.

6 changes: 4 additions & 2 deletions examples/example_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import argparse
import time
import numpy as np
import torch

from steerable.SCFpyr_PyTorch import SCFpyr_PyTorch
Expand Down Expand Up @@ -79,6 +80,7 @@
if config.visualize:
import cv2
coeff_grid = utils.make_grid_coeff(coeff, normalize=True)
cv2.imshow('image', im_batch_numpy[0,0,])
cv2.imshow('image', (im_batch_numpy[0,0,]*255.).astype(np.uint8))
cv2.imshow('coeff', coeff_grid)
cv2.waitKey(0)
cv2.waitKey(0)

65 changes: 22 additions & 43 deletions steerable/SCFpyr_NumPy.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ def _build_levels(self, lodft, log_rad, angle, Xrcos, Yrcos, height):

# Both are tuples of size 2
low_ind_start = (np.ceil((dims+0.5)/2) - np.ceil((np.ceil((dims-0.5)/2)+0.5)/2)).astype(int)
low_ind_end = (low_ind_start + np.ceil((dims-0.5)/2)).astype(int)
low_ind_end = (low_ind_start + np.ceil((dims-0.5)/2)).astype(int)

# Selection
log_rad = log_rad[low_ind_start[0]:low_ind_end[0], low_ind_start[1]:low_ind_end[1]]
Expand All @@ -164,20 +164,21 @@ def _build_levels(self, lodft, log_rad, angle, Xrcos, Yrcos, height):

return coeff

################################################################################
# Reconstruction to Image
############################################################################
########################### RECONSTRUCTION #################################
############################################################################

def reconstruct(self, coeff):

if self.nbands != len(coeff[1]):
raise Exception("Unmatched number of orientations")

M, N = coeff[0].shape
log_rad, angle = math_utils.prepare_grid(M, N)
height, width = coeff[0].shape
log_rad, angle = math_utils.prepare_grid(height, width)

Xrcos, Yrcos = math_utils.rcosFn(1, -0.5)
Yrcos = np.sqrt(Yrcos)
YIrcos = np.sqrt(np.abs(1 - Yrcos*Yrcos))
YIrcos = np.sqrt(np.abs(1 - Yrcos**2))

lo0mask = pointOp(log_rad, YIrcos, Xrcos)
hi0mask = pointOp(log_rad, Yrcos, Xrcos)
Expand All @@ -187,15 +188,19 @@ def reconstruct(self, coeff):
hidft = np.fft.fftshift(np.fft.fft2(coeff[0]))
outdft = tempdft * lo0mask + hidft * hi0mask

reconstruction = np.fft.ifft2(np.fft.ifftshift(outdft)).real.astype(int)
reconstruction = np.fft.ifftshift(outdft)
reconstruction = np.fft.ifft2(reconstruction)
reconstruction = reconstruction.real

return reconstruction

def _reconstruct_levels(self, coeff, log_rad, Xrcos, Yrcos, angle):

if len(coeff) == 1:
return np.fft.fftshift(np.fft.fft2(coeff[0]))
dft = np.fft.fft2(coeff[0])
dft = np.fft.fftshift(dft)
return dft


Xrcos = Xrcos - np.log2(self.scale_factor)

####################################################################
Expand All @@ -213,10 +218,9 @@ def _reconstruct_levels(self, coeff, log_rad, Xrcos, Yrcos, angle):
orientdft = np.zeros(coeff[0][0].shape)

for b in range(self.nbands):

anglemask = pointOp(angle, Ycosn, Xcosn + np.pi * b/self.nbands)

banddft = np.fft.fftshift(np.fft.fft2(coeff[0][b]))
banddft = np.fft.fft2(coeff[0][b])
banddft = np.fft.fftshift(banddft)
orientdft = orientdft + np.power(np.complex(0, 1), order) * banddft * anglemask * himask

####################################################################
Expand All @@ -225,45 +229,20 @@ def _reconstruct_levels(self, coeff, log_rad, Xrcos, Yrcos, angle):

dims = np.array(coeff[0][0].shape)

lostart = (np.ceil((dims+0.5)/2) -
np.ceil((np.ceil((dims-0.5)/2)+0.5)/2)).astype(np.int32)
lostart = (np.ceil((dims+0.5)/2) - np.ceil((np.ceil((dims-0.5)/2)+0.5)/2)).astype(np.int32)
loend = lostart + np.ceil((dims-0.5)/2).astype(np.int32)

nlog_rad = log_rad[lostart[0]:loend[0], lostart[1]:loend[1]]
nangle = angle[lostart[0]:loend[0], lostart[1]:loend[1]]
YIrcos = np.sqrt(np.abs(1 - Yrcos * Yrcos))
YIrcos = np.sqrt(np.abs(1 - Yrcos**2))
lomask = pointOp(nlog_rad, YIrcos, Xrcos)

################################################################################

# Recursive call for image reconstruction
nresdft = self._reconstruct_levels(coeff[1:], nlog_rad, Xrcos, Yrcos, nangle)

resdft = np.zeros(dims, 'complex')
resdft[lostart[0]:loend[0], lostart[1]:loend[1]] = nresdft * lomask

return resdft + orientdft


################################################################################
################################################################################
# Work in Progress

class ComplexSteerablePyramid():

def __init__(self, height, nbands):
self._height = height # including low-pass and high-pass
self._nbands = nbands # number of orientation bands
self._coeff = [None]*self._height

def set_level(self, level, coeff):
self._coeff[level] = coeff

def get_level(self, level):
return self._coeff[level]

def level_size(self, level):
if level == 0:
# High-pass
return self._coeff[level].shape
elif level == self._nbands:
# Low-pass
return self._coeff[level][0].shape
# Intermediate levels
return self._coeff[level][0].shape
Loading

0 comments on commit 0b6514d

Please sign in to comment.