[Pix2struct] Simplify generation (#22527)
* Add model to doc tests * Remove generate and replace by prepare_inputs_for_generation * More fixes * Remove print statements * Update integration tests * Fix generate * Remove model from auto mapping * Use auto processor * Fix integration tests * Fix test * Add inference code snippet * Remove is_encoder_decoder * Update docs * Remove notebook link
This commit is contained in:
@@ -443,24 +443,22 @@ class Pix2StructTextImageModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
# signature.parameters is an OrderedDict => so arg_names order is deterministic
|
||||
arg_names = [*signature.parameters.keys()]
|
||||
|
||||
if model.config.is_encoder_decoder:
|
||||
expected_arg_names = [
|
||||
"input_ids",
|
||||
"attention_mask",
|
||||
"decoder_input_ids",
|
||||
"decoder_attention_mask",
|
||||
]
|
||||
expected_arg_names.extend(
|
||||
["head_mask", "decoder_head_mask", "cross_attn_head_mask", "encoder_outputs"]
|
||||
if "head_mask" and "decoder_head_mask" and "cross_attn_head_mask" in arg_names
|
||||
else ["encoder_outputs"]
|
||||
)
|
||||
self.assertListEqual(arg_names[: len(expected_arg_names)], expected_arg_names)
|
||||
else:
|
||||
expected_arg_names = (
|
||||
["input_ids"] if model_class != Pix2StructForConditionalGeneration else ["flattened_patches"]
|
||||
)
|
||||
self.assertListEqual(arg_names[:1], expected_arg_names)
|
||||
expected_arg_names = [
|
||||
"flattened_patches",
|
||||
"attention_mask",
|
||||
"decoder_input_ids",
|
||||
"decoder_attention_mask",
|
||||
"head_mask",
|
||||
"decoder_head_mask",
|
||||
"cross_attn_head_mask",
|
||||
"encoder_outputs",
|
||||
"past_key_values",
|
||||
"labels",
|
||||
"decoder_inputs_embeds",
|
||||
"use_cache",
|
||||
]
|
||||
|
||||
self.assertListEqual(arg_names[: len(expected_arg_names)], expected_arg_names)
|
||||
|
||||
def test_training(self):
|
||||
if not self.model_tester.is_training:
|
||||
@@ -765,7 +763,7 @@ class Pix2StructIntegrationTest(unittest.TestCase):
|
||||
)
|
||||
|
||||
def test_vqa_model(self):
|
||||
model_id = "ybelkada/pix2struct-ai2d-base"
|
||||
model_id = "google/pix2struct-ai2d-base"
|
||||
|
||||
image_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/ai2d-demo.jpg"
|
||||
image = Image.open(requests.get(image_url, stream=True).raw)
|
||||
@@ -784,7 +782,7 @@ class Pix2StructIntegrationTest(unittest.TestCase):
|
||||
self.assertEqual(processor.decode(predictions[0], skip_special_tokens=True), "ash cloud")
|
||||
|
||||
def test_vqa_model_batched(self):
|
||||
model_id = "ybelkada/pix2struct-ai2d-base"
|
||||
model_id = "google/pix2struct-ai2d-base"
|
||||
|
||||
image_urls = [
|
||||
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/ai2d-demo.jpg",
|
||||
|
||||
Reference in New Issue
Block a user