diff --git a/rudalle/dalle/image_attention.py b/rudalle/dalle/image_attention.py index 0df1a77..038b8a1 100644 --- a/rudalle/dalle/image_attention.py +++ b/rudalle/dalle/image_attention.py @@ -5,7 +5,7 @@ import torch def _init_mask(text_tokens, image_tokens_per_dim): attn_size = text_tokens + image_tokens_per_dim**2 - mask = torch.tril(torch.ones(attn_size, attn_size)) + mask = torch.tril(torch.ones(attn_size, attn_size, dtype=torch.bool)) return mask @@ -13,7 +13,7 @@ def get_row_mask(text_tokens=256, image_tokens_per_dim=32): mask = _init_mask(text_tokens, image_tokens_per_dim) step = image_tokens_per_dim + 1 for col in range(text_tokens, mask.size(1)): - mask[col + step:, col] = 0.0 + mask[col + step:, col] = False return mask @@ -22,7 +22,7 @@ def get_col_mask(text_tokens=256, image_tokens_per_dim=32): step = image_tokens_per_dim - 1 for col in range(text_tokens, mask.size(1)): for i in range(1, mask.size(0), step+1): - mask[col + i: col + i + step, col] = 0.0 + mask[col + i: col + i + step, col] = False return mask @@ -30,7 +30,7 @@ def get_conv_mask(text_tokens=256, image_tokens_per_dim=32, kernel=11): mask = _init_mask(text_tokens, image_tokens_per_dim) shift = kernel // 2 for pos in range(text_tokens, mask.size(1)): - mask[pos+1:, pos] = 0.0 + mask[pos+1:, pos] = False img = torch.zeros(image_tokens_per_dim, image_tokens_per_dim) pixel_id = pos - text_tokens row = pixel_id // image_tokens_per_dim @@ -42,7 +42,7 @@ def get_conv_mask(text_tokens=256, image_tokens_per_dim=32, kernel=11): img[r_abs, c_abs] = 0.2 cell_id = r_abs * image_tokens_per_dim + c_abs if text_tokens + cell_id > pos: - mask[text_tokens + cell_id, pos] = 1.0 + mask[text_tokens + cell_id, pos] = True img[row, col] = 1.0 return mask