|
|
|
@ -25,9 +25,11 @@ MODELS = { |
|
|
|
cogview_sandwich_layernorm=True, |
|
|
|
cogview_pb_relax=True, |
|
|
|
vocab_size=16384+128, |
|
|
|
image_vocab_size=8192, |
|
|
|
image_vocab_size=8192*2, |
|
|
|
# image_vocab_size=8192, |
|
|
|
), |
|
|
|
repo_id='sberbank-ai/rudalle-Malevich', |
|
|
|
# repo_id='sberbank-ai/rudalle-Malevich', # TODO update checkpoint |
|
|
|
repo_id='shonenkov/rudalle-Malevich', |
|
|
|
filename='pytorch_model.bin', |
|
|
|
full_description='', # TODO |
|
|
|
), |
|
|
|
@ -66,7 +68,9 @@ def get_rudalle_model(name, pretrained=True, fp16=False, device='cpu', cache_dir |
|
|
|
config_file_url = hf_hub_url(repo_id=config['repo_id'], filename=config['filename']) |
|
|
|
cached_download(config_file_url, cache_dir=cache_dir, force_filename=config['filename']) |
|
|
|
checkpoint = torch.load(os.path.join(cache_dir, config['filename']), map_location='cpu') |
|
|
|
model.load_state_dict(checkpoint) |
|
|
|
if 'module' in checkpoint: |
|
|
|
checkpoint = checkpoint['module'] |
|
|
|
model.load_state_dict(checkpoint, strict=False) |
|
|
|
if fp16: |
|
|
|
model = FP16Module(model) |
|
|
|
model.eval() |
|
|
|
|