Browse Source

Update image_attention.py

pull/49/head
blue-fish 4 years ago
committed by GitHub
parent
commit
c136378e42
No known key found for this signature in database. GPG Key ID: 4AEE18F83AFDEB23
1 changed file with 5 additions and 5 deletions
  1. +5
    -5
      rudalle/dalle/image_attention.py

+ 5
- 5
rudalle/dalle/image_attention.py View File

@@ -5,7 +5,7 @@ import torch
def _init_mask(text_tokens, image_tokens_per_dim):
attn_size = text_tokens + image_tokens_per_dim**2
- mask = torch.tril(torch.ones(attn_size, attn_size))
+ mask = torch.tril(torch.ones(attn_size, attn_size, dtype=torch.bool))
return mask
@@ -13,7 +13,7 @@ def get_row_mask(text_tokens=256, image_tokens_per_dim=32):
mask = _init_mask(text_tokens, image_tokens_per_dim)
step = image_tokens_per_dim + 1
for col in range(text_tokens, mask.size(1)):
- mask[col + step:, col] = 0.0
+ mask[col + step:, col] = False
return mask
@@ -22,7 +22,7 @@ def get_col_mask(text_tokens=256, image_tokens_per_dim=32):
step = image_tokens_per_dim - 1
for col in range(text_tokens, mask.size(1)):
for i in range(1, mask.size(0), step+1):
- mask[col + i: col + i + step, col] = 0.0
+ mask[col + i: col + i + step, col] = False
return mask
@@ -30,7 +30,7 @@ def get_conv_mask(text_tokens=256, image_tokens_per_dim=32, kernel=11):
mask = _init_mask(text_tokens, image_tokens_per_dim)
shift = kernel // 2
for pos in range(text_tokens, mask.size(1)):
- mask[pos+1:, pos] = 0.0
+ mask[pos+1:, pos] = False
img = torch.zeros(image_tokens_per_dim, image_tokens_per_dim)
pixel_id = pos - text_tokens
row = pixel_id // image_tokens_per_dim
@@ -42,7 +42,7 @@ def get_conv_mask(text_tokens=256, image_tokens_per_dim=32, kernel=11):
img[r_abs, c_abs] = 0.2
cell_id = r_abs * image_tokens_per_dim + c_abs
if text_tokens + cell_id > pos:
- mask[text_tokens + cell_id, pos] = 1.0
+ mask[text_tokens + cell_id, pos] = True
img[row, col] = 1.0
return mask

Loading…
Cancel
Save