|
|
|
@ -85,10 +85,11 @@ def cherry_pick_by_clip(pil_images, text, ruclip, ruclip_processor, device='cpu' |
|
|
|
return top_pil_images, top_scores |
|
|
|
|
|
|
|
|
|
|
|
def show(pil_images, nrow=4, save_dir=None, show=True): |
|
|
|
def show(pil_images, nrow=4, size=14, save_dir=None, show=True): |
|
|
|
""" |
|
|
|
:param pil_images: list of images in PIL |
|
|
|
:param nrow: number of rows |
|
|
|
:param size: size of the images |
|
|
|
:param save_dir: dir for separately saving of images, example: save_dir='./pics' |
|
|
|
""" |
|
|
|
if save_dir is not None: |
|
|
|
@ -100,7 +101,7 @@ def show(pil_images, nrow=4, save_dir=None, show=True): |
|
|
|
imgs = torchvision.utils.make_grid(utils.pil_list_to_torch_tensors(pil_images), nrow=nrow) |
|
|
|
if not isinstance(imgs, list): |
|
|
|
imgs = [imgs.cpu()] |
|
|
|
fix, axs = plt.subplots(ncols=len(imgs), squeeze=False, figsize=(14, 14)) |
|
|
|
fix, axs = plt.subplots(ncols=len(imgs), squeeze=False, figsize=(size, size)) |
|
|
|
for i, img in enumerate(imgs): |
|
|
|
img = img.detach() |
|
|
|
img = torchvision.transforms.functional.to_pil_image(img) |
|
|
|
|