From cecfac0ec5115be79e1ad36afe0cdf1d60da5b43 Mon Sep 17 00:00:00 2001
From: oriBetelgeuse
Date: Sun, 7 Nov 2021 16:15:10 +0300
Subject: [PATCH] add cache for image prompts

---
 rudalle/dalle/transformer.py | 12 +++---------
 1 file changed, 3 insertions(+), 9 deletions(-)

diff --git a/rudalle/dalle/transformer.py b/rudalle/dalle/transformer.py
index 1f0c4f8..8dabffc 100755
--- a/rudalle/dalle/transformer.py
+++ b/rudalle/dalle/transformer.py
@@ -214,7 +214,6 @@ class DalleSelfAttention(torch.nn.Module):
         self.output_dropout = torch.nn.Dropout(output_dropout_prob)
 
         # Cache
-        self.cache_size = 0
         self.past_key = None
         self.past_value = None
         self.past_output = None
@@ -254,7 +253,7 @@ class DalleSelfAttention(torch.nn.Module):
         # ltor_mask: [1, 1, s, s]
         # Attention heads. [b, s, hp]
         if has_cache and use_cache:
-            mixed_x_layer = self.query_key_value(hidden_states[:, self.cache_size:, :])
+            mixed_x_layer = self.query_key_value(hidden_states[:, self.past_key.shape[-2]:, :])
         else:
             mixed_x_layer = self.query_key_value(hidden_states)
 
@@ -279,11 +278,10 @@ class DalleSelfAttention(torch.nn.Module):
         )
 
         if use_cache and has_cache:
-            extra_cache_size = hidden_states.shape[-2] - self.cache_size
+            extra_cache_size = hidden_states.shape[-2] - self.past_key.shape[-2]
             attention_scores = attention_scores[..., -extra_cache_size:, :]
 
         if use_cache:
-            self.cache_size = hidden_states.shape[-2]
             self.past_key = key_layer
             self.past_value = value_layer
         else:
@@ -346,14 +344,11 @@ class DalleMLP(torch.nn.Module):
         self.dense_4h_to_h = torch.nn.Linear(4 * hidden_size, hidden_size)
         self.dropout = torch.nn.Dropout(output_dropout_prob)
         # MLP cache
-        self.cache_size = 0
         self.past_x = None
 
     def forward(self, hidden_states, has_cache=False, use_cache=False):
         if has_cache and use_cache:
-            extra_cache_size = hidden_states.shape[-2] - self.cache_size
-            self.cache_size += extra_cache_size
-            hidden_states = hidden_states[:, -extra_cache_size:]
+            hidden_states = hidden_states[:, self.past_x.shape[-2]:]
 
         # [b, s, 4hp]
         x = self.dense_h_to_4h(hidden_states)
@@ -366,7 +361,6 @@
                 x = torch.cat((self.past_x, x), dim=-2)
                 self.past_x = x
             else:
-                self.cache_size = hidden_states.shape[-2]
                 self.past_x = x
                 has_cache = True
 
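Note (not part of the patch): the change drops the separate cache_size counter and instead derives the number of already-cached positions from the cached tensors themselves, past_key.shape[-2] in DalleSelfAttention and past_x.shape[-2] in DalleMLP, so only the new suffix of the sequence is recomputed on each cached step. The sketch below is a minimal standalone illustration of that pattern; CachedLinear and everything in it are hypothetical stand-ins, not code from rudalle.

import torch


class CachedLinear(torch.nn.Module):
    def __init__(self, hidden_size):
        super().__init__()
        self.proj = torch.nn.Linear(hidden_size, hidden_size)
        self.past_x = None  # cached outputs for positions already processed

    def forward(self, hidden_states, has_cache=False, use_cache=False):
        if has_cache and use_cache:
            # The cached length is past_x.shape[-2]; only the new suffix
            # of the input still needs to be projected.
            hidden_states = hidden_states[:, self.past_x.shape[-2]:]
        x = self.proj(hidden_states)
        if use_cache:
            if has_cache:
                # Prepend previously computed outputs along the sequence dim.
                x = torch.cat((self.past_x, x), dim=-2)
            self.past_x = x
        return x


# Usage: feed a 3-token prefix, then the same prefix plus 2 new tokens;
# the second call only projects the 2 new positions.
layer = CachedLinear(8)
prefix = torch.randn(1, 3, 8)
out1 = layer(prefix, has_cache=False, use_cache=True)
full = torch.cat((prefix, torch.randn(1, 2, 8)), dim=-2)
out2 = layer(full, has_cache=True, use_cache=True)
assert out2.shape == (1, 5, 8)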