# -*- coding: utf-8 -*-
import PIL.Image
import pytest
import torch
import torchvision.transforms as T
import torchvision.transforms.functional as TF


@pytest.mark.parametrize('target_image_size', [128, 192, 256])
def test_decode_vae(vae, sample_image, target_image_size):
    img = sample_image.copy()
    img = preprocess(img, target_image_size=target_image_size)
    with torch.no_grad():
        img_seq = vae.get_codebook_indices(img)
        out_img = vae.decode(img_seq)
    assert out_img.shape == (1, 3, target_image_size, target_image_size)


@pytest.mark.parametrize('target_image_size', [128, 192, 256])
def test_reconstruct_vae(vae, sample_image, target_image_size):
    img = sample_image.copy()
    with torch.no_grad():
        x_vqgan = preprocess(img, target_image_size=target_image_size)
        output = reconstruct_with_vqgan(preprocess_vqgan(x_vqgan), vae.model)
    assert output.shape == (1, 3, target_image_size, target_image_size)


@pytest.mark.parametrize('target_image_size', [256])
def test_reconstruct_dwt_vae(dwt_vae, sample_image, target_image_size):
    img = sample_image.copy()
    with torch.no_grad():
        x_vqgan = preprocess(img, target_image_size=target_image_size)
        output = reconstruct_with_vqgan(preprocess_vqgan(x_vqgan), dwt_vae.model)
    # The DWT VAE decodes to twice the input resolution.
    assert output.shape == (1, 3, target_image_size * 2, target_image_size * 2)


def preprocess(img, target_image_size=256):
    # Resize so the short side matches target_image_size, center-crop to a
    # square, and convert to a (1, 3, H, W) float tensor in [0, 1].
    s = min(img.size)
    if s < target_image_size:
        raise ValueError(f'min dim for image {s} < {target_image_size}')
    r = target_image_size / s
    # img.size is (W, H); TF.resize expects (H, W).
    s = (round(r * img.size[1]), round(r * img.size[0]))
    img = TF.resize(img, s, interpolation=PIL.Image.LANCZOS)
    img = TF.center_crop(img, output_size=2 * [target_image_size])
    img = torch.unsqueeze(T.ToTensor()(img), 0)
    return img


def preprocess_vqgan(x):
    # Map [0, 1] tensors to the [-1, 1] range the VQGAN expects.
    x = 2. * x - 1.
    return x


def reconstruct_with_vqgan(x, model):
    # encode() returns (quantized latents, embedding loss, info); only the
    # latents are needed for reconstruction.
    z, _, [_, _, _] = model.encode(x)
    print(f'VQGAN --- {model.__class__.__name__}: latent shape: {z.shape[2:]}')
    xrec = model.decode(z)
    return xrec
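

# ---------------------------------------------------------------------------
# Note: the `vae`, `dwt_vae`, and `sample_image` fixtures used above are
# expected to come from a conftest.py that is not part of this file. Below is
# a minimal sketch of a compatible `sample_image` fixture, assuming any RGB
# PIL image whose short side is at least 256 px will do; the VAE fixtures are
# project-specific, so only hypothetical placeholders are shown. Kept
# commented out so it does not shadow the real conftest fixtures.
#
# import numpy as np
# from PIL import Image
#
# @pytest.fixture
# def sample_image():
#     rng = np.random.default_rng(seed=0)
#     arr = rng.integers(0, 256, size=(512, 512, 3), dtype=np.uint8)
#     return Image.fromarray(arr)
#
# @pytest.fixture
# def vae():
#     return load_pretrained_vae()      # hypothetical project-specific loader
#
# @pytest.fixture
# def dwt_vae():
#     return load_pretrained_dwt_vae()  # hypothetical project-specific loader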