From 0e708178ed72e6b55aab047653dd7c703f02b95b Mon Sep 17 00:00:00 2001 From: NielsRogge <48327001+NielsRogge@users.noreply.github.com> Date: Mon, 27 Mar 2023 17:38:07 +0200 Subject: [PATCH] [Pix2Struct] Add support to resize embeddings (#22394) * First draft * Fix integration test * Remove script * Fix test and typos * Fix one more test * Skip tied embeddings test * Remove line * Address comments --- .../pix2struct/configuration_pix2struct.py | 42 +++---- .../models/pix2struct/modeling_pix2struct.py | 23 +++- .../pix2struct/test_modeling_pix2struct.py | 103 +++++++++++++++++- 3 files changed, 143 insertions(+), 25 deletions(-) diff --git a/src/transformers/models/pix2struct/configuration_pix2struct.py b/src/transformers/models/pix2struct/configuration_pix2struct.py index c8a866966c..8642602cf9 100644 --- a/src/transformers/models/pix2struct/configuration_pix2struct.py +++ b/src/transformers/models/pix2struct/configuration_pix2struct.py @@ -35,17 +35,16 @@ class Pix2StructTextConfig(PretrainedConfig): r""" This is the configuration class to store the configuration of a [`Pix2StructTextModel`]. It is used to instantiate a Pix2Struct text model according to the specified arguments, defining the model architecture. Instantiating a - configuration with the defaults will yield a similar configuration to that of the `Pix2StructText` used by the - [base architectures](https://huggingface.co/google/pix2struct-textcaps-base). + configuration with the defaults will yield a similar configuration to that of the Pix2Struct text decoder used by + the [google/pix2struct-base](https://huggingface.co/google/pix2struct-base) architecture. Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the documentation from [`PretrainedConfig`] for more information. - Args: vocab_size (`int`, *optional*, defaults to 50244): Vocabulary size of the `Pix2Struct` text model. Defines the number of different tokens that can be - represented by the `inputs_ids` passed when calling [`Pix2StructModel`]. + represented by the `inputs_ids` passed when calling [`Pix2StructTextModel`]. hidden_size (`int`, *optional*, defaults to 768): Dimensionality of the encoder layers and the pooler layer. d_kv (`int`, *optional*, defaults to 64): @@ -83,10 +82,10 @@ class Pix2StructTextConfig(PretrainedConfig): ```python >>> from transformers import Pix2StructTextConfig, Pix2StructTextModel - >>> # Initializing a Pix2StructTextConfig with Salesforce/pix2struct-vqa-base style configuration + >>> # Initializing a Pix2StructTextConfig with google/pix2struct-base style configuration >>> configuration = Pix2StructTextConfig() - >>> # Initializing a Pix2StructTextModel (with random weights) from the Salesforce/pix2struct-vqa-base style configuration + >>> # Initializing a Pix2StructTextModel (with random weights) from the google/pix2struct-base style configuration >>> model = Pix2StructTextModel(configuration) >>> # Accessing the model configuration @@ -118,6 +117,7 @@ class Pix2StructTextConfig(PretrainedConfig): use_cache=False, pad_token_id=0, eos_token_id=1, + tie_word_embeddings=False, **kwargs, ): self.vocab_size = vocab_size @@ -143,6 +143,7 @@ class Pix2StructTextConfig(PretrainedConfig): pad_token_id=pad_token_id, eos_token_id=eos_token_id, decoder_start_token_id=decoder_start_token_id, + tie_word_embeddings=tie_word_embeddings, **kwargs, ) @@ -168,14 +169,13 @@ class Pix2StructTextConfig(PretrainedConfig): class Pix2StructVisionConfig(PretrainedConfig): r""" This is the configuration class to store the configuration of a [`Pix2StructVisionModel`]. It is used to - instantiate a PIX2STRUCT vision model according to the specified arguments, defining the model architecture. + instantiate a Pix2Struct vision model according to the specified arguments, defining the model architecture. Instantiating a configuration defaults will yield a similar configuration to that of the Pix2Struct-base - [Salesforce/pix2struct-vqa-base](https://huggingface.co/Salesforce/pix2struct-vqa-base) architecture. + [google/pix2struct-base](https://huggingface.co/google/pix2struct-base) architecture. Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the documentation from [`PretrainedConfig`] for more information. - Args: hidden_size (`int`, *optional*, defaults to 768): Dimensionality of the encoder layers and the pooler layer. @@ -223,10 +223,10 @@ class Pix2StructVisionConfig(PretrainedConfig): ```python >>> from transformers import Pix2StructVisionConfig, Pix2StructVisionModel - >>> # Initializing a Pix2StructVisionConfig with Salesforce/pix2struct-vqa-base style configuration + >>> # Initializing a Pix2StructVisionConfig with google/pix2struct-base style configuration >>> configuration = Pix2StructVisionConfig() - >>> # Initializing a Pix2StructVisionModel (with random weights) from the Salesforce/pix2struct-vqa-base style configuration + >>> # Initializing a Pix2StructVisionModel (with random weights) from the google/pix2struct-base style configuration >>> model = Pix2StructVisionModel(configuration) >>> # Accessing the model configuration @@ -301,11 +301,11 @@ class Pix2StructVisionConfig(PretrainedConfig): class Pix2StructConfig(PretrainedConfig): r""" - [`Pix2StructConfig`] is the configuration class to store the configuration of a [`Pix2StructModel`]. It is used to - instantiate a PIX2STRUCT model according to the specified arguments, defining the text model and vision model - configs. Instantiating a configuration with the defaults will yield a similar configuration to that of the - PIX2STRUCT-base [Salesforce/pix2struct-vqa-base](https://huggingface.co/Salesforce/pix2struct-vqa-base) - architecture. + [`Pix2StructConfig`] is the configuration class to store the configuration of a + [`Pix2StructForConditionalGeneration`]. It is used to instantiate a Pix2Struct model according to the specified + arguments, defining the text model and vision model configs. Instantiating a configuration with the defaults will + yield a similar configuration to that of the Pix2Struct-base + [google/pix2struct-base](https://huggingface.co/google/pix2struct-base) architecture. Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the documentation from [`PretrainedConfig`] for more information. @@ -327,20 +327,20 @@ class Pix2StructConfig(PretrainedConfig): Example: ```python - >>> from transformers import Pix2StructConfig, Pix2StructModel + >>> from transformers import Pix2StructConfig, Pix2StructForConditionalGeneration - >>> # Initializing a Pix2StructConfig with Salesforce/pix2struct-vqa-base style configuration + >>> # Initializing a Pix2StructConfig with google/pix2struct-base style configuration >>> configuration = Pix2StructConfig() - >>> # Initializing a Pix2StructPModel (with random weights) from the Salesforce/pix2struct-vqa-base style configuration - >>> model = Pix2StructModel(configuration) + >>> # Initializing a Pix2StructForConditionalGeneration (with random weights) from the google/pix2struct-base style configuration + >>> model = Pix2StructForConditionalGeneration(configuration) >>> # Accessing the model configuration >>> configuration = model.config >>> # We can also initialize a Pix2StructConfig from a Pix2StructTextConfig and a Pix2StructVisionConfig - >>> # Initializing a PIX2STRUCTText and PIX2STRUCTVision configuration + >>> # Initializing a Pix2Struct text and Pix2Struct vision configuration >>> config_text = Pix2StructTextConfig() >>> config_vision = Pix2StructVisionConfig() diff --git a/src/transformers/models/pix2struct/modeling_pix2struct.py b/src/transformers/models/pix2struct/modeling_pix2struct.py index 1dcd22e0f4..ead913e1df 100644 --- a/src/transformers/models/pix2struct/modeling_pix2struct.py +++ b/src/transformers/models/pix2struct/modeling_pix2struct.py @@ -1369,6 +1369,12 @@ class Pix2StructTextModel(Pix2StructPreTrainedModel): def set_input_embeddings(self, new_embeddings): self.embed_tokens = new_embeddings + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + @add_start_docstrings_to_model_forward(PIX2STRUCT_TEXT_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC) def forward( @@ -1626,12 +1632,25 @@ class Pix2StructForConditionalGeneration(Pix2StructPreTrainedModel): self.post_init() def get_input_embeddings(self): - return self.shared + return self.decoder.get_input_embeddings() def set_input_embeddings(self, new_embeddings): - self.shared = new_embeddings self.decoder.set_input_embeddings(new_embeddings) + def get_output_embeddings(self) -> nn.Module: + return self.decoder.get_output_embeddings() + + def set_output_embeddings(self, new_embeddings): + self.decoder.set_output_embeddings(new_embeddings) + + def resize_token_embeddings(self, new_num_tokens: Optional[int] = None) -> nn.Embedding: + model_embeds = self.decoder.resize_token_embeddings(new_num_tokens) + + # update vocab size + self.config.text_config.vocab_size = new_num_tokens + + return model_embeds + def get_decoder(self): return self.decoder diff --git a/tests/models/pix2struct/test_modeling_pix2struct.py b/tests/models/pix2struct/test_modeling_pix2struct.py index 3896f8d845..dc219bbd61 100644 --- a/tests/models/pix2struct/test_modeling_pix2struct.py +++ b/tests/models/pix2struct/test_modeling_pix2struct.py @@ -14,7 +14,7 @@ # limitations under the License. """ Testing suite for the PyTorch Pix2Struct model. """ - +import copy import inspect import os import tempfile @@ -396,7 +396,7 @@ class Pix2StructTextImageModelTest(ModelTesterMixin, unittest.TestCase): fx_compatible = False test_head_masking = False test_pruning = False - test_resize_embeddings = False + test_resize_embeddings = True test_attention_outputs = False test_torchscript = False @@ -526,6 +526,105 @@ class Pix2StructTextImageModelTest(ModelTesterMixin, unittest.TestCase): msg=f"Parameter {name} of model {model_class} seems not properly initialized", ) + # overwrite because `vocab_size` is not an attribute of `Pix2StructConfig` but rather `Pix2StructTextConfig` + def test_resize_tokens_embeddings(self): + original_config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + if not self.test_resize_embeddings: + return + + for model_class in self.all_model_classes: + config = copy.deepcopy(original_config) + model = model_class(config) + model.to(torch_device) + + if self.model_tester.is_training is False: + model.eval() + + model_vocab_size = config.text_config.vocab_size + # Retrieve the embeddings and clone theme + model_embed = model.resize_token_embeddings(model_vocab_size) + cloned_embeddings = model_embed.weight.clone() + + # Check that resizing the token embeddings with a larger vocab size increases the model's vocab size + model_embed = model.resize_token_embeddings(model_vocab_size + 10) + self.assertEqual(model.config.text_config.vocab_size, model_vocab_size + 10) + # Check that it actually resizes the embeddings matrix + self.assertEqual(model_embed.weight.shape[0], cloned_embeddings.shape[0] + 10) + # Check that the model can still do a forward pass successfully (every parameter should be resized) + model(**self._prepare_for_class(inputs_dict, model_class)) + + # Check that resizing the token embeddings with a smaller vocab size decreases the model's vocab size + model_embed = model.resize_token_embeddings(model_vocab_size - 15) + self.assertEqual(model.config.text_config.vocab_size, model_vocab_size - 15) + # Check that it actually resizes the embeddings matrix + self.assertEqual(model_embed.weight.shape[0], cloned_embeddings.shape[0] - 15) + + # Check that the model can still do a forward pass successfully (every parameter should be resized) + # Decoder input ids should be clamped to the maximum size of the vocabulary + if "decoder_input_ids" in inputs_dict: + inputs_dict["decoder_input_ids"].clamp_(max=model_vocab_size - 15 - 1) + model(**self._prepare_for_class(inputs_dict, model_class)) + + # Check that adding and removing tokens has not modified the first part of the embedding matrix. + models_equal = True + for p1, p2 in zip(cloned_embeddings, model_embed.weight): + if p1.data.ne(p2.data).sum() > 0: + models_equal = False + + self.assertTrue(models_equal) + + # overwrite because `vocab_size` is not an attribute of `Pix2StructConfig` but rather `Pix2StructTextConfig` + def test_resize_embeddings_untied(self): + original_config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + if not self.test_resize_embeddings: + return + + original_config.tie_word_embeddings = False + + # if model cannot untied embeddings -> leave test + if original_config.tie_word_embeddings: + return + + for model_class in self.all_model_classes: + config = copy.deepcopy(original_config) + model = model_class(config).to(torch_device) + + # if no output embeddings -> leave test + if model.get_output_embeddings() is None: + continue + + # Check that resizing the token embeddings with a larger vocab size increases the model's vocab size + model_vocab_size = config.text_config.vocab_size + model.resize_token_embeddings(model_vocab_size + 10) + self.assertEqual(model.config.text_config.vocab_size, model_vocab_size + 10) + output_embeds = model.get_output_embeddings() + self.assertEqual(output_embeds.weight.shape[0], model_vocab_size + 10) + # Check bias if present + if output_embeds.bias is not None: + self.assertEqual(output_embeds.bias.shape[0], model_vocab_size + 10) + # Check that the model can still do a forward pass successfully (every parameter should be resized) + model(**self._prepare_for_class(inputs_dict, model_class)) + + # Check that resizing the token embeddings with a smaller vocab size decreases the model's vocab size + model.resize_token_embeddings(model_vocab_size - 15) + self.assertEqual(model.config.text_config.vocab_size, model_vocab_size - 15) + # Check that it actually resizes the embeddings matrix + output_embeds = model.get_output_embeddings() + self.assertEqual(output_embeds.weight.shape[0], model_vocab_size - 15) + # Check bias if present + if output_embeds.bias is not None: + self.assertEqual(output_embeds.bias.shape[0], model_vocab_size - 15) + # Check that the model can still do a forward pass successfully (every parameter should be resized) + # Decoder input ids should be clamped to the maximum size of the vocabulary + if "decoder_input_ids" in inputs_dict: + inputs_dict["decoder_input_ids"].clamp_(max=model_vocab_size - 15 - 1) + # Check that the model can still do a forward pass successfully (every parameter should be resized) + model(**self._prepare_for_class(inputs_dict, model_class)) + + @unittest.skip(reason="Pix2Struct doesn't use tied weights") + def test_tied_model_weights_key_ignore(self): + pass + def _create_and_check_torchscript(self, config, inputs_dict): if not self.test_torchscript: return