From a23a834806d5947f0586ef9288b1f4d124be75b0 Mon Sep 17 00:00:00 2001 From: Max Woolf Date: Sun, 7 Nov 2021 13:11:54 -0800 Subject: [PATCH] Add size param to show() (#40) --- rudalle/pipelines.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/rudalle/pipelines.py b/rudalle/pipelines.py index e37260b..169d3e0 100644 --- a/rudalle/pipelines.py +++ b/rudalle/pipelines.py @@ -85,10 +85,11 @@ def cherry_pick_by_clip(pil_images, text, ruclip, ruclip_processor, device='cpu' return top_pil_images, top_scores -def show(pil_images, nrow=4, save_dir=None, show=True): +def show(pil_images, nrow=4, size=14, save_dir=None, show=True): """ :param pil_images: list of images in PIL :param nrow: number of rows + :param size: size of the images :param save_dir: dir for separately saving of images, example: save_dir='./pics' """ if save_dir is not None: @@ -100,7 +101,7 @@ def show(pil_images, nrow=4, save_dir=None, show=True): imgs = torchvision.utils.make_grid(utils.pil_list_to_torch_tensors(pil_images), nrow=nrow) if not isinstance(imgs, list): imgs = [imgs.cpu()] - fix, axs = plt.subplots(ncols=len(imgs), squeeze=False, figsize=(14, 14)) + fix, axs = plt.subplots(ncols=len(imgs), squeeze=False, figsize=(size, size)) for i, img in enumerate(imgs): img = img.detach() img = torchvision.transforms.functional.to_pil_image(img)