# -*- coding: utf-8 -*-
import torch
import torch.nn.functional as F
from einops import rearrange
from .utils import exists, is_empty, init_method_normal
from .image_attention import get_conv_mask, get_row_mask, get_col_mask
from .transformer import DalleTransformer


class DalleModel(torch.nn.Module):
    def __init__(self,
                 device,
                 num_layers,
                 vocab_size,
                 hidden_size,
                 num_attention_heads,
                 embedding_dropout_prob,
                 attention_dropout_prob,
                 output_dropout_prob,
                 text_seq_length=128,
                 image_tokens_per_dim=32,
                 image_vocab_size=16384,
                 loss_img_weight=7,
                 fp16=False,
                 use_masks=True,
                 cogview_sandwich_layernorm=False,
                 cogview_pb_relax=False):
        super(DalleModel, self).__init__()
        self.device = device
        self.fp16 = fp16
        self.image_tokens_per_dim = image_tokens_per_dim
        self.image_seq_length = image_tokens_per_dim ** 2
        self.text_seq_length = text_seq_length
        self.total_seq_length = self.text_seq_length + self.image_seq_length
        self.total_vocab_size = vocab_size + image_vocab_size
        self.vocab_size = vocab_size
        self.loss_img_weight = loss_img_weight
        # TODO "to"
        mask_map = self.prepare_image_masks(num_layers, text_seq_length, image_tokens_per_dim)
        self._mask_map = mask_map if use_masks else []

        init_method = init_method_normal(std=0.02)

        self.text_embeddings = torch.nn.Embedding(vocab_size, hidden_size)
        self.image_embeddings = torch.nn.Embedding(image_vocab_size, hidden_size)

        # Position embedding (serial).
        self.text_pos_embeddings = torch.nn.Embedding(text_seq_length + 1, hidden_size)
        self.image_row_embeddings = torch.nn.Embedding(image_tokens_per_dim, hidden_size)
        self.image_col_embeddings = torch.nn.Embedding(image_tokens_per_dim, hidden_size)
        init_method(self.text_pos_embeddings.weight)
        init_method(self.image_row_embeddings.weight)
        init_method(self.image_col_embeddings.weight)

        self.to_logits = torch.nn.Sequential(
            torch.nn.LayerNorm(hidden_size),
            torch.nn.Linear(hidden_size, self.total_vocab_size),
        )

        # Embeddings dropout
        self.embedding_dropout = torch.nn.Dropout(embedding_dropout_prob)

        # Transformer
        self.transformer = DalleTransformer(
            num_layers,
            hidden_size,
            num_attention_heads,
            attention_dropout_prob,
            output_dropout_prob,
            cogview_sandwich_layernorm=cogview_sandwich_layernorm,
            cogview_pb_relax=cogview_pb_relax,
        )
        self.transformer._mask_map = self._mask_map
    def get_param(self, item):
        return getattr(self, item)
    def prepare_image_masks(self, num_layers, text_seq_length, image_tokens_per_dim):
        row_mask = get_row_mask(text_seq_length, image_tokens_per_dim).to(self.device)
        col_mask = get_col_mask(text_seq_length, image_tokens_per_dim).to(self.device)
        conv_mask = get_conv_mask(text_seq_length, image_tokens_per_dim).to(self.device)
        if self.fp16:
            row_mask = row_mask.half()
            col_mask = col_mask.half()
            conv_mask = conv_mask.half()
        self.register_buffer('row_mask', row_mask)
        self.register_buffer('col_mask', col_mask)
        self.register_buffer('conv_mask', conv_mask)
        # Per-layer sparse attention pattern: layers with (i - 1) % 4 == 0 use the
        # column mask, the final layer (when it is not a column-mask layer) uses
        # the convolutional mask, and every other layer uses the row mask.
        mask_map = []
        for i in range(num_layers):
            if (i - 1) % 4 == 0:
                mask_map.append(col_mask)
            elif i != num_layers - 1:
                mask_map.append(row_mask)
            else:
                mask_map.append(conv_mask)
        return mask_map
    def get_image_pos_embeddings(self, image_input_ids, past_length=0):
        # Recover (row, col) coordinates from the flattened image token positions
        # and sum the corresponding row and column position embeddings.
        input_shape = image_input_ids.size()
        row_ids = torch.arange(past_length, input_shape[-1] + past_length,
                               dtype=torch.long, device=self.device) // self.image_tokens_per_dim
        row_ids = row_ids.unsqueeze(0).view(-1, input_shape[-1])
        col_ids = torch.arange(past_length, input_shape[-1] + past_length,
                               dtype=torch.long, device=self.device) % self.image_tokens_per_dim
        col_ids = col_ids.unsqueeze(0).view(-1, input_shape[-1])
        return self.image_row_embeddings(row_ids) + self.image_col_embeddings(col_ids)
    def forward(
            self,
            input_ids,
            attention_mask,
            return_loss=False,
            has_cache=False,
            use_cache=False,
    ):
        text = input_ids[:, :self.text_seq_length]
        # Replace padding tokens (id 0) with position-specific ids taken from
        # the tail of the text vocabulary.
        text_range = torch.arange(self.text_seq_length)
        text_range += (self.vocab_size - self.text_seq_length)
        text_range = text_range.to(self.device)
        text = torch.where(text == 0, text_range, text)
        # some hardcode :) prepend a fixed special token (id 2) to the text sequence
        text = F.pad(text, (1, 0), value=2)
        text_embeddings = self.text_embeddings(text) + \
            self.text_pos_embeddings(torch.arange(text.shape[1], device=self.device))

        image_input_ids = input_ids[:, self.text_seq_length:]
        if exists(image_input_ids) and not is_empty(image_input_ids):
            image_embeddings = self.image_embeddings(image_input_ids) + \
                self.get_image_pos_embeddings(image_input_ids, past_length=0)
            embeddings = torch.cat((text_embeddings, image_embeddings), dim=1)
        else:
            embeddings = text_embeddings

        # some hardcode :) drop the last position so the sequence fits total_seq_length
        if embeddings.shape[1] > self.total_seq_length:
            embeddings = embeddings[:, :-1]

        # Scale gradients by alpha while leaving the forward values unchanged.
        alpha = 0.1
        embeddings = embeddings * alpha + embeddings.detach() * (1 - alpha)

        attention_mask = attention_mask[:, :, :embeddings.shape[1], :embeddings.shape[1]]
        transformer_output, present_has_cache = self.transformer(
            embeddings, attention_mask, has_cache=has_cache, use_cache=use_cache)

        logits = self.to_logits(transformer_output)
        if not return_loss:
            return logits, present_has_cache

        labels = torch.cat((text[:, 1:], image_input_ids), dim=1).contiguous().long()
        logits = rearrange(logits, 'b n c -> b c n')

        text_logits = logits[:, :self.vocab_size, :self.text_seq_length].contiguous().float()
        image_logits = logits[:, self.vocab_size:, self.text_seq_length:].contiguous().float()

        loss_text = F.cross_entropy(
            text_logits,
            labels[:, :self.text_seq_length])
        loss_img = F.cross_entropy(
            image_logits,
            labels[:, self.text_seq_length:])

        # Weighted combination of the text and image losses.
        loss = (loss_text + self.loss_img_weight * loss_img) / (self.loss_img_weight + 1)
        return loss, {'text': loss_text.detach().float(), 'image': loss_img.detach().float()}
    def to(self, device, *args, **kwargs):
        self.device = device
        # The mask maps are plain Python lists of tensors, so nn.Module.to()
        # does not move them automatically; move them explicitly.
        self._mask_map = [mask.to(device) for mask in self._mask_map]
        self.transformer._mask_map = [mask.to(device) for mask in self.transformer._mask_map]
        return super().to(device, *args, **kwargs)
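

# -----------------------------------------------------------------------------
# Minimal smoke-test sketch (illustrative only). The hyperparameters below are
# small placeholder values, not the project's defaults, and the lower-triangular
# attention mask is an assumption about how callers are expected to build the
# (batch, 1, seq, seq) mask that forward() slices down to the embedding length.
# -----------------------------------------------------------------------------
if __name__ == '__main__':
    device = 'cpu'
    text_seq_length, image_tokens_per_dim = 128, 32
    vocab_size, image_vocab_size = 16384, 8192

    model = DalleModel(
        device=device,
        num_layers=2,
        vocab_size=vocab_size,
        hidden_size=64,
        num_attention_heads=4,
        embedding_dropout_prob=0.1,
        attention_dropout_prob=0.1,
        output_dropout_prob=0.1,
        text_seq_length=text_seq_length,
        image_tokens_per_dim=image_tokens_per_dim,
        image_vocab_size=image_vocab_size,
    ).to(device)

    total_seq_length = text_seq_length + image_tokens_per_dim ** 2
    # Random text and image token ids standing in for tokenizer / VQ-VAE output.
    text_tokens = torch.randint(0, vocab_size, (1, text_seq_length))
    image_tokens = torch.randint(0, image_vocab_size, (1, image_tokens_per_dim ** 2))
    input_ids = torch.cat((text_tokens, image_tokens), dim=1)

    # Assumed causal mask; forward() trims it to the actual sequence length.
    attention_mask = torch.tril(
        torch.ones((1, 1, total_seq_length, total_seq_length)))

    loss, loss_values = model(input_ids, attention_mask, return_loss=True)
    print(loss.item(), loss_values)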