From 9eda6b52e20d80cf165224d996babc67f913017e Mon Sep 17 00:00:00 2001 From: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Date: Wed, 23 Jun 2021 10:40:54 -0400 Subject: [PATCH] Add all XxxPreTrainedModel to the main init (#12314) * Add all XxxPreTrainedModel to the main init * Add to template * Add to template bis * Add FlaxT5 --- src/transformers/__init__.py | 121 +++++++++++--- src/transformers/models/bart/__init__.py | 2 + .../models/bert_generation/__init__.py | 2 + .../models/blenderbot/__init__.py | 12 +- .../models/blenderbot_small/__init__.py | 7 +- src/transformers/models/clip/__init__.py | 11 +- src/transformers/models/flaubert/__init__.py | 2 + src/transformers/models/funnel/__init__.py | 4 + src/transformers/models/gpt2/__init__.py | 4 +- src/transformers/models/layoutlm/__init__.py | 2 + .../models/longformer/__init__.py | 4 + src/transformers/models/marian/__init__.py | 4 +- src/transformers/models/mbart/__init__.py | 8 +- .../models/megatron_bert/__init__.py | 2 + src/transformers/models/pegasus/__init__.py | 8 +- src/transformers/models/rag/__init__.py | 23 ++- src/transformers/models/reformer/__init__.py | 2 + src/transformers/models/roberta/__init__.py | 2 + src/transformers/models/tapas/__init__.py | 2 + src/transformers/models/vit/__init__.py | 8 +- src/transformers/utils/dummy_flax_objects.py | 54 +++++++ src/transformers/utils/dummy_pt_objects.py | 153 ++++++++++++++++++ src/transformers/utils/dummy_tf_objects.py | 81 ++++++++++ .../utils/dummy_timm_and_vision_objects.py | 9 ++ ...ce_{{cookiecutter.lowercase_modelname}}.py | 2 + utils/check_repo.py | 54 +++++-- 26 files changed, 532 insertions(+), 51 deletions(-) diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 08068c5bec..f1c0383336 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -427,6 +427,7 @@ if is_timm_available() and is_vision_available(): "DetrForObjectDetection", "DetrForSegmentation", "DetrModel", + "DetrPreTrainedModel", ] ) else: @@ -570,6 +571,7 @@ if is_torch_available(): [ "BertGenerationDecoder", "BertGenerationEncoder", + "BertGenerationPreTrainedModel", "load_tf_weights_in_bert_generation", ] ) @@ -597,6 +599,7 @@ if is_torch_available(): "BigBirdPegasusForQuestionAnswering", "BigBirdPegasusForSequenceClassification", "BigBirdPegasusModel", + "BigBirdPegasusPreTrainedModel", ] ) _import_structure["models.blenderbot"].extend( @@ -605,6 +608,7 @@ if is_torch_available(): "BlenderbotForCausalLM", "BlenderbotForConditionalGeneration", "BlenderbotModel", + "BlenderbotPreTrainedModel", ] ) _import_structure["models.blenderbot_small"].extend( @@ -613,6 +617,7 @@ if is_torch_available(): "BlenderbotSmallForCausalLM", "BlenderbotSmallForConditionalGeneration", "BlenderbotSmallModel", + "BlenderbotSmallPreTrainedModel", ] ) _import_structure["models.camembert"].extend( @@ -754,6 +759,7 @@ if is_torch_available(): "FunnelForSequenceClassification", "FunnelForTokenClassification", "FunnelModel", + "FunnelPreTrainedModel", "load_tf_weights_in_funnel", ] ) @@ -805,6 +811,7 @@ if is_torch_available(): "LayoutLMForSequenceClassification", "LayoutLMForTokenClassification", "LayoutLMModel", + "LayoutLMPreTrainedModel", ] ) _import_structure["models.led"].extend( @@ -814,6 +821,7 @@ if is_torch_available(): "LEDForQuestionAnswering", "LEDForSequenceClassification", "LEDModel", + "LEDPreTrainedModel", ] ) _import_structure["models.longformer"].extend( @@ -825,6 +833,7 @@ if is_torch_available(): "LongformerForSequenceClassification", "LongformerForTokenClassification", "LongformerModel", + "LongformerPreTrainedModel", "LongformerSelfAttention", ] ) @@ -854,6 +863,7 @@ if is_torch_available(): "M2M_100_PRETRAINED_MODEL_ARCHIVE_LIST", "M2M100ForConditionalGeneration", "M2M100Model", + "M2M100PreTrainedModel", ] ) _import_structure["models.marian"].extend(["MarianForCausalLM", "MarianModel", "MarianMTModel"]) @@ -864,6 +874,7 @@ if is_torch_available(): "MBartForQuestionAnswering", "MBartForSequenceClassification", "MBartModel", + "MBartPreTrainedModel", ] ) _import_structure["models.megatron_bert"].extend( @@ -878,6 +889,7 @@ if is_torch_available(): "MegatronBertForSequenceClassification", "MegatronBertForTokenClassification", "MegatronBertModel", + "MegatronBertPreTrainedModel", ] ) _import_structure["models.mmbt"].extend(["MMBTForClassification", "MMBTModel", "ModalEmbeddings"]) @@ -923,7 +935,7 @@ if is_torch_available(): ] ) _import_structure["models.pegasus"].extend( - ["PegasusForCausalLM", "PegasusForConditionalGeneration", "PegasusModel"] + ["PegasusForCausalLM", "PegasusForConditionalGeneration", "PegasusModel", "PegasusPreTrainedModel"] ) _import_structure["models.prophetnet"].extend( [ @@ -936,7 +948,9 @@ if is_torch_available(): "ProphetNetPreTrainedModel", ] ) - _import_structure["models.rag"].extend(["RagModel", "RagSequenceForGeneration", "RagTokenForGeneration"]) + _import_structure["models.rag"].extend( + ["RagModel", "RagPreTrainedModel", "RagSequenceForGeneration", "RagTokenForGeneration"] + ) _import_structure["models.reformer"].extend( [ "REFORMER_PRETRAINED_MODEL_ARCHIVE_LIST", @@ -947,6 +961,7 @@ if is_torch_available(): "ReformerLayer", "ReformerModel", "ReformerModelWithLMHead", + "ReformerPreTrainedModel", ] ) _import_structure["models.retribert"].extend( @@ -962,6 +977,7 @@ if is_torch_available(): "RobertaForSequenceClassification", "RobertaForTokenClassification", "RobertaModel", + "RobertaPreTrainedModel", ] ) _import_structure["models.roformer"].extend( @@ -984,6 +1000,7 @@ if is_torch_available(): "SPEECH_TO_TEXT_PRETRAINED_MODEL_ARCHIVE_LIST", "Speech2TextForConditionalGeneration", "Speech2TextModel", + "Speech2TextPreTrainedModel", ] ) _import_structure["models.squeezebert"].extend( @@ -1016,6 +1033,7 @@ if is_torch_available(): "TapasForQuestionAnswering", "TapasForSequenceClassification", "TapasModel", + "TapasPreTrainedModel", ] ) _import_structure["models.transfo_xl"].extend( @@ -1197,9 +1215,11 @@ if is_tf_available(): "TFBertPreTrainedModel", ] ) - _import_structure["models.blenderbot"].extend(["TFBlenderbotForConditionalGeneration", "TFBlenderbotModel"]) + _import_structure["models.blenderbot"].extend( + ["TFBlenderbotForConditionalGeneration", "TFBlenderbotModel", "TFBlenderbotPreTrainedModel"] + ) _import_structure["models.blenderbot_small"].extend( - ["TFBlenderbotSmallForConditionalGeneration", "TFBlenderbotSmallModel"] + ["TFBlenderbotSmallForConditionalGeneration", "TFBlenderbotSmallModel", "TFBlenderbotSmallPreTrainedModel"] ) _import_structure["models.camembert"].extend( [ @@ -1281,6 +1301,7 @@ if is_tf_available(): "TFFlaubertForSequenceClassification", "TFFlaubertForTokenClassification", "TFFlaubertModel", + "TFFlaubertPreTrainedModel", "TFFlaubertWithLMHeadModel", ] ) @@ -1295,6 +1316,7 @@ if is_tf_available(): "TFFunnelForSequenceClassification", "TFFunnelForTokenClassification", "TFFunnelModel", + "TFFunnelPreTrainedModel", ] ) _import_structure["models.gpt2"].extend( @@ -1329,6 +1351,7 @@ if is_tf_available(): "TFLongformerForSequenceClassification", "TFLongformerForTokenClassification", "TFLongformerModel", + "TFLongformerPreTrainedModel", "TFLongformerSelfAttention", ] ) @@ -1342,8 +1365,10 @@ if is_tf_available(): "TFLxmertVisualFeatureEncoder", ] ) - _import_structure["models.marian"].extend(["TFMarianModel", "TFMarianMTModel"]) - _import_structure["models.mbart"].extend(["TFMBartForConditionalGeneration", "TFMBartModel"]) + _import_structure["models.marian"].extend(["TFMarianModel", "TFMarianMTModel", "TFMarianPreTrainedModel"]) + _import_structure["models.mbart"].extend( + ["TFMBartForConditionalGeneration", "TFMBartModel", "TFMBartPreTrainedModel"] + ) _import_structure["models.mobilebert"].extend( [ "TF_MOBILEBERT_PRETRAINED_MODEL_ARCHIVE_LIST", @@ -1384,10 +1409,13 @@ if is_tf_available(): "TFOpenAIGPTPreTrainedModel", ] ) - _import_structure["models.pegasus"].extend(["TFPegasusForConditionalGeneration", "TFPegasusModel"]) + _import_structure["models.pegasus"].extend( + ["TFPegasusForConditionalGeneration", "TFPegasusModel", "TFPegasusPreTrainedModel"] + ) _import_structure["models.rag"].extend( [ "TFRagModel", + "TFRagPreTrainedModel", "TFRagSequenceForGeneration", "TFRagTokenForGeneration", ] @@ -1538,6 +1566,7 @@ if is_flax_available(): "FlaxBartForQuestionAnswering", "FlaxBartForSequenceClassification", "FlaxBartModel", + "FlaxBartPreTrainedModel", ] ) _import_structure["models.bert"].extend( @@ -1570,7 +1599,9 @@ if is_flax_available(): "FlaxCLIPModel", "FlaxCLIPPreTrainedModel", "FlaxCLIPTextModel", + "FlaxCLIPTextPreTrainedModel", "FlaxCLIPVisionModel", + "FlaxCLIPVisionPreTrainedModel", ] ) _import_structure["models.electra"].extend( @@ -1585,7 +1616,7 @@ if is_flax_available(): "FlaxElectraPreTrainedModel", ] ) - _import_structure["models.gpt2"].extend(["FlaxGPT2LMHeadModel", "FlaxGPT2Model"]) + _import_structure["models.gpt2"].extend(["FlaxGPT2LMHeadModel", "FlaxGPT2Model", "FlaxGPT2PreTrainedModel"]) _import_structure["models.roberta"].extend( [ "FlaxRobertaForMaskedLM", @@ -1597,8 +1628,8 @@ if is_flax_available(): "FlaxRobertaPreTrainedModel", ] ) - _import_structure["models.t5"].extend(["FlaxT5ForConditionalGeneration", "FlaxT5Model"]) - _import_structure["models.vit"].extend(["FlaxViTForImageClassification", "FlaxViTModel"]) + _import_structure["models.t5"].extend(["FlaxT5ForConditionalGeneration", "FlaxT5Model", "FlaxT5PreTrainedModel"]) + _import_structure["models.vit"].extend(["FlaxViTForImageClassification", "FlaxViTModel", "FlaxViTPreTrainedModel"]) else: from .utils import dummy_flax_objects @@ -1949,6 +1980,7 @@ if TYPE_CHECKING: DetrForObjectDetection, DetrForSegmentation, DetrModel, + DetrPreTrainedModel, ) else: from .utils.dummy_timm_objects import * @@ -2074,6 +2106,7 @@ if TYPE_CHECKING: from .models.bert_generation import ( BertGenerationDecoder, BertGenerationEncoder, + BertGenerationPreTrainedModel, load_tf_weights_in_bert_generation, ) from .models.big_bird import ( @@ -2097,18 +2130,21 @@ if TYPE_CHECKING: BigBirdPegasusForQuestionAnswering, BigBirdPegasusForSequenceClassification, BigBirdPegasusModel, + BigBirdPegasusPreTrainedModel, ) from .models.blenderbot import ( BLENDERBOT_PRETRAINED_MODEL_ARCHIVE_LIST, BlenderbotForCausalLM, BlenderbotForConditionalGeneration, BlenderbotModel, + BlenderbotPreTrainedModel, ) from .models.blenderbot_small import ( BLENDERBOT_SMALL_PRETRAINED_MODEL_ARCHIVE_LIST, BlenderbotSmallForCausalLM, BlenderbotSmallForConditionalGeneration, BlenderbotSmallModel, + BlenderbotSmallPreTrainedModel, ) from .models.camembert import ( CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_LIST, @@ -2226,6 +2262,7 @@ if TYPE_CHECKING: FunnelForSequenceClassification, FunnelForTokenClassification, FunnelModel, + FunnelPreTrainedModel, load_tf_weights_in_funnel, ) from .models.gpt2 import ( @@ -2267,6 +2304,7 @@ if TYPE_CHECKING: LayoutLMForSequenceClassification, LayoutLMForTokenClassification, LayoutLMModel, + LayoutLMPreTrainedModel, ) from .models.led import ( LED_PRETRAINED_MODEL_ARCHIVE_LIST, @@ -2274,6 +2312,7 @@ if TYPE_CHECKING: LEDForQuestionAnswering, LEDForSequenceClassification, LEDModel, + LEDPreTrainedModel, ) from .models.longformer import ( LONGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST, @@ -2283,6 +2322,7 @@ if TYPE_CHECKING: LongformerForSequenceClassification, LongformerForTokenClassification, LongformerModel, + LongformerPreTrainedModel, LongformerSelfAttention, ) from .models.luke import ( @@ -2302,7 +2342,12 @@ if TYPE_CHECKING: LxmertVisualFeatureEncoder, LxmertXLayer, ) - from .models.m2m_100 import M2M_100_PRETRAINED_MODEL_ARCHIVE_LIST, M2M100ForConditionalGeneration, M2M100Model + from .models.m2m_100 import ( + M2M_100_PRETRAINED_MODEL_ARCHIVE_LIST, + M2M100ForConditionalGeneration, + M2M100Model, + M2M100PreTrainedModel, + ) from .models.marian import MarianForCausalLM, MarianModel, MarianMTModel from .models.mbart import ( MBartForCausalLM, @@ -2310,6 +2355,7 @@ if TYPE_CHECKING: MBartForQuestionAnswering, MBartForSequenceClassification, MBartModel, + MBartPreTrainedModel, ) from .models.megatron_bert import ( MEGATRON_BERT_PRETRAINED_MODEL_ARCHIVE_LIST, @@ -2322,6 +2368,7 @@ if TYPE_CHECKING: MegatronBertForSequenceClassification, MegatronBertForTokenClassification, MegatronBertModel, + MegatronBertPreTrainedModel, ) from .models.mmbt import MMBTForClassification, MMBTModel, ModalEmbeddings from .models.mobilebert import ( @@ -2359,7 +2406,12 @@ if TYPE_CHECKING: OpenAIGPTPreTrainedModel, load_tf_weights_in_openai_gpt, ) - from .models.pegasus import PegasusForCausalLM, PegasusForConditionalGeneration, PegasusModel + from .models.pegasus import ( + PegasusForCausalLM, + PegasusForConditionalGeneration, + PegasusModel, + PegasusPreTrainedModel, + ) from .models.prophetnet import ( PROPHETNET_PRETRAINED_MODEL_ARCHIVE_LIST, ProphetNetDecoder, @@ -2369,7 +2421,7 @@ if TYPE_CHECKING: ProphetNetModel, ProphetNetPreTrainedModel, ) - from .models.rag import RagModel, RagSequenceForGeneration, RagTokenForGeneration + from .models.rag import RagModel, RagPreTrainedModel, RagSequenceForGeneration, RagTokenForGeneration from .models.reformer import ( REFORMER_PRETRAINED_MODEL_ARCHIVE_LIST, ReformerAttention, @@ -2379,6 +2431,7 @@ if TYPE_CHECKING: ReformerLayer, ReformerModel, ReformerModelWithLMHead, + ReformerPreTrainedModel, ) from .models.retribert import RETRIBERT_PRETRAINED_MODEL_ARCHIVE_LIST, RetriBertModel, RetriBertPreTrainedModel from .models.roberta import ( @@ -2390,6 +2443,7 @@ if TYPE_CHECKING: RobertaForSequenceClassification, RobertaForTokenClassification, RobertaModel, + RobertaPreTrainedModel, ) from .models.roformer import ( ROFORMER_PRETRAINED_MODEL_ARCHIVE_LIST, @@ -2408,6 +2462,7 @@ if TYPE_CHECKING: SPEECH_TO_TEXT_PRETRAINED_MODEL_ARCHIVE_LIST, Speech2TextForConditionalGeneration, Speech2TextModel, + Speech2TextPreTrainedModel, ) from .models.squeezebert import ( SQUEEZEBERT_PRETRAINED_MODEL_ARCHIVE_LIST, @@ -2434,6 +2489,7 @@ if TYPE_CHECKING: TapasForQuestionAnswering, TapasForSequenceClassification, TapasModel, + TapasPreTrainedModel, ) from .models.transfo_xl import ( TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_LIST, @@ -2600,8 +2656,16 @@ if TYPE_CHECKING: TFBertModel, TFBertPreTrainedModel, ) - from .models.blenderbot import TFBlenderbotForConditionalGeneration, TFBlenderbotModel - from .models.blenderbot_small import TFBlenderbotSmallForConditionalGeneration, TFBlenderbotSmallModel + from .models.blenderbot import ( + TFBlenderbotForConditionalGeneration, + TFBlenderbotModel, + TFBlenderbotPreTrainedModel, + ) + from .models.blenderbot_small import ( + TFBlenderbotSmallForConditionalGeneration, + TFBlenderbotSmallModel, + TFBlenderbotSmallPreTrainedModel, + ) from .models.camembert import ( TF_CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_LIST, TFCamembertForMaskedLM, @@ -2669,6 +2733,7 @@ if TYPE_CHECKING: TFFlaubertForSequenceClassification, TFFlaubertForTokenClassification, TFFlaubertModel, + TFFlaubertPreTrainedModel, TFFlaubertWithLMHeadModel, ) from .models.funnel import ( @@ -2681,6 +2746,7 @@ if TYPE_CHECKING: TFFunnelForSequenceClassification, TFFunnelForTokenClassification, TFFunnelModel, + TFFunnelPreTrainedModel, ) from .models.gpt2 import ( TF_GPT2_PRETRAINED_MODEL_ARCHIVE_LIST, @@ -2700,6 +2766,7 @@ if TYPE_CHECKING: TFLongformerForSequenceClassification, TFLongformerForTokenClassification, TFLongformerModel, + TFLongformerPreTrainedModel, TFLongformerSelfAttention, ) from .models.lxmert import ( @@ -2710,8 +2777,8 @@ if TYPE_CHECKING: TFLxmertPreTrainedModel, TFLxmertVisualFeatureEncoder, ) - from .models.marian import TFMarianModel, TFMarianMTModel - from .models.mbart import TFMBartForConditionalGeneration, TFMBartModel + from .models.marian import TFMarianModel, TFMarianMTModel, TFMarianPreTrainedModel + from .models.mbart import TFMBartForConditionalGeneration, TFMBartModel, TFMBartPreTrainedModel from .models.mobilebert import ( TF_MOBILEBERT_PRETRAINED_MODEL_ARCHIVE_LIST, TFMobileBertForMaskedLM, @@ -2746,8 +2813,8 @@ if TYPE_CHECKING: TFOpenAIGPTModel, TFOpenAIGPTPreTrainedModel, ) - from .models.pegasus import TFPegasusForConditionalGeneration, TFPegasusModel - from .models.rag import TFRagModel, TFRagSequenceForGeneration, TFRagTokenForGeneration + from .models.pegasus import TFPegasusForConditionalGeneration, TFPegasusModel, TFPegasusPreTrainedModel + from .models.rag import TFRagModel, TFRagPreTrainedModel, TFRagSequenceForGeneration, TFRagTokenForGeneration from .models.roberta import ( TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST, TFRobertaForMaskedLM, @@ -2878,6 +2945,7 @@ if TYPE_CHECKING: FlaxBartForQuestionAnswering, FlaxBartForSequenceClassification, FlaxBartModel, + FlaxBartPreTrainedModel, ) from .models.bert import ( FlaxBertForMaskedLM, @@ -2900,7 +2968,14 @@ if TYPE_CHECKING: FlaxBigBirdModel, FlaxBigBirdPreTrainedModel, ) - from .models.clip import FlaxCLIPModel, FlaxCLIPPreTrainedModel, FlaxCLIPTextModel, FlaxCLIPVisionModel + from .models.clip import ( + FlaxCLIPModel, + FlaxCLIPPreTrainedModel, + FlaxCLIPTextModel, + FlaxCLIPTextPreTrainedModel, + FlaxCLIPVisionModel, + FlaxCLIPVisionPreTrainedModel, + ) from .models.electra import ( FlaxElectraForMaskedLM, FlaxElectraForMultipleChoice, @@ -2911,7 +2986,7 @@ if TYPE_CHECKING: FlaxElectraModel, FlaxElectraPreTrainedModel, ) - from .models.gpt2 import FlaxGPT2LMHeadModel, FlaxGPT2Model + from .models.gpt2 import FlaxGPT2LMHeadModel, FlaxGPT2Model, FlaxGPT2PreTrainedModel from .models.roberta import ( FlaxRobertaForMaskedLM, FlaxRobertaForMultipleChoice, @@ -2921,8 +2996,8 @@ if TYPE_CHECKING: FlaxRobertaModel, FlaxRobertaPreTrainedModel, ) - from .models.t5 import FlaxT5ForConditionalGeneration, FlaxT5Model - from .models.vit import FlaxViTForImageClassification, FlaxViTModel + from .models.t5 import FlaxT5ForConditionalGeneration, FlaxT5Model, FlaxT5PreTrainedModel + from .models.vit import FlaxViTForImageClassification, FlaxViTModel, FlaxViTPreTrainedModel else: # Import the same objects as dummies to get them in the namespace. # They will raise an import error if the user tries to instantiate / use them. diff --git a/src/transformers/models/bart/__init__.py b/src/transformers/models/bart/__init__.py index 529f2cf20c..c0a135ecc4 100644 --- a/src/transformers/models/bart/__init__.py +++ b/src/transformers/models/bart/__init__.py @@ -55,6 +55,7 @@ if is_flax_available(): "FlaxBartForQuestionAnswering", "FlaxBartForSequenceClassification", "FlaxBartModel", + "FlaxBartPreTrainedModel", ] if TYPE_CHECKING: @@ -85,6 +86,7 @@ if TYPE_CHECKING: FlaxBartForQuestionAnswering, FlaxBartForSequenceClassification, FlaxBartModel, + FlaxBartPreTrainedModel, ) else: diff --git a/src/transformers/models/bert_generation/__init__.py b/src/transformers/models/bert_generation/__init__.py index edbaf705eb..8d4bba925b 100644 --- a/src/transformers/models/bert_generation/__init__.py +++ b/src/transformers/models/bert_generation/__init__.py @@ -32,6 +32,7 @@ if is_torch_available(): _import_structure["modeling_bert_generation"] = [ "BertGenerationDecoder", "BertGenerationEncoder", + "BertGenerationPreTrainedModel", "load_tf_weights_in_bert_generation", ] @@ -46,6 +47,7 @@ if TYPE_CHECKING: from .modeling_bert_generation import ( BertGenerationDecoder, BertGenerationEncoder, + BertGenerationPreTrainedModel, load_tf_weights_in_bert_generation, ) diff --git a/src/transformers/models/blenderbot/__init__.py b/src/transformers/models/blenderbot/__init__.py index daf0b3dc4e..c6652f118f 100644 --- a/src/transformers/models/blenderbot/__init__.py +++ b/src/transformers/models/blenderbot/__init__.py @@ -37,7 +37,11 @@ if is_torch_available(): if is_tf_available(): - _import_structure["modeling_tf_blenderbot"] = ["TFBlenderbotForConditionalGeneration", "TFBlenderbotModel"] + _import_structure["modeling_tf_blenderbot"] = [ + "TFBlenderbotForConditionalGeneration", + "TFBlenderbotModel", + "TFBlenderbotPreTrainedModel", + ] if TYPE_CHECKING: @@ -54,7 +58,11 @@ if TYPE_CHECKING: ) if is_tf_available(): - from .modeling_tf_blenderbot import TFBlenderbotForConditionalGeneration, TFBlenderbotModel + from .modeling_tf_blenderbot import ( + TFBlenderbotForConditionalGeneration, + TFBlenderbotModel, + TFBlenderbotPreTrainedModel, + ) else: import importlib diff --git a/src/transformers/models/blenderbot_small/__init__.py b/src/transformers/models/blenderbot_small/__init__.py index a40ab18ff1..dd170ccbe9 100644 --- a/src/transformers/models/blenderbot_small/__init__.py +++ b/src/transformers/models/blenderbot_small/__init__.py @@ -38,6 +38,7 @@ if is_tf_available(): _import_structure["modeling_tf_blenderbot_small"] = [ "TFBlenderbotSmallForConditionalGeneration", "TFBlenderbotSmallModel", + "TFBlenderbotSmallPreTrainedModel", ] if TYPE_CHECKING: @@ -54,7 +55,11 @@ if TYPE_CHECKING: ) if is_tf_available(): - from .modeling_tf_blenderbot_small import TFBlenderbotSmallForConditionalGeneration, TFBlenderbotSmallModel + from .modeling_tf_blenderbot_small import ( + TFBlenderbotSmallForConditionalGeneration, + TFBlenderbotSmallModel, + TFBlenderbotSmallPreTrainedModel, + ) else: import importlib diff --git a/src/transformers/models/clip/__init__.py b/src/transformers/models/clip/__init__.py index d3fda176f6..1bef0ee311 100644 --- a/src/transformers/models/clip/__init__.py +++ b/src/transformers/models/clip/__init__.py @@ -52,7 +52,9 @@ if is_flax_available(): "FlaxCLIPModel", "FlaxCLIPPreTrainedModel", "FlaxCLIPTextModel", + "FlaxCLIPTextPreTrainedModel", "FlaxCLIPVisionModel", + "FlaxCLIPVisionPreTrainedModel", ] @@ -77,7 +79,14 @@ if TYPE_CHECKING: ) if is_flax_available(): - from .modeling_flax_clip import FlaxCLIPModel, FlaxCLIPPreTrainedModel, FlaxCLIPTextModel, FlaxCLIPVisionModel + from .modeling_flax_clip import ( + FlaxCLIPModel, + FlaxCLIPPreTrainedModel, + FlaxCLIPTextModel, + FlaxCLIPTextPreTrainedModel, + FlaxCLIPVisionModel, + FlaxCLIPVisionPreTrainedModel, + ) else: diff --git a/src/transformers/models/flaubert/__init__.py b/src/transformers/models/flaubert/__init__.py index 8c1c319322..8b15adc331 100644 --- a/src/transformers/models/flaubert/__init__.py +++ b/src/transformers/models/flaubert/__init__.py @@ -46,6 +46,7 @@ if is_tf_available(): "TFFlaubertForSequenceClassification", "TFFlaubertForTokenClassification", "TFFlaubertModel", + "TFFlaubertPreTrainedModel", "TFFlaubertWithLMHeadModel", ] @@ -74,6 +75,7 @@ if TYPE_CHECKING: TFFlaubertForSequenceClassification, TFFlaubertForTokenClassification, TFFlaubertModel, + TFFlaubertPreTrainedModel, TFFlaubertWithLMHeadModel, ) diff --git a/src/transformers/models/funnel/__init__.py b/src/transformers/models/funnel/__init__.py index 363df7e557..39fdda301b 100644 --- a/src/transformers/models/funnel/__init__.py +++ b/src/transformers/models/funnel/__init__.py @@ -41,6 +41,7 @@ if is_torch_available(): "FunnelForSequenceClassification", "FunnelForTokenClassification", "FunnelModel", + "FunnelPreTrainedModel", "load_tf_weights_in_funnel", ] @@ -55,6 +56,7 @@ if is_tf_available(): "TFFunnelForSequenceClassification", "TFFunnelForTokenClassification", "TFFunnelModel", + "TFFunnelPreTrainedModel", ] @@ -76,6 +78,7 @@ if TYPE_CHECKING: FunnelForSequenceClassification, FunnelForTokenClassification, FunnelModel, + FunnelPreTrainedModel, load_tf_weights_in_funnel, ) @@ -90,6 +93,7 @@ if TYPE_CHECKING: TFFunnelForSequenceClassification, TFFunnelForTokenClassification, TFFunnelModel, + TFFunnelPreTrainedModel, ) else: diff --git a/src/transformers/models/gpt2/__init__.py b/src/transformers/models/gpt2/__init__.py index e0bf154f75..d157b5bb5e 100644 --- a/src/transformers/models/gpt2/__init__.py +++ b/src/transformers/models/gpt2/__init__.py @@ -58,7 +58,7 @@ if is_tf_available(): ] if is_flax_available(): - _import_structure["modeling_flax_gpt2"] = ["FlaxGPT2LMHeadModel", "FlaxGPT2Model"] + _import_structure["modeling_flax_gpt2"] = ["FlaxGPT2LMHeadModel", "FlaxGPT2Model", "FlaxGPT2PreTrainedModel"] if TYPE_CHECKING: from .configuration_gpt2 import GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP, GPT2Config @@ -90,7 +90,7 @@ if TYPE_CHECKING: ) if is_flax_available(): - from .modeling_flax_gpt2 import FlaxGPT2LMHeadModel, FlaxGPT2Model + from .modeling_flax_gpt2 import FlaxGPT2LMHeadModel, FlaxGPT2Model, FlaxGPT2PreTrainedModel else: import importlib diff --git a/src/transformers/models/layoutlm/__init__.py b/src/transformers/models/layoutlm/__init__.py index 3551891891..0b58954d23 100644 --- a/src/transformers/models/layoutlm/__init__.py +++ b/src/transformers/models/layoutlm/__init__.py @@ -38,6 +38,7 @@ if is_torch_available(): "LayoutLMForSequenceClassification", "LayoutLMForTokenClassification", "LayoutLMModel", + "LayoutLMPreTrainedModel", ] if is_tf_available(): @@ -66,6 +67,7 @@ if TYPE_CHECKING: LayoutLMForSequenceClassification, LayoutLMForTokenClassification, LayoutLMModel, + LayoutLMPreTrainedModel, ) if is_tf_available(): from .modeling_tf_layoutlm import ( diff --git a/src/transformers/models/longformer/__init__.py b/src/transformers/models/longformer/__init__.py index 8cdae7c88f..31beb4d3a4 100644 --- a/src/transformers/models/longformer/__init__.py +++ b/src/transformers/models/longformer/__init__.py @@ -38,6 +38,7 @@ if is_torch_available(): "LongformerForSequenceClassification", "LongformerForTokenClassification", "LongformerModel", + "LongformerPreTrainedModel", "LongformerSelfAttention", ] @@ -50,6 +51,7 @@ if is_tf_available(): "TFLongformerForSequenceClassification", "TFLongformerForTokenClassification", "TFLongformerModel", + "TFLongformerPreTrainedModel", "TFLongformerSelfAttention", ] @@ -70,6 +72,7 @@ if TYPE_CHECKING: LongformerForSequenceClassification, LongformerForTokenClassification, LongformerModel, + LongformerPreTrainedModel, LongformerSelfAttention, ) @@ -82,6 +85,7 @@ if TYPE_CHECKING: TFLongformerForSequenceClassification, TFLongformerForTokenClassification, TFLongformerModel, + TFLongformerPreTrainedModel, TFLongformerSelfAttention, ) diff --git a/src/transformers/models/marian/__init__.py b/src/transformers/models/marian/__init__.py index 4ec04e192a..a2d95d2da6 100644 --- a/src/transformers/models/marian/__init__.py +++ b/src/transformers/models/marian/__init__.py @@ -43,7 +43,7 @@ if is_torch_available(): ] if is_tf_available(): - _import_structure["modeling_tf_marian"] = ["TFMarianModel", "TFMarianMTModel"] + _import_structure["modeling_tf_marian"] = ["TFMarianModel", "TFMarianMTModel", "TFMarianPreTrainedModel"] if TYPE_CHECKING: @@ -62,7 +62,7 @@ if TYPE_CHECKING: ) if is_tf_available(): - from .modeling_tf_marian import TFMarianModel, TFMarianMTModel + from .modeling_tf_marian import TFMarianModel, TFMarianMTModel, TFMarianPreTrainedModel else: import importlib diff --git a/src/transformers/models/mbart/__init__.py b/src/transformers/models/mbart/__init__.py index 3367c3c43b..414d33a9fa 100644 --- a/src/transformers/models/mbart/__init__.py +++ b/src/transformers/models/mbart/__init__.py @@ -50,7 +50,11 @@ if is_torch_available(): ] if is_tf_available(): - _import_structure["modeling_tf_mbart"] = ["TFMBartForConditionalGeneration", "TFMBartModel"] + _import_structure["modeling_tf_mbart"] = [ + "TFMBartForConditionalGeneration", + "TFMBartModel", + "TFMBartPreTrainedModel", + ] if TYPE_CHECKING: @@ -76,7 +80,7 @@ if TYPE_CHECKING: ) if is_tf_available(): - from .modeling_tf_mbart import TFMBartForConditionalGeneration, TFMBartModel + from .modeling_tf_mbart import TFMBartForConditionalGeneration, TFMBartModel, TFMBartPreTrainedModel else: import importlib diff --git a/src/transformers/models/megatron_bert/__init__.py b/src/transformers/models/megatron_bert/__init__.py index 714f1b1ecc..e3d83cb79c 100644 --- a/src/transformers/models/megatron_bert/__init__.py +++ b/src/transformers/models/megatron_bert/__init__.py @@ -36,6 +36,7 @@ if is_torch_available(): "MegatronBertForSequenceClassification", "MegatronBertForTokenClassification", "MegatronBertModel", + "MegatronBertPreTrainedModel", ] if TYPE_CHECKING: @@ -53,6 +54,7 @@ if TYPE_CHECKING: MegatronBertForSequenceClassification, MegatronBertForTokenClassification, MegatronBertModel, + MegatronBertPreTrainedModel, ) else: diff --git a/src/transformers/models/pegasus/__init__.py b/src/transformers/models/pegasus/__init__.py index daecd7825b..ac71aeebc2 100644 --- a/src/transformers/models/pegasus/__init__.py +++ b/src/transformers/models/pegasus/__init__.py @@ -46,7 +46,11 @@ if is_torch_available(): ] if is_tf_available(): - _import_structure["modeling_tf_pegasus"] = ["TFPegasusForConditionalGeneration", "TFPegasusModel"] + _import_structure["modeling_tf_pegasus"] = [ + "TFPegasusForConditionalGeneration", + "TFPegasusModel", + "TFPegasusPreTrainedModel", + ] if TYPE_CHECKING: @@ -68,7 +72,7 @@ if TYPE_CHECKING: ) if is_tf_available(): - from .modeling_tf_pegasus import TFPegasusForConditionalGeneration, TFPegasusModel + from .modeling_tf_pegasus import TFPegasusForConditionalGeneration, TFPegasusModel, TFPegasusPreTrainedModel else: import importlib diff --git a/src/transformers/models/rag/__init__.py b/src/transformers/models/rag/__init__.py index 0c96db8756..644768a4e8 100644 --- a/src/transformers/models/rag/__init__.py +++ b/src/transformers/models/rag/__init__.py @@ -28,10 +28,20 @@ _import_structure = { } if is_torch_available(): - _import_structure["modeling_rag"] = ["RagModel", "RagSequenceForGeneration", "RagTokenForGeneration"] + _import_structure["modeling_rag"] = [ + "RagModel", + "RagPreTrainedModel", + "RagSequenceForGeneration", + "RagTokenForGeneration", + ] if is_tf_available(): - _import_structure["modeling_tf_rag"] = ["TFRagModel", "TFRagSequenceForGeneration", "TFRagTokenForGeneration"] + _import_structure["modeling_tf_rag"] = [ + "TFRagModel", + "TFRagPreTrainedModel", + "TFRagSequenceForGeneration", + "TFRagTokenForGeneration", + ] if TYPE_CHECKING: @@ -40,10 +50,15 @@ if TYPE_CHECKING: from .tokenization_rag import RagTokenizer if is_torch_available(): - from .modeling_rag import RagModel, RagSequenceForGeneration, RagTokenForGeneration + from .modeling_rag import RagModel, RagPreTrainedModel, RagSequenceForGeneration, RagTokenForGeneration if is_tf_available(): - from .modeling_tf_rag import TFRagModel, TFRagSequenceForGeneration, TFRagTokenForGeneration + from .modeling_tf_rag import ( + TFRagModel, + TFRagPreTrainedModel, + TFRagSequenceForGeneration, + TFRagTokenForGeneration, + ) else: import importlib diff --git a/src/transformers/models/reformer/__init__.py b/src/transformers/models/reformer/__init__.py index 63e393c499..d255ce60b9 100644 --- a/src/transformers/models/reformer/__init__.py +++ b/src/transformers/models/reformer/__init__.py @@ -41,6 +41,7 @@ if is_torch_available(): "ReformerLayer", "ReformerModel", "ReformerModelWithLMHead", + "ReformerPreTrainedModel", ] @@ -63,6 +64,7 @@ if TYPE_CHECKING: ReformerLayer, ReformerModel, ReformerModelWithLMHead, + ReformerPreTrainedModel, ) else: diff --git a/src/transformers/models/roberta/__init__.py b/src/transformers/models/roberta/__init__.py index 2194a2decf..b4f1833d0e 100644 --- a/src/transformers/models/roberta/__init__.py +++ b/src/transformers/models/roberta/__init__.py @@ -45,6 +45,7 @@ if is_torch_available(): "RobertaForSequenceClassification", "RobertaForTokenClassification", "RobertaModel", + "RobertaPreTrainedModel", ] if is_tf_available(): @@ -89,6 +90,7 @@ if TYPE_CHECKING: RobertaForSequenceClassification, RobertaForTokenClassification, RobertaModel, + RobertaPreTrainedModel, ) if is_tf_available(): diff --git a/src/transformers/models/tapas/__init__.py b/src/transformers/models/tapas/__init__.py index 76a649df1f..e88943c4f7 100644 --- a/src/transformers/models/tapas/__init__.py +++ b/src/transformers/models/tapas/__init__.py @@ -33,6 +33,7 @@ if is_torch_available(): "TapasForQuestionAnswering", "TapasForSequenceClassification", "TapasModel", + "TapasPreTrainedModel", ] @@ -47,6 +48,7 @@ if TYPE_CHECKING: TapasForQuestionAnswering, TapasForSequenceClassification, TapasModel, + TapasPreTrainedModel, ) else: diff --git a/src/transformers/models/vit/__init__.py b/src/transformers/models/vit/__init__.py index eb9c8f4308..d731eb1d67 100644 --- a/src/transformers/models/vit/__init__.py +++ b/src/transformers/models/vit/__init__.py @@ -37,7 +37,11 @@ if is_torch_available(): if is_flax_available(): - _import_structure["modeling_flax_vit"] = ["FlaxViTForImageClassification", "FlaxViTModel"] + _import_structure["modeling_flax_vit"] = [ + "FlaxViTForImageClassification", + "FlaxViTModel", + "FlaxViTPreTrainedModel", + ] if TYPE_CHECKING: from .configuration_vit import VIT_PRETRAINED_CONFIG_ARCHIVE_MAP, ViTConfig @@ -54,7 +58,7 @@ if TYPE_CHECKING: ) if is_flax_available(): - from .modeling_flax_vit import FlaxViTForImageClassification, FlaxViTModel + from .modeling_flax_vit import FlaxViTForImageClassification, FlaxViTModel, FlaxViTPreTrainedModel else: diff --git a/src/transformers/utils/dummy_flax_objects.py b/src/transformers/utils/dummy_flax_objects.py index b6d6490559..e4a56113d2 100644 --- a/src/transformers/utils/dummy_flax_objects.py +++ b/src/transformers/utils/dummy_flax_objects.py @@ -244,6 +244,15 @@ class FlaxBartModel: requires_backends(cls, ["flax"]) +class FlaxBartPreTrainedModel: + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["flax"]) + + class FlaxBertForMaskedLM: def __init__(self, *args, **kwargs): requires_backends(self, ["flax"]) @@ -412,6 +421,15 @@ class FlaxCLIPTextModel: requires_backends(cls, ["flax"]) +class FlaxCLIPTextPreTrainedModel: + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["flax"]) + + class FlaxCLIPVisionModel: def __init__(self, *args, **kwargs): requires_backends(self, ["flax"]) @@ -421,6 +439,15 @@ class FlaxCLIPVisionModel: requires_backends(cls, ["flax"]) +class FlaxCLIPVisionPreTrainedModel: + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["flax"]) + + class FlaxElectraForMaskedLM: def __init__(self, *args, **kwargs): requires_backends(self, ["flax"]) @@ -507,6 +534,15 @@ class FlaxGPT2Model: requires_backends(cls, ["flax"]) +class FlaxGPT2PreTrainedModel: + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["flax"]) + + class FlaxRobertaForMaskedLM: def __init__(self, *args, **kwargs): requires_backends(self, ["flax"]) @@ -588,6 +624,15 @@ class FlaxT5Model: requires_backends(cls, ["flax"]) +class FlaxT5PreTrainedModel: + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["flax"]) + + class FlaxViTForImageClassification: def __init__(self, *args, **kwargs): requires_backends(self, ["flax"]) @@ -600,3 +645,12 @@ class FlaxViTModel: @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["flax"]) + + +class FlaxViTPreTrainedModel: + def __init__(self, *args, **kwargs): + requires_backends(self, ["flax"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["flax"]) diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index c8ce871ea3..50e2b43180 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -692,6 +692,15 @@ class BertGenerationEncoder: requires_backends(self, ["torch"]) +class BertGenerationPreTrainedModel: + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + def load_tf_weights_in_bert_generation(*args, **kwargs): requires_backends(load_tf_weights_in_bert_generation, ["torch"]) @@ -833,6 +842,15 @@ class BigBirdPegasusModel: requires_backends(cls, ["torch"]) +class BigBirdPegasusPreTrainedModel: + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + BLENDERBOT_PRETRAINED_MODEL_ARCHIVE_LIST = None @@ -863,6 +881,15 @@ class BlenderbotModel: requires_backends(cls, ["torch"]) +class BlenderbotPreTrainedModel: + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + BLENDERBOT_SMALL_PRETRAINED_MODEL_ARCHIVE_LIST = None @@ -893,6 +920,15 @@ class BlenderbotSmallModel: requires_backends(cls, ["torch"]) +class BlenderbotSmallPreTrainedModel: + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_LIST = None @@ -1610,6 +1646,15 @@ class FunnelModel: requires_backends(cls, ["torch"]) +class FunnelPreTrainedModel: + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + def load_tf_weights_in_funnel(*args, **kwargs): requires_backends(load_tf_weights_in_funnel, ["torch"]) @@ -1840,6 +1885,15 @@ class LayoutLMModel: requires_backends(cls, ["torch"]) +class LayoutLMPreTrainedModel: + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + LED_PRETRAINED_MODEL_ARCHIVE_LIST = None @@ -1879,6 +1933,15 @@ class LEDModel: requires_backends(cls, ["torch"]) +class LEDPreTrainedModel: + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + LONGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = None @@ -1936,6 +1999,15 @@ class LongformerModel: requires_backends(cls, ["torch"]) +class LongformerPreTrainedModel: + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class LongformerSelfAttention: def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @@ -2045,6 +2117,15 @@ class M2M100Model: requires_backends(cls, ["torch"]) +class M2M100PreTrainedModel: + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class MarianForCausalLM: def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @@ -2117,6 +2198,15 @@ class MBartModel: requires_backends(cls, ["torch"]) +class MBartPreTrainedModel: + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + MEGATRON_BERT_PRETRAINED_MODEL_ARCHIVE_LIST = None @@ -2193,6 +2283,15 @@ class MegatronBertModel: requires_backends(cls, ["torch"]) +class MegatronBertPreTrainedModel: + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class MMBTForClassification: def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @@ -2474,6 +2573,15 @@ class PegasusModel: requires_backends(cls, ["torch"]) +class PegasusPreTrainedModel: + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + PROPHETNET_PRETRAINED_MODEL_ARCHIVE_LIST = None @@ -2532,6 +2640,15 @@ class RagModel: requires_backends(cls, ["torch"]) +class RagPreTrainedModel: + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class RagSequenceForGeneration: def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) @@ -2600,6 +2717,15 @@ class ReformerModelWithLMHead: requires_backends(cls, ["torch"]) +class ReformerPreTrainedModel: + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + RETRIBERT_PRETRAINED_MODEL_ARCHIVE_LIST = None @@ -2687,6 +2813,15 @@ class RobertaModel: requires_backends(cls, ["torch"]) +class RobertaPreTrainedModel: + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + ROFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = None @@ -2792,6 +2927,15 @@ class Speech2TextModel: requires_backends(cls, ["torch"]) +class Speech2TextPreTrainedModel: + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + SQUEEZEBERT_PRETRAINED_MODEL_ARCHIVE_LIST = None @@ -2945,6 +3089,15 @@ class TapasModel: requires_backends(cls, ["torch"]) +class TapasPreTrainedModel: + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_LIST = None diff --git a/src/transformers/utils/dummy_tf_objects.py b/src/transformers/utils/dummy_tf_objects.py index e7ecc731cf..24e686f984 100644 --- a/src/transformers/utils/dummy_tf_objects.py +++ b/src/transformers/utils/dummy_tf_objects.py @@ -431,6 +431,15 @@ class TFBlenderbotModel: requires_backends(cls, ["tf"]) +class TFBlenderbotPreTrainedModel: + def __init__(self, *args, **kwargs): + requires_backends(self, ["tf"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["tf"]) + + class TFBlenderbotSmallForConditionalGeneration: def __init__(self, *args, **kwargs): requires_backends(self, ["tf"]) @@ -449,6 +458,15 @@ class TFBlenderbotSmallModel: requires_backends(cls, ["tf"]) +class TFBlenderbotSmallPreTrainedModel: + def __init__(self, *args, **kwargs): + requires_backends(self, ["tf"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["tf"]) + + TF_CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_LIST = None @@ -845,6 +863,15 @@ class TFFlaubertModel: requires_backends(cls, ["tf"]) +class TFFlaubertPreTrainedModel: + def __init__(self, *args, **kwargs): + requires_backends(self, ["tf"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["tf"]) + + class TFFlaubertWithLMHeadModel: def __init__(self, *args, **kwargs): requires_backends(self, ["tf"]) @@ -925,6 +952,15 @@ class TFFunnelModel: requires_backends(cls, ["tf"]) +class TFFunnelPreTrainedModel: + def __init__(self, *args, **kwargs): + requires_backends(self, ["tf"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["tf"]) + + TF_GPT2_PRETRAINED_MODEL_ARCHIVE_LIST = None @@ -1062,6 +1098,15 @@ class TFLongformerModel: requires_backends(cls, ["tf"]) +class TFLongformerPreTrainedModel: + def __init__(self, *args, **kwargs): + requires_backends(self, ["tf"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["tf"]) + + class TFLongformerSelfAttention: def __init__(self, *args, **kwargs): requires_backends(self, ["tf"]) @@ -1121,6 +1166,15 @@ class TFMarianMTModel: requires_backends(cls, ["tf"]) +class TFMarianPreTrainedModel: + def __init__(self, *args, **kwargs): + requires_backends(self, ["tf"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["tf"]) + + class TFMBartForConditionalGeneration: def __init__(self, *args, **kwargs): requires_backends(self, ["tf"]) @@ -1139,6 +1193,15 @@ class TFMBartModel: requires_backends(cls, ["tf"]) +class TFMBartPreTrainedModel: + def __init__(self, *args, **kwargs): + requires_backends(self, ["tf"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["tf"]) + + TF_MOBILEBERT_PRETRAINED_MODEL_ARCHIVE_LIST = None @@ -1389,6 +1452,15 @@ class TFPegasusModel: requires_backends(cls, ["tf"]) +class TFPegasusPreTrainedModel: + def __init__(self, *args, **kwargs): + requires_backends(self, ["tf"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["tf"]) + + class TFRagModel: def __init__(self, *args, **kwargs): requires_backends(self, ["tf"]) @@ -1398,6 +1470,15 @@ class TFRagModel: requires_backends(cls, ["tf"]) +class TFRagPreTrainedModel: + def __init__(self, *args, **kwargs): + requires_backends(self, ["tf"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["tf"]) + + class TFRagSequenceForGeneration: def __init__(self, *args, **kwargs): requires_backends(self, ["tf"]) diff --git a/src/transformers/utils/dummy_timm_and_vision_objects.py b/src/transformers/utils/dummy_timm_and_vision_objects.py index a1da2d14be..6a92c8dc27 100644 --- a/src/transformers/utils/dummy_timm_and_vision_objects.py +++ b/src/transformers/utils/dummy_timm_and_vision_objects.py @@ -30,3 +30,12 @@ class DetrModel: @classmethod def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["timm", "vision"]) + + +class DetrPreTrainedModel: + def __init__(self, *args, **kwargs): + requires_backends(self, ["timm", "vision"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["timm", "vision"]) diff --git a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/to_replace_{{cookiecutter.lowercase_modelname}}.py b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/to_replace_{{cookiecutter.lowercase_modelname}}.py index 2480c461be..764a2586ef 100644 --- a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/to_replace_{{cookiecutter.lowercase_modelname}}.py +++ b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/to_replace_{{cookiecutter.lowercase_modelname}}.py @@ -52,6 +52,7 @@ "{{cookiecutter.camelcase_modelname}}ForQuestionAnswering", "{{cookiecutter.camelcase_modelname}}ForSequenceClassification", "{{cookiecutter.camelcase_modelname}}Model", + "{{cookiecutter.camelcase_modelname}}PreTrainedModel", ] ) {% endif -%} @@ -120,6 +121,7 @@ {{cookiecutter.camelcase_modelname}}ForQuestionAnswering, {{cookiecutter.camelcase_modelname}}ForSequenceClassification, {{cookiecutter.camelcase_modelname}}Model, + {{cookiecutter.camelcase_modelname}}PreTrainedModel, ) {% endif -%} # End. diff --git a/utils/check_repo.py b/utils/check_repo.py index 23285c9355..244bd20185 100644 --- a/utils/check_repo.py +++ b/utils/check_repo.py @@ -31,9 +31,16 @@ PATH_TO_TRANSFORMERS = "src/transformers" PATH_TO_TESTS = "tests" PATH_TO_DOC = "docs/source" +# Update this list with models that are supposed to be private. +PRIVATE_MODELS = [ + "DPRSpanPredictor", + "T5Stack", + "TFDPRSpanPredictor", +] + # Update this list for models that are not tested with a comment explaining the reason it should not be. # Being in this list is an exception and should **not** be the rule. -IGNORE_NON_TESTED = [ +IGNORE_NON_TESTED = PRIVATE_MODELS.copy() + [ # models to ignore for not tested "BigBirdPegasusEncoder", # Building part of bigger (tested) model. "BigBirdPegasusDecoder", # Building part of bigger (tested) model. @@ -63,12 +70,9 @@ IGNORE_NON_TESTED = [ "PegasusEncoder", # Building part of bigger (tested) model. "PegasusDecoderWrapper", # Building part of bigger (tested) model. "DPREncoder", # Building part of bigger (tested) model. - "DPRSpanPredictor", # Building part of bigger (tested) model. "ProphetNetDecoderWrapper", # Building part of bigger (tested) model. "ReformerForMaskedLM", # Needs to be setup as decoder. - "T5Stack", # Building part of bigger (tested) model. "TFDPREncoder", # Building part of bigger (tested) model. - "TFDPRSpanPredictor", # Building part of bigger (tested) model. "TFElectraMainLayer", # Building part of bigger (tested) model (should it be a TFPreTrainedModel ?) "TFRobertaForMultipleChoice", # TODO: fix "SeparableConv1D", # Building part of bigger (tested) model. @@ -92,7 +96,7 @@ TEST_FILES_WITH_NO_COMMON_TESTS = [ # Update this list for models that are not in any of the auto MODEL_XXX_MAPPING. Being in this list is an exception and # should **not** be the rule. -IGNORE_NON_AUTO_CONFIGURED = [ +IGNORE_NON_AUTO_CONFIGURED = PRIVATE_MODELS.copy() + [ # models to ignore for model xxx mapping "CLIPTextModel", "CLIPVisionModel", @@ -100,7 +104,6 @@ IGNORE_NON_AUTO_CONFIGURED = [ "FlaxCLIPVisionModel", "DetrForSegmentation", "DPRReader", - "DPRSpanPredictor", "FlaubertForQuestionAnswering", "GPT2DoubleHeadsModel", "LukeForEntityClassification", @@ -110,9 +113,7 @@ IGNORE_NON_AUTO_CONFIGURED = [ "RagModel", "RagSequenceForGeneration", "RagTokenForGeneration", - "T5Stack", "TFDPRReader", - "TFDPRSpanPredictor", "TFGPT2DoubleHeadsModel", "TFOpenAIGPTDoubleHeadsModel", "TFRagModel", @@ -173,12 +174,12 @@ def get_model_modules(): return modules -def get_models(module): +def get_models(module, include_pretrained=False): """Get the objects in module that are models.""" models = [] model_classes = (transformers.PreTrainedModel, transformers.TFPreTrainedModel, transformers.FlaxPreTrainedModel) for attr_name in dir(module): - if "Pretrained" in attr_name or "PreTrained" in attr_name: + if not include_pretrained and ("Pretrained" in attr_name or "PreTrained" in attr_name): continue attr = getattr(module, attr_name) if isinstance(attr, type) and issubclass(attr, model_classes) and attr.__module__ == module.__name__: @@ -186,6 +187,36 @@ def get_models(module): return models +def is_a_private_model(model): + """Returns True if the model should not be in the main init.""" + if model in PRIVATE_MODELS: + return True + + # Wrapper, Encoder and Decoder are all privates + if model.endswith("Wrapper"): + return True + if model.endswith("Encoder"): + return True + if model.endswith("Decoder"): + return True + return False + + +def check_models_are_in_init(): + """Checks all models defined in the library are in the main init.""" + models_not_in_init = [] + dir_transformers = dir(transformers) + for module in get_model_modules(): + models_not_in_init += [ + model[0] for model in get_models(module, include_pretrained=True) if model[0] not in dir_transformers + ] + + # Remove private models + models_not_in_init = [model for model in models_not_in_init if not is_a_private_model(model)] + if len(models_not_in_init) > 0: + raise Exception(f"The following models should be in the main init: {','.join(models_not_in_init)}.") + + # If some test_modeling files should be ignored when checking models are all tested, they should be added in the # nested list _ignore_files of this function. def get_model_test_files(): @@ -229,6 +260,7 @@ def find_tested_models(test_file): def check_models_are_tested(module, test_file): """Check models defined in module are tested in test_file.""" + # XxxPreTrainedModel are not tested defined_models = get_models(module) tested_models = find_tested_models(test_file) if tested_models is None: @@ -515,6 +547,8 @@ def check_all_objects_are_documented(): def check_repo_quality(): """Check all models are properly tested and documented.""" + print("Checking all models are public.") + check_models_are_in_init() print("Checking all models are properly tested.") check_all_decorator_order() check_all_models_are_tested()