You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 

102 lines
3.6 KiB

# -*- coding: utf-8 -*-
import pywt
import torch
import torch.nn as nn
from taming.modules.diffusionmodules.model import Decoder
from .pytorch_wavelets_utils import SFB2D, _SFB2D, prep_filt_sfb2d, mode_to_int
class DecoderDWT(nn.Module):
    """Decode a latent into wavelet coefficients and invert them to an image.

    The wrapped ``Decoder`` is always built with ``out_ch=12``. Its output is
    split into 3 low-pass channels and 9 high-pass channels (viewed as
    3 x 3) and fed to a single-level ``DWTInverse`` (``wave='db1'``,
    ``mode='zero'``).

    Args:
        ddconfig: Mapping of keyword arguments for ``Decoder``. Must contain
            ``z_channels``. Any ``out_ch`` entry is overridden with 12; the
            mapping passed in is not modified.
        embed_dim: Number of input channels of ``post_quant_conv``.
    """

    # 3 low-pass channels + 3 * 3 high-pass channels (see ``dwt_to_img``).
    _LOW_CH = 3
    _OUT_CH = 12

    def __init__(self, ddconfig, embed_dim):
        super().__init__()
        # Work on a keyword copy: this accepts plain dicts as well as
        # attribute-style configs, and avoids mutating the caller's config
        # just to force the channel count the inverse DWT needs.
        decoder_kwargs = {**ddconfig, 'out_ch': self._OUT_CH}
        self.post_quant_conv = torch.nn.Conv2d(
            embed_dim, decoder_kwargs['z_channels'], 1)
        self.decoder = Decoder(**decoder_kwargs)
        self.idwt = DWTInverse(mode='zero', wave='db1')

    def forward(self, x):
        """Run the decoder on ``x`` and return the inverse-DWT of its output.

        NOTE: ``post_quant_conv`` is constructed in ``__init__`` but is not
        applied here; ``x`` goes straight into the decoder.
        """
        freq = self.decoder(x)
        return self.dwt_to_img(freq)

    def dwt_to_img(self, img):
        """Split a 12-channel tensor into DWT coefficients and invert them.

        Args:
            img: Tensor of shape ``(b, 12, h, w)``. Channels ``0:3`` are used
                as the low-pass input; channels ``3:12`` are viewed as
                ``(b, 3, 3, h, w)`` and used as the only high-pass level.

        Returns:
            Whatever ``self.idwt`` returns for ``(low, [high])``.
        """
        b, _, h, w = img.size()
        low = img[:, :self._LOW_CH, :, :]
        high = img[:, self._LOW_CH:, :, :].view(b, 3, 3, h, w)
        return self.idwt((low, [high]))
class DWTInverse(nn.Module):
    """ Performs a 2d DWT Inverse reconstruction of an image

    Args:
        wave (str or pywt.Wavelet or sequence): Which wavelet to use. A
            string is turned into a ``pywt.Wavelet``; a ``pywt.Wavelet``
            supplies ``rec_lo``/``rec_hi``, used for both columns and rows.
            Otherwise a sequence of 2 filters ``(g0, g1)`` (shared by columns
            and rows) or 4 filters ``(g0_col, g1_col, g0_row, g1_row)``.
        mode (str): Padding mode name; converted with ``mode_to_int`` on
            every forward call.
        trace_model (bool): If True, ``forward`` calls ``_SFB2D`` directly
            instead of going through ``SFB2D.apply``.

    Raises:
        ValueError: If ``wave`` is a sequence whose length is not 2 or 4.
    """

    def __init__(self, wave='db1', mode='zero', trace_model=False):
        super().__init__()
        if isinstance(wave, str):
            wave = pywt.Wavelet(wave)
        if isinstance(wave, pywt.Wavelet):
            g0_col, g1_col = wave.rec_lo, wave.rec_hi
            g0_row, g1_row = g0_col, g1_col
        elif len(wave) == 2:
            g0_col, g1_col = wave[0], wave[1]
            g0_row, g1_row = g0_col, g1_col
        elif len(wave) == 4:
            g0_col, g1_col = wave[0], wave[1]
            g0_row, g1_row = wave[2], wave[3]
        else:
            # Previously fell through and crashed later with an
            # UnboundLocalError; fail early with an actionable message.
            raise ValueError(
                "wave must be a wavelet name, a pywt.Wavelet, or a sequence "
                "of 2 or 4 filters; got a sequence of length "
                "{}".format(len(wave)))
        # Prepare the filters
        filts = prep_filt_sfb2d(g0_col, g1_col, g0_row, g1_row)
        self.register_buffer('g0_col', filts[0])
        self.register_buffer('g1_col', filts[1])
        self.register_buffer('g0_row', filts[2])
        self.register_buffer('g1_row', filts[3])
        self.mode = mode
        self.trace_model = trace_model

    def forward(self, coeffs):
        """
        Args:
            coeffs (yl, yh): tuple of lowpass and bandpass coefficients, where:
              yl is a lowpass tensor of shape :math:`(N, C_{in}, H_{in}',
              W_{in}')` and yh is a list of bandpass tensors of shape
              :math:`list(N, C_{in}, 3, H_{in}'', W_{in}'')`. I.e. should match
              the format returned by DWTForward

        Returns:
            Reconstructed input of shape :math:`(N, C_{in}, H_{in}, W_{in})`

        Note:
            :math:`H_{in}', W_{in}', H_{in}'', W_{in}''` denote the correctly
            downsampled shapes of the DWT pyramid.

        Note:
            Can have None for any of the highpass scales and will treat the
            values as zeros (not in an efficient way though).
        """
        yl, yh = coeffs
        ll = yl
        mode = mode_to_int(self.mode)

        # Do a multilevel inverse transform, starting from the last entry
        # of yh and working back to the first.
        for h in yh[::-1]:
            if h is None:
                # Match dtype as well as device so the substitute zeros agree
                # with ll (e.g. under half/double precision).
                h = torch.zeros(ll.shape[0], ll.shape[1], 3, ll.shape[-2],
                                ll.shape[-1], device=ll.device,
                                dtype=ll.dtype)

            # 'Unpad' added dimensions
            if ll.shape[-2] > h.shape[-2]:
                ll = ll[..., :-1, :]
            if ll.shape[-1] > h.shape[-1]:
                ll = ll[..., :-1]
            if not self.trace_model:
                ll = SFB2D.apply(ll, h, self.g0_col, self.g1_col,
                                 self.g0_row, self.g1_row, mode)
            else:
                ll = _SFB2D(ll, h, self.g0_col, self.g1_col,
                            self.g0_row, self.g1_row, mode)
        return ll