@@ -4,7 +4,6 @@ import torch.nn.functional as F
from einops import rearrange

from .utils import exists, is_empty, init_method_normal
from .image_attention import get_conv_mask, get_row_mask, get_col_mask
from .transformer import DalleTransformer
@@ -23,14 +22,11 @@ class DalleModel(torch.nn.Module):
                 image_tokens_per_dim=32,
                 image_vocab_size=16384,
                 loss_img_weight=7,
                 fp16=False,
                 use_masks=True,
                 cogview_sandwich_layernorm=False,
                 cogview_pb_relax=False):
        super(DalleModel, self).__init__()
        self.device = device
        self.fp16 = fp16
        self.image_tokens_per_dim = image_tokens_per_dim
        self.image_seq_length = image_tokens_per_dim ** 2
        self.text_seq_length = text_seq_length
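        # Illustration (follows from the defaults above): with image_tokens_per_dim=32
        # the image part of the sequence is 32 * 32 = 1024 tokens, so the transformer
        # context spans text_seq_length + 1024 positions in total.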
@@ -39,13 +35,6 @@ class DalleModel(torch.nn.Module):
        self.vocab_size = vocab_size
        self.loss_img_weight = loss_img_weight

        # TODO "to"
        mask_map = self.prepare_image_masks(num_layers, text_seq_length, image_tokens_per_dim)
        if use_masks:
            self._mask_map = mask_map
        else:
            self._mask_map = []
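        # With use_masks=False no per-layer sparse masks are handed to the
        # transformer; presumably it then falls back to its ordinary causal
        # attention (the fallback itself lives in DalleTransformer, not here).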

        init_method = init_method_normal(std=0.02)
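        # Judging by its name and argument, init_method_normal(std=0.02) builds an
        # initializer drawing weights from N(0, 0.02^2), the classic GPT-style init.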

        self.text_embeddings = torch.nn.Embedding(vocab_size, hidden_size)
@@ -74,35 +63,15 @@ class DalleModel(torch.nn.Module):
            num_attention_heads,
            attention_dropout_prob,
            output_dropout_prob,
            text_seq_length=text_seq_length,
            image_tokens_per_dim=image_tokens_per_dim,
            cogview_sandwich_layernorm=cogview_sandwich_layernorm,
            cogview_pb_relax=cogview_pb_relax,
        )
        self.transformer._mask_map = self._mask_map
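        # The transformer receives the very same list object, so model and
        # transformer see identical per-layer masks until to() rebuilds them.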

    def get_param(self, item):
        return getattr(self, item)

    def prepare_image_masks(self, num_layers, text_seq_length, image_tokens_per_dim):
        row_mask = get_row_mask(text_seq_length, image_tokens_per_dim).to(self.device)
        col_mask = get_col_mask(text_seq_length, image_tokens_per_dim).to(self.device)
        conv_mask = get_conv_mask(text_seq_length, image_tokens_per_dim).to(self.device)
        if self.fp16:
            row_mask = row_mask.half()
            col_mask = col_mask.half()
            conv_mask = conv_mask.half()
        self.register_buffer('row_mask', row_mask)
        self.register_buffer('col_mask', col_mask)
        self.register_buffer('conv_mask', conv_mask)
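        # register_buffer() makes the masks part of state_dict() and lets the
        # standard Module.to()/.half() machinery track them; the _mask_map list
        # built below holds plain Python references that Module.to() cannot see.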
        mask_map = []
        for i in range(num_layers):
            if (i - 1) % 4 == 0:
                mask_map.append(col_mask)
            elif i != num_layers - 1:
                mask_map.append(row_mask)
            else:
                mask_map.append(conv_mask)
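        # Worked example: for num_layers=8 the loop above yields
        # [row, col, row, row, row, col, row, conv]: layers with (i - 1) % 4 == 0
        # attend column-wise, the last layer (when it does not match the column
        # rule) uses the convolutional mask, and all other layers attend row-wise.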
        return mask_map

    def get_image_pos_embeddings(self, image_input_ids, past_length=0):
        input_shape = image_input_ids.size()
        row_ids = torch.arange(past_length, input_shape[-1] + past_length,
@@ -172,6 +141,4 @@ class DalleModel(torch.nn.Module):

    def to(self, device, *args, **kwargs):
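        # nn.Module.to() moves parameters and registered buffers, but it does not
        # touch tensors kept in plain Python lists such as _mask_map, so the
        # per-layer mask references must be rebuilt by hand before delegating.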
        self.device = device
        self._mask_map = [mask.to(device) for mask in self._mask_map]
        self.transformer._mask_map = [mask.to(device) for mask in self.transformer._mask_map]
        return super().to(device, *args, **kwargs)