Fix pix2struct (#34374)
* fix * fix and test use_cache test * style * remove atol
This commit is contained in:
committed by
GitHub
parent
1d06379331
commit
fddbd3c13c
@@ -419,6 +419,7 @@ class Pix2StructModelTester:
|
||||
@require_torch
|
||||
class Pix2StructModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (Pix2StructForConditionalGeneration,) if is_torch_available() else ()
|
||||
all_generative_model_classes = (Pix2StructForConditionalGeneration,) if is_torch_available() else {}
|
||||
pipeline_model_mapping = {"image-to-text": Pix2StructForConditionalGeneration} if is_torch_available() else {}
|
||||
fx_compatible = False
|
||||
test_head_masking = False
|
||||
@@ -445,6 +446,16 @@ class Pix2StructModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCa
|
||||
),
|
||||
)
|
||||
|
||||
def test_generative_model(self):
|
||||
config, input_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
for model_class in self.all_generative_model_classes:
|
||||
model = model_class(config).eval().to(torch_device)
|
||||
|
||||
output = model.generate(**input_dict, use_cache=False, min_new_tokens=10, max_new_tokens=10)
|
||||
output_use_cache = model.generate(**input_dict, use_cache=True, min_new_tokens=10, max_new_tokens=10)
|
||||
|
||||
torch.testing.assert_close(output, output_use_cache)
|
||||
|
||||
@unittest.skip(reason="Hidden_states is tested in individual model tests")
|
||||
def test_hidden_states_output(self):
|
||||
pass
|
||||
|
||||
Reference in New Issue
Block a user