VLMs: enable generation tests - last batch (#34484)

* add tests for 3 more vlms

* fix fuyu back

* skip test
This commit is contained in:
Raushan Turganbay
2024-11-21 11:00:22 +01:00
committed by GitHub
parent 40821a2478
commit 28fb02fc05
6 changed files with 129 additions and 9 deletions

View File

@@ -21,7 +21,9 @@ import tempfile
import unittest
import numpy as np
import pytest
import requests
from parameterized import parameterized
from transformers import AutoModelForImageTextToText, AutoProcessor, Kosmos2Config
from transformers.models.kosmos2.configuration_kosmos2 import Kosmos2TextConfig, Kosmos2VisionConfig
@@ -37,6 +39,7 @@ from transformers.utils import (
is_vision_available,
)
from ...generation.test_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import (
ModelTesterMixin,
@@ -205,6 +208,7 @@ class Kosmos2ModelTester:
self.text_model_tester = Kosmos2TextModelTester(parent, **text_kwargs)
self.vision_model_tester = Kosmos2VisionModelTester(parent, **vision_kwargs)
self.batch_size = self.text_model_tester.batch_size # need bs for batching_equivalence test
self.seq_length = self.text_model_tester.seq_length
self.latent_query_num = latent_query_num
self.is_training = is_training
@@ -253,7 +257,7 @@ class Kosmos2ModelTester:
@require_torch
class Kosmos2ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
class Kosmos2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (Kosmos2Model, Kosmos2ForConditionalGeneration) if is_torch_available() else ()
all_generative_model_classes = (Kosmos2ForConditionalGeneration,) if is_torch_available() else ()
pipeline_model_mapping = (
@@ -451,6 +455,68 @@ class Kosmos2ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase)
# self.assertTrue(model.transformer.wte.weight.shape, model.lm_head.weight.shape)
# self.assertTrue(check_same_values(model.transformer.wte, model.lm_head))
@pytest.mark.generate
@parameterized.expand([("greedy", 1), ("beam search", 2)])
@unittest.skip(
"KOSMOS-2 doesn't support inputs embeds. The test isn't skipped by checking input args because KOSMOS-2 has `generate()` overwritten"
)
def test_generate_from_inputs_embeds(self):
pass
@pytest.mark.generate
def test_left_padding_compatibility(self):
# Overwrite because Kosmos-2 need to padd pixel values and pad image-attn-mask
def _prepare_model_kwargs(input_ids, attention_mask, pad_size, signature):
model_kwargs = {"input_ids": input_ids, "attention_mask": attention_mask}
if "position_ids" in signature:
position_ids = torch.cumsum(attention_mask, dim=-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
model_kwargs["position_ids"] = position_ids
if "cache_position" in signature:
cache_position = torch.arange(input_ids.shape[-1], device=torch_device)
model_kwargs["cache_position"] = cache_position
if "image_embeds_position_mask" in signature:
image_embeds_position_mask = torch.zeros_like(input_ids)
image_embeds_position_mask[:, (pad_size + 1) : pad_size + 1 + self.model_tester.latent_query_num] = 1
model_kwargs["image_embeds_position_mask"] = image_embeds_position_mask
return model_kwargs
for model_class in self.all_generative_model_classes:
config, inputs_dict = self.prepare_config_and_inputs_for_generate()
input_ids = inputs_dict["input_ids"]
pixel_values = inputs_dict["pixel_values"]
attention_mask = inputs_dict.get("attention_mask")
if attention_mask is None:
attention_mask = torch.ones_like(input_ids)
model = model_class(config).to(torch_device).eval()
signature = inspect.signature(model.forward).parameters.keys()
# no cache as some models require special cache classes to be init outside forward
model.generation_config.use_cache = False
# Without padding
model_kwargs = _prepare_model_kwargs(input_ids, attention_mask, pad_size=0, signature=signature)
next_logits_wo_padding = model(**model_kwargs, pixel_values=pixel_values).logits[:, -1, :]
# With left-padding (length 32)
# can hardcode pad_token to be 0 as we'll do attn masking anyway
pad_token_id = (
config.get_text_config().pad_token_id if config.get_text_config().pad_token_id is not None else 0
)
pad_size = (input_ids.shape[0], 32)
padding = torch.ones(pad_size, dtype=input_ids.dtype, device=torch_device) * pad_token_id
padded_input_ids = torch.cat((padding, input_ids), dim=1)
padded_attention_mask = torch.cat((torch.zeros_like(padding), attention_mask), dim=1)
model_kwargs = _prepare_model_kwargs(
padded_input_ids, padded_attention_mask, pad_size=32, signature=signature
)
next_logits_with_padding = model(**model_kwargs, pixel_values=pixel_values).logits[:, -1, :]
# They should result in very similar logits
self.assertTrue(torch.allclose(next_logits_wo_padding, next_logits_with_padding, atol=1e-3))
@slow
def test_model_from_pretrained(self):
model_name = "microsoft/kosmos-2-patch14-224"