[Pix2struct] Simplify generation (#22527)

* Add model to doc tests

* Remove generate and replace by prepare_inputs_for_generation

* More fixes

* Remove print statements

* Update integration tests

* Fix generate

* Remove model from auto mapping

* Use auto processor

* Fix integration tests

* Fix test

* Add inference code snippet

* Remove is_encoder_decoder

* Update docs

* Remove notebook link
This commit is contained in:
NielsRogge
2023-04-13 15:01:14 +02:00
committed by GitHub
parent 95e7057507
commit 8eb38f638d
6 changed files with 96 additions and 113 deletions

View File

@@ -28,9 +28,8 @@ We therefore advise you to use these models for the tasks they have been fine tu
This model was contributed by [ybelkada](https://huggingface.co/ybelkada). This model was contributed by [ybelkada](https://huggingface.co/ybelkada).
The original code can be found [here](https://github.com/google-research/pix2struct). The original code can be found [here](https://github.com/google-research/pix2struct).
## Resources: ## Resources
- [Paper](https://arxiv.org/abs/2210.03347)
- [Fine-tuning Notebook](https://github.com/huggingface/notebooks/blob/main/examples/image_captioning_pix2struct.ipynb) - [Fine-tuning Notebook](https://github.com/huggingface/notebooks/blob/main/examples/image_captioning_pix2struct.ipynb)
- [All models](https://huggingface.co/models?search=pix2struct) - [All models](https://huggingface.co/models?search=pix2struct)

View File

@@ -681,7 +681,7 @@ class GenerationConfig(PushToHubMixin):
# Special case: some models have generation attributes set in the decoder. Use them if still unset in the # Special case: some models have generation attributes set in the decoder. Use them if still unset in the
# generation config. # generation config.
for decoder_name in ("decoder", "generator"): for decoder_name in ("decoder", "generator", "text_config"):
if decoder_name in config_dict: if decoder_name in config_dict:
default_generation_config = GenerationConfig() default_generation_config = GenerationConfig()
decoder_config = config_dict[decoder_name] decoder_config = config_dict[decoder_name]

View File

@@ -358,9 +358,10 @@ class Pix2StructConfig(PretrainedConfig):
initializer_range=0.02, initializer_range=0.02,
is_vqa=False, is_vqa=False,
tie_word_embeddings=False, tie_word_embeddings=False,
is_encoder_decoder=True,
**kwargs, **kwargs,
): ):
super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs) super().__init__(tie_word_embeddings=tie_word_embeddings, is_encoder_decoder=is_encoder_decoder, **kwargs)
if text_config is None: if text_config is None:
text_config = {} text_config = {}
@@ -373,9 +374,9 @@ class Pix2StructConfig(PretrainedConfig):
self.text_config = Pix2StructTextConfig(**text_config) self.text_config = Pix2StructTextConfig(**text_config)
self.vision_config = Pix2StructVisionConfig(**vision_config) self.vision_config = Pix2StructVisionConfig(**vision_config)
self.text_config.encoder_hidden_size = self.vision_config.hidden_size
self.decoder_start_token_id = self.text_config.decoder_start_token_id self.decoder_start_token_id = self.text_config.decoder_start_token_id
self.pad_token_id = self.text_config.pad_token_id self.pad_token_id = self.text_config.pad_token_id
self.eos_token_id = self.text_config.eos_token_id
self.initializer_factor = initializer_factor self.initializer_factor = initializer_factor
self.initializer_range = initializer_range self.initializer_range = initializer_range

View File

@@ -14,7 +14,6 @@
# limitations under the License. # limitations under the License.
""" Pix2Struct modeling file""" """ Pix2Struct modeling file"""
import copy
import math import math
from typing import Dict, List, Optional, Tuple, Union from typing import Dict, List, Optional, Tuple, Union
@@ -1580,25 +1579,6 @@ class Pix2StructTextModel(Pix2StructPreTrainedModel):
cross_attentions=all_cross_attentions, cross_attentions=all_cross_attentions,
) )
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, **model_kwargs):
input_shape = input_ids.shape
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
if attention_mask is None:
attention_mask = input_ids.new_ones(input_shape)
# cut decoder_input_ids if past_key_values is used
if past_key_values is not None:
input_ids = input_ids[:, -1:]
return {
"input_ids": input_ids,
"attention_mask": attention_mask,
"past_key_values": past_key_values,
"encoder_hidden_states": model_kwargs.get("encoder_hidden_states", None),
"encoder_attention_mask": model_kwargs.get("encoder_attention_mask", None),
"is_decoder": True,
}
@add_start_docstrings( @add_start_docstrings(
"A conditional generation model with a language modeling head. Can be used for sequence generation tasks.", "A conditional generation model with a language modeling head. Can be used for sequence generation tasks.",
@@ -1618,13 +1598,9 @@ class Pix2StructForConditionalGeneration(Pix2StructPreTrainedModel):
def __init__(self, config: Pix2StructConfig): def __init__(self, config: Pix2StructConfig):
super().__init__(config) super().__init__(config)
encoder_config = copy.deepcopy(config.vision_config)
self.encoder = Pix2StructVisionModel(encoder_config)
decoder_config = copy.deepcopy(config.text_config) self.encoder = Pix2StructVisionModel(config.vision_config)
self.decoder_start_token_id = decoder_config.pad_token_id self.decoder = Pix2StructTextModel(config.text_config)
self.decoder_eos_token_ids = decoder_config.eos_token_id
self.decoder = Pix2StructTextModel(decoder_config)
self.is_vqa = config.is_vqa self.is_vqa = config.is_vqa
@@ -1682,6 +1658,8 @@ class Pix2StructForConditionalGeneration(Pix2StructPreTrainedModel):
Example: Example:
Inference:
```python ```python
>>> from PIL import Image >>> from PIL import Image
>>> import requests >>> import requests
@@ -1690,15 +1668,40 @@ class Pix2StructForConditionalGeneration(Pix2StructPreTrainedModel):
>>> processor = AutoProcessor.from_pretrained("google/pix2struct-textcaps-base") >>> processor = AutoProcessor.from_pretrained("google/pix2struct-textcaps-base")
>>> model = Pix2StructForConditionalGeneration.from_pretrained("google/pix2struct-textcaps-base") >>> model = Pix2StructForConditionalGeneration.from_pretrained("google/pix2struct-textcaps-base")
>>> labels = "A stop sign is on the street corner."
>>> url = "https://www.ilankelman.org/stopsigns/australia.jpg" >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw) >>> image = Image.open(requests.get(url, stream=True).raw)
>>> inputs = processor(images=image, text=labels, return_tensors="pt", add_special_tokens=True) >>> inputs = processor(images=image, return_tensors="pt")
>>> # autoregressive generation
>>> generated_ids = model.generate(**inputs, max_new_tokens=50)
>>> generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
>>> print(generated_text)
A stop sign is on a street corner.
```
Training:
```python
>>> from PIL import Image
>>> import requests
>>> from transformers import AutoProcessor, Pix2StructForConditionalGeneration
>>> processor = AutoProcessor.from_pretrained("google/pix2struct-base")
>>> model = Pix2StructForConditionalGeneration.from_pretrained("google/pix2struct-base")
>>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> text = "A stop sign is on the street corner."
>>> inputs = processor(images=image, return_tensors="pt")
>>> labels = processor(text=text, return_tensors="pt").input_ids
>>> # forward pass >>> # forward pass
>>> outputs = model(**inputs) >>> outputs = model(**inputs, labels=labels)
>>> last_hidden_states = outputs.loss >>> loss = outputs.loss
>>> print(loss.item())
5.239729881286621
```""" ```"""
use_cache = use_cache if use_cache is not None else self.config.text_config.use_cache use_cache = use_cache if use_cache is not None else self.config.text_config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict return_dict = return_dict if return_dict is not None else self.config.use_return_dict
@@ -1759,54 +1762,29 @@ class Pix2StructForConditionalGeneration(Pix2StructPreTrainedModel):
encoder_attentions=encoder_outputs.attentions, encoder_attentions=encoder_outputs.attentions,
) )
@torch.no_grad() def prepare_inputs_for_generation(
def generate(
self, self,
flattened_patches: torch.FloatTensor, input_ids,
decoder_input_ids: Optional[torch.LongTensor] = None, flattened_patches: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None, attention_mask: Optional[torch.FloatTensor] = None,
decoder_attention_mask: Optional[torch.LongTensor] = None, decoder_attention_mask: Optional[torch.BoolTensor] = None,
**generate_kwargs, past_key_values=None,
head_mask=None,
decoder_head_mask=None,
cross_attn_head_mask=None,
use_cache=None,
encoder_outputs=None,
**kwargs,
): ):
r""" if isinstance(input_ids, torch.Tensor):
Returns: # check if the first element of `input_ids` is equal to `input_ids`:
if (input_ids[:, 0] != self.config.decoder_start_token_id).all().item():
Example: # add `input_ids` as first token to `input_ids`
input_ids = torch.cat(
```python
>>> from PIL import Image
>>> import requests
>>> from transformers import AutoProcessor, Pix2StructForConditionalGeneration
>>> processor = AutoProcessor.from_pretrained("google/pix2struct-textcaps-base")
>>> model = Pix2StructForConditionalGeneration.from_pretrained("google/pix2struct-textcaps-base")
>>> conditional_text = "A stop sign"
>>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> inputs = processor(images=image, text=conditional_text, return_tensors="pt", add_special_tokens=True)
>>> # forward pass
>>> outputs = model.generate(**inputs)
>>> print(processor.batch_decode(outputs, skip_special_tokens=True))
['A stop sign the street with a sign that says yes']
```"""
batch_size, _, _ = flattened_patches.shape
vision_outputs = self.encoder(flattened_patches=flattened_patches, attention_mask=attention_mask)
image_embeds = vision_outputs[0]
if isinstance(decoder_input_ids, torch.Tensor):
# check if the first element of `input_ids` is equal to `decoder_input_ids`:
if (decoder_input_ids[:, 0] != self.decoder_start_token_id).all().item():
# add `decoder_input_ids` as first token to `input_ids`
decoder_input_ids = torch.cat(
[ [
torch.ones((decoder_input_ids.shape[0], 1), dtype=torch.long, device=decoder_input_ids.device) torch.ones((input_ids.shape[0], 1), dtype=torch.long, device=input_ids.device)
* self.decoder_start_token_id, * self.config.decoder_start_token_id,
decoder_input_ids, input_ids,
], ],
dim=-1, dim=-1,
) )
@@ -1823,20 +1801,26 @@ class Pix2StructForConditionalGeneration(Pix2StructPreTrainedModel):
], ],
dim=-1, dim=-1,
) )
elif decoder_input_ids is None: elif input_ids is None:
decoder_input_ids = ( batch_size = flattened_patches.shape[0]
torch.LongTensor([[self.decoder_start_token_id]]).repeat(batch_size, 1).to(image_embeds.device) input_ids = torch.LongTensor([[self.input_ids]]).repeat(batch_size, 1).to(input_ids.device)
)
if decoder_attention_mask is None: if decoder_attention_mask is None:
decoder_attention_mask = torch.ones_like(decoder_input_ids).to(image_embeds.device) decoder_attention_mask = torch.ones_like(input_ids).to(input_ids.device)
outputs = self.decoder.generate( # cut decoder_input_ids if past is used
input_ids=decoder_input_ids, if past_key_values is not None:
attention_mask=decoder_attention_mask, input_ids = input_ids[:, -1:]
encoder_hidden_states=image_embeds,
encoder_attention_mask=attention_mask,
**generate_kwargs,
)
return outputs return {
"flattened_patches": flattened_patches,
"decoder_input_ids": input_ids,
"past_key_values": past_key_values,
"encoder_outputs": encoder_outputs,
"attention_mask": attention_mask,
"decoder_attention_mask": decoder_attention_mask,
"head_mask": head_mask,
"decoder_head_mask": decoder_head_mask,
"cross_attn_head_mask": cross_attn_head_mask,
"use_cache": use_cache,
}

View File

@@ -443,24 +443,22 @@ class Pix2StructTextImageModelTest(ModelTesterMixin, unittest.TestCase):
# signature.parameters is an OrderedDict => so arg_names order is deterministic # signature.parameters is an OrderedDict => so arg_names order is deterministic
arg_names = [*signature.parameters.keys()] arg_names = [*signature.parameters.keys()]
if model.config.is_encoder_decoder: expected_arg_names = [
expected_arg_names = [ "flattened_patches",
"input_ids", "attention_mask",
"attention_mask", "decoder_input_ids",
"decoder_input_ids", "decoder_attention_mask",
"decoder_attention_mask", "head_mask",
] "decoder_head_mask",
expected_arg_names.extend( "cross_attn_head_mask",
["head_mask", "decoder_head_mask", "cross_attn_head_mask", "encoder_outputs"] "encoder_outputs",
if "head_mask" and "decoder_head_mask" and "cross_attn_head_mask" in arg_names "past_key_values",
else ["encoder_outputs"] "labels",
) "decoder_inputs_embeds",
self.assertListEqual(arg_names[: len(expected_arg_names)], expected_arg_names) "use_cache",
else: ]
expected_arg_names = (
["input_ids"] if model_class != Pix2StructForConditionalGeneration else ["flattened_patches"] self.assertListEqual(arg_names[: len(expected_arg_names)], expected_arg_names)
)
self.assertListEqual(arg_names[:1], expected_arg_names)
def test_training(self): def test_training(self):
if not self.model_tester.is_training: if not self.model_tester.is_training:
@@ -765,7 +763,7 @@ class Pix2StructIntegrationTest(unittest.TestCase):
) )
def test_vqa_model(self): def test_vqa_model(self):
model_id = "ybelkada/pix2struct-ai2d-base" model_id = "google/pix2struct-ai2d-base"
image_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/ai2d-demo.jpg" image_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/ai2d-demo.jpg"
image = Image.open(requests.get(image_url, stream=True).raw) image = Image.open(requests.get(image_url, stream=True).raw)
@@ -784,7 +782,7 @@ class Pix2StructIntegrationTest(unittest.TestCase):
self.assertEqual(processor.decode(predictions[0], skip_special_tokens=True), "ash cloud") self.assertEqual(processor.decode(predictions[0], skip_special_tokens=True), "ash cloud")
def test_vqa_model_batched(self): def test_vqa_model_batched(self):
model_id = "ybelkada/pix2struct-ai2d-base" model_id = "google/pix2struct-ai2d-base"
image_urls = [ image_urls = [
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/ai2d-demo.jpg", "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/ai2d-demo.jpg",

View File

@@ -306,6 +306,7 @@ src/transformers/models/pegasus/tokenization_pegasus.py
src/transformers/models/pegasus/tokenization_pegasus_fast.py src/transformers/models/pegasus/tokenization_pegasus_fast.py
src/transformers/models/perceiver/tokenization_perceiver.py src/transformers/models/perceiver/tokenization_perceiver.py
src/transformers/models/phobert/tokenization_phobert.py src/transformers/models/phobert/tokenization_phobert.py
src/transformers/models/pix2struct/modeling_pix2struct.py
src/transformers/models/plbart/tokenization_plbart.py src/transformers/models/plbart/tokenization_plbart.py
src/transformers/models/prophetnet/tokenization_prophetnet.py src/transformers/models/prophetnet/tokenization_prophetnet.py
src/transformers/models/rag/tokenization_rag.py src/transformers/models/rag/tokenization_rag.py