|
|
|
@ -5,7 +5,7 @@ import torch |
|
|
|
|
|
|
|
def _init_mask(text_tokens, image_tokens_per_dim): |
|
|
|
attn_size = text_tokens + image_tokens_per_dim**2 |
|
|
|
mask = torch.tril(torch.ones(attn_size, attn_size)) |
|
|
|
mask = torch.tril(torch.ones(attn_size, attn_size, dtype=torch.bool)) |
|
|
|
return mask |
|
|
|
|
|
|
|
|
|
|
|
@ -13,7 +13,7 @@ def get_row_mask(text_tokens=256, image_tokens_per_dim=32): |
|
|
|
mask = _init_mask(text_tokens, image_tokens_per_dim) |
|
|
|
step = image_tokens_per_dim + 1 |
|
|
|
for col in range(text_tokens, mask.size(1)): |
|
|
|
mask[col + step:, col] = 0.0 |
|
|
|
mask[col + step:, col] = False |
|
|
|
return mask |
|
|
|
|
|
|
|
|
|
|
|
@ -22,7 +22,7 @@ def get_col_mask(text_tokens=256, image_tokens_per_dim=32): |
|
|
|
step = image_tokens_per_dim - 1 |
|
|
|
for col in range(text_tokens, mask.size(1)): |
|
|
|
for i in range(1, mask.size(0), step+1): |
|
|
|
mask[col + i: col + i + step, col] = 0.0 |
|
|
|
mask[col + i: col + i + step, col] = False |
|
|
|
return mask |
|
|
|
|
|
|
|
|
|
|
|
@ -30,7 +30,7 @@ def get_conv_mask(text_tokens=256, image_tokens_per_dim=32, kernel=11): |
|
|
|
mask = _init_mask(text_tokens, image_tokens_per_dim) |
|
|
|
shift = kernel // 2 |
|
|
|
for pos in range(text_tokens, mask.size(1)): |
|
|
|
mask[pos+1:, pos] = 0.0 |
|
|
|
mask[pos+1:, pos] = False |
|
|
|
img = torch.zeros(image_tokens_per_dim, image_tokens_per_dim) |
|
|
|
pixel_id = pos - text_tokens |
|
|
|
row = pixel_id // image_tokens_per_dim |
|
|
|
@ -42,7 +42,7 @@ def get_conv_mask(text_tokens=256, image_tokens_per_dim=32, kernel=11): |
|
|
|
img[r_abs, c_abs] = 0.2 |
|
|
|
cell_id = r_abs * image_tokens_per_dim + c_abs |
|
|
|
if text_tokens + cell_id > pos: |
|
|
|
mask[text_tokens + cell_id, pos] = 1.0 |
|
|
|
mask[text_tokens + cell_id, pos] = True |
|
|
|
|
|
|
|
img[row, col] = 1.0 |
|
|
|
return mask |