@@ -214,7 +214,6 @@ class DalleSelfAttention(torch.nn.Module):
         self.output_dropout = torch.nn.Dropout(output_dropout_prob)

         # Cache
-        self.cache_size = 0
         self.past_key = None
         self.past_value = None
         self.past_output = None
@@ -254,7 +253,7 @@ class DalleSelfAttention(torch.nn.Module):
         # ltor_mask: [1, 1, s, s]

         # Attention heads. [b, s, hp]
         if has_cache and use_cache:
-            mixed_x_layer = self.query_key_value(hidden_states[:, self.cache_size:, :])
+            mixed_x_layer = self.query_key_value(hidden_states[:, self.past_key.shape[-2]:, :])
         else:
             mixed_x_layer = self.query_key_value(hidden_states)
@@ -279,11 +278,10 @@ class DalleSelfAttention(torch.nn.Module):
         )

         if use_cache and has_cache:
-            extra_cache_size = hidden_states.shape[-2] - self.cache_size
+            extra_cache_size = hidden_states.shape[-2] - self.past_key.shape[-2]
             attention_scores = attention_scores[..., -extra_cache_size:, :]

         if use_cache:
-            self.cache_size = hidden_states.shape[-2]
             self.past_key = key_layer
             self.past_value = value_layer
         else:
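# --- Illustration (not part of the diff above) ----------------------------------
# A minimal, self-contained sketch of the key/value caching pattern these hunks
# rely on: instead of keeping a separate `cache_size` counter, the number of
# already-processed tokens is read from the cached key tensor's sequence
# dimension (`past_key.shape[-2]`), so the counter can never drift out of sync
# with the cache. Class and variable names below are illustrative, not from the
# repo.
import torch


class TinyCachedAttention(torch.nn.Module):

    def __init__(self, hidden_size):
        super().__init__()
        self.qkv = torch.nn.Linear(hidden_size, 3 * hidden_size)
        self.past_key = None
        self.past_value = None

    def forward(self, hidden_states, use_cache=False):
        has_cache = self.past_key is not None
        if has_cache and use_cache:
            # Project only the tokens that extend beyond the cached prefix.
            hidden_states = hidden_states[:, self.past_key.shape[-2]:, :]
        query, key, value = self.qkv(hidden_states).chunk(3, dim=-1)
        if has_cache and use_cache:
            key = torch.cat((self.past_key, key), dim=-2)
            value = torch.cat((self.past_value, value), dim=-2)
        if use_cache:
            self.past_key, self.past_value = key, value
        scores = torch.matmul(query, key.transpose(-1, -2)) / key.shape[-1] ** 0.5
        # Causal mask: each new query attends to all cached keys plus itself.
        offset = key.shape[-2] - query.shape[-2]
        mask = torch.ones(query.shape[-2], key.shape[-2]).tril(offset).bool()
        scores = scores.masked_fill(~mask, float('-inf'))
        # Outputs are produced only for the new tokens, mirroring the sliced
        # attention_scores in the hunk above.
        return torch.matmul(scores.softmax(dim=-1), value)
# ---------------------------------------------------------------------------------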
@@ -346,14 +344,11 @@ class DalleMLP(torch.nn.Module):
         self.dense_4h_to_h = torch.nn.Linear(4 * hidden_size, hidden_size)
         self.dropout = torch.nn.Dropout(output_dropout_prob)

         # MLP cache
-        self.cache_size = 0
         self.past_x = None

     def forward(self, hidden_states, has_cache=False, use_cache=False):
         if has_cache and use_cache:
-            extra_cache_size = hidden_states.shape[-2] - self.cache_size
-            self.cache_size += extra_cache_size
-            hidden_states = hidden_states[:, -extra_cache_size:]
+            hidden_states = hidden_states[:, self.past_x.shape[-2]:]
         # [b, s, 4hp]
         x = self.dense_h_to_4h(hidden_states)
@@ -366,7 +361,6 @@ class DalleMLP(torch.nn.Module):
             x = torch.cat((self.past_x, x), dim=-2)
             self.past_x = x
         else:
-            self.cache_size = hidden_states.shape[-2]
             self.past_x = x

         has_cache = True
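# --- Illustration (not part of the diff above) ----------------------------------
# A matching sketch for the MLP-side cache: the cached length comes from
# `past_x.shape[-2]`, the MLP runs only on the new suffix of tokens, and the
# result is concatenated onto the cached activations. Since the MLP is purely
# position-wise, this gives the same result as recomputing the full sequence.
# Names are illustrative, not from the repo.
import torch


class TinyCachedMLP(torch.nn.Module):

    def __init__(self, hidden_size):
        super().__init__()
        self.dense_h_to_4h = torch.nn.Linear(hidden_size, 4 * hidden_size)
        self.dense_4h_to_h = torch.nn.Linear(4 * hidden_size, hidden_size)
        self.past_x = None  # activations cached for already-processed tokens

    def forward(self, hidden_states, use_cache=False):
        has_cache = self.past_x is not None
        if has_cache and use_cache:
            # Keep only the tokens that extend beyond the cached prefix.
            hidden_states = hidden_states[:, self.past_x.shape[-2]:]
        x = torch.nn.functional.gelu(self.dense_h_to_4h(hidden_states))
        x = self.dense_4h_to_h(x)
        if use_cache:
            if has_cache:
                x = torch.cat((self.past_x, x), dim=-2)
            self.past_x = x
        return x
# ---------------------------------------------------------------------------------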