Browse Source

checkpoint

pull/4/head
shonenkov 5 years ago
parent
commit
f10240e430
1 changed files with 7 additions and 3 deletions
  1. +7
    -3
      rudalle/dalle/__init__.py

+ 7
- 3
rudalle/dalle/__init__.py View File

@ -25,9 +25,11 @@ MODELS = {
cogview_sandwich_layernorm=True,
cogview_pb_relax=True,
vocab_size=16384+128,
image_vocab_size=8192,
image_vocab_size=8192*2,
# image_vocab_size=8192,
),
repo_id='sberbank-ai/rudalle-Malevich',
# repo_id='sberbank-ai/rudalle-Malevich', # TODO update checkpoint
repo_id='shonenkov/rudalle-Malevich',
filename='pytorch_model.bin',
full_description='', # TODO
),
@ -66,7 +68,9 @@ def get_rudalle_model(name, pretrained=True, fp16=False, device='cpu', cache_dir
config_file_url = hf_hub_url(repo_id=config['repo_id'], filename=config['filename'])
cached_download(config_file_url, cache_dir=cache_dir, force_filename=config['filename'])
checkpoint = torch.load(os.path.join(cache_dir, config['filename']), map_location='cpu')
model.load_state_dict(checkpoint)
if 'module' in checkpoint:
checkpoint = checkpoint['module']
model.load_state_dict(checkpoint, strict=False)
if fp16:
model = FP16Module(model)
model.eval()


Loading…
Cancel
Save