|
|
|
@ -169,3 +169,9 @@ class DalleModel(torch.nn.Module): |
|
|
|
|
|
|
|
loss = (loss_text + self.loss_img_weight * loss_img) / (self.loss_img_weight + 1) |
|
|
|
return loss, {'text': loss_text.data.detach().float(), 'image': loss_img.data.detach().float()} |
|
|
|
|
|
|
|
def to(self, device, *args, **kwargs): |
|
|
|
self.device = device |
|
|
|
self._mask_map = [mask.to(device) for mask in self._mask_map] |
|
|
|
self.transformer._mask_map = [mask.to(device) for mask in self.transformer._mask_map] |
|
|
|
return super().to(device, *args, **kwargs) |