[Pix2struct] Simplify generation (#22527)

* Add model to doc tests

* Remove generate and replace by prepare_inputs_for_generation

* More fixes

* Remove print statements

* Update integration tests

* Fix generate

* Remove model from auto mapping

* Use auto processor

* Fix integration tests

* Fix test

* Add inference code snippet

* Remove is_encoder_decoder

* Update docs

* Remove notebook link
This commit is contained in:
NielsRogge
2023-04-13 15:01:14 +02:00
committed by GitHub
parent 95e7057507
commit 8eb38f638d
6 changed files with 96 additions and 113 deletions

View File

@@ -443,24 +443,22 @@ class Pix2StructTextImageModelTest(ModelTesterMixin, unittest.TestCase):
# signature.parameters is an OrderedDict => so arg_names order is deterministic
arg_names = [*signature.parameters.keys()]
if model.config.is_encoder_decoder:
expected_arg_names = [
"input_ids",
"attention_mask",
"decoder_input_ids",
"decoder_attention_mask",
]
expected_arg_names.extend(
["head_mask", "decoder_head_mask", "cross_attn_head_mask", "encoder_outputs"]
if "head_mask" and "decoder_head_mask" and "cross_attn_head_mask" in arg_names
else ["encoder_outputs"]
)
self.assertListEqual(arg_names[: len(expected_arg_names)], expected_arg_names)
else:
expected_arg_names = (
["input_ids"] if model_class != Pix2StructForConditionalGeneration else ["flattened_patches"]
)
self.assertListEqual(arg_names[:1], expected_arg_names)
expected_arg_names = [
"flattened_patches",
"attention_mask",
"decoder_input_ids",
"decoder_attention_mask",
"head_mask",
"decoder_head_mask",
"cross_attn_head_mask",
"encoder_outputs",
"past_key_values",
"labels",
"decoder_inputs_embeds",
"use_cache",
]
self.assertListEqual(arg_names[: len(expected_arg_names)], expected_arg_names)
def test_training(self):
if not self.model_tester.is_training:
@@ -765,7 +763,7 @@ class Pix2StructIntegrationTest(unittest.TestCase):
)
def test_vqa_model(self):
model_id = "ybelkada/pix2struct-ai2d-base"
model_id = "google/pix2struct-ai2d-base"
image_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/ai2d-demo.jpg"
image = Image.open(requests.get(image_url, stream=True).raw)
@@ -784,7 +782,7 @@ class Pix2StructIntegrationTest(unittest.TestCase):
self.assertEqual(processor.decode(predictions[0], skip_special_tokens=True), "ash cloud")
def test_vqa_model_batched(self):
model_id = "ybelkada/pix2struct-ai2d-base"
model_id = "google/pix2struct-ai2d-base"
image_urls = [
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/ai2d-demo.jpg",