@@ -146,7 +146,7 @@ class DalleTransformerLayer(torch.nn.Module):
        layernorm_output = self.input_layernorm(hidden_states)

        # Self attention.
        attention_output, has_cache = self.attention(
        attention_output, att_has_cache = self.attention(
            layernorm_output, ltor_mask, has_cache=has_cache, use_cache=use_cache)

        if self.cogview_sandwich_layernorm:
@@ -159,7 +159,8 @@ class DalleTransformerLayer(torch.nn.Module):
        layernorm_output = self.post_attention_layernorm(layernorm_input)

        # MLP.
        mlp_output = self.mlp(layernorm_output)
        mlp_output, mlp_has_cache = self.mlp(
            layernorm_output, has_cache=has_cache, use_cache=use_cache)

        if self.cogview_sandwich_layernorm:
            mlp_output = self.before_second_addition_layernorm(mlp_output)
@@ -167,7 +168,7 @@ class DalleTransformerLayer(torch.nn.Module):
        # Second residual connection.
        output = layernorm_input + mlp_output

        return output, has_cache
        return output, att_has_cache and mlp_has_cache


class DalleSelfAttention(torch.nn.Module):
@@ -212,6 +213,11 @@ class DalleSelfAttention(torch.nn.Module):
        self.dense = torch.nn.Linear(hidden_size, hidden_size)
        self.output_dropout = torch.nn.Dropout(output_dropout_prob)

        # Cache
        self.past_key = None
        self.past_value = None
        self.past_output = None

    def _transpose_for_scores(self, tensor):
        """ Transpose a 3D tensor [b, s, np*hn] into a 4D tensor with size [b, np, s, hn]. """
        new_tensor_shape = tensor.size()[:-1] + (self.num_attention_heads, self.hidden_size_per_attention_head)
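# --- Illustration, not part of the patch -----------------------------------
# A minimal sketch of the [b, s, np*hn] -> [b, np, s, hn] reshape the docstring
# above describes, with made-up sizes; view + permute is one standard way to
# realize it and is assumed here, not copied from the file.
import torch
b, s, num_heads, head_dim = 2, 5, 8, 64
t = torch.zeros(b, s, num_heads * head_dim)                  # [b, s, np*hn]
t = t.view(b, s, num_heads, head_dim).permute(0, 2, 1, 3)    # [b, np, s, hn]
# ----------------------------------------------------------------------------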
@@ -227,6 +233,7 @@ class DalleSelfAttention(torch.nn.Module):
            )
        else:
            attention_scores = torch.matmul(query_layer, key_t) / math.sqrt(self.hidden_size_per_attention_head)
        ltor_mask = ltor_mask[:, :, -attention_scores.shape[-2]:]
        attention_scores = torch.mul(attention_scores, ltor_mask) - 10000.0 * (1.0 - ltor_mask)
        if self.cogview_pb_relax:
            # normalize attention scores. Should not affect resulting softmax value
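# --- Illustration, not part of the patch -----------------------------------
# The masking trick used above: positions where ltor_mask is 0 are pushed to a
# large negative score, so the later softmax assigns them ~0 probability.
import torch
scores = torch.tensor([[1.0, 2.0, 3.0]])
mask = torch.tensor([[1.0, 1.0, 0.0]])              # last position masked out
masked = scores * mask - 10000.0 * (1.0 - mask)     # -> [1., 2., -10000.]
probs = torch.softmax(masked, dim=-1)               # last entry is ~0
# ----------------------------------------------------------------------------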
@@ -258,10 +265,10 @@ class DalleSelfAttention(torch.nn.Module):
        key_layer = self._transpose_for_scores(mixed_key_layer)
        value_layer = self._transpose_for_scores(mixed_value_layer)

        # Can be simplified, but I didn't for readability's sake
        if use_cache and has_cache:
            value_layer = torch.cat((self.past_value, value_layer), dim=-2)
            query_layer = torch.cat((self.past_query, query_layer), dim=-2)
            key_layer = torch.cat((self.past_key, key_layer), dim=-2)
            value_layer = torch.cat((self.past_value, value_layer), dim=-2)
        attention_scores = self._calculate_attention_scores(
            query_layer=query_layer, key_layer=key_layer, ltor_mask=ltor_mask
        )
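# --- Illustration, not part of the patch -----------------------------------
# The torch.cat calls above append the newly computed projections to the cached
# ones along the sequence dimension (dim=-2); the shapes here are made up.
import torch
past_key = torch.zeros(1, 8, 4, 64)                 # cached keys, [b, np, s_past, hn]
new_key = torch.zeros(1, 8, 1, 64)                  # keys for the newest position
key_layer = torch.cat((past_key, new_key), dim=-2)  # [1, 8, 5, 64]
# ----------------------------------------------------------------------------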
@@ -271,13 +278,17 @@ class DalleSelfAttention(torch.nn.Module):
        )

        if use_cache:
            self.past_query = query_layer
            self.past_key = key_layer
            self.past_value = value_layer
            has_cache = True
        else:
            self.past_key = None
            self.past_value = None
            self.past_output = None
            has_cache = False

        if use_cache and has_cache:
            attention_scores = attention_scores[..., -1:, :]

        # Attention probabilities. [b, np, s, s]
        attention_probs = torch.nn.Softmax(dim=-1)(attention_scores)
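# --- Illustration, not part of the patch -----------------------------------
# With a warm cache, only the newest query row of the score matrix is kept, so
# the softmax above runs for the new position only; the sizes are made up.
import torch
attention_scores = torch.randn(1, 8, 10, 10)        # [b, np, s_q, s_k]
last_row = attention_scores[..., -1:, :]             # [1, 8, 1, 10]
# ----------------------------------------------------------------------------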
@@ -298,6 +309,16 @@ class DalleSelfAttention(torch.nn.Module):

        # Output. [b, s, h]
        output = self.dense(context_layer)

        if use_cache:
            # Can be simplified, but I didn't for readability's sake
            if has_cache:
                output = torch.cat((self.past_output, output), dim=-2)
                self.past_output = output
            else:
                self.past_output = output
            has_cache = True

        output = self.output_dropout(output)
        return output, has_cache
@@ -321,12 +342,30 @@ class DalleMLP(torch.nn.Module):
        # Project back to h.
        self.dense_4h_to_h = torch.nn.Linear(4*hidden_size, hidden_size)
        self.dropout = torch.nn.Dropout(output_dropout_prob)
        # MLP cache
        self.past_x = None

    def forward(self, hidden_states, has_cache=False, use_cache=False):
        if has_cache and use_cache:
            hidden_states = hidden_states[:, -1:]

    def forward(self, hidden_states):
        # [b, s, 4hp]
        x = self.dense_h_to_4h(hidden_states)
        x = gelu(x)
        # [b, s, h]
        x = self.dense_4h_to_h(x)
        if use_cache:
            # Can be simplified, but I didn't for readability's sake
            if has_cache:
                x = torch.cat((self.past_x, x), dim=-2)
                self.past_x = x
            else:
                self.past_x = x

            has_cache = True
        else:
            self.past_x = None
            has_cache = False
        output = self.dropout(x)
        return output

        return output, has_cache
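# --- Illustration, not part of the patch -----------------------------------
# The MLP cache pattern above: on a cached step only the newest position is run
# through the MLP, and its output is appended to the previously cached outputs
# along the sequence dimension; the sizes are made up.
import torch
past_x = torch.zeros(1, 9, 512)                      # cached MLP outputs, [b, s_past, h]
new_x = torch.zeros(1, 1, 512)                       # output for the newest position
x = torch.cat((past_x, new_x), dim=-2)               # full-length output, [1, 10, 512]
# ----------------------------------------------------------------------------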