# -*- coding: utf-8 -*-
|
|
import pywt
|
|
import torch
|
|
import torch.nn as nn
|
|
from taming.modules.diffusionmodules.model import Decoder
|
|
|
|
from .pytorch_wavelets_utils import SFB2D, _SFB2D, prep_filt_sfb2d, mode_to_int
|
|
|
|
|
|
class DecoderDWT(nn.Module):
    """Decoder that predicts single-level db1 DWT coefficients and inverts them.

    The wrapped taming ``Decoder`` is forced to emit 12 channels. ``dwt_to_img``
    splits these into 3 low-pass channels plus 9 high-pass channels and runs an
    inverse DWT (``DWTInverse`` with ``wave='db1'``, ``mode='zero'``).

    Args:
        ddconfig: Mapping of keyword arguments for ``Decoder``; must contain
            ``'z_channels'``. ``out_ch`` is overridden to 12 on a local copy,
            so the caller's mapping is left unchanged.
        embed_dim: Number of input channels of ``post_quant_conv``.
    """

    def __init__(self, ddconfig, embed_dim):
        super().__init__()
        # Shallow copy + item access: works for plain dicts as well as
        # mapping-like configs, and does not mutate the caller's config.
        ddconfig = {**ddconfig, 'out_ch': 12}
        # NOTE(review): forward() does not apply post_quant_conv; it is
        # presumably kept so existing checkpoints (state_dict keys) still load.
        self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig['z_channels'], 1)
        self.decoder = Decoder(**ddconfig)
        self.idwt = DWTInverse(mode='zero', wave='db1')

    def forward(self, x):
        """Decode ``x`` into 12 DWT coefficient channels and reconstruct the image.

        ``post_quant_conv`` is not applied here.
        """
        freq = self.decoder(x)
        return self.dwt_to_img(freq)

    def dwt_to_img(self, img):
        """Split a 12-channel coefficient tensor and apply the inverse DWT.

        Args:
            img: Tensor of shape (B, 12, H, W). Channels 0-2 are passed as the
                low-pass input; channels 3-11 are viewed as (B, 3, 3, H, W)
                and passed as the single high-pass level.

        Returns:
            The tensor produced by ``self.idwt``.
        """
        b, _, h, w = img.size()
        low = img[:, :3, :, :]
        # NOTE(review): presumably the 9 channels are packed as (channel, band)
        # to match DWTInverse's (N, C, 3, H, W) layout -- confirm against the
        # forward DWT packing used by the encoder side.
        high = img[:, 3:, :, :].view(b, 3, 3, h, w)
        return self.idwt((low, [high]))
|
|
|
|
|
|
class DWTInverse(nn.Module):
    """Perform a 2D inverse DWT (reconstruction) of an image.

    Args:
        wave (str, pywt.Wavelet, or sequence): Wavelet to use. A string is
            passed to ``pywt.Wavelet``, whose ``rec_lo``/``rec_hi`` filters are
            used for both columns and rows. A sequence must hold either two
            filters ``(g0, g1)`` (shared by columns and rows) or four filters
            ``(g0_col, g1_col, g0_row, g1_row)``.
        mode (str): Padding mode name, converted with ``mode_to_int`` on
            every forward call.
        trace_model (bool): If True, call the plain function ``_SFB2D``
            instead of ``SFB2D.apply`` (presumably for tracing/export).

    Raises:
        ValueError: If ``wave`` is a sequence whose length is not 2 or 4.
    """

    def __init__(self, wave='db1', mode='zero', trace_model=False):
        super().__init__()
        if isinstance(wave, str):
            wave = pywt.Wavelet(wave)
        if isinstance(wave, pywt.Wavelet):
            g0_col, g1_col = wave.rec_lo, wave.rec_hi
            g0_row, g1_row = g0_col, g1_col
        elif len(wave) == 2:
            g0_col, g1_col = wave[0], wave[1]
            g0_row, g1_row = g0_col, g1_col
        elif len(wave) == 4:
            g0_col, g1_col = wave[0], wave[1]
            g0_row, g1_row = wave[2], wave[3]
        else:
            # Fail early with a clear message instead of an UnboundLocalError below.
            raise ValueError(
                "wave must be a str, a pywt.Wavelet, or a sequence of 2 or 4 "
                f"filters; got a sequence of length {len(wave)}"
            )

        # Prepare the synthesis filters and store them as buffers so they
        # follow the module across .to()/.cuda() and appear in state_dict.
        filts = prep_filt_sfb2d(g0_col, g1_col, g0_row, g1_row)
        self.register_buffer('g0_col', filts[0])
        self.register_buffer('g1_col', filts[1])
        self.register_buffer('g0_row', filts[2])
        self.register_buffer('g1_row', filts[3])
        self.mode = mode
        self.trace_model = trace_model

    def forward(self, coeffs):
        """
        Args:
            coeffs (yl, yh): tuple of lowpass and bandpass coefficients, where:
                yl is a lowpass tensor of shape :math:`(N, C_{in}, H_{in}',
                W_{in}')` and yh is a list of bandpass tensors of shape
                :math:`list(N, C_{in}, 3, H_{in}'', W_{in}'')`. I.e. should match
                the format returned by DWTForward

        Returns:
            Reconstructed input of shape :math:`(N, C_{in}, H_{in}, W_{in})`

        Note:
            :math:`H_{in}', W_{in}', H_{in}'', W_{in}''` denote the correctly
            downsampled shapes of the DWT pyramid.

        Note:
            Can have None for any of the highpass scales and will treat the
            values as zeros (not in an efficient way though).
        """
        yl, yh = coeffs
        ll = yl
        mode = mode_to_int(self.mode)

        # Do a multilevel inverse transform, coarsest level (last in yh) first.
        for h in yh[::-1]:
            if h is None:
                # new_zeros keeps ll's dtype and device (torch.zeros would
                # default to float32 and mismatch half/double inputs).
                h = ll.new_zeros(
                    (ll.shape[0], ll.shape[1], 3, ll.shape[-2], ll.shape[-1])
                )

            # 'Unpad' a single extra row/column if ll is larger than h.
            if ll.shape[-2] > h.shape[-2]:
                ll = ll[..., :-1, :]
            if ll.shape[-1] > h.shape[-1]:
                ll = ll[..., :-1]

            if not self.trace_model:
                ll = SFB2D.apply(ll, h, self.g0_col, self.g1_col, self.g0_row, self.g1_row, mode)
            else:
                ll = _SFB2D(ll, h, self.g0_col, self.g1_col, self.g0_row, self.g1_row, mode)
        return ll
|