# -*- coding: utf-8 -*-
"""Wavelet-domain decoder.

Decodes latents into 2D wavelet coefficients and reconstructs the image
with an inverse discrete wavelet transform (IDWT).
"""
import pywt
import torch
import torch.nn as nn

from taming.modules.diffusionmodules.model import Decoder

from .pytorch_wavelets_utils import SFB2D, _SFB2D, prep_filt_sfb2d, mode_to_int


class DecoderDWT(nn.Module):
    """Decode latents to 12 wavelet-coefficient channels, then IDWT to an image.

    Args:
        ddconfig: decoder config (attribute/dict-style access, e.g. OmegaConf);
            ``out_ch`` is forced to 12 = 3 lowpass (LL) + 3x3 highpass
            (LH, HL, HH) bands for an RGB image.
        embed_dim: channel count of the latent fed to ``post_quant_conv``.
    """

    def __init__(self, ddconfig, embed_dim):
        super().__init__()
        # The decoder must emit exactly 12 channels for dwt_to_img below.
        if ddconfig.out_ch != 12:
            ddconfig.out_ch = 12
        self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig['z_channels'], 1)
        self.decoder = Decoder(**ddconfig)
        # db1 (Haar) with zero padding; single-level inverse transform.
        self.idwt = DWTInverse(mode='zero', wave='db1')

    def forward(self, x):
        """Decode latent ``x`` to an RGB image via wavelet coefficients."""
        # x = self.post_quant_conv(x)  # NOTE(review): deliberately disabled;
        # the layer is kept so checkpoints with its weights still load.
        freq = self.decoder(x)
        img = self.dwt_to_img(freq)
        return img

    def dwt_to_img(self, img):
        """Split a 12-channel coefficient tensor and run the inverse DWT.

        Args:
            img: tensor of shape ``(B, 12, H, W)``; channels 0-2 are the
                lowpass band, channels 3-11 the three highpass bands.

        Returns:
            Reconstructed image tensor of shape ``(B, 3, 2H, 2W)`` for db1.
        """
        b, c, h, w = img.size()
        low = img[:, :3, :, :]
        # (B, 9, H, W) -> (B, 3, 3, H, W): 3 color channels x 3 subbands.
        high = img[:, 3:, :, :].view(b, 3, 3, h, w)
        return self.idwt((low, [high]))


class DWTInverse(nn.Module):
    """Performs a 2d DWT Inverse reconstruction of an image.

    Args:
        wave (str, pywt.Wavelet, or sequence): which wavelet to use. A
            sequence may hold 2 filters (lowpass, highpass, shared by rows
            and columns) or 4 filters (col lowpass, col highpass, row
            lowpass, row highpass).
        mode (str): padding scheme used during reconstruction.
        trace_model (bool): if True, call the plain function ``_SFB2D``
            instead of the autograd ``SFB2D.apply`` (e.g. for tracing).
    """

    def __init__(self, wave='db1', mode='zero', trace_model=False):
        super().__init__()
        if isinstance(wave, str):
            wave = pywt.Wavelet(wave)
        if isinstance(wave, pywt.Wavelet):
            g0_col, g1_col = wave.rec_lo, wave.rec_hi
            g0_row, g1_row = g0_col, g1_col
        elif len(wave) == 2:
            g0_col, g1_col = wave[0], wave[1]
            g0_row, g1_row = g0_col, g1_col
        elif len(wave) == 4:
            g0_col, g1_col = wave[0], wave[1]
            g0_row, g1_row = wave[2], wave[3]
        else:
            # Fix: previously any other length fell through silently and
            # raised a confusing UnboundLocalError at prep_filt_sfb2d below.
            raise ValueError(
                "wave must be a string, a pywt.Wavelet, or a sequence of "
                "2 or 4 filters; got a sequence of length %d" % len(wave))

        # Prepare the filters as buffers so they follow .to(device)/.cuda().
        filts = prep_filt_sfb2d(g0_col, g1_col, g0_row, g1_row)
        self.register_buffer('g0_col', filts[0])
        self.register_buffer('g1_col', filts[1])
        self.register_buffer('g0_row', filts[2])
        self.register_buffer('g1_row', filts[3])
        self.mode = mode
        self.trace_model = trace_model

    def forward(self, coeffs):
        """
        Args:
            coeffs (yl, yh): tuple of lowpass and bandpass coefficients, where:
              yl is a lowpass tensor of shape :math:`(N, C_{in}, H_{in}',
              W_{in}')` and yh is a list of bandpass tensors of shape
              :math:`list(N, C_{in}, 3, H_{in}'', W_{in}'')`. I.e. should match
              the format returned by DWTForward

        Returns:
            Reconstructed input of shape :math:`(N, C_{in}, H_{in}, W_{in})`

        Note:
            :math:`H_{in}', W_{in}', H_{in}'', W_{in}''` denote the correctly
            downsampled shapes of the DWT pyramid.

        Note:
            Can have None for any of the highpass scales and will treat the
            values as zeros (not in an efficient way though).
        """
        yl, yh = coeffs
        ll = yl
        mode = mode_to_int(self.mode)

        # Do a multilevel inverse transform, coarsest scale first.
        for h in yh[::-1]:
            if h is None:
                # Missing highpass scale: treat the band as all zeros.
                h = torch.zeros(ll.shape[0], ll.shape[1], 3, ll.shape[-2],
                                ll.shape[-1], device=ll.device)

            # 'Unpad' added dimensions: odd input sizes make the lowpass one
            # sample larger than the matching highpass band at each level.
            if ll.shape[-2] > h.shape[-2]:
                ll = ll[..., :-1, :]
            if ll.shape[-1] > h.shape[-1]:
                ll = ll[..., :-1]
            if not self.trace_model:
                ll = SFB2D.apply(
                    ll, h, self.g0_col, self.g1_col, self.g0_row, self.g1_row,
                    mode)
            else:
                ll = _SFB2D(
                    ll, h, self.g0_col, self.g1_col, self.g0_row, self.g1_row,
                    mode)
        return ll