# -*- coding: utf-8 -*-
from os.path import dirname, abspath, join

import torch
from huggingface_hub import hf_hub_url, cached_download
from omegaconf import OmegaConf

from .model import VQGanGumbelVAE


def get_vae(pretrained=True, dwt=False, cache_dir='/tmp/rudalle'):  # TODO
    """Build the Gumbel-VQGAN VAE, optionally fetching pretrained weights.

    :param pretrained: when True, download the checkpoint from the HF Hub
        repo 'shonenkov/rudalle-utils' and load it (non-strict).
    :param dwt: selects the DWT checkpoint variant when True.
    :param cache_dir: local root under which downloads are cached.
    :return: a VQGanGumbelVAE instance (weights loaded if pretrained).
    """
    # Model config ships alongside this module.
    conf_path = join(dirname(abspath(__file__)), 'vqgan.gumbelf8-sber.config.yml')
    vae = VQGanGumbelVAE(OmegaConf.load(conf_path), dwt=dwt)
    if pretrained:
        repo_id = 'shonenkov/rudalle-utils'
        ckpt_name = 'vqgan.gumbelf8-sber-dwt.model.ckpt' if dwt else 'vqgan.gumbelf8-sber.model.ckpt'
        cache_dir = join(cache_dir, 'vae')
        # NOTE(review): cached_download is deprecated in newer huggingface_hub
        # releases (hf_hub_download is the replacement) — kept as-is here.
        url = hf_hub_url(repo_id=repo_id, filename=ckpt_name)
        cached_download(url, cache_dir=cache_dir, force_filename=ckpt_name)
        # Load on CPU; strict=False tolerates key mismatches between
        # checkpoint and module (return value intentionally ignored).
        state = torch.load(join(cache_dir, ckpt_name), map_location='cpu')
        vae.model.load_state_dict(state['state_dict'], strict=False)
        print('vae --> ready')
    return vae