[Pix2Struct] Add support to resize embeddings (#22394)
* First draft * Fix integration test * Remove script * Fix test and typos * Fix one more test * Skip tied embeddings test * Remove line * Address comments
This commit is contained in:
@@ -35,17 +35,16 @@ class Pix2StructTextConfig(PretrainedConfig):
|
|||||||
r"""
|
r"""
|
||||||
This is the configuration class to store the configuration of a [`Pix2StructTextModel`]. It is used to instantiate
|
This is the configuration class to store the configuration of a [`Pix2StructTextModel`]. It is used to instantiate
|
||||||
a Pix2Struct text model according to the specified arguments, defining the model architecture. Instantiating a
|
a Pix2Struct text model according to the specified arguments, defining the model architecture. Instantiating a
|
||||||
configuration with the defaults will yield a similar configuration to that of the `Pix2StructText` used by the
|
configuration with the defaults will yield a similar configuration to that of the Pix2Struct text decoder used by
|
||||||
[base architectures](https://huggingface.co/google/pix2struct-textcaps-base).
|
the [google/pix2struct-base](https://huggingface.co/google/pix2struct-base) architecture.
|
||||||
|
|
||||||
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||||
documentation from [`PretrainedConfig`] for more information.
|
documentation from [`PretrainedConfig`] for more information.
|
||||||
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
vocab_size (`int`, *optional*, defaults to 50244):
|
vocab_size (`int`, *optional*, defaults to 50244):
|
||||||
Vocabulary size of the `Pix2Struct` text model. Defines the number of different tokens that can be
|
Vocabulary size of the `Pix2Struct` text model. Defines the number of different tokens that can be
|
||||||
represented by the `input_ids` passed when calling [`Pix2StructModel`].
|
represented by the `input_ids` passed when calling [`Pix2StructTextModel`].
|
||||||
hidden_size (`int`, *optional*, defaults to 768):
|
hidden_size (`int`, *optional*, defaults to 768):
|
||||||
Dimensionality of the encoder layers and the pooler layer.
|
Dimensionality of the encoder layers and the pooler layer.
|
||||||
d_kv (`int`, *optional*, defaults to 64):
|
d_kv (`int`, *optional*, defaults to 64):
|
||||||
@@ -83,10 +82,10 @@ class Pix2StructTextConfig(PretrainedConfig):
|
|||||||
```python
|
```python
|
||||||
>>> from transformers import Pix2StructTextConfig, Pix2StructTextModel
|
>>> from transformers import Pix2StructTextConfig, Pix2StructTextModel
|
||||||
|
|
||||||
>>> # Initializing a Pix2StructTextConfig with Salesforce/pix2struct-vqa-base style configuration
|
>>> # Initializing a Pix2StructTextConfig with google/pix2struct-base style configuration
|
||||||
>>> configuration = Pix2StructTextConfig()
|
>>> configuration = Pix2StructTextConfig()
|
||||||
|
|
||||||
>>> # Initializing a Pix2StructTextModel (with random weights) from the Salesforce/pix2struct-vqa-base style configuration
|
>>> # Initializing a Pix2StructTextModel (with random weights) from the google/pix2struct-base style configuration
|
||||||
>>> model = Pix2StructTextModel(configuration)
|
>>> model = Pix2StructTextModel(configuration)
|
||||||
|
|
||||||
>>> # Accessing the model configuration
|
>>> # Accessing the model configuration
|
||||||
@@ -118,6 +117,7 @@ class Pix2StructTextConfig(PretrainedConfig):
|
|||||||
use_cache=False,
|
use_cache=False,
|
||||||
pad_token_id=0,
|
pad_token_id=0,
|
||||||
eos_token_id=1,
|
eos_token_id=1,
|
||||||
|
tie_word_embeddings=False,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
self.vocab_size = vocab_size
|
self.vocab_size = vocab_size
|
||||||
@@ -143,6 +143,7 @@ class Pix2StructTextConfig(PretrainedConfig):
|
|||||||
pad_token_id=pad_token_id,
|
pad_token_id=pad_token_id,
|
||||||
eos_token_id=eos_token_id,
|
eos_token_id=eos_token_id,
|
||||||
decoder_start_token_id=decoder_start_token_id,
|
decoder_start_token_id=decoder_start_token_id,
|
||||||
|
tie_word_embeddings=tie_word_embeddings,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -168,14 +169,13 @@ class Pix2StructTextConfig(PretrainedConfig):
|
|||||||
class Pix2StructVisionConfig(PretrainedConfig):
|
class Pix2StructVisionConfig(PretrainedConfig):
|
||||||
r"""
|
r"""
|
||||||
This is the configuration class to store the configuration of a [`Pix2StructVisionModel`]. It is used to
|
This is the configuration class to store the configuration of a [`Pix2StructVisionModel`]. It is used to
|
||||||
instantiate a PIX2STRUCT vision model according to the specified arguments, defining the model architecture.
|
instantiate a Pix2Struct vision model according to the specified arguments, defining the model architecture.
|
||||||
Instantiating a configuration with the defaults will yield a similar configuration to that of the Pix2Struct-base
|
Instantiating a configuration with the defaults will yield a similar configuration to that of the Pix2Struct-base
|
||||||
[Salesforce/pix2struct-vqa-base](https://huggingface.co/Salesforce/pix2struct-vqa-base) architecture.
|
[google/pix2struct-base](https://huggingface.co/google/pix2struct-base) architecture.
|
||||||
|
|
||||||
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||||
documentation from [`PretrainedConfig`] for more information.
|
documentation from [`PretrainedConfig`] for more information.
|
||||||
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
hidden_size (`int`, *optional*, defaults to 768):
|
hidden_size (`int`, *optional*, defaults to 768):
|
||||||
Dimensionality of the encoder layers and the pooler layer.
|
Dimensionality of the encoder layers and the pooler layer.
|
||||||
@@ -223,10 +223,10 @@ class Pix2StructVisionConfig(PretrainedConfig):
|
|||||||
```python
|
```python
|
||||||
>>> from transformers import Pix2StructVisionConfig, Pix2StructVisionModel
|
>>> from transformers import Pix2StructVisionConfig, Pix2StructVisionModel
|
||||||
|
|
||||||
>>> # Initializing a Pix2StructVisionConfig with Salesforce/pix2struct-vqa-base style configuration
|
>>> # Initializing a Pix2StructVisionConfig with google/pix2struct-base style configuration
|
||||||
>>> configuration = Pix2StructVisionConfig()
|
>>> configuration = Pix2StructVisionConfig()
|
||||||
|
|
||||||
>>> # Initializing a Pix2StructVisionModel (with random weights) from the Salesforce/pix2struct-vqa-base style configuration
|
>>> # Initializing a Pix2StructVisionModel (with random weights) from the google/pix2struct-base style configuration
|
||||||
>>> model = Pix2StructVisionModel(configuration)
|
>>> model = Pix2StructVisionModel(configuration)
|
||||||
|
|
||||||
>>> # Accessing the model configuration
|
>>> # Accessing the model configuration
|
||||||
@@ -301,11 +301,11 @@ class Pix2StructVisionConfig(PretrainedConfig):
|
|||||||
|
|
||||||
class Pix2StructConfig(PretrainedConfig):
|
class Pix2StructConfig(PretrainedConfig):
|
||||||
r"""
|
r"""
|
||||||
[`Pix2StructConfig`] is the configuration class to store the configuration of a [`Pix2StructModel`]. It is used to
|
[`Pix2StructConfig`] is the configuration class to store the configuration of a
|
||||||
instantiate a PIX2STRUCT model according to the specified arguments, defining the text model and vision model
|
[`Pix2StructForConditionalGeneration`]. It is used to instantiate a Pix2Struct model according to the specified
|
||||||
configs. Instantiating a configuration with the defaults will yield a similar configuration to that of the
|
arguments, defining the text model and vision model configs. Instantiating a configuration with the defaults will
|
||||||
PIX2STRUCT-base [Salesforce/pix2struct-vqa-base](https://huggingface.co/Salesforce/pix2struct-vqa-base)
|
yield a similar configuration to that of the Pix2Struct-base
|
||||||
architecture.
|
[google/pix2struct-base](https://huggingface.co/google/pix2struct-base) architecture.
|
||||||
|
|
||||||
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||||
documentation from [`PretrainedConfig`] for more information.
|
documentation from [`PretrainedConfig`] for more information.
|
||||||
@@ -327,20 +327,20 @@ class Pix2StructConfig(PretrainedConfig):
|
|||||||
Example:
|
Example:
|
||||||
|
|
||||||
```python
|
```python
|
||||||
>>> from transformers import Pix2StructConfig, Pix2StructModel
|
>>> from transformers import Pix2StructConfig, Pix2StructForConditionalGeneration
|
||||||
|
|
||||||
>>> # Initializing a Pix2StructConfig with Salesforce/pix2struct-vqa-base style configuration
|
>>> # Initializing a Pix2StructConfig with google/pix2struct-base style configuration
|
||||||
>>> configuration = Pix2StructConfig()
|
>>> configuration = Pix2StructConfig()
|
||||||
|
|
||||||
>>> # Initializing a Pix2StructPModel (with random weights) from the Salesforce/pix2struct-vqa-base style configuration
|
>>> # Initializing a Pix2StructForConditionalGeneration (with random weights) from the google/pix2struct-base style configuration
|
||||||
>>> model = Pix2StructModel(configuration)
|
>>> model = Pix2StructForConditionalGeneration(configuration)
|
||||||
|
|
||||||
>>> # Accessing the model configuration
|
>>> # Accessing the model configuration
|
||||||
>>> configuration = model.config
|
>>> configuration = model.config
|
||||||
|
|
||||||
>>> # We can also initialize a Pix2StructConfig from a Pix2StructTextConfig and a Pix2StructVisionConfig
|
>>> # We can also initialize a Pix2StructConfig from a Pix2StructTextConfig and a Pix2StructVisionConfig
|
||||||
|
|
||||||
>>> # Initializing a PIX2STRUCTText and PIX2STRUCTVision configuration
|
>>> # Initializing a Pix2Struct text and Pix2Struct vision configuration
|
||||||
>>> config_text = Pix2StructTextConfig()
|
>>> config_text = Pix2StructTextConfig()
|
||||||
>>> config_vision = Pix2StructVisionConfig()
|
>>> config_vision = Pix2StructVisionConfig()
|
||||||
|
|
||||||
|
|||||||
@@ -1369,6 +1369,12 @@ class Pix2StructTextModel(Pix2StructPreTrainedModel):
|
|||||||
def set_input_embeddings(self, new_embeddings):
|
def set_input_embeddings(self, new_embeddings):
|
||||||
self.embed_tokens = new_embeddings
|
self.embed_tokens = new_embeddings
|
||||||
|
|
||||||
|
def get_output_embeddings(self):
    """Return `self.lm_head`, the module this model exposes as its output embeddings."""
    output_layer = self.lm_head
    return output_layer
|
||||||
|
|
||||||
|
def set_output_embeddings(self, new_embeddings):
    """Replace `self.lm_head` with `new_embeddings`."""
    setattr(self, "lm_head", new_embeddings)
|
||||||
|
|
||||||
@add_start_docstrings_to_model_forward(PIX2STRUCT_TEXT_INPUTS_DOCSTRING)
|
@add_start_docstrings_to_model_forward(PIX2STRUCT_TEXT_INPUTS_DOCSTRING)
|
||||||
@replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
|
@replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
|
||||||
def forward(
|
def forward(
|
||||||
@@ -1626,12 +1632,25 @@ class Pix2StructForConditionalGeneration(Pix2StructPreTrainedModel):
|
|||||||
self.post_init()
|
self.post_init()
|
||||||
|
|
||||||
def get_input_embeddings(self):
|
def get_input_embeddings(self):
|
||||||
return self.shared
|
return self.decoder.get_input_embeddings()
|
||||||
|
|
||||||
def set_input_embeddings(self, new_embeddings):
|
def set_input_embeddings(self, new_embeddings):
|
||||||
self.shared = new_embeddings
|
|
||||||
self.decoder.set_input_embeddings(new_embeddings)
|
self.decoder.set_input_embeddings(new_embeddings)
|
||||||
|
|
||||||
|
def get_output_embeddings(self) -> nn.Module:
    """Return the output embeddings by delegating to `self.decoder.get_output_embeddings()`."""
    text_decoder = self.decoder
    return text_decoder.get_output_embeddings()
|
||||||
|
|
||||||
|
def set_output_embeddings(self, new_embeddings):
    """Install `new_embeddings` as the output embeddings by delegating to `self.decoder`."""
    text_decoder = self.decoder
    text_decoder.set_output_embeddings(new_embeddings)
|
||||||
|
|
||||||
|
def resize_token_embeddings(self, new_num_tokens: Optional[int] = None) -> nn.Embedding:
    """
    Resize the decoder's token embeddings and keep `config.text_config.vocab_size` in sync.

    Args:
        new_num_tokens (`int`, *optional*):
            Requested number of tokens, forwarded unchanged to `self.decoder.resize_token_embeddings`. When
            `None`, the vocabulary size stored in the text config is left untouched.

    Returns:
        The value returned by `self.decoder.resize_token_embeddings` (annotated as `nn.Embedding`).
    """
    model_embeds = self.decoder.resize_token_embeddings(new_num_tokens)

    # Only record a new vocab size when one was actually requested: with the default
    # `new_num_tokens=None`, an unconditional write would store `None` in the config.
    if new_num_tokens is not None:
        self.config.text_config.vocab_size = new_num_tokens

    return model_embeds
|
||||||
|
|
||||||
def get_decoder(self):
|
def get_decoder(self):
|
||||||
return self.decoder
|
return self.decoder
|
||||||
|
|
||||||
|
|||||||
@@ -14,7 +14,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
""" Testing suite for the PyTorch Pix2Struct model. """
|
""" Testing suite for the PyTorch Pix2Struct model. """
|
||||||
|
|
||||||
|
import copy
|
||||||
import inspect
|
import inspect
|
||||||
import os
|
import os
|
||||||
import tempfile
|
import tempfile
|
||||||
@@ -396,7 +396,7 @@ class Pix2StructTextImageModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
fx_compatible = False
|
fx_compatible = False
|
||||||
test_head_masking = False
|
test_head_masking = False
|
||||||
test_pruning = False
|
test_pruning = False
|
||||||
test_resize_embeddings = False
|
test_resize_embeddings = True
|
||||||
test_attention_outputs = False
|
test_attention_outputs = False
|
||||||
test_torchscript = False
|
test_torchscript = False
|
||||||
|
|
||||||
@@ -526,6 +526,105 @@ class Pix2StructTextImageModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
|
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# overwrite because `vocab_size` is not an attribute of `Pix2StructConfig` but rather `Pix2StructTextConfig`
|
||||||
|
def test_resize_tokens_embeddings(self):
    base_config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
    if not self.test_resize_embeddings:
        return

    grow_by, shrink_by = 10, 15

    for model_class in self.all_model_classes:
        config = copy.deepcopy(base_config)
        model = model_class(config)
        model.to(torch_device)

        if self.model_tester.is_training is False:
            model.eval()

        # The vocabulary size lives on the nested text config, not on the top-level config.
        initial_vocab = config.text_config.vocab_size

        # Retrieve the embeddings and clone them so later comparisons use the initial values.
        model_embed = model.resize_token_embeddings(initial_vocab)
        cloned_embeddings = model_embed.weight.clone()

        # Growing the vocabulary must be reflected both in the config and in the embedding matrix.
        model_embed = model.resize_token_embeddings(initial_vocab + grow_by)
        self.assertEqual(model.config.text_config.vocab_size, initial_vocab + grow_by)
        self.assertEqual(model_embed.weight.shape[0], cloned_embeddings.shape[0] + grow_by)
        # The model must still run a forward pass (every parameter should be resized).
        model(**self._prepare_for_class(inputs_dict, model_class))

        # Shrinking the vocabulary must be reflected both in the config and in the embedding matrix.
        model_embed = model.resize_token_embeddings(initial_vocab - shrink_by)
        self.assertEqual(model.config.text_config.vocab_size, initial_vocab - shrink_by)
        self.assertEqual(model_embed.weight.shape[0], cloned_embeddings.shape[0] - shrink_by)

        # Decoder input ids are clamped in place to the largest valid id of the shrunken
        # vocabulary before the forward pass.
        if "decoder_input_ids" in inputs_dict:
            inputs_dict["decoder_input_ids"].clamp_(max=initial_vocab - shrink_by - 1)
        model(**self._prepare_for_class(inputs_dict, model_class))

        # Adding and removing tokens must not have modified the first part of the embedding matrix.
        changed_rows = [
            1
            for old_row, new_row in zip(cloned_embeddings, model_embed.weight)
            if old_row.data.ne(new_row.data).sum() > 0
        ]
        self.assertTrue(len(changed_rows) == 0)
|
||||||
|
|
||||||
|
# overwrite because `vocab_size` is not an attribute of `Pix2StructConfig` but rather `Pix2StructTextConfig`
|
||||||
|
def test_resize_embeddings_untied(self):
    base_config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
    if not self.test_resize_embeddings:
        return

    base_config.tie_word_embeddings = False

    # If the model cannot untie its embeddings, leave the test.
    # NOTE(review): the flag was just set to False above, so this guard looks unreachable
    # unless the config overrides attribute assignment - confirm before removing.
    if base_config.tie_word_embeddings:
        return

    grow_by, shrink_by = 10, 15

    for model_class in self.all_model_classes:
        config = copy.deepcopy(base_config)
        model = model_class(config).to(torch_device)

        # Models without output embeddings have nothing to check.
        if model.get_output_embeddings() is None:
            continue

        # The vocabulary size lives on the nested text config, not on the top-level config.
        initial_vocab = config.text_config.vocab_size

        # Growing the vocabulary must resize the output embeddings (and their bias, if any).
        model.resize_token_embeddings(initial_vocab + grow_by)
        self.assertEqual(model.config.text_config.vocab_size, initial_vocab + grow_by)
        output_embeds = model.get_output_embeddings()
        self.assertEqual(output_embeds.weight.shape[0], initial_vocab + grow_by)
        if output_embeds.bias is not None:
            self.assertEqual(output_embeds.bias.shape[0], initial_vocab + grow_by)
        # The model must still run a forward pass (every parameter should be resized).
        model(**self._prepare_for_class(inputs_dict, model_class))

        # Shrinking the vocabulary must resize the output embeddings (and their bias, if any).
        model.resize_token_embeddings(initial_vocab - shrink_by)
        self.assertEqual(model.config.text_config.vocab_size, initial_vocab - shrink_by)
        output_embeds = model.get_output_embeddings()
        self.assertEqual(output_embeds.weight.shape[0], initial_vocab - shrink_by)
        if output_embeds.bias is not None:
            self.assertEqual(output_embeds.bias.shape[0], initial_vocab - shrink_by)

        # Decoder input ids are clamped in place to the largest valid id of the shrunken
        # vocabulary before the forward pass.
        if "decoder_input_ids" in inputs_dict:
            inputs_dict["decoder_input_ids"].clamp_(max=initial_vocab - shrink_by - 1)
        model(**self._prepare_for_class(inputs_dict, model_class))
|
||||||
|
|
||||||
|
@unittest.skip(reason="Pix2Struct doesn't use tied weights")
|
||||||
|
def test_tied_model_weights_key_ignore(self):
|
||||||
|
pass
|
||||||
|
|
||||||
def _create_and_check_torchscript(self, config, inputs_dict):
|
def _create_and_check_torchscript(self, config, inputs_dict):
|
||||||
if not self.test_torchscript:
|
if not self.test_torchscript:
|
||||||
return
|
return
|
||||||
|
|||||||
Reference in New Issue
Block a user