Browse Source

add tests

pull/2/head
oriBetelgeuse 5 years ago
parent
commit
f0af11941a
2 changed files with 22 additions and 2 deletions
  1. +2
    -2
      rudalle/__init__.py
  2. +20
    -0
      tests/test_image_prompts.py

+ 2
- 2
rudalle/__init__.py View File

@ -4,7 +4,6 @@ from .dalle import get_rudalle_model
from .tokenizer import get_tokenizer
from .realesrgan import get_realesrgan
from .ruclip import get_ruclip
from .image_prompts import ImagePrompts
from . import vae, dalle, tokenizer, realesrgan, pipelines, ruclip, image_prompts
@ -20,7 +19,8 @@ __all__ = [
'tokenizer',
'realesrgan',
'pipelines',
'ImagePrompts',
'image_prompts',
]
__version__ = '0.0.1-rc1'

+ 20
- 0
tests/test_image_prompts.py View File

@ -0,0 +1,20 @@
import pytest
from rudalle.image_prompts import ImagePrompts
@pytest.mark.parametrize('borders, crop_first', [
({'up': 4, 'right': 0, 'left': 0, 'down': 0}, False),
({'up': 4, 'right': 0, 'left': 0, 'down': 0}, True),
({'up': 4, 'right': 3, 'left': 3, 'down': 3}, False)
])
def test_image_prompts(sample_image, vae, borders, crop_first):
img = sample_image.copy()
img = img.resize((256, 256))
image_prompt = ImagePrompts(img, borders, vae, crop_first=crop_first)
if crop_first:
assert image_prompt.image_prompts.shape[1] == borders['up'] * 32
assert len(image_prompt.image_prompts_idx) == borders['up'] * 32
else:
assert image_prompt.image_prompts.shape[1] == 32 * 32
assert len(image_prompt.image_prompts_idx) == (borders['up'] + borders['down']) * 32 \
+ (borders['left'] + borders['right']) * (32 - borders['up'] - borders['down'])

Loading…
Cancel
Save