Browse Source

Add size param to show() (#40)

pull/45/head
Max Woolf 4 years ago
committed by GitHub
parent
commit
a23a834806
No known key found for this signature in database GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 3 additions and 2 deletions
  1. +3
    -2
      rudalle/pipelines.py

+ 3
- 2
rudalle/pipelines.py View File

@ -85,10 +85,11 @@ def cherry_pick_by_clip(pil_images, text, ruclip, ruclip_processor, device='cpu'
return top_pil_images, top_scores
def show(pil_images, nrow=4, save_dir=None, show=True):
def show(pil_images, nrow=4, size=14, save_dir=None, show=True):
"""
:param pil_images: list of images in PIL
:param nrow: number of rows
:param size: size of the images
:param save_dir: dir for separately saving of images, example: save_dir='./pics'
"""
if save_dir is not None:
@ -100,7 +101,7 @@ def show(pil_images, nrow=4, save_dir=None, show=True):
imgs = torchvision.utils.make_grid(utils.pil_list_to_torch_tensors(pil_images), nrow=nrow)
if not isinstance(imgs, list):
imgs = [imgs.cpu()]
fix, axs = plt.subplots(ncols=len(imgs), squeeze=False, figsize=(14, 14))
fix, axs = plt.subplots(ncols=len(imgs), squeeze=False, figsize=(size, size))
for i, img in enumerate(imgs):
img = img.detach()
img = torchvision.transforms.functional.to_pil_image(img)


Loading…
Cancel
Save