[Pix2struct] Simplify generation (#22527)
* Add model to doc tests
* Remove generate and replace by prepare_inputs_for_generation
* More fixes
* Remove print statements
* Update integration tests
* Fix generate
* Remove model from auto mapping
* Use auto processor
* Fix integration tests
* Fix test
* Add inference code snippet
* Remove is_encoder_decoder
* Update docs
* Remove notebook link
This commit is contained in:
@@ -28,9 +28,8 @@ We therefore advise you to use these models for the tasks they have been fine tu
|
|||||||
This model was contributed by [ybelkada](https://huggingface.co/ybelkada).
|
This model was contributed by [ybelkada](https://huggingface.co/ybelkada).
|
||||||
The original code can be found [here](https://github.com/google-research/pix2struct).
|
The original code can be found [here](https://github.com/google-research/pix2struct).
|
||||||
|
|
||||||
## Resources:
|
## Resources
|
||||||
|
|
||||||
- [Paper](https://arxiv.org/abs/2210.03347)
|
|
||||||
- [Fine-tuning Notebook](https://github.com/huggingface/notebooks/blob/main/examples/image_captioning_pix2struct.ipynb)
|
- [Fine-tuning Notebook](https://github.com/huggingface/notebooks/blob/main/examples/image_captioning_pix2struct.ipynb)
|
||||||
- [All models](https://huggingface.co/models?search=pix2struct)
|
- [All models](https://huggingface.co/models?search=pix2struct)
|
||||||
|
|
||||||
@@ -70,4 +69,4 @@ The original code can be found [here](https://github.com/google-research/pix2str
|
|||||||
## Pix2StructForConditionalGeneration
|
## Pix2StructForConditionalGeneration
|
||||||
|
|
||||||
[[autodoc]] Pix2StructForConditionalGeneration
|
[[autodoc]] Pix2StructForConditionalGeneration
|
||||||
- forward
|
- forward
|
||||||
@@ -681,7 +681,7 @@ class GenerationConfig(PushToHubMixin):
|
|||||||
|
|
||||||
# Special case: some models have generation attributes set in the decoder. Use them if still unset in the
|
# Special case: some models have generation attributes set in the decoder. Use them if still unset in the
|
||||||
# generation config.
|
# generation config.
|
||||||
for decoder_name in ("decoder", "generator"):
|
for decoder_name in ("decoder", "generator", "text_config"):
|
||||||
if decoder_name in config_dict:
|
if decoder_name in config_dict:
|
||||||
default_generation_config = GenerationConfig()
|
default_generation_config = GenerationConfig()
|
||||||
decoder_config = config_dict[decoder_name]
|
decoder_config = config_dict[decoder_name]
|
||||||
|
|||||||
@@ -358,9 +358,10 @@ class Pix2StructConfig(PretrainedConfig):
|
|||||||
initializer_range=0.02,
|
initializer_range=0.02,
|
||||||
is_vqa=False,
|
is_vqa=False,
|
||||||
tie_word_embeddings=False,
|
tie_word_embeddings=False,
|
||||||
|
is_encoder_decoder=True,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
|
super().__init__(tie_word_embeddings=tie_word_embeddings, is_encoder_decoder=is_encoder_decoder, **kwargs)
|
||||||
|
|
||||||
if text_config is None:
|
if text_config is None:
|
||||||
text_config = {}
|
text_config = {}
|
||||||
@@ -373,9 +374,9 @@ class Pix2StructConfig(PretrainedConfig):
|
|||||||
self.text_config = Pix2StructTextConfig(**text_config)
|
self.text_config = Pix2StructTextConfig(**text_config)
|
||||||
self.vision_config = Pix2StructVisionConfig(**vision_config)
|
self.vision_config = Pix2StructVisionConfig(**vision_config)
|
||||||
|
|
||||||
self.text_config.encoder_hidden_size = self.vision_config.hidden_size
|
|
||||||
self.decoder_start_token_id = self.text_config.decoder_start_token_id
|
self.decoder_start_token_id = self.text_config.decoder_start_token_id
|
||||||
self.pad_token_id = self.text_config.pad_token_id
|
self.pad_token_id = self.text_config.pad_token_id
|
||||||
|
self.eos_token_id = self.text_config.eos_token_id
|
||||||
|
|
||||||
self.initializer_factor = initializer_factor
|
self.initializer_factor = initializer_factor
|
||||||
self.initializer_range = initializer_range
|
self.initializer_range = initializer_range
|
||||||
|
|||||||
@@ -14,7 +14,6 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
""" Pix2Struct modeling file"""
|
""" Pix2Struct modeling file"""
|
||||||
|
|
||||||
import copy
|
|
||||||
import math
|
import math
|
||||||
from typing import Dict, List, Optional, Tuple, Union
|
from typing import Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
@@ -1580,25 +1579,6 @@ class Pix2StructTextModel(Pix2StructPreTrainedModel):
|
|||||||
cross_attentions=all_cross_attentions,
|
cross_attentions=all_cross_attentions,
|
||||||
)
|
)
|
||||||
|
|
||||||
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, **model_kwargs):
|
|
||||||
input_shape = input_ids.shape
|
|
||||||
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
|
|
||||||
if attention_mask is None:
|
|
||||||
attention_mask = input_ids.new_ones(input_shape)
|
|
||||||
|
|
||||||
# cut decoder_input_ids if past_key_values is used
|
|
||||||
if past_key_values is not None:
|
|
||||||
input_ids = input_ids[:, -1:]
|
|
||||||
|
|
||||||
return {
|
|
||||||
"input_ids": input_ids,
|
|
||||||
"attention_mask": attention_mask,
|
|
||||||
"past_key_values": past_key_values,
|
|
||||||
"encoder_hidden_states": model_kwargs.get("encoder_hidden_states", None),
|
|
||||||
"encoder_attention_mask": model_kwargs.get("encoder_attention_mask", None),
|
|
||||||
"is_decoder": True,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@add_start_docstrings(
|
@add_start_docstrings(
|
||||||
"A conditional generation model with a language modeling head. Can be used for sequence generation tasks.",
|
"A conditional generation model with a language modeling head. Can be used for sequence generation tasks.",
|
||||||
@@ -1618,13 +1598,9 @@ class Pix2StructForConditionalGeneration(Pix2StructPreTrainedModel):
|
|||||||
|
|
||||||
def __init__(self, config: Pix2StructConfig):
|
def __init__(self, config: Pix2StructConfig):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
encoder_config = copy.deepcopy(config.vision_config)
|
|
||||||
self.encoder = Pix2StructVisionModel(encoder_config)
|
|
||||||
|
|
||||||
decoder_config = copy.deepcopy(config.text_config)
|
self.encoder = Pix2StructVisionModel(config.vision_config)
|
||||||
self.decoder_start_token_id = decoder_config.pad_token_id
|
self.decoder = Pix2StructTextModel(config.text_config)
|
||||||
self.decoder_eos_token_ids = decoder_config.eos_token_id
|
|
||||||
self.decoder = Pix2StructTextModel(decoder_config)
|
|
||||||
|
|
||||||
self.is_vqa = config.is_vqa
|
self.is_vqa = config.is_vqa
|
||||||
|
|
||||||
@@ -1682,6 +1658,8 @@ class Pix2StructForConditionalGeneration(Pix2StructPreTrainedModel):
|
|||||||
|
|
||||||
Example:
|
Example:
|
||||||
|
|
||||||
|
Inference:
|
||||||
|
|
||||||
```python
|
```python
|
||||||
>>> from PIL import Image
|
>>> from PIL import Image
|
||||||
>>> import requests
|
>>> import requests
|
||||||
@@ -1690,15 +1668,40 @@ class Pix2StructForConditionalGeneration(Pix2StructPreTrainedModel):
|
|||||||
>>> processor = AutoProcessor.from_pretrained("google/pix2struct-textcaps-base")
|
>>> processor = AutoProcessor.from_pretrained("google/pix2struct-textcaps-base")
|
||||||
>>> model = Pix2StructForConditionalGeneration.from_pretrained("google/pix2struct-textcaps-base")
|
>>> model = Pix2StructForConditionalGeneration.from_pretrained("google/pix2struct-textcaps-base")
|
||||||
|
|
||||||
>>> labels = "A stop sign is on the street corner."
|
|
||||||
>>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
|
>>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
|
||||||
>>> image = Image.open(requests.get(url, stream=True).raw)
|
>>> image = Image.open(requests.get(url, stream=True).raw)
|
||||||
|
|
||||||
>>> inputs = processor(images=image, text=labels, return_tensors="pt", add_special_tokens=True)
|
>>> inputs = processor(images=image, return_tensors="pt")
|
||||||
|
|
||||||
|
>>> # autoregressive generation
|
||||||
|
>>> generated_ids = model.generate(**inputs, max_new_tokens=50)
|
||||||
|
>>> generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
||||||
|
>>> print(generated_text)
|
||||||
|
A stop sign is on a street corner.
|
||||||
|
```
|
||||||
|
|
||||||
|
Training:
|
||||||
|
|
||||||
|
```python
|
||||||
|
>>> from PIL import Image
|
||||||
|
>>> import requests
|
||||||
|
>>> from transformers import AutoProcessor, Pix2StructForConditionalGeneration
|
||||||
|
|
||||||
|
>>> processor = AutoProcessor.from_pretrained("google/pix2struct-base")
|
||||||
|
>>> model = Pix2StructForConditionalGeneration.from_pretrained("google/pix2struct-base")
|
||||||
|
|
||||||
|
>>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
|
||||||
|
>>> image = Image.open(requests.get(url, stream=True).raw)
|
||||||
|
>>> text = "A stop sign is on the street corner."
|
||||||
|
|
||||||
|
>>> inputs = processor(images=image, return_tensors="pt")
|
||||||
|
>>> labels = processor(text=text, return_tensors="pt").input_ids
|
||||||
|
|
||||||
>>> # forward pass
|
>>> # forward pass
|
||||||
>>> outputs = model(**inputs)
|
>>> outputs = model(**inputs, labels=labels)
|
||||||
>>> last_hidden_states = outputs.loss
|
>>> loss = outputs.loss
|
||||||
|
>>> print(loss.item())
|
||||||
|
5.239729881286621
|
||||||
```"""
|
```"""
|
||||||
use_cache = use_cache if use_cache is not None else self.config.text_config.use_cache
|
use_cache = use_cache if use_cache is not None else self.config.text_config.use_cache
|
||||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
@@ -1759,54 +1762,29 @@ class Pix2StructForConditionalGeneration(Pix2StructPreTrainedModel):
|
|||||||
encoder_attentions=encoder_outputs.attentions,
|
encoder_attentions=encoder_outputs.attentions,
|
||||||
)
|
)
|
||||||
|
|
||||||
@torch.no_grad()
|
def prepare_inputs_for_generation(
|
||||||
def generate(
|
|
||||||
self,
|
self,
|
||||||
flattened_patches: torch.FloatTensor,
|
input_ids,
|
||||||
decoder_input_ids: Optional[torch.LongTensor] = None,
|
flattened_patches: Optional[torch.FloatTensor] = None,
|
||||||
attention_mask: Optional[torch.FloatTensor] = None,
|
attention_mask: Optional[torch.FloatTensor] = None,
|
||||||
decoder_attention_mask: Optional[torch.LongTensor] = None,
|
decoder_attention_mask: Optional[torch.BoolTensor] = None,
|
||||||
**generate_kwargs,
|
past_key_values=None,
|
||||||
|
head_mask=None,
|
||||||
|
decoder_head_mask=None,
|
||||||
|
cross_attn_head_mask=None,
|
||||||
|
use_cache=None,
|
||||||
|
encoder_outputs=None,
|
||||||
|
**kwargs,
|
||||||
):
|
):
|
||||||
r"""
|
if isinstance(input_ids, torch.Tensor):
|
||||||
Returns:
|
# check if the first token of `input_ids` is equal to `decoder_start_token_id`:
|
||||||
|
if (input_ids[:, 0] != self.config.decoder_start_token_id).all().item():
|
||||||
Example:
|
# add `decoder_start_token_id` as first token to `input_ids`
|
||||||
|
input_ids = torch.cat(
|
||||||
```python
|
|
||||||
>>> from PIL import Image
|
|
||||||
>>> import requests
|
|
||||||
>>> from transformers import AutoProcessor, Pix2StructForConditionalGeneration
|
|
||||||
|
|
||||||
>>> processor = AutoProcessor.from_pretrained("google/pix2struct-textcaps-base")
|
|
||||||
>>> model = Pix2StructForConditionalGeneration.from_pretrained("google/pix2struct-textcaps-base")
|
|
||||||
|
|
||||||
>>> conditional_text = "A stop sign"
|
|
||||||
>>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
|
|
||||||
>>> image = Image.open(requests.get(url, stream=True).raw)
|
|
||||||
|
|
||||||
>>> inputs = processor(images=image, text=conditional_text, return_tensors="pt", add_special_tokens=True)
|
|
||||||
|
|
||||||
>>> # forward pass
|
|
||||||
>>> outputs = model.generate(**inputs)
|
|
||||||
>>> print(processor.batch_decode(outputs, skip_special_tokens=True))
|
|
||||||
['A stop sign the street with a sign that says yes']
|
|
||||||
```"""
|
|
||||||
batch_size, _, _ = flattened_patches.shape
|
|
||||||
|
|
||||||
vision_outputs = self.encoder(flattened_patches=flattened_patches, attention_mask=attention_mask)
|
|
||||||
|
|
||||||
image_embeds = vision_outputs[0]
|
|
||||||
|
|
||||||
if isinstance(decoder_input_ids, torch.Tensor):
|
|
||||||
# check if the first element of `input_ids` is equal to `decoder_input_ids`:
|
|
||||||
if (decoder_input_ids[:, 0] != self.decoder_start_token_id).all().item():
|
|
||||||
# add `decoder_input_ids` as first token to `input_ids`
|
|
||||||
decoder_input_ids = torch.cat(
|
|
||||||
[
|
[
|
||||||
torch.ones((decoder_input_ids.shape[0], 1), dtype=torch.long, device=decoder_input_ids.device)
|
torch.ones((input_ids.shape[0], 1), dtype=torch.long, device=input_ids.device)
|
||||||
* self.decoder_start_token_id,
|
* self.config.decoder_start_token_id,
|
||||||
decoder_input_ids,
|
input_ids,
|
||||||
],
|
],
|
||||||
dim=-1,
|
dim=-1,
|
||||||
)
|
)
|
||||||
@@ -1823,20 +1801,26 @@ class Pix2StructForConditionalGeneration(Pix2StructPreTrainedModel):
|
|||||||
],
|
],
|
||||||
dim=-1,
|
dim=-1,
|
||||||
)
|
)
|
||||||
elif decoder_input_ids is None:
|
elif input_ids is None:
|
||||||
decoder_input_ids = (
|
batch_size = flattened_patches.shape[0]
|
||||||
torch.LongTensor([[self.decoder_start_token_id]]).repeat(batch_size, 1).to(image_embeds.device)
|
input_ids = torch.LongTensor([[self.input_ids]]).repeat(batch_size, 1).to(input_ids.device)
|
||||||
)
|
|
||||||
|
|
||||||
if decoder_attention_mask is None:
|
if decoder_attention_mask is None:
|
||||||
decoder_attention_mask = torch.ones_like(decoder_input_ids).to(image_embeds.device)
|
decoder_attention_mask = torch.ones_like(input_ids).to(input_ids.device)
|
||||||
|
|
||||||
outputs = self.decoder.generate(
|
# cut decoder_input_ids if past is used
|
||||||
input_ids=decoder_input_ids,
|
if past_key_values is not None:
|
||||||
attention_mask=decoder_attention_mask,
|
input_ids = input_ids[:, -1:]
|
||||||
encoder_hidden_states=image_embeds,
|
|
||||||
encoder_attention_mask=attention_mask,
|
|
||||||
**generate_kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
return outputs
|
return {
|
||||||
|
"flattened_patches": flattened_patches,
|
||||||
|
"decoder_input_ids": input_ids,
|
||||||
|
"past_key_values": past_key_values,
|
||||||
|
"encoder_outputs": encoder_outputs,
|
||||||
|
"attention_mask": attention_mask,
|
||||||
|
"decoder_attention_mask": decoder_attention_mask,
|
||||||
|
"head_mask": head_mask,
|
||||||
|
"decoder_head_mask": decoder_head_mask,
|
||||||
|
"cross_attn_head_mask": cross_attn_head_mask,
|
||||||
|
"use_cache": use_cache,
|
||||||
|
}
|
||||||
|
|||||||
@@ -443,24 +443,22 @@ class Pix2StructTextImageModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
# signature.parameters is an OrderedDict => so arg_names order is deterministic
|
# signature.parameters is an OrderedDict => so arg_names order is deterministic
|
||||||
arg_names = [*signature.parameters.keys()]
|
arg_names = [*signature.parameters.keys()]
|
||||||
|
|
||||||
if model.config.is_encoder_decoder:
|
expected_arg_names = [
|
||||||
expected_arg_names = [
|
"flattened_patches",
|
||||||
"input_ids",
|
"attention_mask",
|
||||||
"attention_mask",
|
"decoder_input_ids",
|
||||||
"decoder_input_ids",
|
"decoder_attention_mask",
|
||||||
"decoder_attention_mask",
|
"head_mask",
|
||||||
]
|
"decoder_head_mask",
|
||||||
expected_arg_names.extend(
|
"cross_attn_head_mask",
|
||||||
["head_mask", "decoder_head_mask", "cross_attn_head_mask", "encoder_outputs"]
|
"encoder_outputs",
|
||||||
if "head_mask" and "decoder_head_mask" and "cross_attn_head_mask" in arg_names
|
"past_key_values",
|
||||||
else ["encoder_outputs"]
|
"labels",
|
||||||
)
|
"decoder_inputs_embeds",
|
||||||
self.assertListEqual(arg_names[: len(expected_arg_names)], expected_arg_names)
|
"use_cache",
|
||||||
else:
|
]
|
||||||
expected_arg_names = (
|
|
||||||
["input_ids"] if model_class != Pix2StructForConditionalGeneration else ["flattened_patches"]
|
self.assertListEqual(arg_names[: len(expected_arg_names)], expected_arg_names)
|
||||||
)
|
|
||||||
self.assertListEqual(arg_names[:1], expected_arg_names)
|
|
||||||
|
|
||||||
def test_training(self):
|
def test_training(self):
|
||||||
if not self.model_tester.is_training:
|
if not self.model_tester.is_training:
|
||||||
@@ -765,7 +763,7 @@ class Pix2StructIntegrationTest(unittest.TestCase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def test_vqa_model(self):
|
def test_vqa_model(self):
|
||||||
model_id = "ybelkada/pix2struct-ai2d-base"
|
model_id = "google/pix2struct-ai2d-base"
|
||||||
|
|
||||||
image_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/ai2d-demo.jpg"
|
image_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/ai2d-demo.jpg"
|
||||||
image = Image.open(requests.get(image_url, stream=True).raw)
|
image = Image.open(requests.get(image_url, stream=True).raw)
|
||||||
@@ -784,7 +782,7 @@ class Pix2StructIntegrationTest(unittest.TestCase):
|
|||||||
self.assertEqual(processor.decode(predictions[0], skip_special_tokens=True), "ash cloud")
|
self.assertEqual(processor.decode(predictions[0], skip_special_tokens=True), "ash cloud")
|
||||||
|
|
||||||
def test_vqa_model_batched(self):
|
def test_vqa_model_batched(self):
|
||||||
model_id = "ybelkada/pix2struct-ai2d-base"
|
model_id = "google/pix2struct-ai2d-base"
|
||||||
|
|
||||||
image_urls = [
|
image_urls = [
|
||||||
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/ai2d-demo.jpg",
|
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/ai2d-demo.jpg",
|
||||||
|
|||||||
@@ -306,6 +306,7 @@ src/transformers/models/pegasus/tokenization_pegasus.py
|
|||||||
src/transformers/models/pegasus/tokenization_pegasus_fast.py
|
src/transformers/models/pegasus/tokenization_pegasus_fast.py
|
||||||
src/transformers/models/perceiver/tokenization_perceiver.py
|
src/transformers/models/perceiver/tokenization_perceiver.py
|
||||||
src/transformers/models/phobert/tokenization_phobert.py
|
src/transformers/models/phobert/tokenization_phobert.py
|
||||||
|
src/transformers/models/pix2struct/modeling_pix2struct.py
|
||||||
src/transformers/models/plbart/tokenization_plbart.py
|
src/transformers/models/plbart/tokenization_plbart.py
|
||||||
src/transformers/models/prophetnet/tokenization_prophetnet.py
|
src/transformers/models/prophetnet/tokenization_prophetnet.py
|
||||||
src/transformers/models/rag/tokenization_rag.py
|
src/transformers/models/rag/tokenization_rag.py
|
||||||
|
|||||||
Reference in New Issue
Block a user