Browse Source

add cache for image prompts

feature/cache_image_prompts
oriBetelgeuse 4 years ago
parent
commit
cecfac0ec5
1 changed files with 3 additions and 9 deletions
  1. +3
    -9
      rudalle/dalle/transformer.py

+ 3
- 9
rudalle/dalle/transformer.py View File

@ -214,7 +214,6 @@ class DalleSelfAttention(torch.nn.Module):
self.output_dropout = torch.nn.Dropout(output_dropout_prob)
# Cache
self.cache_size = 0
self.past_key = None
self.past_value = None
self.past_output = None
@ -254,7 +253,7 @@ class DalleSelfAttention(torch.nn.Module):
# ltor_mask: [1, 1, s, s]
# Attention heads. [b, s, hp]
if has_cache and use_cache:
mixed_x_layer = self.query_key_value(hidden_states[:, self.cache_size:, :])
mixed_x_layer = self.query_key_value(hidden_states[:, self.past_key.shape[-2]:, :])
else:
mixed_x_layer = self.query_key_value(hidden_states)
@ -279,11 +278,10 @@ class DalleSelfAttention(torch.nn.Module):
)
if use_cache and has_cache:
extra_cache_size = hidden_states.shape[-2] - self.cache_size
extra_cache_size = hidden_states.shape[-2] - self.past_key.shape[-2]
attention_scores = attention_scores[..., -extra_cache_size:, :]
if use_cache:
self.cache_size = hidden_states.shape[-2]
self.past_key = key_layer
self.past_value = value_layer
else:
@ -346,14 +344,11 @@ class DalleMLP(torch.nn.Module):
self.dense_4h_to_h = torch.nn.Linear(4 * hidden_size, hidden_size)
self.dropout = torch.nn.Dropout(output_dropout_prob)
# MLP cache
self.cache_size = 0
self.past_x = None
def forward(self, hidden_states, has_cache=False, use_cache=False):
if has_cache and use_cache:
extra_cache_size = hidden_states.shape[-2] - self.cache_size
self.cache_size += extra_cache_size
hidden_states = hidden_states[:, -extra_cache_size:]
hidden_states = hidden_states[:, self.past_x.shape[-2]:]
# [b, s, 4hp]
x = self.dense_h_to_4h(hidden_states)
@ -366,7 +361,6 @@ class DalleMLP(torch.nn.Module):
x = torch.cat((self.past_x, x), dim=-2)
self.past_x = x
else:
self.cache_size = hidden_states.shape[-2]
self.past_x = x
has_cache = True


Loading…
Cancel
Save