# -*- coding: utf-8 -*-
|
|
"""
|
|
Useful utilities for testing the 2-D DTCWT with synthetic images
|
|
License: https://github.com/fbcotter/pytorch_wavelets/blob/master/LICENSE
|
|
Source: https://github.com/fbcotter/pytorch_wavelets/blob/31d6ac1b51b08f811a6a70eb7b3440f106009da0/pytorch_wavelets/dwt/lowlevel.py # noqa
|
|
"""
|
|
|
|
import pywt
|
|
import torch
|
|
import numpy as np
|
|
import torch.nn.functional as F
|
|
from torch.autograd import Function
|
|
|
|
|
|
def sfb1d(lo, hi, g0, g1, mode='zero', dim=-1):
    """ 1D synthesis filter bank of an image tensor

    Upsamples ``lo`` and ``hi`` by two along one spatial dimension with a
    grouped transposed convolution and sums the two results.

    Inputs:
        lo (tensor): lowpass coefficients. Indexed as a 4D tensor whose
            dimension 1 is the channel dimension.
        hi (tensor): highpass coefficients, filtered with ``g1``.
            Presumably the same shape as ``lo`` (the two conv outputs are
            added) - confirm against callers.
        g0 (tensor or array-like): lowpass synthesis filter. Non-tensors are
            flattened and converted; tensors are used as given.
        g1 (tensor or array-like): highpass synthesis filter.
        mode (str): 'zero', 'symmetric', 'reflect', 'periodic', 'per' or
            'periodization'.
        dim (int): dimension to filter along; reduced modulo 4. 2 filters
            vertically, anything else is handled as horizontal.

    Returns:
        y: the reconstructed tensor.

    Raises:
        ValueError: if ``mode`` is not one of the supported names.
    """
    C = lo.shape[1]
    d = dim % 4

    # Array-like filters follow the dtype of the coefficients so that double
    # (or half) inputs do not hit a dtype mismatch inside conv_transpose2d.
    # Non-floating inputs keep the historical float32 filters.
    filt_dtype = lo.dtype if lo.is_floating_point() else torch.float

    # If g0, g1 are not tensors, make them. If they are, then assume that they
    # are in the right order
    if not isinstance(g0, torch.Tensor):
        g0 = torch.tensor(np.copy(np.array(g0).ravel()),
                          dtype=filt_dtype, device=lo.device)
    if not isinstance(g1, torch.Tensor):
        g1 = torch.tensor(np.copy(np.array(g1).ravel()),
                          dtype=filt_dtype, device=lo.device)

    L = g0.numel()
    shape = [1, 1, 1, 1]
    shape[d] = L
    # Output length along the filtered dimension for periodization
    N = 2*lo.shape[d]

    # If g aren't in the right shape, make them so
    if g0.shape != tuple(shape):
        g0 = g0.reshape(*shape)
    if g1.shape != tuple(shape):
        g1 = g1.reshape(*shape)

    # Stride 2 along the filtered dimension only
    s = (2, 1) if d == 2 else (1, 2)
    # One copy of the filter per channel for the grouped convolution
    g0 = torch.cat([g0]*C, dim=0)
    g1 = torch.cat([g1]*C, dim=0)

    if mode == 'per' or mode == 'periodization':
        y = F.conv_transpose2d(lo, g0, stride=s, groups=C) + \
            F.conv_transpose2d(hi, g1, stride=s, groups=C)
        # Wrap the L-2 samples that spill past N back onto the start, then
        # crop to N samples
        if d == 2:
            y[:, :, :L-2] = y[:, :, :L-2] + y[:, :, N:N+L-2]
            y = y[:, :, :N]
        else:
            y[:, :, :, :L-2] = y[:, :, :, :L-2] + y[:, :, :, N:N+L-2]
            y = y[:, :, :, :N]
        y = roll(y, 1-L//2, dim=dim)
    else:
        if mode == 'zero' or mode == 'symmetric' or mode == 'reflect' or \
                mode == 'periodic':
            # All of these modes share the same synthesis: trim L-2 samples
            # from each end of the filtered dimension via the conv padding
            pad = (L-2, 0) if d == 2 else (0, L-2)
            y = F.conv_transpose2d(lo, g0, stride=s, padding=pad, groups=C) + \
                F.conv_transpose2d(hi, g1, stride=s, padding=pad, groups=C)
        else:
            raise ValueError('Unkown pad type: {}'.format(mode))

    return y
|
|
|
|
|
|
def _SFB2D(low, highs, g0_row, g1_row, g0_col, g1_col, mode):
    """ Plain-function 2D synthesis: column pass on each subband pair, then
    a row pass on the two intermediate results.

    Inputs:
        low (tensor): lowpass subband
        highs (tensor): the three highpass subbands stacked along dim 2
        g0_row, g1_row: row synthesis filters (low, high)
        g0_col, g1_col: column synthesis filters (low, high)
        mode (int): integer pad-mode code, decoded with int_to_mode

    Returns:
        y: the reconstructed tensor
    """
    pad_mode = int_to_mode(mode)
    band_lh, band_hl, band_hh = torch.unbind(highs, dim=2)

    # Column synthesis (dim 2) for the low and the high halves
    col_low = sfb1d(low, band_lh, g0_col, g1_col, mode=pad_mode, dim=2)
    col_high = sfb1d(band_hl, band_hh, g0_col, g1_col, mode=pad_mode, dim=2)

    # Row synthesis (dim 3) merges the two halves
    return sfb1d(col_low, col_high, g0_row, g1_row, mode=pad_mode, dim=3)
|
|
|
|
|
|
def roll(x, n, dim, make_even=False):
    """ Circularly shift tensor ``x`` by ``n`` places along ``dim``.

    Inputs:
        x (tensor): tensor to roll
        n (int): shift amount. A negative value is converted to
            ``x.shape[dim] + n``.
        dim (int): dimension to roll along. Any valid dimension of ``x`` is
            accepted, negative values counting from the end.
        make_even (bool): if true and the size along ``dim`` is odd, the
            second piece is taken one sample longer, so the output is one
            sample longer than the input.

    Returns:
        The last ``n`` entries along ``dim`` followed by the leading entries.

    Raises:
        IndexError: if ``dim`` is not a valid dimension of ``x``.
    """
    ndim = x.dim()
    if not -ndim <= dim < ndim:
        raise IndexError(
            'Dimension out of range (expected to be in range of '
            '[{}, {}], but got {})'.format(-ndim, ndim - 1, dim))
    axis = dim % ndim
    size = x.shape[axis]

    if n < 0:
        n = size + n

    end = 1 if (make_even and size % 2 == 1) else 0

    # Build the index tuples generically instead of hard-coding one branch
    # per dimension; the slice bounds themselves are unchanged.
    # NOTE(review): when -n + end == 0 the second slice is empty (x[:0]).
    lead = (slice(None),) * axis
    tail = x[lead + (slice(-n, None),)]
    head = x[lead + (slice(None, -n + end),)]
    return torch.cat((tail, head), dim=axis)
|
|
|
|
|
|
def int_to_mode(mode):
    """ Translate an integer pad-mode code (0-6) into its string name.

    Raises:
        ValueError: if ``mode`` matches none of the known codes.
    """
    # Position in the tuple is the integer code
    names = ('zero', 'symmetric', 'periodization', 'constant', 'reflect',
             'replicate', 'periodic')
    for code, name in enumerate(names):
        if mode == code:
            return name
    raise ValueError('Unkown pad type: {}'.format(mode))
|
|
|
|
|
|
def prep_filt_sfb2d(g0_col, g1_col, g0_row=None, g1_row=None, device=None):
    """
    Prepares the filters to be of the right form for the sfb2d function. In
    particular, makes the tensors the right shape. It does not mirror image
    them.

    Inputs:
        g0_col (array-like): low pass column filter bank
        g1_col (array-like): high pass column filter bank
        g0_row (array-like): low pass row filter bank. If None, will assume
            the same as the column filter
        g1_row (array-like): high pass row filter bank. If None, will assume
            the same as the column filter
        device: which device to put the tensors on to

    Returns:
        (g0_col, g1_col, g0_row, g1_row): column filters reshaped to
        (1, 1, -1, 1) and row filters reshaped to (1, 1, 1, -1)
    """
    # Resolve each row filter on its own, as documented. Previously only
    # g0_row was checked, so a lone g1_row was discarded and a lone g0_row
    # sent None into the conversion.
    row_low_src = g0_col if g0_row is None else g0_row
    row_high_src = g1_col if g1_row is None else g1_row

    g0_col, g1_col = prep_filt_sfb1d(g0_col, g1_col, device)
    g0_row, g1_row = prep_filt_sfb1d(row_low_src, row_high_src, device)

    # Column filters run along dim 2, row filters along dim 3
    g0_col = g0_col.reshape((1, 1, -1, 1))
    g1_col = g1_col.reshape((1, 1, -1, 1))
    g0_row = g0_row.reshape((1, 1, 1, -1))
    g1_row = g1_row.reshape((1, 1, 1, -1))

    return g0_col, g1_col, g0_row, g1_row
|
|
|
|
|
|
def prep_filt_sfb1d(g0, g1, device=None):
    """
    Converts a pair of synthesis filters into tensors of shape (1, 1, L).
    The taps are flattened but not reversed.

    Inputs:
        g0 (array-like): low pass filter bank
        g1 (array-like): high pass filter bank
        device: which device to put the tensors on to

    Returns:
        (g0, g1) as tensors of torch's current default dtype
    """
    low_taps, high_taps = np.array(g0).ravel(), np.array(g1).ravel()
    dtype = torch.get_default_dtype()
    g0, g1 = (
        torch.tensor(taps, device=device, dtype=dtype).reshape((1, 1, -1))
        for taps in (low_taps, high_taps)
    )
    return g0, g1
|
|
|
|
|
|
def mode_to_int(mode):
    """ Translate a pad-mode name into its integer code (0-6).

    'per' and 'periodization' both map to 2.

    Raises:
        ValueError: if ``mode`` is not a known name.
    """
    codes = (
        ('zero', 0),
        ('symmetric', 1),
        ('per', 2),
        ('periodization', 2),
        ('constant', 3),
        ('reflect', 4),
        ('replicate', 5),
        ('periodic', 6),
    )
    for name, code in codes:
        if mode == name:
            return code
    raise ValueError('Unkown pad type: {}'.format(mode))
|
|
|
|
|
|
def afb1d(x, h0, h1, mode='zero', dim=-1):
    """ 1D analysis filter bank (along one dimension only) of an image

    Inputs:
        x (tensor): 4D input with the last two dimensions the spatial input
        h0 (tensor): 4D input for the lowpass filter. Should have shape (1, 1,
            h, 1) or (1, 1, 1, w). A non-tensor is flattened, reversed and
            converted to a tensor.
        h1 (tensor): 4D input for the highpass filter. Should have shape (1, 1,
            h, 1) or (1, 1, 1, w). A non-tensor is flattened, reversed and
            converted to a tensor.
        mode (str): padding method. One of 'zero', 'symmetric', 'reflect',
            'periodic', 'per' or 'periodization'.
        dim (int) - dimension of filtering. d=2 is for a vertical filter (called
            column filtering but filters across the rows). d=3 is for a
            horizontal filter, (called row filtering but filters across the
            columns).

    Returns:
        lohi: lowpass and highpass subbands concatenated along the channel
            dimension (a lowpass/highpass pair for each input channel)

    Raises:
        ValueError: if ``mode`` is not a supported padding method.
    """
    C = x.shape[1]
    # Convert the dim to positive
    d = dim % 4
    # Stride 2 along the filtered dimension only
    s = (2, 1) if d == 2 else (1, 2)
    N = x.shape[d]

    # Array-like filters follow the input dtype so that double (or half)
    # inputs do not hit a dtype mismatch inside conv2d. Non-floating inputs
    # keep the historical float32 filters.
    filt_dtype = x.dtype if x.is_floating_point() else torch.float

    # If h0, h1 are not tensors, make them. If they are, then assume that they
    # are in the right order
    if not isinstance(h0, torch.Tensor):
        # np.copy materialises the reversed view before handing it to torch
        h0 = torch.tensor(np.copy(np.array(h0).ravel()[::-1]),
                          dtype=filt_dtype, device=x.device)
    if not isinstance(h1, torch.Tensor):
        h1 = torch.tensor(np.copy(np.array(h1).ravel()[::-1]),
                          dtype=filt_dtype, device=x.device)
    L = h0.numel()
    L2 = L // 2
    shape = [1, 1, 1, 1]
    shape[d] = L
    # If h aren't in the right shape, make them so
    if h0.shape != tuple(shape):
        h0 = h0.reshape(*shape)
    if h1.shape != tuple(shape):
        h1 = h1.reshape(*shape)
    # One (lowpass, highpass) filter pair per channel for the grouped conv
    h = torch.cat([h0, h1] * C, dim=0)

    if mode == 'per' or mode == 'periodization':
        # Odd lengths are extended by repeating the last sample
        if x.shape[dim] % 2 == 1:
            if d == 2:
                x = torch.cat((x, x[:, :, -1:]), dim=2)
            else:
                x = torch.cat((x, x[:, :, :, -1:]), dim=3)
            N += 1
        x = roll(x, -L2, dim=d)
        pad = (L-1, 0) if d == 2 else (0, L-1)
        lohi = F.conv2d(x, h, padding=pad, stride=s, groups=C)
        N2 = N//2
        # Wrap the L2 samples past N2 back onto the start, then crop to N2
        if d == 2:
            lohi[:, :, :L2] = lohi[:, :, :L2] + lohi[:, :, N2:N2+L2]
            lohi = lohi[:, :, :N2]
        else:
            lohi[:, :, :, :L2] = lohi[:, :, :, :L2] + lohi[:, :, :, N2:N2+L2]
            lohi = lohi[:, :, :, :N2]
    else:
        # Calculate the pad size
        outsize = pywt.dwt_coeff_len(N, L, mode=mode)
        p = 2 * (outsize - 1) - N + L
        if mode == 'zero':
            # Sadly, pytorch only allows for same padding before and after, if
            # we need to do more padding after for odd length signals, have to
            # prepad
            if p % 2 == 1:
                pad = (0, 0, 0, 1) if d == 2 else (0, 1, 0, 0)
                x = F.pad(x, pad)
            pad = (p//2, 0) if d == 2 else (0, p//2)
            # Calculate the high and lowpass
            lohi = F.conv2d(x, h, padding=pad, stride=s, groups=C)
        elif mode == 'symmetric' or mode == 'reflect' or mode == 'periodic':
            # Pad explicitly, then convolve without any implicit padding
            pad = (0, 0, p//2, (p+1)//2) if d == 2 else (p//2, (p+1)//2, 0, 0)
            x = mypad(x, pad=pad, mode=mode)
            lohi = F.conv2d(x, h, stride=s, groups=C)
        else:
            raise ValueError('Unkown pad type: {}'.format(mode))

    return lohi
|
|
|
|
|
|
def mypad(x, pad, mode='constant', value=0):
    """ Function to do numpy like padding on tensors. Only works for 2-D
    padding.

    Inputs:
        x (tensor): tensor to pad. Indexed as a 4D tensor whose last two
            dimensions are padded.
        pad (tuple): tuple of (left, right, top, bottom) pad sizes
        mode (str): 'symmetric', 'periodic', 'constant', 'reflect',
            'replicate', or 'zero'. The padding technique.
        value: fill value, forwarded to ``F.pad`` for the 'constant',
            'reflect' and 'replicate' modes.

    Returns:
        The padded tensor.

    Raises:
        ValueError: if ``mode`` is not one of the names above.
    """
    if mode == 'symmetric' or mode == 'periodic':
        # Both modes pad by gathering: build the source index of every
        # output row/column, then index the input with it.
        if mode == 'symmetric':
            def extended_index(length, before, after):
                # Reflecting about the half-sample points repeats the edge
                return reflect(
                    np.arange(-before, length + after, dtype='int32'),
                    -0.5, length - 0.5)
        else:
            def extended_index(length, before, after):
                return np.pad(np.arange(length), (before, after), mode='wrap')

        # Vertical only
        if pad[0] == 0 and pad[1] == 0:
            xe = extended_index(x.shape[-2], pad[2], pad[3])
            return x[:, :, xe]
        # Horizontal only
        if pad[2] == 0 and pad[3] == 0:
            xe = extended_index(x.shape[-1], pad[0], pad[1])
            return x[:, :, :, xe]
        # Both. The index grids must be integers: the previous
        # np.outer(..., np.ones(...)) grids were float64, which cannot be
        # used to index a tensor. Broadcasting a column of row indices
        # against a row of column indices gives the same grid.
        xe_row = extended_index(x.shape[-1], pad[0], pad[1])
        xe_col = extended_index(x.shape[-2], pad[2], pad[3])
        i = torch.as_tensor(xe_col, dtype=torch.long, device=x.device)
        j = torch.as_tensor(xe_row, dtype=torch.long, device=x.device)
        return x[:, :, i[:, None], j[None, :]]
    elif mode == 'constant' or mode == 'reflect' or mode == 'replicate':
        return F.pad(x, pad, mode, value)
    elif mode == 'zero':
        return F.pad(x, pad)
    else:
        raise ValueError('Unkown pad type: {}'.format(mode))
|
|
|
|
|
|
def reflect(x, minx, maxx):
    """Fold the values of *x* into the interval between *minx* and *maxx*.

    Values outside the interval are mirrored back at its two ends, so a long
    linearly increasing series becomes a waveform ramping up and down between
    *minx* and *maxx*. If *x* contains integers and the bounds are (integers +
    0.5), the ramps have repeated max and min samples.

    The result is cast back to the dtype of *x*.

    .. codeauthor:: Rich Wareham, Aug 2013
    .. codeauthor:: Nick Kingsbury, Cambridge University, January 1999.
    """
    values = np.asanyarray(x)
    span = maxx - minx
    period = 2 * span

    # Offset from the lower bound, wrapped into [0, period)
    offset = np.fmod(values - minx, period)
    offset = np.where(offset < 0, offset + period, offset)

    # The second half of each period runs back down towards minx
    folded = np.where(offset >= span, period - offset, offset)
    return np.array(folded + minx, dtype=values.dtype)
|
|
|
|
|
|
class SFB2D(Function):
    """ Does a single level 2d wavelet reconstruction (synthesis) from a
    lowpass band and three highpass bands. Does separate column and row
    filtering by calls to ``sfb1d``; the backward pass applies ``afb1d``
    with the same filters.

    Needs to have the tensors in the right form. Because this function defines
    its own backward pass, only the four filters are saved for backward, not
    the input tensors.

    Inputs:
        low (torch.Tensor): lowpass subband
        highs (torch.Tensor): the three highpass subbands stacked along
            dim 2
        g0_row: row lowpass
        g1_row: row highpass
        g0_col: col lowpass
        g1_col: col highpass
        mode (int): use mode_to_int to get the int code here

    We encode the mode as an integer rather than a string as gradcheck causes an
    error when a string is provided.

    Returns:
        y: the reconstructed tensor
    """
    @staticmethod
    def forward(ctx, low, highs, g0_row, g1_row, g0_col, g1_col, mode):
        mode = int_to_mode(mode)
        ctx.mode = mode
        ctx.save_for_backward(g0_row, g1_row, g0_col, g1_col)

        lh, hl, hh = torch.unbind(highs, dim=2)
        # Column synthesis on each subband pair, then row synthesis
        lo = sfb1d(low, lh, g0_col, g1_col, mode=mode, dim=2)
        hi = sfb1d(hl, hh, g0_col, g1_col, mode=mode, dim=2)
        y = sfb1d(lo, hi, g0_row, g1_row, mode=mode, dim=3)
        return y

    @staticmethod
    def backward(ctx, dy):
        dlow, dhigh = None, None
        # Both gradients come out of the same analysis pass, so run it when
        # either `low` or `highs` needs a gradient. Checking only index 0
        # left `highs` without a gradient whenever `low` did not require one.
        if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]:
            mode = ctx.mode
            g0_row, g1_row, g0_col, g1_col = ctx.saved_tensors
            dx = afb1d(dy, g0_row, g1_row, mode=mode, dim=3)
            dx = afb1d(dx, g0_col, g1_col, mode=mode, dim=2)
            s = dx.shape
            # Split the channel dimension into groups of four subbands
            dx = dx.reshape(s[0], -1, 4, s[-2], s[-1])
            dlow = dx[:, :, 0].contiguous()
            dhigh = dx[:, :, 1:].contiguous()
        # No gradients for the four filters or the mode
        return dlow, dhigh, None, None, None, None, None
|