|
|
|
@ -22,6 +22,9 @@ def get_vae(pretrained=True, dwt=False, cache_dir='/tmp/rudalle'): |
|
|
|
config_file_url = hf_hub_url(repo_id=repo_id, filename=filename) |
|
|
|
cached_download(config_file_url, cache_dir=cache_dir, force_filename=filename) |
|
|
|
checkpoint = torch.load(join(cache_dir, filename), map_location='cpu') |
|
|
|
vae.model.load_state_dict(checkpoint['state_dict'], strict=False) |
|
|
|
if dwt: |
|
|
|
vae.load_state_dict(checkpoint['state_dict']) |
|
|
|
else: |
|
|
|
vae.model.load_state_dict(checkpoint['state_dict'], strict=False) |
|
|
|
print('vae --> ready') |
|
|
|
return vae |