From 33f36c869fa5db07bb35789d411cfed8bf9d3b0c Mon Sep 17 00:00:00 2001 From: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Date: Mon, 20 Dec 2021 11:19:08 -0500 Subject: [PATCH] Add a main_input_name attribute to all models (#14803) * Add a main_input_name attribute to all models * Fix tests * Wtf Vs Code? * Update src/transformers/models/imagegpt/modeling_imagegpt.py Co-authored-by: Patrick von Platen * Style * Fix copies Co-authored-by: Patrick von Platen --- src/transformers/modeling_flax_utils.py | 3 +++ src/transformers/modeling_tf_utils.py | 4 ++++ src/transformers/modeling_utils.py | 12 +++++++----- src/transformers/models/beit/modeling_beit.py | 1 + src/transformers/models/beit/modeling_flax_beit.py | 1 + src/transformers/models/clip/modeling_clip.py | 1 + src/transformers/models/clip/modeling_flax_clip.py | 1 + src/transformers/models/deit/modeling_deit.py | 1 + src/transformers/models/detr/modeling_detr.py | 1 + src/transformers/models/hubert/modeling_hubert.py | 1 + src/transformers/models/hubert/modeling_tf_hubert.py | 1 + .../models/imagegpt/modeling_imagegpt.py | 1 + .../models/perceiver/modeling_perceiver.py | 1 + .../models/segformer/modeling_segformer.py | 1 + src/transformers/models/sew/modeling_sew.py | 1 + src/transformers/models/sew_d/modeling_sew_d.py | 1 + .../modeling_speech_encoder_decoder.py | 1 + .../models/speech_to_text/modeling_speech_to_text.py | 1 + .../models/unispeech/modeling_unispeech.py | 1 + .../models/unispeech_sat/modeling_unispeech_sat.py | 1 + .../modeling_flax_vision_encoder_decoder.py | 1 + .../modeling_vision_encoder_decoder.py | 1 + src/transformers/models/vit/modeling_flax_vit.py | 1 + src/transformers/models/vit/modeling_tf_vit.py | 1 + src/transformers/models/vit/modeling_vit.py | 1 + .../models/wav2vec2/modeling_flax_wav2vec2.py | 1 + .../models/wav2vec2/modeling_tf_wav2vec2.py | 1 + .../models/wav2vec2/modeling_wav2vec2.py | 1 + src/transformers/models/wavlm/modeling_wavlm.py | 1 + tests/test_modeling_common.py | 7 +++++++ tests/test_modeling_flax_common.py | 7 +++++++ tests/test_modeling_tf_common.py | 7 +++++++ 32 files changed, 61 insertions(+), 5 deletions(-) diff --git a/src/transformers/modeling_flax_utils.py b/src/transformers/modeling_flax_utils.py index 80f4e549c2..2be53474c3 100644 --- a/src/transformers/modeling_flax_utils.py +++ b/src/transformers/modeling_flax_utils.py @@ -76,9 +76,12 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin): :class:`~transformers.PretrainedConfig` to use as configuration class for this model architecture. - **base_model_prefix** (:obj:`str`) -- A string indicating the attribute associated to the base model in derived classes of the same architecture adding modules on top of the base model. + - **main_input_name** (:obj:`str`) -- The name of the principal input to the model (often :obj:`input_ids` for + NLP models, :obj:`pixel_values` for vision models and :obj:`input_values` for speech models). """ config_class = None base_model_prefix = "" + main_input_name = "input_ids" def __init__( self, diff --git a/src/transformers/modeling_tf_utils.py b/src/transformers/modeling_tf_utils.py index df3fcc4639..bb66e3f62f 100644 --- a/src/transformers/modeling_tf_utils.py +++ b/src/transformers/modeling_tf_utils.py @@ -653,9 +653,13 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu :class:`~transformers.PretrainedConfig` to use as configuration class for this model architecture. - **base_model_prefix** (:obj:`str`) -- A string indicating the attribute associated to the base model in derived classes of the same architecture adding modules on top of the base model. + - **main_input_name** (:obj:`str`) -- The name of the principal input to the model (often :obj:`input_ids` for + NLP models, :obj:`pixel_values` for vision models and :obj:`input_values` for speech models). """ config_class = None base_model_prefix = "" + main_input_name = "input_ids" + # a list of re pattern of tensor names to ignore from the model when loading the model weights # (and avoid unnecessary warnings). _keys_to_ignore_on_load_missing = None diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 177f0cb79a..2cc37a6f94 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -17,7 +17,6 @@ import inspect import os import re -import warnings from contextlib import contextmanager from dataclasses import dataclass from functools import partial @@ -376,11 +375,10 @@ class ModuleUtilsMixin: Returns: :obj:`int`: The total number of tokens. """ - token_inputs = [tensor for key, tensor in input_dict.items() if "input" in key] - if token_inputs: - return sum([token_input.numel() for token_input in token_inputs]) + if self.main_input_name in input_dict: + return input_dict[self.main_input_name].numel() else: - warnings.warn( + logger.warn( "Could not estimate the number of tokens of the input, floating-point operations will not be computed" ) return 0 @@ -438,9 +436,13 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix - **base_model_prefix** (:obj:`str`) -- A string indicating the attribute associated to the base model in derived classes of the same architecture adding modules on top of the base model. - **is_parallelizable** (:obj:`bool`) -- A flag indicating whether this model supports model parallelization. + - **main_input_name** (:obj:`str`) -- The name of the principal input to the model (often :obj:`input_ids` for + NLP models, :obj:`pixel_values` for vision models and :obj:`input_values` for speech models). """ config_class = None base_model_prefix = "" + main_input_name = "input_ids" + # a list of re pattern of tensor names to ignore from the model when loading the model weights # (and avoid unnecessary warnings). _keys_to_ignore_on_load_missing = None diff --git a/src/transformers/models/beit/modeling_beit.py b/src/transformers/models/beit/modeling_beit.py index 12f0050912..5ff43d4c56 100755 --- a/src/transformers/models/beit/modeling_beit.py +++ b/src/transformers/models/beit/modeling_beit.py @@ -523,6 +523,7 @@ class BeitPreTrainedModel(PreTrainedModel): config_class = BeitConfig base_model_prefix = "beit" + main_input_name = "pixel_values" supports_gradient_checkpointing = True def _init_weights(self, module): diff --git a/src/transformers/models/beit/modeling_flax_beit.py b/src/transformers/models/beit/modeling_flax_beit.py index 065dd0519a..e276b34fb8 100644 --- a/src/transformers/models/beit/modeling_flax_beit.py +++ b/src/transformers/models/beit/modeling_flax_beit.py @@ -590,6 +590,7 @@ class FlaxBeitPreTrainedModel(FlaxPreTrainedModel): config_class = BeitConfig base_model_prefix = "beit" + main_input_name = "pixel_values" module_class: nn.Module = None def __init__(self, config: BeitConfig, input_shape=None, seed: int = 0, dtype: jnp.dtype = jnp.float32, **kwargs): diff --git a/src/transformers/models/clip/modeling_clip.py b/src/transformers/models/clip/modeling_clip.py index 89f885ef71..aa3f724b93 100755 --- a/src/transformers/models/clip/modeling_clip.py +++ b/src/transformers/models/clip/modeling_clip.py @@ -789,6 +789,7 @@ class CLIPVisionTransformer(nn.Module): class CLIPVisionModel(CLIPPreTrainedModel): config_class = CLIPVisionConfig + main_input_name = "pixel_values" def __init__(self, config: CLIPVisionConfig): super().__init__(config) diff --git a/src/transformers/models/clip/modeling_flax_clip.py b/src/transformers/models/clip/modeling_flax_clip.py index ab20758d7d..2a088f0f02 100644 --- a/src/transformers/models/clip/modeling_flax_clip.py +++ b/src/transformers/models/clip/modeling_flax_clip.py @@ -653,6 +653,7 @@ class FlaxCLIPTextPreTrainedModel(FlaxPreTrainedModel): class FlaxCLIPVisionPreTrainedModel(FlaxPreTrainedModel): config_class = CLIPVisionConfig + main_input_name = "pixel_values" module_class: nn.Module = None def __init__( diff --git a/src/transformers/models/deit/modeling_deit.py b/src/transformers/models/deit/modeling_deit.py index 3f26cbef31..6698b5f77f 100644 --- a/src/transformers/models/deit/modeling_deit.py +++ b/src/transformers/models/deit/modeling_deit.py @@ -385,6 +385,7 @@ class DeiTPreTrainedModel(PreTrainedModel): config_class = DeiTConfig base_model_prefix = "deit" + main_input_name = "pixel_values" supports_gradient_checkpointing = True def _init_weights(self, module): diff --git a/src/transformers/models/detr/modeling_detr.py b/src/transformers/models/detr/modeling_detr.py index e7771a4adb..fbaa46e17f 100644 --- a/src/transformers/models/detr/modeling_detr.py +++ b/src/transformers/models/detr/modeling_detr.py @@ -784,6 +784,7 @@ class DetrClassificationHead(nn.Module): class DetrPreTrainedModel(PreTrainedModel): config_class = DetrConfig base_model_prefix = "model" + main_input_name = "pixel_values" def _init_weights(self, module): std = self.config.init_std diff --git a/src/transformers/models/hubert/modeling_hubert.py b/src/transformers/models/hubert/modeling_hubert.py index a26066f2f5..52d66831e8 100755 --- a/src/transformers/models/hubert/modeling_hubert.py +++ b/src/transformers/models/hubert/modeling_hubert.py @@ -776,6 +776,7 @@ class HubertPreTrainedModel(PreTrainedModel): config_class = HubertConfig base_model_prefix = "hubert" + main_input_name = "input_values" supports_gradient_checkpointing = True _keys_to_ignore_on_load_missing = [r"position_ids"] diff --git a/src/transformers/models/hubert/modeling_tf_hubert.py b/src/transformers/models/hubert/modeling_tf_hubert.py index aa7f659608..d08b747b70 100644 --- a/src/transformers/models/hubert/modeling_tf_hubert.py +++ b/src/transformers/models/hubert/modeling_tf_hubert.py @@ -1265,6 +1265,7 @@ class TFHubertPreTrainedModel(TFPreTrainedModel): config_class = HubertConfig base_model_prefix = "hubert" + main_input_name = "input_values" @property def dummy_inputs(self) -> Dict[str, tf.Tensor]: diff --git a/src/transformers/models/imagegpt/modeling_imagegpt.py b/src/transformers/models/imagegpt/modeling_imagegpt.py index 4c90e8148c..054054df80 100755 --- a/src/transformers/models/imagegpt/modeling_imagegpt.py +++ b/src/transformers/models/imagegpt/modeling_imagegpt.py @@ -496,6 +496,7 @@ class ImageGPTPreTrainedModel(PreTrainedModel): config_class = ImageGPTConfig load_tf_weights = load_tf_weights_in_imagegpt base_model_prefix = "transformer" + main_input_name = "input_ids" supports_gradient_checkpointing = True def __init__(self, *inputs, **kwargs): diff --git a/src/transformers/models/perceiver/modeling_perceiver.py b/src/transformers/models/perceiver/modeling_perceiver.py index d3f91c34df..e231fa8877 100755 --- a/src/transformers/models/perceiver/modeling_perceiver.py +++ b/src/transformers/models/perceiver/modeling_perceiver.py @@ -619,6 +619,7 @@ class PerceiverPreTrainedModel(PreTrainedModel): config_class = PerceiverConfig base_model_prefix = "perceiver" + main_input_name = "inputs" def _init_weights(self, module): """Initialize the weights""" diff --git a/src/transformers/models/segformer/modeling_segformer.py b/src/transformers/models/segformer/modeling_segformer.py index dc358ae15b..7cfbe0ceb7 100755 --- a/src/transformers/models/segformer/modeling_segformer.py +++ b/src/transformers/models/segformer/modeling_segformer.py @@ -406,6 +406,7 @@ class SegformerPreTrainedModel(PreTrainedModel): config_class = SegformerConfig base_model_prefix = "segformer" + main_input_name = "pixel_values" def _init_weights(self, module): """Initialize the weights""" diff --git a/src/transformers/models/sew/modeling_sew.py b/src/transformers/models/sew/modeling_sew.py index 5cecf18f2e..92ac0d1e0f 100644 --- a/src/transformers/models/sew/modeling_sew.py +++ b/src/transformers/models/sew/modeling_sew.py @@ -675,6 +675,7 @@ class SEWPreTrainedModel(PreTrainedModel): config_class = SEWConfig base_model_prefix = "sew" + main_input_name = "input_values" supports_gradient_checkpointing = True _keys_to_ignore_on_load_missing = [r"position_ids"] diff --git a/src/transformers/models/sew_d/modeling_sew_d.py b/src/transformers/models/sew_d/modeling_sew_d.py index e4e1bb0f0b..27971624cf 100644 --- a/src/transformers/models/sew_d/modeling_sew_d.py +++ b/src/transformers/models/sew_d/modeling_sew_d.py @@ -1201,6 +1201,7 @@ class SEWDPreTrainedModel(PreTrainedModel): config_class = SEWDConfig base_model_prefix = "sew-d" + main_input_name = "input_values" _keys_to_ignore_on_load_missing = [r"position_ids"] supports_gradient_checkpointing = True diff --git a/src/transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py b/src/transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py index 75c939d906..87041f5e24 100644 --- a/src/transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py +++ b/src/transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py @@ -180,6 +180,7 @@ class SpeechEncoderDecoderModel(PreTrainedModel): """ config_class = SpeechEncoderDecoderConfig base_model_prefix = "speech_encoder_decoder" + main_input_name = "input_values" def __init__( self, diff --git a/src/transformers/models/speech_to_text/modeling_speech_to_text.py b/src/transformers/models/speech_to_text/modeling_speech_to_text.py index aead484a59..4fade8ba7c 100755 --- a/src/transformers/models/speech_to_text/modeling_speech_to_text.py +++ b/src/transformers/models/speech_to_text/modeling_speech_to_text.py @@ -539,6 +539,7 @@ class Speech2TextDecoderLayer(nn.Module): class Speech2TextPreTrainedModel(PreTrainedModel): config_class = Speech2TextConfig base_model_prefix = "model" + main_input_name = "input_features" supports_gradient_checkpointing = True def _init_weights(self, module): diff --git a/src/transformers/models/unispeech/modeling_unispeech.py b/src/transformers/models/unispeech/modeling_unispeech.py index cfacf721b0..772fac5dca 100755 --- a/src/transformers/models/unispeech/modeling_unispeech.py +++ b/src/transformers/models/unispeech/modeling_unispeech.py @@ -912,6 +912,7 @@ class UniSpeechPreTrainedModel(PreTrainedModel): config_class = UniSpeechConfig base_model_prefix = "unispeech" + main_input_name = "input_values" _keys_to_ignore_on_load_missing = [r"position_ids"] supports_gradient_checkpointing = True diff --git a/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py b/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py index 7b7009b329..4d0ef05e06 100755 --- a/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py +++ b/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py @@ -947,6 +947,7 @@ class UniSpeechSatPreTrainedModel(PreTrainedModel): config_class = UniSpeechSatConfig base_model_prefix = "unispeech_sat" + main_input_name = "input_values" _keys_to_ignore_on_load_missing = [r"position_ids"] supports_gradient_checkpointing = True diff --git a/src/transformers/models/vision_encoder_decoder/modeling_flax_vision_encoder_decoder.py b/src/transformers/models/vision_encoder_decoder/modeling_flax_vision_encoder_decoder.py index 85d1aa732b..54813f482e 100644 --- a/src/transformers/models/vision_encoder_decoder/modeling_flax_vision_encoder_decoder.py +++ b/src/transformers/models/vision_encoder_decoder/modeling_flax_vision_encoder_decoder.py @@ -283,6 +283,7 @@ class FlaxVisionEncoderDecoderModel(FlaxPreTrainedModel): """ config_class = VisionEncoderDecoderConfig base_model_prefix = "vision_encoder_decoder" + main_input_name = "pixel_values" module_class = FlaxVisionEncoderDecoderModule def __init__( diff --git a/src/transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py b/src/transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py index 75c1bcd3b0..818e90ea30 100644 --- a/src/transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py +++ b/src/transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py @@ -160,6 +160,7 @@ class VisionEncoderDecoderModel(PreTrainedModel): """ config_class = VisionEncoderDecoderConfig base_model_prefix = "vision_encoder_decoder" + main_input_name = "pixel_values" def __init__( self, diff --git a/src/transformers/models/vit/modeling_flax_vit.py b/src/transformers/models/vit/modeling_flax_vit.py index 32252d5551..04df178381 100644 --- a/src/transformers/models/vit/modeling_flax_vit.py +++ b/src/transformers/models/vit/modeling_flax_vit.py @@ -406,6 +406,7 @@ class FlaxViTPreTrainedModel(FlaxPreTrainedModel): config_class = ViTConfig base_model_prefix = "vit" + main_input_name = "pixel_values" module_class: nn.Module = None def __init__(self, config: ViTConfig, input_shape=None, seed: int = 0, dtype: jnp.dtype = jnp.float32, **kwargs): diff --git a/src/transformers/models/vit/modeling_tf_vit.py b/src/transformers/models/vit/modeling_tf_vit.py index 8a809bfe8a..32bfc9a6b8 100644 --- a/src/transformers/models/vit/modeling_tf_vit.py +++ b/src/transformers/models/vit/modeling_tf_vit.py @@ -555,6 +555,7 @@ class TFViTPreTrainedModel(TFPreTrainedModel): config_class = ViTConfig base_model_prefix = "vit" + main_input_name = "pixel_values" @property def dummy_inputs(self) -> Dict[str, tf.Tensor]: diff --git a/src/transformers/models/vit/modeling_vit.py b/src/transformers/models/vit/modeling_vit.py index 1a8a0db513..92915ee75d 100644 --- a/src/transformers/models/vit/modeling_vit.py +++ b/src/transformers/models/vit/modeling_vit.py @@ -412,6 +412,7 @@ class ViTPreTrainedModel(PreTrainedModel): config_class = ViTConfig base_model_prefix = "vit" + main_input_name = "pixel_values" supports_gradient_checkpointing = True def _init_weights(self, module): diff --git a/src/transformers/models/wav2vec2/modeling_flax_wav2vec2.py b/src/transformers/models/wav2vec2/modeling_flax_wav2vec2.py index 8105a30c14..a62c88ec12 100644 --- a/src/transformers/models/wav2vec2/modeling_flax_wav2vec2.py +++ b/src/transformers/models/wav2vec2/modeling_flax_wav2vec2.py @@ -775,6 +775,7 @@ class FlaxWav2Vec2PreTrainedModel(FlaxPreTrainedModel): config_class = Wav2Vec2Config base_model_prefix: str = "wav2vec2" + main_input_name = "input_values" module_class: nn.Module = None def __init__( diff --git a/src/transformers/models/wav2vec2/modeling_tf_wav2vec2.py b/src/transformers/models/wav2vec2/modeling_tf_wav2vec2.py index b1599b9e42..c758bf6187 100644 --- a/src/transformers/models/wav2vec2/modeling_tf_wav2vec2.py +++ b/src/transformers/models/wav2vec2/modeling_tf_wav2vec2.py @@ -1256,6 +1256,7 @@ class TFWav2Vec2PreTrainedModel(TFPreTrainedModel): config_class = Wav2Vec2Config base_model_prefix = "wav2vec2" + main_input_name = "input_values" @property def dummy_inputs(self) -> Dict[str, tf.Tensor]: diff --git a/src/transformers/models/wav2vec2/modeling_wav2vec2.py b/src/transformers/models/wav2vec2/modeling_wav2vec2.py index 046a04b2db..bd7550e773 100755 --- a/src/transformers/models/wav2vec2/modeling_wav2vec2.py +++ b/src/transformers/models/wav2vec2/modeling_wav2vec2.py @@ -1044,6 +1044,7 @@ class Wav2Vec2PreTrainedModel(PreTrainedModel): config_class = Wav2Vec2Config base_model_prefix = "wav2vec2" + main_input_name = "input_values" _keys_to_ignore_on_load_missing = [r"position_ids"] supports_gradient_checkpointing = True diff --git a/src/transformers/models/wavlm/modeling_wavlm.py b/src/transformers/models/wavlm/modeling_wavlm.py index d16472beb6..8c89acabf5 100755 --- a/src/transformers/models/wavlm/modeling_wavlm.py +++ b/src/transformers/models/wavlm/modeling_wavlm.py @@ -996,6 +996,7 @@ class WavLMPreTrainedModel(PreTrainedModel): config_class = WavLMConfig base_model_prefix = "wavlm" + main_input_name = "input_values" _keys_to_ignore_on_load_missing = [r"position_ids"] supports_gradient_checkpointing = True diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 6640028293..1df5e9e0f0 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -1315,6 +1315,13 @@ class ModelTesterMixin: x = model.get_output_embeddings() self.assertTrue(x is None or isinstance(x, nn.Linear)) + def test_model_main_input_name(self): + for model_class in self.all_model_classes: + model_signature = inspect.signature(getattr(model_class, "forward")) + # The main input is the name of the argument after `self` + observed_main_input_name = list(model_signature.parameters.keys())[1] + self.assertEqual(model_class.main_input_name, observed_main_input_name) + def test_correct_missing_keys(self): if not self.test_missing_keys: return diff --git a/tests/test_modeling_flax_common.py b/tests/test_modeling_flax_common.py index f4bcff71f1..b1d15b6673 100644 --- a/tests/test_modeling_flax_common.py +++ b/tests/test_modeling_flax_common.py @@ -778,6 +778,13 @@ class FlaxModelTesterMixin: for name, type_ in types.items(): self.assertEqual(type_, jnp.bfloat16, msg=f"param {name} is not in bf16.") + def test_model_main_input_name(self): + for model_class in self.all_model_classes: + model_signature = inspect.signature(getattr(model_class, "__call__")) + # The main input is the name of the argument after `self` + observed_main_input_name = list(model_signature.parameters.keys())[1] + self.assertEqual(model_class.main_input_name, observed_main_input_name) + def test_headmasking(self): if not self.test_head_masking: return diff --git a/tests/test_modeling_tf_common.py b/tests/test_modeling_tf_common.py index 8ad7ec4472..fd7a6f9069 100644 --- a/tests/test_modeling_tf_common.py +++ b/tests/test_modeling_tf_common.py @@ -1183,6 +1183,13 @@ class TFModelTesterMixin: else: new_model_without_prefix(input_ids) + def test_model_main_input_name(self): + for model_class in self.all_model_classes: + model_signature = inspect.signature(getattr(model_class, "call")) + # The main input is the name of the argument after `self` + observed_main_input_name = list(model_signature.parameters.keys())[1] + self.assertEqual(model_class.main_input_name, observed_main_input_name) + def _generate_random_bad_tokens(self, num_bad_tokens, model): # special tokens cannot be bad tokens special_tokens = []