Fix pix2struct (#34374)

* fix

* fix and test use_cache test

* style

* remove atol
This commit is contained in:
Ilyas Moutawwakil
2024-10-28 11:24:56 +01:00
committed by GitHub
parent 1d06379331
commit fddbd3c13c
2 changed files with 46 additions and 25 deletions

View File

@@ -419,6 +419,7 @@ class Pix2StructModelTester:
@require_torch
class Pix2StructModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (Pix2StructForConditionalGeneration,) if is_torch_available() else ()
all_generative_model_classes = (Pix2StructForConditionalGeneration,) if is_torch_available() else {}
pipeline_model_mapping = {"image-to-text": Pix2StructForConditionalGeneration} if is_torch_available() else {}
fx_compatible = False
test_head_masking = False
@@ -445,6 +446,16 @@ class Pix2StructModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCa
),
)
def test_generative_model(self):
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_generative_model_classes:
model = model_class(config).eval().to(torch_device)
output = model.generate(**input_dict, use_cache=False, min_new_tokens=10, max_new_tokens=10)
output_use_cache = model.generate(**input_dict, use_cache=True, min_new_tokens=10, max_new_tokens=10)
torch.testing.assert_close(output, output_use_cache)
@unittest.skip(reason="Hidden_states is tested in individual model tests")
def test_hidden_states_output(self):
pass