# -*- coding: utf-8 -*-
import math
import torch
from torch.nn import LayerNorm
from .utils import divide, split_tensor_along_last_dim
@torch.jit.script
def gelu_impl(x):
    """OpenAI's tanh approximation of GELU.

    Computes 0.5 * x * (1 + tanh(sqrt(2/pi) * x * (1 + 0.044715 * x^2))),
    where 0.7978845608028654 = sqrt(2/pi).
    """
    inner = 0.7978845608028654 * x * (1.0 + 0.044715 * x * x)
    return x * 0.5 * (1.0 + torch.tanh(inner))
def gelu(x):
    """Apply the scripted OpenAI GELU approximation to ``x``."""
    return gelu_impl(x)
class DalleTransformer(torch.nn.Module):
    """Stack of DALL-E transformer layers followed by a final LayerNorm.

    Consumes the embedding-layer output and produces hidden states that can
    feed a logit layer directly. Each of the ``num_layers`` blocks applies:
    layer norm -> self attention -> residual -> layer norm -> MLP -> residual.

    Arguments:
        num_layers: number of transformer layers.
        hidden_size: the hidden size of the self attention.
        num_attention_heads: number of attention heads in the self attention.
        attention_dropout_prob: dropout probability of the attention scores.
        output_dropout_prob: dropout probability for the outputs after
            self attention and the final output.
        layernorm_epsilon: epsilon used in layernorm to avoid division by zero.
        cogview_sandwich_layernorm: enable CogView "Sandwich-LN".
        cogview_pb_relax: enable CogView PB-relax attention stabilization.
    """
    # Optional per-layer mask overrides shared at class level; when non-empty,
    # entry i is elementwise-multiplied into the attention mask of layer i.
    _mask_map = []

    def __init__(self, num_layers, hidden_size, num_attention_heads, attention_dropout_prob, output_dropout_prob,
                 layernorm_epsilon=1.0e-5, cogview_sandwich_layernorm=False, cogview_pb_relax=False):
        super().__init__()
        # CogView training-stabilization features, see chapter 2.4 of
        # https://arxiv.org/pdf/2105.13290.pdf
        self.cogview_pb_relax = cogview_pb_relax
        # Transformer layers.
        blocks = [
            DalleTransformerLayer(
                hidden_size,
                num_attention_heads,
                attention_dropout_prob,
                output_dropout_prob,
                layernorm_epsilon,
                cogview_sandwich_layernorm=cogview_sandwich_layernorm,
                cogview_pb_relax=cogview_pb_relax,
            )
            for _ in range(num_layers)
        ]
        self.layers = torch.nn.ModuleList(blocks)
        # Final layer norm applied to the last block's output.
        self.final_layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon)

    def forward(self, hidden_states, attention_mask, has_cache, use_cache):
        for idx, block in enumerate(self.layers):
            mask = attention_mask
            if self._mask_map:
                per_layer = self._mask_map[idx][:mask.size(2), :mask.size(3)]
                mask = torch.mul(attention_mask, per_layer)
            hidden_states, present_has_cache = block(
                hidden_states, mask, has_cache=has_cache, use_cache=use_cache)
        # NOTE(review): assumes at least one layer; with num_layers == 0 the
        # name ``present_has_cache`` would be unbound here.
        return self.final_layernorm(hidden_states), present_has_cache
class DalleTransformerLayer(torch.nn.Module):
    """A single transformer layer.

    Notation: h = hidden size, n = number of attention heads, b = batch size,
    s = sequence length. Input has size [b, s, h]; output has the same size.

    Arguments:
        hidden_size: the hidden size of the self attention.
        num_attention_heads: number of attention heads in the self attention.
        attention_dropout_prob: dropout probability of the attention scores.
        output_dropout_prob: dropout probability for the outputs after
            self attention and the final output.
        layernorm_epsilon: epsilon used in layernorm to avoid division by zero.
        cogview_sandwich_layernorm: add extra LayerNorms right before each
            residual addition (CogView "Sandwich-LN").
        cogview_pb_relax: enable CogView PB-relax attention stabilization.
    """
    def __init__(self,
                 hidden_size,
                 num_attention_heads,
                 attention_dropout_prob,
                 output_dropout_prob,
                 layernorm_epsilon,
                 cogview_sandwich_layernorm=False,
                 cogview_pb_relax=False):
        super().__init__()
        # CogView training-stabilization features, see chapter 2.4 of
        # https://arxiv.org/pdf/2105.13290.pdf
        self.cogview_sandwich_layernorm = cogview_sandwich_layernorm
        self.cogview_pb_relax = cogview_pb_relax
        # Norm applied to the layer input.
        self.input_layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon)
        if self.cogview_sandwich_layernorm:
            # "Sandwich" norms placed just before each residual addition.
            self.before_first_addition_layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon)
            self.before_second_addition_layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon)
        # Self attention.
        self.attention = DalleSelfAttention(
            hidden_size,
            num_attention_heads,
            attention_dropout_prob,
            output_dropout_prob,
            cogview_pb_relax=cogview_pb_relax,
        )
        # Norm applied after the attention residual.
        self.post_attention_layernorm = LayerNorm(hidden_size, eps=layernorm_epsilon)
        # Position-wise feed-forward network.
        self.mlp = DalleMLP(hidden_size, output_dropout_prob)

    def forward(self, hidden_states, ltor_mask, has_cache, use_cache):
        # hidden_states: [b, s, h]; ltor_mask: [1, 1, s, s]
        normed = self.input_layernorm(hidden_states)
        attention_output, att_has_cache = self.attention(
            normed, ltor_mask, has_cache=has_cache, use_cache=use_cache)
        if self.cogview_sandwich_layernorm:
            attention_output = self.before_first_addition_layernorm(attention_output)
        # First residual connection.
        residual = hidden_states + attention_output
        normed = self.post_attention_layernorm(residual)
        mlp_output, mlp_has_cache = self.mlp(
            normed, has_cache=has_cache, use_cache=use_cache)
        if self.cogview_sandwich_layernorm:
            mlp_output = self.before_second_addition_layernorm(mlp_output)
        # Second residual connection; report a populated cache only when both
        # sublayers do.
        return residual + mlp_output, att_has_cache and mlp_has_cache
class DalleSelfAttention(torch.nn.Module):
    """
    Self-attention layer takes input with size [b, s, h] where b is
    the batch size, s is the sequence length, and h is the hidden size
    and creates output of the same size. Supports an incremental cache
    for autoregressive generation (``has_cache`` / ``use_cache``).
    Arguments:
        hidden_size: total hidden size of the layer (h).
        num_attention_heads: number of attention heads (n). Note that we
                             require n to be divisible by number of GPUs
                             used to parallelize the model. Also, we
                             require hidden size to be divisible by n.
        attention_dropout_prob: dropout probability for the attention scores.
        output_dropout_prob: dropout probability for the output.
        cogview_pb_relax: enable CogView PB-relax attention-score scaling,
                          see chapter 2.4 https://arxiv.org/pdf/2105.13290.pdf
    We use the following notation:
        h: hidden_size
        n: num_attention_heads
        p: number of partitions
        np: n/p
        hp: h/p
        hn: h/n
        b: batch size
        s: sequence length
    """
    def __init__(self, hidden_size, num_attention_heads,
                 attention_dropout_prob, output_dropout_prob, cogview_pb_relax=False):
        super(DalleSelfAttention, self).__init__()
        # CogView stabilization of training features, see chapter 2.4 https://arxiv.org/pdf/2105.13290.pdf
        self.cogview_pb_relax = cogview_pb_relax
        self.hidden_size = hidden_size
        self.num_attention_heads = num_attention_heads
        self.hidden_size_per_attention_head = divide(hidden_size, num_attention_heads)
        # Fused projection producing query, key and value in one matmul.
        self.query_key_value = torch.nn.Linear(hidden_size, 3*hidden_size)
        self.attention_dropout = torch.nn.Dropout(attention_dropout_prob)
        # Output.
        self.dense = torch.nn.Linear(hidden_size, hidden_size)
        self.output_dropout = torch.nn.Dropout(output_dropout_prob)
        # Incremental-decoding cache.
        # BUGFIX: __init__ previously set ``self.cash_size`` while forward()
        # read/wrote ``self.chache_size``, so the attribute forward() uses was
        # never initialized here. Unified to ``cache_size`` (matching DalleMLP)
        # so the constructor actually initializes the state forward() reads.
        self.cache_size = 0
        self.past_key = None
        self.past_value = None
        self.past_output = None

    def _transpose_for_scores(self, tensor):
        """ Transpose a 3D tensor [b, s, np*hn] into a 4D tensor with size [b, np, s, hn]. """
        new_tensor_shape = tensor.size()[:-1] + (self.num_attention_heads, self.hidden_size_per_attention_head)
        tensor = tensor.view(*new_tensor_shape)
        return tensor.permute(0, 2, 1, 3)

    def _calculate_attention_scores(self, query_layer, key_layer, ltor_mask):
        """Compute masked (and optionally PB-relax-scaled) attention scores.

        Returns scores of shape [b, np, sq, sk]; positions where the mask is 0
        receive a large negative score so softmax effectively ignores them.
        """
        key_t = key_layer.transpose(-1, -2)
        if self.cogview_pb_relax:
            # PB-relax: divide the query *before* the matmul so the product
            # itself stays in a numerically safe range.
            attention_scores = torch.matmul(
                query_layer / math.sqrt(self.hidden_size_per_attention_head),
                key_t
            )
        else:
            attention_scores = torch.matmul(query_layer, key_t) / math.sqrt(self.hidden_size_per_attention_head)
        # Crop the mask to the current query length.
        # NOTE(review): only dim 2 (queries) is cropped; the mask is assumed to
        # already match the key length in dim 3 — confirm against callers.
        ltor_mask = ltor_mask[:, :, -attention_scores.shape[-2]:]
        attention_scores = torch.mul(attention_scores, ltor_mask) - 10000.0 * (1.0 - ltor_mask)
        if self.cogview_pb_relax:
            # Normalize attention scores by the per-head max. Should not affect
            # the resulting softmax value, only its numerical stability.
            alpha = 32
            attention_scores_scaled = attention_scores / alpha
            attention_scores_scaled_maxes, _ = attention_scores_scaled.detach().view(
                [attention_scores.size(0), attention_scores.size(1), -1]
            ).max(dim=-1)  # max per head per sample
            attention_scores_scaled_maxes = attention_scores_scaled_maxes.unsqueeze(-1).unsqueeze(-1).expand(
                [-1, -1, attention_scores.size(2), attention_scores.size(3)]
            )  # expand to [b, np, s, s]
            attention_scores = (attention_scores_scaled - attention_scores_scaled_maxes) * alpha
        return attention_scores

    def forward(self, hidden_states, ltor_mask, has_cache=False, use_cache=False):
        # hidden_states: [b, s, h]
        # ltor_mask: [1, 1, s, s]
        # Attention heads. [b, s, hp]
        if has_cache and use_cache:
            # Only project the suffix of tokens that is not cached yet.
            extra_cache_size = hidden_states.shape[-2] - self.cache_size
            mixed_x_layer = self.query_key_value(hidden_states[:, -extra_cache_size:, :])
        else:
            mixed_x_layer = self.query_key_value(hidden_states)
        (mixed_query_layer,
         mixed_key_layer,
         mixed_value_layer) = split_tensor_along_last_dim(mixed_x_layer, 3)
        query_layer = self._transpose_for_scores(mixed_query_layer)
        key_layer = self._transpose_for_scores(mixed_key_layer)
        value_layer = self._transpose_for_scores(mixed_value_layer)
        if use_cache and has_cache:
            # Prepend cached keys/values so the new queries attend to the full
            # prefix. (The original had an identical call in both branches of
            # an if/else; merged into one call after the optional concat.)
            key_layer = torch.cat((self.past_key, key_layer), dim=-2)
            value_layer = torch.cat((self.past_value, value_layer), dim=-2)
        attention_scores = self._calculate_attention_scores(
            query_layer=query_layer, key_layer=key_layer, ltor_mask=ltor_mask
        )
        if use_cache:
            self.cache_size = hidden_states.shape[-2]
            self.past_key = key_layer
            self.past_value = value_layer
        else:
            self.past_key = None
            self.past_value = None
            self.past_output = None
            has_cache = False
        if use_cache and has_cache:
            # Keep only the scores of the freshly added query positions.
            attention_scores = attention_scores[..., -extra_cache_size:, :]
        # Attention probabilities. [b, np, s, s]
        attention_probs = torch.nn.Softmax(dim=-1)(attention_scores)
        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.attention_dropout(attention_probs)
        # Context layer. [b, np, s, hn]
        context_layer = torch.matmul(attention_probs, value_layer)
        # [b, s, np, hn]
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size,)
        # [b, s, hp]
        context_layer = context_layer.view(*new_context_layer_shape)
        # Output. [b, s, h]
        output = self.dense(context_layer)
        if use_cache:
            if has_cache:
                # Append the fresh outputs to the cached prefix.
                output = torch.cat((self.past_output, output), dim=-2)
            self.past_output = output
            has_cache = True
        output = self.output_dropout(output)
        return output, has_cache
class DalleMLP(torch.nn.Module):
    """Position-wise feed-forward block: h -> 4h -> gelu -> h, then dropout.

    Supports an incremental cache for autoregressive generation: when both
    ``has_cache`` and ``use_cache`` are set, only the uncached suffix of the
    sequence is projected and the result is appended to the cached output.

    Arguments:
        hidden_size: The hidden size of the self attention.
        output_dropout_prob: dropout probability for the outputs
            after self attention and final output.
    """
    def __init__(self, hidden_size, output_dropout_prob):
        super().__init__()
        # Project to 4h, then back to h.
        self.dense_h_to_4h = torch.nn.Linear(hidden_size, 4*hidden_size)
        self.dense_4h_to_h = torch.nn.Linear(4*hidden_size, hidden_size)
        self.dropout = torch.nn.Dropout(output_dropout_prob)
        # Incremental cache state.
        self.cache_size = 0
        self.past_x = None

    def forward(self, hidden_states, has_cache=False, use_cache=False):
        if has_cache and use_cache:
            # Only the tokens not processed yet need to go through the MLP.
            fresh = hidden_states.shape[-2] - self.cache_size
            self.cache_size += fresh
            hidden_states = hidden_states[:, -fresh:]
        # [b, s, 4hp] -> gelu -> [b, s, h]
        x = self.dense_4h_to_h(gelu(self.dense_h_to_4h(hidden_states)))
        if not use_cache:
            # Caching disabled: drop any stale state.
            self.past_x = None
            has_cache = False
        elif has_cache:
            # Append the fresh outputs to the cached prefix.
            x = torch.cat((self.past_x, x), dim=-2)
            self.past_x = x
        else:
            # First cached step: remember the full output.
            self.cache_size = hidden_states.shape[-2]
            self.past_x = x
            has_cache = True
        return self.dropout(x), has_cache