Compare commits

...

Author SHA1 Message Date
  oriBetelgeuse cecfac0ec5 add cache for image prompts 4 years ago
  oriBetelgeuse 7d27e7bdb7 add cache for image prompts 4 years ago
4 changed files with 11 additions and 17 deletions
Split View
  1. +1
    -1
      rudalle/__init__.py
  2. +10
    -9
      rudalle/dalle/transformer.py
  3. +0
    -4
      rudalle/image_prompts.py
  4. +0
    -3
      rudalle/pipelines.py

+ 1
- 1
rudalle/__init__.py View File

@@ -22,4 +22,4 @@ __all__ = [
'image_prompts',
]
__version__ = '0.0.1-rc5'
__version__ = '0.0.1-rc6'

+ 10
- 9
rudalle/dalle/transformer.py View File

@@ -206,7 +206,7 @@ class DalleSelfAttention(torch.nn.Module):
self.num_attention_heads = num_attention_heads
self.hidden_size_per_attention_head = divide(hidden_size, num_attention_heads)
self.query_key_value = torch.nn.Linear(hidden_size, 3*hidden_size)
self.query_key_value = torch.nn.Linear(hidden_size, 3 * hidden_size)
self.attention_dropout = torch.nn.Dropout(attention_dropout_prob)
# Output.
@@ -248,12 +248,12 @@ class DalleSelfAttention(torch.nn.Module):
attention_scores = (attention_scores_scaled - attention_scores_scaled_maxes) * alpha
return attention_scores
def forward(self, hidden_states, ltor_mask, has_cache=False, use_cache=False,):
def forward(self, hidden_states, ltor_mask, has_cache=False, use_cache=False, ):
# hidden_states: [b, s, h]
# ltor_mask: [1, 1, s, s]
# Attention heads. [b, s, hp]
if has_cache and use_cache:
mixed_x_layer = self.query_key_value(hidden_states[:, -1:, :])
mixed_x_layer = self.query_key_value(hidden_states[:, self.past_key.shape[-2]:, :])
else:
mixed_x_layer = self.query_key_value(hidden_states)
@@ -277,6 +277,10 @@ class DalleSelfAttention(torch.nn.Module):
query_layer=query_layer, key_layer=key_layer, ltor_mask=ltor_mask
)
if use_cache and has_cache:
extra_cache_size = hidden_states.shape[-2] - self.past_key.shape[-2]
attention_scores = attention_scores[..., -extra_cache_size:, :]
if use_cache:
self.past_key = key_layer
self.past_value = value_layer
@@ -286,9 +290,6 @@ class DalleSelfAttention(torch.nn.Module):
self.past_output = None
has_cache = False
if use_cache and has_cache:
attention_scores = attention_scores[..., -1:, :]
# Attention probabilities. [b, np, s, s]
attention_probs = torch.nn.Softmax(dim=-1)(attention_scores)
@@ -338,16 +339,16 @@ class DalleMLP(torch.nn.Module):
def __init__(self, hidden_size, output_dropout_prob):
super(DalleMLP, self).__init__()
# Project to 4h.
self.dense_h_to_4h = torch.nn.Linear(hidden_size, 4*hidden_size)
self.dense_h_to_4h = torch.nn.Linear(hidden_size, 4 * hidden_size)
# Project back to h.
self.dense_4h_to_h = torch.nn.Linear(4*hidden_size, hidden_size)
self.dense_4h_to_h = torch.nn.Linear(4 * hidden_size, hidden_size)
self.dropout = torch.nn.Dropout(output_dropout_prob)
# MLP cache
self.past_x = None
def forward(self, hidden_states, has_cache=False, use_cache=False):
if has_cache and use_cache:
hidden_states = hidden_states[:, -1:]
hidden_states = hidden_states[:, self.past_x.shape[-2]:]
# [b, s, 4hp]
x = self.dense_h_to_4h(hidden_states)


+ 0
- 4
rudalle/image_prompts.py View File

@@ -18,7 +18,6 @@ class ImagePrompts:
self.device = device
img = self._preprocess_img(pil_image)
self.image_prompts_idx, self.image_prompts = self._get_image_prompts(img, borders, vae, crop_first)
self.allow_cache = True
def _preprocess_img(self, pil_img):
img = torch.tensor(np.array(pil_img.convert('RGB')).transpose(2, 0, 1)) / 255.
@@ -34,9 +33,6 @@ class ImagePrompts:
else:
_, _, [_, _, vqg_img] = vae.model.encode(img)
if borders['right'] + borders['left'] + borders['down'] != 0:
self.allow_cache = False # TODO fix cache in attention
bs, vqg_img_w, vqg_img_h = vqg_img.shape
mask = torch.zeros(vqg_img_w, vqg_img_h)
if borders['up'] != 0:


+ 0
- 3
rudalle/pipelines.py View File

@@ -39,9 +39,6 @@ def generate_images(text, tokenizer, dalle, vae, top_k, top_p, images_num, image
if image_prompts is not None:
prompts_idx, prompts = image_prompts.image_prompts_idx, image_prompts.image_prompts
prompts = prompts.repeat(chunk_bs, 1)
if use_cache and image_prompts.allow_cache is False:
print('Warning: use_cache changed to False')
use_cache = False
for idx in tqdm(range(out.shape[1], total_seq_length)):
idx -= text_seq_length
if image_prompts is not None and idx in prompts_idx:


Loading…
Cancel
Save