VLMs: enable generation tests - last batch (#34484)
* add tests for 3 more vlms * fix fuyu back * skip test
This commit is contained in:
committed by
GitHub
parent
40821a2478
commit
28fb02fc05
@@ -27,6 +27,7 @@ from transformers import Pix2StructConfig, Pix2StructTextConfig, Pix2StructVisio
|
||||
from transformers.testing_utils import require_torch, require_vision, slow, torch_device
|
||||
from transformers.utils import is_torch_available, is_vision_available
|
||||
|
||||
from ...generation.test_utils import GenerationTesterMixin
|
||||
from ...test_configuration_common import ConfigTester
|
||||
from ...test_modeling_common import (
|
||||
ModelTesterMixin,
|
||||
@@ -388,6 +389,7 @@ class Pix2StructModelTester:
|
||||
self.batch_size = self.text_model_tester.batch_size # need bs for batching_equivalence test
|
||||
self.seq_length = self.text_model_tester.seq_length # need seq_length for common tests
|
||||
self.is_training = is_training
|
||||
self.max_patches = self.vision_model_tester.max_patches
|
||||
|
||||
def prepare_config_and_inputs(self):
|
||||
text_config, input_ids, attention_mask = self.text_model_tester.prepare_config_and_inputs()
|
||||
@@ -417,7 +419,7 @@ class Pix2StructModelTester:
|
||||
|
||||
|
||||
@require_torch
|
||||
class Pix2StructModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
class Pix2StructModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (Pix2StructForConditionalGeneration,) if is_torch_available() else ()
|
||||
all_generative_model_classes = (Pix2StructForConditionalGeneration,) if is_torch_available() else {}
|
||||
pipeline_model_mapping = (
|
||||
@@ -751,6 +753,26 @@ class Pix2StructModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCa
|
||||
text_config = Pix2StructTextConfig.from_pretrained(tmp_dir_name)
|
||||
self.assertDictEqual(config.text_config.to_dict(), text_config.to_dict())
|
||||
|
||||
def _check_encoder_attention_for_generate(self, attentions, batch_size, config, seq_length):
|
||||
# overwrite because # pix2struct seq length depends on image inputs
|
||||
seq_length = self.model_tester.max_patches
|
||||
encoder_expected_shape = (batch_size, config.num_attention_heads, seq_length, seq_length)
|
||||
self.assertIsInstance(attentions, tuple)
|
||||
self.assertListEqual(
|
||||
[layer_attentions.shape for layer_attentions in attentions],
|
||||
[encoder_expected_shape] * len(attentions),
|
||||
)
|
||||
|
||||
def _check_encoder_hidden_states_for_generate(self, hidden_states, batch_size, config, seq_length):
|
||||
# overwrite because # pix2struct seq length depends on image inputs
|
||||
seq_length = self.model_tester.max_patches
|
||||
encoder_expected_shape = (batch_size, seq_length, config.hidden_size)
|
||||
self.assertIsInstance(hidden_states, tuple)
|
||||
self.assertListEqual(
|
||||
[layer_hidden_states.shape for layer_hidden_states in hidden_states],
|
||||
[encoder_expected_shape] * len(hidden_states),
|
||||
)
|
||||
|
||||
|
||||
# We will verify our results on an image of a stop sign
|
||||
def prepare_img():
|
||||
|
||||
Reference in New Issue
Block a user