Browse Source

to device support (#10)

* to device support
pull/13/head
Alex 5 years ago
committed by GitHub
parent
commit
0b3d6488c5
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 16 additions and 2 deletions
  1. +1
    -0
      .gitignore
  2. +1
    -1
      README.md
  3. +1
    -1
      rudalle/__init__.py
  4. +3
    -0
      rudalle/dalle/__init__.py
  5. +4
    -0
      rudalle/dalle/fp16.py
  6. +6
    -0
      rudalle/dalle/model.py

+ 1
- 0
.gitignore View File

@ -166,3 +166,4 @@ runs/
jupyters/custom_*
*logs/
.DS_store

+ 1
- 1
README.md View File

@ -2,7 +2,7 @@
### Generate images from texts
```
pip install rudalle==0.0.1rc2
pip install rudalle==0.0.1rc3
```
### 🤗 HF Models:
[ruDALL-E Malevich (XL)](https://huggingface.co/sberbank-ai/rudalle-Malevich)


+ 1
- 1
rudalle/__init__.py View File

@ -22,4 +22,4 @@ __all__ = [
'image_prompts',
]
__version__ = '0.0.1-rc2'
__version__ = '0.0.1-rc3'

+ 3
- 0
rudalle/dalle/__init__.py View File

@ -59,6 +59,9 @@ def get_rudalle_model(name, pretrained=True, fp16=False, device='cpu', cache_dir
# TODO docstring
assert name in MODELS
if fp16 and device == 'cpu':
print('Warning! Using both fp16 and cpu doesnt support. You can use cuda device or turn off fp16.')
config = MODELS[name]
model = DalleModel(device=device, fp16=fp16, **config['model_params'])
if pretrained:


+ 4
- 0
rudalle/dalle/fp16.py View File

@ -58,3 +58,7 @@ class FP16Module(nn.Module):
def get_param(self, item):
    """Delegate parameter lookup to the wrapped module.

    Args:
        item: key identifying the parameter; passed through unchanged
            to ``self.module.get_param``.

    Returns:
        Whatever the wrapped module's ``get_param`` returns.
    """
    return self.module.get_param(item)
def to(self, device, *args, **kwargs):
    """Move the wrapper and its wrapped module to ``device``.

    Only ``device`` is forwarded to the inner module; extra positional
    and keyword arguments go to ``super().to`` alone.
    # NOTE(review): *args/**kwargs (e.g. a dtype) are deliberately NOT
    # forwarded to self.module.to — presumably to keep the wrapped
    # module in fp16 regardless of the caller's dtype request. Confirm
    # this is intentional before changing.
    """
    self.module.to(device)
    return super().to(device, *args, **kwargs)

+ 6
- 0
rudalle/dalle/model.py View File

@ -169,3 +169,9 @@ class DalleModel(torch.nn.Module):
loss = (loss_text + self.loss_img_weight * loss_img) / (self.loss_img_weight + 1)
return loss, {'text': loss_text.data.detach().float(), 'image': loss_img.data.detach().float()}
def to(self, device, *args, **kwargs):
    """Move the model to ``device``, including non-parameter state.

    Besides the standard ``nn.Module.to`` behavior (delegated via
    ``super()``), this moves the attention-mask tensors kept in plain
    Python lists (``self._mask_map`` and the transformer's
    ``_mask_map``), which ``nn.Module.to`` would not traverse, and
    records the target in ``self.device``.
    # assumes _mask_map entries are tensors exposing .to(device) —
    # TODO confirm against where _mask_map is built.
    """
    self.device = device
    self._mask_map = [mask.to(device) for mask in self._mask_map]
    self.transformer._mask_map = [mask.to(device) for mask in self.transformer._mask_map]
    return super().to(device, *args, **kwargs)

Loading…
Cancel
Save