Browse Source

Fix bugs

pull/2/head
oriBetelgeuse 5 years ago
parent
commit
654233a8cf
2 changed files with 3 additions and 3 deletions
  1. +2
    -2
      rudalle/image_prompts.py
  2. +1
    -1
      rudalle/pipelines.py

+ 2
- 2
rudalle/image_prompts.py View File

@ -24,10 +24,10 @@ class ImagePrompts:
return img
@staticmethod
def _get_image_prompts(img, borders, vae, crop_first=False):
def _get_image_prompts(img, borders, vae, crop_first):
if crop_first:
assert borders['right'] + borders['left'] + borders['down'] == 0
up_border = borders['up'] * 7
up_border = borders['up'] * 8
_, _, [_, _, vqg_img] = vae.model.encode(img[:, :, :up_border, :])
else:
_, _, [_, _, vqg_img] = vae.model.encode(img)


+ 1
- 1
rudalle/pipelines.py View File

@ -39,7 +39,7 @@ def generate_images(text, tokenizer, dalle, vae, top_k, top_p, images_num, image
use_cache = False
print("Warning: use_cache changed to False")
for idx in tqdm(range(out.shape[1], total_seq_length)):
idx -= text_seq_lengths
idx -= text_seq_length
if image_prompts is not None and idx in prompts_idx:
out = torch.cat((out, prompts[:, idx].unsqueeze(1)), dim=-1)
else:


Loading…
Cancel
Save