|
|
|
@ -0,0 +1,387 @@ |
|
|
|
# -*- coding: utf-8 -*- |
|
|
|
""" |
|
|
|
Useful utilities for testing the 2-D DTCWT with synthetic images |
|
|
|
License: https://github.com/fbcotter/pytorch_wavelets/blob/master/LICENSE |
|
|
|
Source: https://github.com/fbcotter/pytorch_wavelets/blob/31d6ac1b51b08f811a6a70eb7b3440f106009da0/pytorch_wavelets/dwt/lowlevel.py # noqa |
|
|
|
""" |
|
|
|
|
|
|
|
import pywt |
|
|
|
import torch |
|
|
|
import numpy as np |
|
|
|
import torch.nn.functional as F |
|
|
|
from torch.autograd import Function |
|
|
|
|
|
|
|
|
|
|
|
def sfb1d(lo, hi, g0, g1, mode='zero', dim=-1):
    """ 1D synthesis filter bank of an image tensor

    Upsamples ``lo`` and ``hi`` by two along one dimension using strided
    transposed convolutions with ``g0`` and ``g1`` respectively, and returns
    the sum of the two results.

    Inputs:
        lo (tensor): 4D lowpass coefficients. ``lo.shape[1]`` is used as the
            number of channels (convolution groups).
        hi (tensor): highpass coefficients, same layout as ``lo``
        g0 (tensor or array-like): lowpass synthesis filter. Array-likes are
            flattened and converted to a tensor on the device of ``lo``;
            tensors are assumed to already be in the right order.
        g1 (tensor or array-like): highpass synthesis filter, handled like
            ``g0``
        mode (str): padding method. One of 'zero', 'symmetric', 'reflect',
            'periodic', 'per' or 'periodization'.
        dim (int): dimension of filtering. ``dim % 4 == 2`` uses a stride of
            (2, 1); any other value uses a stride of (1, 2).

    Returns:
        y: the synthesised tensor

    Raises:
        ValueError: if ``mode`` is not one of the modes listed above
    """
    C = lo.shape[1]
    d = dim % 4
    # Array-like filters follow the floating dtype of the input so that e.g.
    # float64 coefficients do not trigger a dtype mismatch inside
    # conv_transpose2d. Non-floating inputs keep the historical float32.
    filt_dtype = lo.dtype if lo.is_floating_point() else torch.float
    # If g0, g1 are not tensors, make them. If they are, then assume that they
    # are in the right order
    if not isinstance(g0, torch.Tensor):
        g0 = torch.tensor(np.copy(np.array(g0).ravel()),
                          dtype=filt_dtype, device=lo.device)
    if not isinstance(g1, torch.Tensor):
        g1 = torch.tensor(np.copy(np.array(g1).ravel()),
                          dtype=filt_dtype, device=lo.device)
    L = g0.numel()
    shape = [1, 1, 1, 1]
    shape[d] = L
    # Length of the output along the filtered dimension
    N = 2 * lo.shape[d]
    # If g aren't in the right shape, make them so
    if g0.shape != tuple(shape):
        g0 = g0.reshape(*shape)
    if g1.shape != tuple(shape):
        g1 = g1.reshape(*shape)

    s = (2, 1) if d == 2 else (1, 2)
    # One copy of the filter per channel for the grouped convolution
    g0 = torch.cat([g0] * C, dim=0)
    g1 = torch.cat([g1] * C, dim=0)
    if mode == 'per' or mode == 'periodization':
        y = F.conv_transpose2d(lo, g0, stride=s, groups=C) + \
            F.conv_transpose2d(hi, g1, stride=s, groups=C)
        # Fold the samples that ran past position N back onto the start, then
        # crop to N samples
        if d == 2:
            y[:, :, :L-2] = y[:, :, :L-2] + y[:, :, N:N+L-2]
            y = y[:, :, :N]
        else:
            y[:, :, :, :L-2] = y[:, :, :, :L-2] + y[:, :, :, N:N+L-2]
            y = y[:, :, :, :N]
        y = roll(y, 1-L//2, dim=dim)
    else:
        if mode == 'zero' or mode == 'symmetric' or mode == 'reflect' or \
                mode == 'periodic':
            pad = (L-2, 0) if d == 2 else (0, L-2)
            y = F.conv_transpose2d(lo, g0, stride=s, padding=pad, groups=C) + \
                F.conv_transpose2d(hi, g1, stride=s, padding=pad, groups=C)
        else:
            raise ValueError('Unkown pad type: {}'.format(mode))

    return y
|
|
|
|
|
|
|
|
|
|
|
def _SFB2D(low, highs, g0_row, g1_row, g0_col, g1_col, mode):
    """ Single level 2D synthesis written as a plain function.

    Inputs:
        low: lowpass coefficients
        highs: highpass coefficients; split along dimension 2 into exactly
            three bands (lh, hl, hh)
        g0_row, g1_row: lowpass / highpass filters used for the dim=3 pass
        g0_col, g1_col: lowpass / highpass filters used for the dim=2 passes
        mode (int): integer padding code, decoded with int_to_mode

    Returns:
        y: result of the two dim=2 syntheses followed by the dim=3 synthesis
    """
    pad_mode = int_to_mode(mode)
    band_lh, band_hl, band_hh = highs.unbind(dim=2)

    # First pass along dim=2: pair (low, lh) and (hl, hh)
    col_lo = sfb1d(low, band_lh, g0_col, g1_col, mode=pad_mode, dim=2)
    col_hi = sfb1d(band_hl, band_hh, g0_col, g1_col, mode=pad_mode, dim=2)

    # Second pass along dim=3 merges the two intermediate results
    return sfb1d(col_lo, col_hi, g0_row, g1_row, mode=pad_mode, dim=3)
|
|
|
|
|
|
|
|
|
|
|
def roll(x, n, dim, make_even=False):
    """ Circularly shift a tensor along one dimension.

    The last ``n`` entries along ``dim`` are moved to the front. Negative
    ``n`` shifts the other way, and shifts larger than the axis length wrap
    around.

    Inputs:
        x (tensor): tensor to shift; must have at least one dimension
        n (int): number of positions to shift by
        dim (int): dimension to shift along; negative values count from the
            last dimension
        make_even (bool): if True and the length along ``dim`` is odd, one
            extra sample (a copy of the first sample of the shifted result)
            is appended so the output length is even

    Returns:
        A new tensor holding the shifted (and possibly extended) data

    Raises:
        IndexError: if ``dim`` is not a valid dimension of ``x``
    """
    ndim = x.dim()
    if not -ndim <= dim < ndim:
        raise IndexError(
            'Dimension out of range (expected to be in range of [{}, {}], '
            'but got {})'.format(-ndim, ndim - 1, dim))
    axis = dim % ndim
    length = x.shape[axis]

    def _span(start, stop):
        # Slice x[start:stop] along ``axis`` only
        index = [slice(None)] * ndim
        index[axis] = slice(start, stop)
        return x[tuple(index)]

    # Normalise the shift into [0, length); guards the empty-axis case
    shift = n % length if length > 0 else 0
    split = length - shift
    pieces = [_span(split, length), _span(0, split)]
    if make_even and length % 2 == 1:
        # Explicit indices avoid the x[:-n+1] == x[:0] trap when shift == 1
        wrap_at = split % length
        pieces.append(_span(wrap_at, wrap_at + 1))
    return torch.cat(pieces, dim=axis)
|
|
|
|
|
|
|
|
|
|
|
def int_to_mode(mode):
    """ Translate an integer padding code (0-6) into its mode name.

    Raises:
        ValueError: if ``mode`` does not compare equal to any known code
    """
    # Position in this tuple is the integer code of the mode
    names = ('zero', 'symmetric', 'periodization', 'constant', 'reflect',
             'replicate', 'periodic')
    for code, name in enumerate(names):
        if mode == code:
            return name
    raise ValueError('Unkown pad type: {}'.format(mode))
|
|
|
|
|
|
|
|
|
|
|
def prep_filt_sfb2d(g0_col, g1_col, g0_row=None, g1_row=None, device=None):
    """
    Builds the four filter tensors in the layout used by the 2D synthesis
    code. Only conversion and reshaping is done here; the filter taps are not
    reversed.

    Inputs:
        g0_col (array-like): low pass column filter bank
        g1_col (array-like): high pass column filter bank
        g0_row (array-like): low pass row filter bank. If None, both row
            filters are taken from the column filters and g1_row is ignored.
        g1_row (array-like): high pass row filter bank
        device: which device to put the tensors on to

    Returns:
        (g0_col, g1_col, g0_row, g1_row), with the column filters shaped
        (1, 1, L, 1) and the row filters shaped (1, 1, 1, L)
    """
    col_lo, col_hi = prep_filt_sfb1d(g0_col, g1_col, device)
    if g0_row is not None:
        row_lo, row_hi = prep_filt_sfb1d(g0_row, g1_row, device)
    else:
        # No separate row filters supplied: reuse the prepared column filters
        row_lo, row_hi = col_lo, col_hi

    col_shape = (1, 1, -1, 1)
    row_shape = (1, 1, 1, -1)
    return (col_lo.reshape(col_shape), col_hi.reshape(col_shape),
            row_lo.reshape(row_shape), row_hi.reshape(row_shape))
|
|
|
|
|
|
|
|
|
|
|
def prep_filt_sfb1d(g0, g1, device=None):
    """
    Converts a pair of synthesis filters into tensors of shape (1, 1, L) in
    torch's current default dtype. The taps are flattened but not reversed.

    Inputs:
        g0 (array-like): low pass filter bank
        g1 (array-like): high pass filter bank
        device: which device to put the tensors on to

    Returns:
        (g0, g1)
    """
    flat_taps = [np.array(taps).ravel() for taps in (g0, g1)]
    dtype = torch.get_default_dtype()
    lo, hi = (
        torch.tensor(taps, device=device, dtype=dtype).reshape((1, 1, -1))
        for taps in flat_taps
    )
    return lo, hi
|
|
|
|
|
|
|
|
|
|
|
def mode_to_int(mode):
    """ Translate a padding mode name into its integer code.

    'per' and 'periodization' share the same code.

    Raises:
        ValueError: if ``mode`` does not compare equal to any known name
    """
    table = (
        ('zero', 0),
        ('symmetric', 1),
        ('per', 2),
        ('periodization', 2),
        ('constant', 3),
        ('reflect', 4),
        ('replicate', 5),
        ('periodic', 6),
    )
    for name, code in table:
        if mode == name:
            return code
    raise ValueError('Unkown pad type: {}'.format(mode))
|
|
|
|
|
|
|
|
|
|
|
def afb1d(x, h0, h1, mode='zero', dim=-1):
    """ 1D analysis filter bank (along one dimension only) of an image

    Inputs:
        x (tensor): 4D input with the last two dimensions the spatial input
        h0 (tensor or array-like): lowpass filter. A tensor should have shape
            (1, 1, h, 1) or (1, 1, 1, w) (it is reshaped if not) and is
            assumed to be in the right order. An array-like is flattened,
            reversed and converted to a tensor on the device of x.
        h1 (tensor or array-like): highpass filter, handled like h0
        mode (str): padding method. One of 'zero', 'symmetric', 'reflect',
            'periodic', 'per' or 'periodization'.
        dim (int) - dimension of filtering. d=2 is for a vertical filter (called
            column filtering but filters across the rows). d=3 is for a
            horizontal filter, (called row filtering but filters across the
            columns).

    Returns:
        lohi: lowpass and highpass subbands interleaved per input channel
            along the channel dimension (lo, hi, lo, hi, ...)

    Raises:
        ValueError: if mode is not one of the modes listed above
    """
    C = x.shape[1]
    # Convert the dim to positive
    d = dim % 4
    s = (2, 1) if d == 2 else (1, 2)
    N = x.shape[d]
    # Array-like filters follow the floating dtype of the input so that e.g.
    # float64 images do not trigger a dtype mismatch inside conv2d.
    # Non-floating inputs keep the historical float32.
    filt_dtype = x.dtype if x.is_floating_point() else torch.float
    # If h0, h1 are not tensors, make them. If they are, then assume that they
    # are in the right order
    if not isinstance(h0, torch.Tensor):
        h0 = torch.tensor(np.copy(np.array(h0).ravel()[::-1]),
                          dtype=filt_dtype, device=x.device)
    if not isinstance(h1, torch.Tensor):
        h1 = torch.tensor(np.copy(np.array(h1).ravel()[::-1]),
                          dtype=filt_dtype, device=x.device)
    L = h0.numel()
    L2 = L // 2
    shape = [1, 1, 1, 1]
    shape[d] = L
    # If h aren't in the right shape, make them so
    if h0.shape != tuple(shape):
        h0 = h0.reshape(*shape)
    if h1.shape != tuple(shape):
        h1 = h1.reshape(*shape)
    # (lo, hi) filter pair repeated once per channel for the grouped conv
    h = torch.cat([h0, h1] * C, dim=0)

    if mode == 'per' or mode == 'periodization':
        # Odd-length inputs are extended by repeating the last sample
        if x.shape[dim] % 2 == 1:
            if d == 2:
                x = torch.cat((x, x[:, :, -1:]), dim=2)
            else:
                x = torch.cat((x, x[:, :, :, -1:]), dim=3)
            N += 1
        x = roll(x, -L2, dim=d)
        pad = (L-1, 0) if d == 2 else (0, L-1)
        lohi = F.conv2d(x, h, padding=pad, stride=s, groups=C)
        N2 = N//2
        # Fold the samples past position N2 back onto the start, then crop
        if d == 2:
            lohi[:, :, :L2] = lohi[:, :, :L2] + lohi[:, :, N2:N2+L2]
            lohi = lohi[:, :, :N2]
        else:
            lohi[:, :, :, :L2] = lohi[:, :, :, :L2] + lohi[:, :, :, N2:N2+L2]
            lohi = lohi[:, :, :, :N2]
    else:
        # Calculate the pad size
        outsize = pywt.dwt_coeff_len(N, L, mode=mode)
        p = 2 * (outsize - 1) - N + L
        if mode == 'zero':
            # Sadly, pytorch only allows for same padding before and after, if
            # we need to do more padding after for odd length signals, have to
            # prepad
            if p % 2 == 1:
                pad = (0, 0, 0, 1) if d == 2 else (0, 1, 0, 0)
                x = F.pad(x, pad)
            pad = (p//2, 0) if d == 2 else (0, p//2)
            # Calculate the high and lowpass
            lohi = F.conv2d(x, h, padding=pad, stride=s, groups=C)
        elif mode == 'symmetric' or mode == 'reflect' or mode == 'periodic':
            pad = (0, 0, p//2, (p+1)//2) if d == 2 else (p//2, (p+1)//2, 0, 0)
            x = mypad(x, pad=pad, mode=mode)
            lohi = F.conv2d(x, h, stride=s, groups=C)
        else:
            raise ValueError('Unkown pad type: {}'.format(mode))

    return lohi
|
|
|
|
|
|
|
|
|
|
|
def mypad(x, pad, mode='constant', value=0):
    """ Function to do numpy like padding on tensors. Only works for 2-D
    padding.

    Inputs:
        x (tensor): tensor to pad
        pad (tuple): tuple of (left, right, top, bottom) pad sizes
        mode (str): 'symmetric', 'periodic', 'constant', 'reflect',
            'replicate', or 'zero'. The padding technique.
        value: fill value, forwarded to F.pad for the 'constant', 'reflect'
            and 'replicate' modes

    Raises:
        ValueError: if mode is not one of the modes listed above
    """
    if mode == 'symmetric' or mode == 'periodic':
        # Both modes are implemented by indexing x with an extended index
        # vector; they only differ in how that vector is generated.
        if mode == 'symmetric':
            def extended(size, before, after):
                positions = np.arange(-before, size + after, dtype='int32')
                return reflect(positions, -0.5, size - 0.5)
        else:
            def extended(size, before, after):
                return np.pad(np.arange(size), (before, after), mode='wrap')

        left, right, top, bottom = pad[0], pad[1], pad[2], pad[3]
        if left == 0 and right == 0:
            # Vertical only
            return x[:, :, extended(x.shape[-2], top, bottom)]
        if top == 0 and bottom == 0:
            # Horizontal only
            return x[:, :, :, extended(x.shape[-1], left, right)]
        # Both directions: build full row/column index grids
        xe_row = extended(x.shape[-1], left, right)
        xe_col = extended(x.shape[-2], top, bottom)
        row_grid = np.outer(xe_col, np.ones(xe_row.shape[0]))
        col_grid = np.outer(np.ones(xe_col.shape[0]), xe_row)
        return x[:, :, row_grid, col_grid]

    if mode in ('constant', 'reflect', 'replicate'):
        return F.pad(x, pad, mode, value)
    if mode == 'zero':
        return F.pad(x, pad)
    raise ValueError('Unkown pad type: {}'.format(mode))
|
|
|
|
|
|
|
|
|
|
|
def reflect(x, minx, maxx):
    """Fold the values of *x* back and forth between *minx* and *maxx*.

    A linearly increasing input becomes a triangular waveform that ramps up
    and down between the two bounds. With integer *x* and bounds of the form
    (integer + 0.5), the turning points are repeated samples.

    The result is returned as a new array with the dtype of *x*.

    .. codeauthor:: Rich Wareham <rjw57@cantab.net>, Aug 2013
    .. codeauthor:: Nick Kingsbury, Cambridge University, January 1999.

    """
    x = np.asanyarray(x)
    span = maxx - minx
    period = 2 * span

    # Offset from the lower bound, wrapped into [0, period)
    offset = np.fmod(x - minx, period)
    offset = np.where(offset < 0, offset + period, offset)

    # The second half of each period runs back down towards minx
    folded = np.where(offset >= span, period - offset, offset)
    result = folded + minx
    return np.array(result, dtype=x.dtype)
|
|
|
|
|
|
|
|
|
|
|
class SFB2D(Function):
    """ Does a single level 2d wavelet reconstruction (synthesis) from a
    lowpass band and three highpass bands. Does separate column and row
    filtering by three calls to sfb1d.

    Needs to have the tensors in the right form. Because this function defines
    its own backward pass, saves on memory by not having to save the input
    tensors.

    Inputs:
        low (torch.Tensor): lowpass coefficients
        highs (torch.Tensor): highpass coefficients; unbound along dimension 2
            into the three bands (lh, hl, hh)
        g0_row: row lowpass
        g1_row: row highpass
        g0_col: col lowpass
        g1_col: col highpass
        mode (int): use mode_to_int to get the int code here

    The four filters are stored with ``ctx.save_for_backward``.

    We encode the mode as an integer rather than a string as gradcheck causes an
    error when a string is provided.

    Returns:
        y: the reconstructed tensor
    """
    @staticmethod
    def forward(ctx, low, highs, g0_row, g1_row, g0_col, g1_col, mode):
        mode = int_to_mode(mode)
        ctx.mode = mode
        ctx.save_for_backward(g0_row, g1_row, g0_col, g1_col)

        lh, hl, hh = torch.unbind(highs, dim=2)
        # Column synthesis first (low with lh, hl with hh), then row synthesis
        lo = sfb1d(low, lh, g0_col, g1_col, mode=mode, dim=2)
        hi = sfb1d(hl, hh, g0_col, g1_col, mode=mode, dim=2)
        y = sfb1d(lo, hi, g0_row, g1_row, mode=mode, dim=3)
        return y

    @staticmethod
    def backward(ctx, dy):
        dlow, dhigh = None, None
        # Both input gradients come from the same analysis pass, so run it
        # whenever either ``low`` or ``highs`` needs a gradient. Checking only
        # needs_input_grad[0] would leave ``highs`` without a gradient when
        # ``low`` does not require one.
        if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]:
            mode = ctx.mode
            g0_row, g1_row, g0_col, g1_col = ctx.saved_tensors
            dx = afb1d(dy, g0_row, g1_row, mode=mode, dim=3)
            dx = afb1d(dx, g0_col, g1_col, mode=mode, dim=2)
            s = dx.shape
            # Group the channel axis into sets of four subbands per channel
            dx = dx.reshape(s[0], -1, 4, s[-2], s[-1])
            dlow = dx[:, :, 0].contiguous()
            dhigh = dx[:, :, 1:].contiguous()
        # One entry per forward argument; filters and mode get no gradient
        return dlow, dhigh, None, None, None, None, None