Browse Source

add dwt vae

feature/dwt_vae
shonenkov 4 years ago
parent
commit
c753cd624b
9 changed files with 525 additions and 9 deletions
  1. +2
    -1
      README.md
  2. +1
    -0
      requirements.txt
  3. +1
    -1
      rudalle/__init__.py
  4. +6
    -3
      rudalle/vae/__init__.py
  5. +102
    -0
      rudalle/vae/decoder_dwt.py
  6. +11
    -4
      rudalle/vae/model.py
  7. +387
    -0
      rudalle/vae/pytorch_wavelets_utils.py
  8. +6
    -0
      tests/conftest.py
  9. +9
    -0
      tests/test_vae.py

+ 2
- 1
README.md View File

@ -7,7 +7,7 @@
[![pre-commit.ci status](https://results.pre-commit.ci/badge/github/sberbank-ai/ru-dalle/master.svg)](https://results.pre-commit.ci/latest/github/sberbank-ai/ru-dalle/master)
```
pip install rudalle==0.0.1rc6
pip install rudalle==0.0.1rc7
```
### 🤗 HF Models:
[ruDALL-E Malevich (XL)](https://huggingface.co/sberbank-ai/rudalle-Malevich)
@ -92,6 +92,7 @@ skyes = [red_sky, sunny_sky, cloudy_sky, night_sky]
### 🚀 Contributors 🚀
- [@bes](https://github.com/bes-dev) shared [great idea and realization with IDWT](https://github.com/bes-dev/vqvae_dwt_distiller.pytorch) for decoding images with higher quality 512x512! 😈💪
- [@neverix](https://www.kaggle.com/neverix) thanks a lot for contributing for speed up of inference
- [@Igor Pavlov](https://github.com/boomb0om) trained model and prepared code with [super-resolution](https://github.com/boomb0om/Real-ESRGAN-colab)
- [@oriBetelgeuse](https://github.com/oriBetelgeuse) thanks a lot for easy API of generation using image prompt


+ 1
- 0
requirements.txt View File

@ -4,6 +4,7 @@ transformers~=4.10.2
youtokentome~=1.0.6
omegaconf>=2.0.0
einops~=0.3.2
PyWavelets==1.1.1
torch
torchvision
matplotlib

+ 1
- 1
rudalle/__init__.py View File

@ -22,4 +22,4 @@ __all__ = [
'image_prompts',
]
__version__ = '0.0.1-rc6'
__version__ = '0.0.1-rc7'

+ 6
- 3
rudalle/vae/__init__.py View File

@ -8,13 +8,16 @@ from omegaconf import OmegaConf
from .model import VQGanGumbelVAE
def get_vae(pretrained=True, cache_dir='/tmp/rudalle'):
def get_vae(pretrained=True, dwt=False, cache_dir='/tmp/rudalle'):
# TODO
config = OmegaConf.load(join(dirname(abspath(__file__)), 'vqgan.gumbelf8-sber.config.yml'))
vae = VQGanGumbelVAE(config)
vae = VQGanGumbelVAE(config, dwt=dwt)
if pretrained:
repo_id = 'shonenkov/rudalle-utils'
filename = 'vqgan.gumbelf8-sber.model.ckpt'
if dwt:
filename = 'vqgan.gumbelf8-sber-dwt.model.ckpt'
else:
filename = 'vqgan.gumbelf8-sber.model.ckpt'
cache_dir = join(cache_dir, 'vae')
config_file_url = hf_hub_url(repo_id=repo_id, filename=filename)
cached_download(config_file_url, cache_dir=cache_dir, force_filename=filename)


+ 102
- 0
rudalle/vae/decoder_dwt.py View File

@ -0,0 +1,102 @@
# -*- coding: utf-8 -*-
import pywt
import torch
import torch.nn as nn
from taming.modules.diffusionmodules.model import Decoder
from .pytorch_wavelets_utils import SFB2D, _SFB2D, prep_filt_sfb2d, mode_to_int
class DecoderDWT(nn.Module):
def __init__(self, ddconfig, embed_dim):
super().__init__()
if ddconfig.out_ch != 12:
ddconfig.out_ch = 12
self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig['z_channels'], 1)
self.decoder = Decoder(**ddconfig)
self.idwt = DWTInverse(mode='zero', wave='db1')
def forward(self, x):
# x = self.post_quant_conv(x)
freq = self.decoder(x)
img = self.dwt_to_img(freq)
return img
def dwt_to_img(self, img):
b, c, h, w = img.size()
low = img[:, :3, :, :]
high = img[:, 3:, :, :].view(b, 3, 3, h, w)
return self.idwt((low, [high]))
class DWTInverse(nn.Module):
""" Performs a 2d DWT Inverse reconstruction of an image
Args:
wave (str or pywt.Wavelet): Which wavelet to use
C: deprecated, will be removed in future
"""
def __init__(self, wave='db1', mode='zero', trace_model=False):
super().__init__()
if isinstance(wave, str):
wave = pywt.Wavelet(wave)
if isinstance(wave, pywt.Wavelet):
g0_col, g1_col = wave.rec_lo, wave.rec_hi
g0_row, g1_row = g0_col, g1_col
else:
if len(wave) == 2:
g0_col, g1_col = wave[0], wave[1]
g0_row, g1_row = g0_col, g1_col
elif len(wave) == 4:
g0_col, g1_col = wave[0], wave[1]
g0_row, g1_row = wave[2], wave[3]
# Prepare the filters
filts = prep_filt_sfb2d(g0_col, g1_col, g0_row, g1_row)
self.register_buffer('g0_col', filts[0])
self.register_buffer('g1_col', filts[1])
self.register_buffer('g0_row', filts[2])
self.register_buffer('g1_row', filts[3])
self.mode = mode
self.trace_model = trace_model
def forward(self, coeffs):
"""
Args:
coeffs (yl, yh): tuple of lowpass and bandpass coefficients, where:
yl is a lowpass tensor of shape :math:`(N, C_{in}, H_{in}',
W_{in}')` and yh is a list of bandpass tensors of shape
:math:`list(N, C_{in}, 3, H_{in}'', W_{in}'')`. I.e. should match
the format returned by DWTForward
Returns:
Reconstructed input of shape :math:`(N, C_{in}, H_{in}, W_{in})`
Note:
:math:`H_{in}', W_{in}', H_{in}'', W_{in}''` denote the correctly
downsampled shapes of the DWT pyramid.
Note:
Can have None for any of the highpass scales and will treat the
values as zeros (not in an efficient way though).
"""
yl, yh = coeffs
ll = yl
mode = mode_to_int(self.mode)
# Do a multilevel inverse transform
for h in yh[::-1]:
if h is None:
h = torch.zeros(ll.shape[0], ll.shape[1], 3, ll.shape[-2],
ll.shape[-1], device=ll.device)
# 'Unpad' added dimensions
if ll.shape[-2] > h.shape[-2]:
ll = ll[..., :-1, :]
if ll.shape[-1] > h.shape[-1]:
ll = ll[..., :-1]
if not self.trace_model:
ll = SFB2D.apply(ll, h, self.g0_col, self.g1_col, self.g0_row, self.g1_row, mode)
else:
ll = _SFB2D(ll, h, self.g0_col, self.g1_col, self.g0_row, self.g1_row, mode)
return ll

+ 11
- 4
rudalle/vae/model.py View File

@ -8,16 +8,19 @@ from torch import einsum
from einops import rearrange
from taming.modules.diffusionmodules.model import Encoder, Decoder
from .decoder_dwt import DecoderDWT
class VQGanGumbelVAE(torch.nn.Module):
def __init__(self, config):
def __init__(self, config, dwt=False):
super().__init__()
model = GumbelVQ(
ddconfig=config.model.params.ddconfig,
n_embed=config.model.params.n_embed,
embed_dim=config.model.params.embed_dim,
kl_weight=config.model.params.kl_weight,
dwt=dwt,
)
self.model = model
self.num_layers = int(log(config.model.params.ddconfig.attn_resolutions[0]) / log(2))
@ -79,11 +82,12 @@ class GumbelQuantize(nn.Module):
class GumbelVQ(nn.Module):
def __init__(self, ddconfig, n_embed, embed_dim, kl_weight=1e-8):
def __init__(self, ddconfig, n_embed, embed_dim, dwt=False, kl_weight=1e-8):
super().__init__()
z_channels = ddconfig['z_channels']
self.dwt = dwt
self.encoder = Encoder(**ddconfig)
self.decoder = Decoder(**ddconfig)
self.decoder = DecoderDWT(ddconfig, embed_dim) if dwt else Decoder(**ddconfig)
self.quantize = GumbelQuantize(z_channels, embed_dim, n_embed=n_embed, kl_weight=kl_weight, temp_init=1.0)
self.quant_conv = torch.nn.Conv2d(ddconfig['z_channels'], embed_dim, 1)
self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig['z_channels'], 1)
@ -95,6 +99,9 @@ class GumbelVQ(nn.Module):
return quant, emb_loss, info
def decode(self, quant):
quant = self.post_quant_conv(quant)
if self.dwt:
quant = self.decoder.post_quant_conv(quant)
else:
quant = self.post_quant_conv(quant)
dec = self.decoder(quant)
return dec

+ 387
- 0
rudalle/vae/pytorch_wavelets_utils.py View File

@ -0,0 +1,387 @@
# -*- coding: utf-8 -*-
"""
Useful utilities for testing the 2-D DTCWT with synthetic images
License: https://github.com/fbcotter/pytorch_wavelets/blob/master/LICENSE
Source: https://github.com/fbcotter/pytorch_wavelets/blob/31d6ac1b51b08f811a6a70eb7b3440f106009da0/pytorch_wavelets/dwt/lowlevel.py # noqa
"""
import pywt
import torch
import numpy as np
import torch.nn.functional as F
from torch.autograd import Function
def sfb1d(lo, hi, g0, g1, mode='zero', dim=-1):
""" 1D synthesis filter bank of an image tensor
"""
C = lo.shape[1]
d = dim % 4
# If g0, g1 are not tensors, make them. If they are, then assume that they
# are in the right order
if not isinstance(g0, torch.Tensor):
g0 = torch.tensor(np.copy(np.array(g0).ravel()),
dtype=torch.float, device=lo.device)
if not isinstance(g1, torch.Tensor):
g1 = torch.tensor(np.copy(np.array(g1).ravel()),
dtype=torch.float, device=lo.device)
L = g0.numel()
shape = [1, 1, 1, 1]
shape[d] = L
N = 2*lo.shape[d]
# If g aren't in the right shape, make them so
if g0.shape != tuple(shape):
g0 = g0.reshape(*shape)
if g1.shape != tuple(shape):
g1 = g1.reshape(*shape)
s = (2, 1) if d == 2 else (1, 2)
g0 = torch.cat([g0]*C, dim=0)
g1 = torch.cat([g1]*C, dim=0)
if mode == 'per' or mode == 'periodization':
y = F.conv_transpose2d(lo, g0, stride=s, groups=C) + \
F.conv_transpose2d(hi, g1, stride=s, groups=C)
if d == 2:
y[:, :, :L-2] = y[:, :, :L-2] + y[:, :, N:N+L-2]
y = y[:, :, :N]
else:
y[:, :, :, :L-2] = y[:, :, :, :L-2] + y[:, :, :, N:N+L-2]
y = y[:, :, :, :N]
y = roll(y, 1-L//2, dim=dim)
else:
if mode == 'zero' or mode == 'symmetric' or mode == 'reflect' or \
mode == 'periodic':
pad = (L-2, 0) if d == 2 else (0, L-2)
y = F.conv_transpose2d(lo, g0, stride=s, padding=pad, groups=C) + \
F.conv_transpose2d(hi, g1, stride=s, padding=pad, groups=C)
else:
raise ValueError('Unkown pad type: {}'.format(mode))
return y
def _SFB2D(low, highs, g0_row, g1_row, g0_col, g1_col, mode):
mode = int_to_mode(mode)
lh, hl, hh = torch.unbind(highs, dim=2)
lo = sfb1d(low, lh, g0_col, g1_col, mode=mode, dim=2)
hi = sfb1d(hl, hh, g0_col, g1_col, mode=mode, dim=2)
y = sfb1d(lo, hi, g0_row, g1_row, mode=mode, dim=3)
return y
def roll(x, n, dim, make_even=False):
if n < 0:
n = x.shape[dim] + n
if make_even and x.shape[dim] % 2 == 1:
end = 1
else:
end = 0
if dim == 0:
return torch.cat((x[-n:], x[:-n+end]), dim=0)
elif dim == 1:
return torch.cat((x[:, -n:], x[:, :-n+end]), dim=1)
elif dim == 2 or dim == -2:
return torch.cat((x[:, :, -n:], x[:, :, :-n+end]), dim=2)
elif dim == 3 or dim == -1:
return torch.cat((x[:, :, :, -n:], x[:, :, :, :-n+end]), dim=3)
def int_to_mode(mode):
if mode == 0:
return 'zero'
elif mode == 1:
return 'symmetric'
elif mode == 2:
return 'periodization'
elif mode == 3:
return 'constant'
elif mode == 4:
return 'reflect'
elif mode == 5:
return 'replicate'
elif mode == 6:
return 'periodic'
else:
raise ValueError('Unkown pad type: {}'.format(mode))
def prep_filt_sfb2d(g0_col, g1_col, g0_row=None, g1_row=None, device=None):
"""
Prepares the filters to be of the right form for the sfb2d function. In
particular, makes the tensors the right shape. It does not mirror image them
as as sfb2d uses conv2d_transpose which acts like normal convolution.
Inputs:
g0_col (array-like): low pass column filter bank
g1_col (array-like): high pass column filter bank
g0_row (array-like): low pass row filter bank. If none, will assume the
same as column filter
g1_row (array-like): high pass row filter bank. If none, will assume the
same as column filter
device: which device to put the tensors on to
Returns:
(g0_col, g1_col, g0_row, g1_row)
"""
g0_col, g1_col = prep_filt_sfb1d(g0_col, g1_col, device)
if g0_row is None:
g0_row, g1_row = g0_col, g1_col
else:
g0_row, g1_row = prep_filt_sfb1d(g0_row, g1_row, device)
g0_col = g0_col.reshape((1, 1, -1, 1))
g1_col = g1_col.reshape((1, 1, -1, 1))
g0_row = g0_row.reshape((1, 1, 1, -1))
g1_row = g1_row.reshape((1, 1, 1, -1))
return g0_col, g1_col, g0_row, g1_row
def prep_filt_sfb1d(g0, g1, device=None):
"""
Prepares the filters to be of the right form for the sfb1d function. In
particular, makes the tensors the right shape. It does not mirror image them
as as sfb2d uses conv2d_transpose which acts like normal convolution.
Inputs:
g0 (array-like): low pass filter bank
g1 (array-like): high pass filter bank
device: which device to put the tensors on to
Returns:
(g0, g1)
"""
g0 = np.array(g0).ravel()
g1 = np.array(g1).ravel()
t = torch.get_default_dtype()
g0 = torch.tensor(g0, device=device, dtype=t).reshape((1, 1, -1))
g1 = torch.tensor(g1, device=device, dtype=t).reshape((1, 1, -1))
return g0, g1
def mode_to_int(mode):
if mode == 'zero':
return 0
elif mode == 'symmetric':
return 1
elif mode == 'per' or mode == 'periodization':
return 2
elif mode == 'constant':
return 3
elif mode == 'reflect':
return 4
elif mode == 'replicate':
return 5
elif mode == 'periodic':
return 6
else:
raise ValueError('Unkown pad type: {}'.format(mode))
def afb1d(x, h0, h1, mode='zero', dim=-1):
""" 1D analysis filter bank (along one dimension only) of an image
Inputs:
x (tensor): 4D input with the last two dimensions the spatial input
h0 (tensor): 4D input for the lowpass filter. Should have shape (1, 1,
h, 1) or (1, 1, 1, w)
h1 (tensor): 4D input for the highpass filter. Should have shape (1, 1,
h, 1) or (1, 1, 1, w)
mode (str): padding method
dim (int) - dimension of filtering. d=2 is for a vertical filter (called
column filtering but filters across the rows). d=3 is for a
horizontal filter, (called row filtering but filters across the
columns).
Returns:
lohi: lowpass and highpass subbands concatenated along the channel
dimension
"""
C = x.shape[1]
# Convert the dim to positive
d = dim % 4
s = (2, 1) if d == 2 else (1, 2)
N = x.shape[d]
# If h0, h1 are not tensors, make them. If they are, then assume that they
# are in the right order
if not isinstance(h0, torch.Tensor):
h0 = torch.tensor(np.copy(np.array(h0).ravel()[::-1]),
dtype=torch.float, device=x.device)
if not isinstance(h1, torch.Tensor):
h1 = torch.tensor(np.copy(np.array(h1).ravel()[::-1]),
dtype=torch.float, device=x.device)
L = h0.numel()
L2 = L // 2
shape = [1, 1, 1, 1]
shape[d] = L
# If h aren't in the right shape, make them so
if h0.shape != tuple(shape):
h0 = h0.reshape(*shape)
if h1.shape != tuple(shape):
h1 = h1.reshape(*shape)
h = torch.cat([h0, h1] * C, dim=0)
if mode == 'per' or mode == 'periodization':
if x.shape[dim] % 2 == 1:
if d == 2:
x = torch.cat((x, x[:, :, -1:]), dim=2)
else:
x = torch.cat((x, x[:, :, :, -1:]), dim=3)
N += 1
x = roll(x, -L2, dim=d)
pad = (L-1, 0) if d == 2 else (0, L-1)
lohi = F.conv2d(x, h, padding=pad, stride=s, groups=C)
N2 = N//2
if d == 2:
lohi[:, :, :L2] = lohi[:, :, :L2] + lohi[:, :, N2:N2+L2]
lohi = lohi[:, :, :N2]
else:
lohi[:, :, :, :L2] = lohi[:, :, :, :L2] + lohi[:, :, :, N2:N2+L2]
lohi = lohi[:, :, :, :N2]
else:
# Calculate the pad size
outsize = pywt.dwt_coeff_len(N, L, mode=mode)
p = 2 * (outsize - 1) - N + L
if mode == 'zero':
# Sadly, pytorch only allows for same padding before and after, if
# we need to do more padding after for odd length signals, have to
# prepad
if p % 2 == 1:
pad = (0, 0, 0, 1) if d == 2 else (0, 1, 0, 0)
x = F.pad(x, pad)
pad = (p//2, 0) if d == 2 else (0, p//2)
# Calculate the high and lowpass
lohi = F.conv2d(x, h, padding=pad, stride=s, groups=C)
elif mode == 'symmetric' or mode == 'reflect' or mode == 'periodic':
pad = (0, 0, p//2, (p+1)//2) if d == 2 else (p//2, (p+1)//2, 0, 0)
x = mypad(x, pad=pad, mode=mode)
lohi = F.conv2d(x, h, stride=s, groups=C)
else:
raise ValueError('Unkown pad type: {}'.format(mode))
return lohi
def mypad(x, pad, mode='constant', value=0):
""" Function to do numpy like padding on tensors. Only works for 2-D
padding.
Inputs:
x (tensor): tensor to pad
pad (tuple): tuple of (left, right, top, bottom) pad sizes
mode (str): 'symmetric', 'wrap', 'constant, 'reflect', 'replicate', or
'zero'. The padding technique.
"""
if mode == 'symmetric':
# Vertical only
if pad[0] == 0 and pad[1] == 0:
m1, m2 = pad[2], pad[3]
l = x.shape[-2] # noqa
xe = reflect(np.arange(-m1, l+m2, dtype='int32'), -0.5, l-0.5)
return x[:, :, xe]
# horizontal only
elif pad[2] == 0 and pad[3] == 0:
m1, m2 = pad[0], pad[1]
l = x.shape[-1] # noqa
xe = reflect(np.arange(-m1, l+m2, dtype='int32'), -0.5, l-0.5)
return x[:, :, :, xe]
# Both
else:
m1, m2 = pad[0], pad[1]
l1 = x.shape[-1]
xe_row = reflect(np.arange(-m1, l1+m2, dtype='int32'), -0.5, l1-0.5)
m1, m2 = pad[2], pad[3]
l2 = x.shape[-2]
xe_col = reflect(np.arange(-m1, l2+m2, dtype='int32'), -0.5, l2-0.5)
i = np.outer(xe_col, np.ones(xe_row.shape[0]))
j = np.outer(np.ones(xe_col.shape[0]), xe_row)
return x[:, :, i, j]
elif mode == 'periodic':
# Vertical only
if pad[0] == 0 and pad[1] == 0:
xe = np.arange(x.shape[-2])
xe = np.pad(xe, (pad[2], pad[3]), mode='wrap')
return x[:, :, xe]
# Horizontal only
elif pad[2] == 0 and pad[3] == 0:
xe = np.arange(x.shape[-1])
xe = np.pad(xe, (pad[0], pad[1]), mode='wrap')
return x[:, :, :, xe]
# Both
else:
xe_col = np.arange(x.shape[-2])
xe_col = np.pad(xe_col, (pad[2], pad[3]), mode='wrap')
xe_row = np.arange(x.shape[-1])
xe_row = np.pad(xe_row, (pad[0], pad[1]), mode='wrap')
i = np.outer(xe_col, np.ones(xe_row.shape[0]))
j = np.outer(np.ones(xe_col.shape[0]), xe_row)
return x[:, :, i, j]
elif mode == 'constant' or mode == 'reflect' or mode == 'replicate':
return F.pad(x, pad, mode, value)
elif mode == 'zero':
return F.pad(x, pad)
else:
raise ValueError('Unkown pad type: {}'.format(mode))
def reflect(x, minx, maxx):
"""Reflect the values in matrix *x* about the scalar values *minx* and
*maxx*. Hence a vector *x* containing a long linearly increasing series is
converted into a waveform which ramps linearly up and down between *minx*
and *maxx*. If *x* contains integers and *minx* and *maxx* are (integers +
0.5), the ramps will have repeated max and min samples.
.. codeauthor:: Rich Wareham <rjw57@cantab.net>, Aug 2013
.. codeauthor:: Nick Kingsbury, Cambridge University, January 1999.
"""
x = np.asanyarray(x)
rng = maxx - minx
rng_by_2 = 2 * rng
mod = np.fmod(x - minx, rng_by_2)
normed_mod = np.where(mod < 0, mod + rng_by_2, mod)
out = np.where(normed_mod >= rng, rng_by_2 - normed_mod, normed_mod) + minx
return np.array(out, dtype=x.dtype)
class SFB2D(Function):
""" Does a single level 2d wavelet decomposition of an input. Does separate
row and column filtering by two calls to
:py:func:`pytorch_wavelets.dwt.lowlevel.afb1d`
Needs to have the tensors in the right form. Because this function defines
its own backward pass, saves on memory by not having to save the input
tensors.
Inputs:
x (torch.Tensor): Input to decompose
h0_row: row lowpass
h1_row: row highpass
h0_col: col lowpass
h1_col: col highpass
mode (int): use mode_to_int to get the int code here
We encode the mode as an integer rather than a string as gradcheck causes an
error when a string is provided.
Returns:
y: Tensor of shape (N, C*4, H, W)
"""
@staticmethod
def forward(ctx, low, highs, g0_row, g1_row, g0_col, g1_col, mode):
mode = int_to_mode(mode)
ctx.mode = mode
ctx.save_for_backward(g0_row, g1_row, g0_col, g1_col)
lh, hl, hh = torch.unbind(highs, dim=2)
lo = sfb1d(low, lh, g0_col, g1_col, mode=mode, dim=2)
hi = sfb1d(hl, hh, g0_col, g1_col, mode=mode, dim=2)
y = sfb1d(lo, hi, g0_row, g1_row, mode=mode, dim=3)
return y
@staticmethod
def backward(ctx, dy):
dlow, dhigh = None, None
if ctx.needs_input_grad[0]:
mode = ctx.mode
g0_row, g1_row, g0_col, g1_col = ctx.saved_tensors
dx = afb1d(dy, g0_row, g1_row, mode=mode, dim=3)
dx = afb1d(dx, g0_col, g1_col, mode=mode, dim=2)
s = dx.shape
dx = dx.reshape(s[0], -1, 4, s[-2], s[-1])
dlow = dx[:, :, 0].contiguous()
dhigh = dx[:, :, 1:].contiguous()
return dlow, dhigh, None, None, None, None, None

+ 6
- 0
tests/conftest.py View File

@ -24,6 +24,12 @@ def vae():
yield vae
@pytest.fixture(scope='module')
def dwt_vae():
vae = get_vae(pretrained=False, dwt=True)
yield vae
@pytest.fixture(scope='module')
def yttm_tokenizer():
tokenizer = get_tokenizer()


+ 9
- 0
tests/test_vae.py View File

@ -25,6 +25,15 @@ def test_reconstruct_vae(vae, sample_image, target_image_size):
assert output.shape == (1, 3, target_image_size, target_image_size)
@pytest.mark.parametrize('target_image_size', [256])
def test_reconstruct_dwt_vae(dwt_vae, sample_image, target_image_size):
img = sample_image.copy()
with torch.no_grad():
x_vqgan = preprocess(img, target_image_size=target_image_size)
output = reconstruct_with_vqgan(preprocess_vqgan(x_vqgan), dwt_vae.model)
assert output.shape == (1, 3, target_image_size*2, target_image_size*2)
def preprocess(img, target_image_size=256):
s = min(img.size)
if s < target_image_size:


Loading…
Cancel
Save