optimize generation caching (#12)

Over 10x speedup: adds MLP caching and optimizes attention caching.
Uses changes from https://t.co/BTwo6NKq9H.
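
Both caches follow the same incremental-decoding pattern: during generation only the newest token needs fresh computation, so each module keeps its previous outputs and concatenates the freshly computed last-token slice onto them. A minimal sketch of that pattern in plain PyTorch (class and attribute names here are illustrative, not the rudalle API):

import torch

class CachedMLP(torch.nn.Module):
    """Toy feed-forward block with last-token caching (illustrative only)."""

    def __init__(self, hidden_size):
        super().__init__()
        self.fc = torch.nn.Linear(hidden_size, hidden_size)
        self.past_x = None  # outputs already computed for earlier positions

    def forward(self, hidden_states, has_cache=False, use_cache=False):
        if use_cache and has_cache:
            # Earlier positions are cached; only the newest token is new work.
            hidden_states = hidden_states[:, -1:]
        x = torch.relu(self.fc(hidden_states))
        if use_cache:
            if has_cache:
                x = torch.cat((self.past_x, x), dim=-2)  # prepend cached outputs
            self.past_x = x
            has_cache = True
        else:
            self.past_x = None
            has_cache = False
        return x, has_cache

The attention path adds one more trick on top of this: the attention-score matrix is sliced to its last row, so the softmax and the value aggregation also run for a single query position.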
neverix committed 5 years ago via GitHub · commit 47de7a2fd5
1 changed file with 48 additions and 9 deletions
rudalle/dalle/transformer.py (+48, -9)

@@ -146,7 +146,7 @@ class DalleTransformerLayer(torch.nn.Module):
         layernorm_output = self.input_layernorm(hidden_states)
         # Self attention.
-        attention_output, has_cache = self.attention(
+        attention_output, att_has_cache = self.attention(
             layernorm_output, ltor_mask, has_cache=has_cache, use_cache=use_cache)
         if self.cogview_sandwich_layernorm:
@@ -159,7 +159,8 @@ class DalleTransformerLayer(torch.nn.Module):
         layernorm_output = self.post_attention_layernorm(layernorm_input)
         # MLP.
-        mlp_output = self.mlp(layernorm_output)
+        mlp_output, mlp_has_cache = self.mlp(
+            layernorm_output, has_cache=has_cache, use_cache=use_cache)
         if self.cogview_sandwich_layernorm:
             mlp_output = self.before_second_addition_layernorm(mlp_output)
@@ -167,7 +168,7 @@ class DalleTransformerLayer(torch.nn.Module):
         # Second residual connection.
         output = layernorm_input + mlp_output
-        return output, has_cache
+        return output, att_has_cache and mlp_has_cache


 class DalleSelfAttention(torch.nn.Module):
@@ -212,6 +213,11 @@ class DalleSelfAttention(torch.nn.Module):
         self.dense = torch.nn.Linear(hidden_size, hidden_size)
         self.output_dropout = torch.nn.Dropout(output_dropout_prob)
+        # Cache
+        self.past_key = None
+        self.past_value = None
+        self.past_output = None

     def _transpose_for_scores(self, tensor):
         """ Transpose a 3D tensor [b, s, np*hn] into a 4D tensor with size [b, np, s, hn]. """
         new_tensor_shape = tensor.size()[:-1] + (self.num_attention_heads, self.hidden_size_per_attention_head)
@ -227,6 +233,7 @@ class DalleSelfAttention(torch.nn.Module):
)
else:
attention_scores = torch.matmul(query_layer, key_t) / math.sqrt(self.hidden_size_per_attention_head)
ltor_mask = ltor_mask[:, :, -attention_scores.shape[-2]:]
attention_scores = torch.mul(attention_scores, ltor_mask) - 10000.0 * (1.0 - ltor_mask)
if self.cogview_pb_relax:
# normalize attention scores. Should not affect resulting softmax value
@@ -258,10 +265,10 @@ class DalleSelfAttention(torch.nn.Module):
         key_layer = self._transpose_for_scores(mixed_key_layer)
         value_layer = self._transpose_for_scores(mixed_value_layer)

+        # Can be simplified, but I didn't for readability's sake
         if use_cache and has_cache:
-            value_layer = torch.cat((self.past_value, value_layer), dim=-2)
             query_layer = torch.cat((self.past_query, query_layer), dim=-2)
             key_layer = torch.cat((self.past_key, key_layer), dim=-2)
+            value_layer = torch.cat((self.past_value, value_layer), dim=-2)

         attention_scores = self._calculate_attention_scores(
             query_layer=query_layer, key_layer=key_layer, ltor_mask=ltor_mask
         )
@@ -271,13 +278,17 @@ class DalleSelfAttention(torch.nn.Module):
         )

         if use_cache:
             self.past_query = query_layer
             self.past_key = key_layer
             self.past_value = value_layer
             has_cache = True
         else:
             self.past_key = None
             self.past_value = None
+            self.past_output = None
             has_cache = False

+        if use_cache and has_cache:
+            attention_scores = attention_scores[..., -1:, :]
+
         # Attention probabilities. [b, np, s, s]
         attention_probs = torch.nn.Softmax(dim=-1)(attention_scores)
@@ -298,6 +309,16 @@ class DalleSelfAttention(torch.nn.Module):
         # Output. [b, s, h]
         output = self.dense(context_layer)
+
+        if use_cache:
+            # Can be simplified, but I didn't for readability's sake
+            if has_cache:
+                output = torch.cat((self.past_output, output), dim=-2)
+                self.past_output = output
+            else:
+                self.past_output = output
+                has_cache = True
+
         output = self.output_dropout(output)
         return output, has_cache
@@ -321,12 +342,30 @@ class DalleMLP(torch.nn.Module):
         # Project back to h.
         self.dense_4h_to_h = torch.nn.Linear(4*hidden_size, hidden_size)
         self.dropout = torch.nn.Dropout(output_dropout_prob)
+        # MLP cache
+        self.past_x = None

-    def forward(self, hidden_states):
+    def forward(self, hidden_states, has_cache=False, use_cache=False):
+        if has_cache and use_cache:
+            hidden_states = hidden_states[:, -1:]
         # [b, s, 4hp]
         x = self.dense_h_to_4h(hidden_states)
         x = gelu(x)
         # [b, s, h]
         x = self.dense_4h_to_h(x)
+
+        if use_cache:
+            # Can be simplified, but I didn't for readability's sake
+            if has_cache:
+                x = torch.cat((self.past_x, x), dim=-2)
+                self.past_x = x
+            else:
+                self.past_x = x
+                has_cache = True
+        else:
+            self.past_x = None
+            has_cache = False
+
         output = self.dropout(x)
-        return output
+        return output, has_cache
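
As a sanity check on the new masking and slicing logic, attending with only the last query row against the full set of keys and values should reproduce the last row of full causal attention. A small self-contained sketch with plain tensors, not the rudalle classes (the -10000.0 masking mirrors the formula in the diff):

import math
import torch

torch.manual_seed(0)
b, heads, s, hn = 1, 2, 5, 8
q = torch.randn(b, heads, s, hn)
k = torch.randn(b, heads, s, hn)
v = torch.randn(b, heads, s, hn)
ltor_mask = torch.tril(torch.ones(s, s))

# Full (uncached) causal attention for every position.
scores = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(hn)
scores = scores * ltor_mask - 10000.0 * (1.0 - ltor_mask)
full_out = torch.softmax(scores, dim=-1) @ v

# Incremental step: only the newest query row, keys/values taken from the cache.
last_scores = torch.matmul(q[..., -1:, :], k.transpose(-1, -2)) / math.sqrt(hn)
last_mask = ltor_mask[-last_scores.shape[-2]:]  # mask sliced to the remaining query rows
last_scores = last_scores * last_mask - 10000.0 * (1.0 - last_mask)
last_out = torch.softmax(last_scores, dim=-1) @ v

assert torch.allclose(full_out[..., -1:, :], last_out, atol=1e-5)

In the patched module the cached dense-layer outputs (past_output) are then prepended to the fresh last-token output, so each layer still returns activations for the full sequence.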
