You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 

58 lines
2.0 KiB

# -*- coding: utf-8 -*-
import PIL
import pytest
import torch
import torchvision.transforms as T
import torchvision.transforms.functional as TF
@pytest.mark.parametrize('target_image_size', [128, 192, 256])
def test_decode_vae(vae, sample_image, target_image_size):
    """Encode a sample image to codebook indices, decode them, and check the size.

    The decoded image is expected to be a single-item batch with 3 channels at
    the same square resolution the input was preprocessed to.
    """
    expected_shape = (1, 3, target_image_size, target_image_size)
    batch = preprocess(sample_image.copy(), target_image_size=target_image_size)
    with torch.no_grad():
        codes = vae.get_codebook_indices(batch)
        decoded = vae.decode(codes)
    assert decoded.shape == expected_shape
@pytest.mark.parametrize('target_image_size', [128, 192, 256])
def test_reconstruct_vae(vae, sample_image, target_image_size):
    """Round-trip a sample image through the VQGAN model and check the size.

    The reconstruction is expected to keep the input's spatial resolution.
    """
    side = target_image_size
    with torch.no_grad():
        pixels = preprocess(sample_image.copy(), target_image_size=side)
        rebuilt = reconstruct_with_vqgan(preprocess_vqgan(pixels), vae.model)
    assert rebuilt.shape == (1, 3, side, side)
@pytest.mark.parametrize('target_image_size', [256])
def test_reconstruct_dwt_vae(dwt_vae, sample_image, target_image_size):
    """Round-trip a sample image through the DWT model variant and check the size.

    Unlike the plain VQGAN test, this one expects the output to be twice the
    input resolution along each spatial axis.
    """
    doubled = 2 * target_image_size
    with torch.no_grad():
        pixels = preprocess(sample_image.copy(), target_image_size=target_image_size)
        rebuilt = reconstruct_with_vqgan(preprocess_vqgan(pixels), dwt_vae.model)
    assert rebuilt.shape == (1, 3, doubled, doubled)
def preprocess(img, target_image_size=256):
    """Resize and center-crop a PIL image into a batched float tensor.

    The image is scaled with Lanczos resampling so that its shorter side equals
    ``target_image_size``. It is then center-cropped to a square of that size
    and converted with ``ToTensor``, which scales 8-bit images to [0, 1]. A
    leading batch axis of size 1 is added.

    Args:
        img: a PIL image.
        target_image_size: side length of the square output.

    Returns:
        A tensor of shape (1, C, target_image_size, target_image_size), where
        C is the channel count ``ToTensor`` produces for the image's mode.

    Raises:
        ValueError: if the image's shorter side is smaller than
            ``target_image_size``. The helper never upscales.
    """
    # `import PIL` alone does not load the `PIL.Image` submodule; import it
    # explicitly rather than relying on torchvision having imported it first.
    from PIL import Image

    shorter_side = min(img.size)
    if shorter_side < target_image_size:
        raise ValueError(f'min dim for image {shorter_side} < {target_image_size}')
    scale = target_image_size / shorter_side
    # PIL reports size as (width, height); TF.resize expects (height, width).
    new_size = (round(scale * img.size[1]), round(scale * img.size[0]))
    img = TF.resize(img, new_size, interpolation=Image.LANCZOS)
    img = TF.center_crop(img, output_size=2 * [target_image_size])
    return torch.unsqueeze(T.ToTensor()(img), 0)
def preprocess_vqgan(x):
    """Affinely rescale ``x`` so that 0 maps to -1 and 1 maps to 1 (i.e. ``2x - 1``)."""
    return x * 2. - 1.
def reconstruct_with_vqgan(x, model):
    """Encode ``x`` with ``model`` and decode the latent back to an image.

    ``model.encode`` must return a 3-tuple whose last item unpacks into exactly
    three values. Only the first item (the latent) is used. The spatial part of
    the latent's shape (dims 2 onward) is printed for inspection.

    Returns:
        Whatever ``model.decode`` returns for the latent.
    """
    # The nested 3-way unpack doubles as an arity check on encode's last output.
    z, _, (_, _, _) = model.encode(x)
    print(f'VQGAN --- {model.__class__.__name__}: latent shape: {z.shape[2:]}')
    return model.decode(z)