|
|
|
@ -206,7 +206,7 @@ class DalleSelfAttention(torch.nn.Module): |
|
|
|
self.num_attention_heads = num_attention_heads |
|
|
|
self.hidden_size_per_attention_head = divide(hidden_size, num_attention_heads) |
|
|
|
|
|
|
|
self.query_key_value = torch.nn.Linear(hidden_size, 3*hidden_size) |
|
|
|
self.query_key_value = torch.nn.Linear(hidden_size, 3 * hidden_size) |
|
|
|
self.attention_dropout = torch.nn.Dropout(attention_dropout_prob) |
|
|
|
|
|
|
|
# Output. |
|
|
|
@ -248,12 +248,12 @@ class DalleSelfAttention(torch.nn.Module): |
|
|
|
attention_scores = (attention_scores_scaled - attention_scores_scaled_maxes) * alpha |
|
|
|
return attention_scores |
|
|
|
|
|
|
|
def forward(self, hidden_states, ltor_mask, has_cache=False, use_cache=False,): |
|
|
|
def forward(self, hidden_states, ltor_mask, has_cache=False, use_cache=False, ): |
|
|
|
# hidden_states: [b, s, h] |
|
|
|
# ltor_mask: [1, 1, s, s] |
|
|
|
# Attention heads. [b, s, hp] |
|
|
|
if has_cache and use_cache: |
|
|
|
mixed_x_layer = self.query_key_value(hidden_states[:, -1:, :]) |
|
|
|
mixed_x_layer = self.query_key_value(hidden_states[:, self.past_key.shape[-2]:, :]) |
|
|
|
else: |
|
|
|
mixed_x_layer = self.query_key_value(hidden_states) |
|
|
|
|
|
|
|
@ -277,6 +277,10 @@ class DalleSelfAttention(torch.nn.Module): |
|
|
|
query_layer=query_layer, key_layer=key_layer, ltor_mask=ltor_mask |
|
|
|
) |
|
|
|
|
|
|
|
if use_cache and has_cache: |
|
|
|
extra_cache_size = hidden_states.shape[-2] - self.past_key.shape[-2] |
|
|
|
attention_scores = attention_scores[..., -extra_cache_size:, :] |
|
|
|
|
|
|
|
if use_cache: |
|
|
|
self.past_key = key_layer |
|
|
|
self.past_value = value_layer |
|
|
|
@ -286,9 +290,6 @@ class DalleSelfAttention(torch.nn.Module): |
|
|
|
self.past_output = None |
|
|
|
has_cache = False |
|
|
|
|
|
|
|
if use_cache and has_cache: |
|
|
|
attention_scores = attention_scores[..., -1:, :] |
|
|
|
|
|
|
|
# Attention probabilities. [b, np, s, s] |
|
|
|
attention_probs = torch.nn.Softmax(dim=-1)(attention_scores) |
|
|
|
|
|
|
|
@ -338,16 +339,16 @@ class DalleMLP(torch.nn.Module): |
|
|
|
def __init__(self, hidden_size, output_dropout_prob): |
|
|
|
super(DalleMLP, self).__init__() |
|
|
|
# Project to 4h. |
|
|
|
self.dense_h_to_4h = torch.nn.Linear(hidden_size, 4*hidden_size) |
|
|
|
self.dense_h_to_4h = torch.nn.Linear(hidden_size, 4 * hidden_size) |
|
|
|
# Project back to h. |
|
|
|
self.dense_4h_to_h = torch.nn.Linear(4*hidden_size, hidden_size) |
|
|
|
self.dense_4h_to_h = torch.nn.Linear(4 * hidden_size, hidden_size) |
|
|
|
self.dropout = torch.nn.Dropout(output_dropout_prob) |
|
|
|
# MLP cache |
|
|
|
self.past_x = None |
|
|
|
|
|
|
|
def forward(self, hidden_states, has_cache=False, use_cache=False): |
|
|
|
if has_cache and use_cache: |
|
|
|
hidden_states = hidden_states[:, -1:] |
|
|
|
hidden_states = hidden_states[:, self.past_x.shape[-2]:] |
|
|
|
|
|
|
|
# [b, s, 4hp] |
|
|
|
x = self.dense_h_to_4h(hidden_states) |
|
|
|
|