Compare commits

...

11 Commits

Author SHA1 Message Date
  oriBetelgeuse d1ff5f07c4 Add cache for image prompts and crop first for all type of image prompts 4 years ago
  oriBetelgeuse 63fcb4a01e add image prompts jupyter 5 years ago
  oriBetelgeuse 8a9dbc4505 change version 5 years ago
  oriBetelgeuse f0af11941a add tests 5 years ago
  oriBetelgeuse 16e05efe52 add image_prompts to __init__ file 5 years ago
  oriBetelgeuse 82647573f4 add convert to RGB 5 years ago
  oriBetelgeuse 9812794e47 add border example 5 years ago
  oriBetelgeuse 0d4ee87a21 Merge branch 'master' of https://github.com/sberbank-ai/ru-dalle into feature/image_prompts 5 years ago
  oriBetelgeuse 654233a8cf Fix bugs 5 years ago
  oriBetelgeuse 1c95f7158c fix bugs 5 years ago
  oriBetelgeuse 0000ee7a3d add image_prompts 5 years ago
5 changed files with 36 additions and 22 deletions
Split View
  1. +1
    -1
      rudalle/__init__.py
  2. +10
    -3
      rudalle/dalle/transformer.py
  3. +21
    -8
      rudalle/image_prompts.py
  4. +1
    -3
      rudalle/pipelines.py
  5. +3
    -7
      tests/test_image_prompts.py

+ 1
- 1
rudalle/__init__.py View File

@ -22,4 +22,4 @@ __all__ = [
'image_prompts',
]
__version__ = '0.0.1-rc5'
__version__ = '0.0.1-rc6'

+ 10
- 3
rudalle/dalle/transformer.py View File

@ -214,6 +214,7 @@ class DalleSelfAttention(torch.nn.Module):
self.output_dropout = torch.nn.Dropout(output_dropout_prob)
# Cache
self.cash_size = 0
self.past_key = None
self.past_value = None
self.past_output = None
@ -253,7 +254,8 @@ class DalleSelfAttention(torch.nn.Module):
# ltor_mask: [1, 1, s, s]
# Attention heads. [b, s, hp]
if has_cache and use_cache:
mixed_x_layer = self.query_key_value(hidden_states[:, -1:, :])
extra_cache_size = hidden_states.shape[-2] - self.chache_size
mixed_x_layer = self.query_key_value(hidden_states[:, -extra_cache_size:, :])
else:
mixed_x_layer = self.query_key_value(hidden_states)
@ -278,6 +280,7 @@ class DalleSelfAttention(torch.nn.Module):
)
if use_cache:
self.chache_size = hidden_states.shape[-2]
self.past_key = key_layer
self.past_value = value_layer
else:
@ -287,7 +290,7 @@ class DalleSelfAttention(torch.nn.Module):
has_cache = False
if use_cache and has_cache:
attention_scores = attention_scores[..., -1:, :]
attention_scores = attention_scores[..., -extra_cache_size:, :]
# Attention probabilities. [b, np, s, s]
attention_probs = torch.nn.Softmax(dim=-1)(attention_scores)
@ -343,11 +346,14 @@ class DalleMLP(torch.nn.Module):
self.dense_4h_to_h = torch.nn.Linear(4*hidden_size, hidden_size)
self.dropout = torch.nn.Dropout(output_dropout_prob)
# MLP cache
self.cache_size = 0
self.past_x = None
def forward(self, hidden_states, has_cache=False, use_cache=False):
if has_cache and use_cache:
hidden_states = hidden_states[:, -1:]
extra_cache_size = hidden_states.shape[-2] - self.cache_size
self.cache_size += extra_cache_size
hidden_states = hidden_states[:, -extra_cache_size:]
# [b, s, 4hp]
x = self.dense_h_to_4h(hidden_states)
@ -360,6 +366,7 @@ class DalleMLP(torch.nn.Module):
x = torch.cat((self.past_x, x), dim=-2)
self.past_x = x
else:
self.cache_size = hidden_states.shape[-2]
self.past_x = x
has_cache = True


+ 21
- 8
rudalle/image_prompts.py View File

@ -18,7 +18,6 @@ class ImagePrompts:
self.device = device
img = self._preprocess_img(pil_image)
self.image_prompts_idx, self.image_prompts = self._get_image_prompts(img, borders, vae, crop_first)
self.allow_cache = True
def _preprocess_img(self, pil_img):
img = torch.tensor(np.array(pil_img.convert('RGB')).transpose(2, 0, 1)) / 255.
@ -26,17 +25,31 @@ class ImagePrompts:
img = (2 * img) - 1
return img
def _get_image_prompts(self, img, borders, vae, crop_first):
@staticmethod
def _get_image_prompts(img, borders, vae, crop_first):
if crop_first:
assert borders['right'] + borders['left'] + borders['down'] == 0
up_border = borders['up'] * 8
_, _, [_, _, vqg_img] = vae.model.encode(img[:, :, :up_border, :])
bs, _, img_w, img_h = img.shape
vqg_img_w, vqg_img_h = img_w // 8, img_h // 8
vqg_img = torch.zeros((bs, vqg_img_w, vqg_img_h), dtype=torch.int32, device=img.device)
if borders['down'] != 0:
down_border = borders['down'] * 8
_, _, [_, _, down_vqg_img] = vae.model.encode(img[:, :, -down_border:, :])
vqg_img[:, -borders['down']:, :] = down_vqg_img
if borders['right'] != 0:
right_border = borders['right'] * 8
_, _, [_, _, right_vqg_img] = vae.model.encode(img[:, :, :, :right_border])
vqg_img[:, :, :borders['right']] = right_vqg_img
if borders['left'] != 0:
left_border = borders['left'] * 8
_, _, [_, _, left_vqg_img] = vae.model.encode(img[:, :, :, -left_border:])
vqg_img[:, :, -borders['left']:] = left_vqg_img
if borders['up'] != 0:
up_border = borders['up'] * 8
_, _, [_, _, up_vqg_img] = vae.model.encode(img[:, :, :up_border, :])
vqg_img[:, :borders['up'], :] = up_vqg_img
else:
_, _, [_, _, vqg_img] = vae.model.encode(img)
if borders['right'] + borders['left'] + borders['down'] != 0:
self.allow_cache = False # TODO fix cache in attention
bs, vqg_img_w, vqg_img_h = vqg_img.shape
mask = torch.zeros(vqg_img_w, vqg_img_h)
if borders['up'] != 0:


+ 1
- 3
rudalle/pipelines.py View File

@ -38,10 +38,8 @@ def generate_images(text, tokenizer, dalle, vae, top_k, top_p, images_num, image
sample_scores = []
if image_prompts is not None:
prompts_idx, prompts = image_prompts.image_prompts_idx, image_prompts.image_prompts
prompts = prompts.repeat(images_num, 1)
prompts = prompts.repeat(chunk_bs, 1)
if use_cache and image_prompts.allow_cache is False:
print('Warning: use_cache changed to False')
use_cache = False
for idx in tqdm(range(out.shape[1], total_seq_length)):
idx -= text_seq_length
if image_prompts is not None and idx in prompts_idx:


+ 3
- 7
tests/test_image_prompts.py View File

@ -13,10 +13,6 @@ def test_image_prompts(sample_image, vae, borders, crop_first):
img = sample_image.copy()
img = img.resize((256, 256))
image_prompt = ImagePrompts(img, borders, vae, crop_first=crop_first)
if crop_first:
assert image_prompt.image_prompts.shape[1] == borders['up'] * 32
assert len(image_prompt.image_prompts_idx) == borders['up'] * 32
else:
assert image_prompt.image_prompts.shape[1] == 32 * 32
assert len(image_prompt.image_prompts_idx) == (borders['up'] + borders['down']) * 32 \
+ (borders['left'] + borders['right']) * (32 - borders['up'] - borders['down'])
assert image_prompt.image_prompts.shape[1] == 32 * 32
assert len(image_prompt.image_prompts_idx) == (borders['up'] + borders['down']) * 32 \
+ (borders['left'] + borders['right']) * (32 - borders['up'] - borders['down'])

Loading…
Cancel
Save