# -*- coding: utf-8 -*-
"""Shared module-scoped pytest fixtures for the ruDALL-E test suite."""
import io
from os.path import abspath, dirname

# ``import PIL`` alone does not load the ``PIL.Image`` submodule; importing the
# submodule explicitly guarantees ``PIL.Image`` is available (and still binds ``PIL``).
import PIL.Image
import pytest
import requests

from rudalle import get_tokenizer, get_rudalle_model, get_vae, get_realesrgan

TEST_ROOT = dirname(abspath(__file__))

# Seconds to wait for the sample-image download; without a timeout an
# unresponsive host would hang the whole test session.
SAMPLE_IMAGE_TIMEOUT = 30


@pytest.fixture(scope='module')
def realesrgan():
    """Yield the 'x2' Real-ESRGAN model loaded on CPU via ``get_realesrgan``."""
    yield get_realesrgan('x2', device='cpu')


@pytest.fixture(scope='module')
def vae():
    """Yield a VAE built by ``get_vae`` without pretrained weights."""
    yield get_vae(pretrained=False)


@pytest.fixture(scope='module')
def dwt_vae():
    """Yield a VAE built by ``get_vae`` with ``dwt=True`` and no pretrained weights."""
    yield get_vae(pretrained=False, dwt=True)


@pytest.fixture(scope='module')
def yttm_tokenizer():
    """Yield the tokenizer returned by ``get_tokenizer()`` with default arguments."""
    yield get_tokenizer()


@pytest.fixture(scope='module')
def sample_image():
    """Download a sample PNG and yield it as a ``PIL.Image.Image``.

    Raises:
        requests.HTTPError: if the server returns an error status.
        requests.Timeout: if the download exceeds ``SAMPLE_IMAGE_TIMEOUT`` seconds.
    """
    url = 'https://cdn.kqed.org/wp-content/uploads/sites/12/2013/12/rudolph.png'
    resp = requests.get(url, timeout=SAMPLE_IMAGE_TIMEOUT)
    resp.raise_for_status()
    # The image is read lazily from an in-memory buffer, so the HTTP response
    # does not need to stay open while tests use it.
    yield PIL.Image.open(io.BytesIO(resp.content))


@pytest.fixture(scope='module')
def small_dalle():
    """Yield the 'small' ruDALL-E model on CPU, fp32, without pretrained weights."""
    # ``yield`` (rather than ``return``) keeps this fixture consistent with the others.
    yield get_rudalle_model('small', pretrained=False, fp16=False, device='cpu')