Browse Source

add dwt checkpoint load

feature/dwt_vae
shonenkov 4 years ago
parent
commit
d5d131b776
1 changed files with 4 additions and 1 deletions
  1. +4
    -1
      rudalle/vae/__init__.py

+ 4
- 1
rudalle/vae/__init__.py View File

@ -22,6 +22,9 @@ def get_vae(pretrained=True, dwt=False, cache_dir='/tmp/rudalle'):
config_file_url = hf_hub_url(repo_id=repo_id, filename=filename)
cached_download(config_file_url, cache_dir=cache_dir, force_filename=filename)
checkpoint = torch.load(join(cache_dir, filename), map_location='cpu')
vae.model.load_state_dict(checkpoint['state_dict'], strict=False)
if dwt:
vae.load_state_dict(checkpoint['state_dict'])
else:
vae.model.load_state_dict(checkpoint['state_dict'], strict=False)
print('vae --> ready')
return vae

Loading…
Cancel
Save