Skip to content

Commit

Permalink
reconstruction finished
Browse files Browse the repository at this point in the history
  • Loading branch information
tomrunia committed Dec 13, 2018
1 parent 1e32704 commit 7463f77
Show file tree
Hide file tree
Showing 9 changed files with 57 additions and 392 deletions.
31 changes: 8 additions & 23 deletions examples/compare_both.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,20 +61,12 @@

pyr_torch = SCFpyr_PyTorch(pyr_height, pyr_nbands, device=device)
coeff_torch = pyr_torch.build(im_batch)
reconstruction_torch_v2 = pyr_torch.reconstruct(coeff_torch)
#reconstruction_torch_v2 = reconstruction_torch_v2.cpu().numpy()[0,]
reconstruction_torch = pyr_torch.reconstruct(coeff_torch)
reconstruction_torch = reconstruction_torch.cpu().numpy()[0,]

# Just extract a single example from the batch
# Also moves the example to CPU and NumPy
# Extract first example from the batch and move to CPU
coeff_torch = utils.extract_from_batch(coeff_torch, 0)

# NOTE: reconstruction using NumPy implementation
# Gives the same result, so decomposition is correct (!)
# reconstruction_torch = pyr_numpy.reconstruct(coeff_torch)
# reconstruction_torch = reconstruction_torch.astype(np.uint8)

#exit()

################################################################################
# Check correctness

Expand Down Expand Up @@ -133,17 +125,10 @@
coeff_grid_numpy = utils.make_grid_coeff(coeff_numpy, normalize=False)
coeff_grid_torch = utils.make_grid_coeff(coeff_torch, normalize=False)

# import cortex.vision
# reconstruction_torch = np.ascontiguousarray(reconstruction_torch[0], np.float32)
# reconstruction_numpy = np.ascontiguousarray(reconstruction_numpy, np.float32)
# reconstruction_torch = cortex.vision.normalize_for_display(reconstruction_torch)
# reconstruction_numpy = cortex.vision.normalize_for_display(reconstruction_numpy)

#cv2.imshow('image', image)
#cv2.imshow('coeff numpy', coeff_grid_numpy)
#cv2.imshow('coeff torch', coeff_grid_torch)
cv2.imshow('reconstruction numpy', reconstruction_numpy)
#cv2.imshow('reconstruction torch', reconstruction_torch)
cv2.imshow('reconstruction torch', reconstruction_torch_v2)
cv2.imshow('image', image)
cv2.imshow('coeff numpy', np.ascontiguousarray(coeff_grid_numpy))
cv2.imshow('coeff torch', np.ascontiguousarray(coeff_grid_torch))
cv2.imshow('reconstruction numpy', reconstruction_numpy.astype(np.uint8))
cv2.imshow('reconstruction torch', reconstruction_torch.astype(np.uint8))

cv2.waitKey(0)
69 changes: 0 additions & 69 deletions examples/debug_reconstruction.py

This file was deleted.

3 changes: 2 additions & 1 deletion examples/example_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import time
import argparse
import numpy as np

from steerable.SCFpyr_NumPy import SCFpyr_NumPy
import steerable.utils as utils
Expand Down Expand Up @@ -71,7 +72,7 @@
if config.visualize:
import cv2
coeff_grid = utils.make_grid_coeff(coeff, normalize=True)
cv2.imshow('image', im_batch_numpy[0,])
cv2.imshow('image', (im_batch_numpy[0,]*255.).astype(np.uint8))
cv2.imshow('coeff', coeff_grid)
cv2.waitKey(0)

85 changes: 0 additions & 85 deletions examples/example_numpy_reconstruct.py

This file was deleted.

3 changes: 2 additions & 1 deletion examples/example_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import argparse
import time
import numpy as np
import torch

from steerable.SCFpyr_PyTorch import SCFpyr_PyTorch
Expand Down Expand Up @@ -79,7 +80,7 @@
if config.visualize:
import cv2
coeff_grid = utils.make_grid_coeff(coeff, normalize=True)
cv2.imshow('image', im_batch_numpy[0,0,])
cv2.imshow('image', (im_batch_numpy[0,0,]*255.).astype(np.uint8))
cv2.imshow('coeff', coeff_grid)
cv2.waitKey(0)

86 changes: 4 additions & 82 deletions steerable/SCFpyr_NumPy.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ def _build_levels(self, lodft, log_rad, angle, Xrcos, Yrcos, height):

# Both are tuples of size 2
low_ind_start = (np.ceil((dims+0.5)/2) - np.ceil((np.ceil((dims-0.5)/2)+0.5)/2)).astype(int)
low_ind_end = (low_ind_start + np.ceil((dims-0.5)/2)).astype(int)
low_ind_end = (low_ind_start + np.ceil((dims-0.5)/2)).astype(int)

# Selection
log_rad = log_rad[low_ind_start[0]:low_ind_end[0], low_ind_start[1]:low_ind_end[1]]
Expand Down Expand Up @@ -188,28 +188,9 @@ def reconstruct(self, coeff):
hidft = np.fft.fftshift(np.fft.fft2(coeff[0]))
outdft = tempdft * lo0mask + hidft * hi0mask

real = outdft.real
imag = outdft.imag
print(' [numpy] levels remaining {}. outdft real ({:.3f}, {:.3f}, {:.3f})'.format(
len(coeff), real.mean().item(), real.std().item(), real.sum().item()
))
print(' [numpy] levels remaining {}. outdft imag ({:.3f}, {:.3f}, {:.3f})'.format(
len(coeff), imag.mean().item(), imag.std().item(), imag.sum().item()
))

reconstruction = np.fft.ifftshift(outdft)
reconstruction = np.fft.ifft2(reconstruction)

real = reconstruction.real
imag = reconstruction.imag
print(' [numpy] levels remaining {}. outdft real ({:.3f}, {:.3f}, {:.3f})'.format(
len(coeff), real.mean().item(), real.std().item(), real.sum().item()
))
print(' [numpy] levels remaining {}. outdft imag ({:.3f}, {:.3f}, {:.3f})'.format(
len(coeff), imag.mean().item(), imag.std().item(), imag.sum().item()
))

reconstruction = reconstruction.real.astype(np.uint8)
reconstruction = reconstruction.real

return reconstruction

Expand All @@ -218,15 +199,6 @@ def _reconstruct_levels(self, coeff, log_rad, Xrcos, Yrcos, angle):
if len(coeff) == 1:
dft = np.fft.fft2(coeff[0])
dft = np.fft.fftshift(dft)

real, imag = dft.real, dft.imag
print(' [numpy] levels remaining {}. dft real ({:.3f}, {:.3f}, {:.3f})'.format(
len(coeff), real.mean().item(), real.std().item(), real.sum().item()
))
print(' [numpy] levels remaining {}. dft imag ({:.3f}, {:.3f}, {:.3f})'.format(
len(coeff), imag.mean().item(), imag.std().item(), imag.sum().item()
))

return dft

Xrcos = Xrcos - np.log2(self.scale_factor)
Expand All @@ -246,21 +218,11 @@ def _reconstruct_levels(self, coeff, log_rad, Xrcos, Yrcos, angle):
orientdft = np.zeros(coeff[0][0].shape)

for b in range(self.nbands):

anglemask = pointOp(angle, Ycosn, Xcosn + np.pi * b/self.nbands)

banddft = np.fft.fftshift(np.fft.fft2(coeff[0][b]))
banddft = np.fft.fft2(coeff[0][b])
banddft = np.fft.fftshift(banddft)
orientdft = orientdft + np.power(np.complex(0, 1), order) * banddft * anglemask * himask

real = orientdft.real
imag = orientdft.imag
print(' [numpy] levels remaining {}. orientdft real ({:.3f}, {:.3f}, {:.3f})'.format(
len(coeff), real.mean().item(), real.std().item(), real.sum().item()
))
print(' [numpy] levels remaining {}. orientdft imag ({:.3f}, {:.3f}, {:.3f})'.format(
len(coeff), imag.mean().item(), imag.std().item(), imag.sum().item()
))

####################################################################
########## Lowpass component are upsampled and convoluted ##########
####################################################################
Expand All @@ -275,10 +237,6 @@ def _reconstruct_levels(self, coeff, log_rad, Xrcos, Yrcos, angle):
YIrcos = np.sqrt(np.abs(1 - Yrcos**2))
lomask = pointOp(nlog_rad, YIrcos, Xrcos)

print(' [numpy] levels remaining {}. nlog_rad = {:.3f}, nangle = {:.3f}, YIrcos = {:.3f}, lomask = {:.3f}'.format(
len(coeff), nlog_rad.sum(), nangle.sum(), YIrcos.sum(), lomask.sum()
))

################################################################################

# Recursive call for image reconstruction
Expand All @@ -287,40 +245,4 @@ def _reconstruct_levels(self, coeff, log_rad, Xrcos, Yrcos, angle):
resdft = np.zeros(dims, 'complex')
resdft[lostart[0]:loend[0], lostart[1]:loend[1]] = nresdft * lomask

real, imag = nresdft.real, nresdft.imag
print(' [numpy] levels remaining {}. nresdft real ({:.3f}, {:.3f}, {:.3f})'.format(
len(coeff), real.mean().item(), real.std().item(), real.sum().item()
))
print(' [numpy] levels remaining {}. nresdft imag ({:.3f}, {:.3f}, {:.3f})'.format(
len(coeff), imag.mean().item(), imag.std().item(), imag.sum().item()
))

return resdft + orientdft


################################################################################
################################################################################
# Work in Progress

class ComplexSteerablePyramid():

def __init__(self, height, nbands):
self._height = height # including low-pass and high-pass
self._nbands = nbands # number of orientation bands
self._coeff = [None]*self._height

def set_level(self, level, coeff):
self._coeff[level] = coeff

def get_level(self, level):
return self._coeff[level]

def level_size(self, level):
if level == 0:
# High-pass
return self._coeff[level].shape
elif level == self._nbands:
# Low-pass
return self._coeff[level][0].shape
# Intermediate levels
return self._coeff[level][0].shape
Loading

0 comments on commit 7463f77

Please sign in to comment.