# -*- coding: utf-8 -*-
import torch
import pytest

from .test_vae import preprocess


@pytest.mark.parametrize('text', [
    'мальчик играет с оленем',
])
def test_forward_step_and_criterion(text, sample_image, yttm_tokenizer, vae, small_dalle):
    """Run one forward pass of ``small_dalle`` with ``return_loss=True`` and check its outputs.

    Builds a batch of ``bs`` identical (text, image) pairs:
    * the text is lower-cased, stripped and tokenized to ``text_seq_length`` ids;
    * the image is preprocessed to 256px and encoded to codebook indices by ``vae``.
    The two id sequences are concatenated along dim 1 and passed to the model
    together with a lower-triangular (causal) attention mask.

    The test asserts that the returned loss converts to a Python ``float`` and
    that the per-component loss values come back as a ``dict``.

    NOTE(review): ``sample_image``, ``yttm_tokenizer``, ``vae`` and ``small_dalle``
    are pytest fixtures defined outside this file (presumably in conftest.py).
    """
    bs = 4
    text_seq_length = small_dalle.get_param('text_seq_length')
    total_seq_length = small_dalle.get_param('total_seq_length')
    device = small_dalle.get_param('device')

    # Copy so the shared fixture image is not modified by preprocessing.
    img = sample_image.copy()
    img = preprocess(img, target_image_size=256)
    images = img.repeat(bs, 1, 1, 1).to(device)

    text = text.lower().strip()
    text_input_ids = yttm_tokenizer.encode_text(text, text_seq_length=text_seq_length)
    text_input_ids = text_input_ids.unsqueeze(0).repeat(bs, 1).to(device)

    # Causal mask: every position may attend only to itself and earlier positions.
    attention_mask = torch.tril(
        torch.ones((bs, 1, total_seq_length, total_seq_length), device=device)
    )

    # The VAE only supplies target tokens here, so no gradients are needed for it.
    with torch.no_grad():
        image_input_ids = vae.get_codebook_indices(images)

    # Assumes image_input_ids is (bs, image_seq_len) so it can be concatenated
    # after the text ids on dim 1 -- TODO confirm against vae implementation.
    input_ids = torch.cat((text_input_ids, image_input_ids), dim=1)

    loss, loss_values = small_dalle.forward(input_ids, attention_mask, return_loss=True)

    # .item() already yields a detached Python scalar; no .data/.detach() needed.
    assert isinstance(loss.item(), float)
    assert isinstance(loss_values, dict)