Add all XxxPreTrainedModel to the main init (#12314)

* Add all XxxPreTrainedModel to the main init

* Add to template

* Add to template bis

* Add FlaxT5
Commit metadata:
Author: Sylvain Gugger
Date: 2021-06-23 10:40:54 -04:00
Committed by: GitHub
Parent: 53c60babe4
Commit: 9eda6b52e2
26 changed files with 532 additions and 51 deletions

View File

@@ -427,6 +427,7 @@ if is_timm_available() and is_vision_available():
"DetrForObjectDetection", "DetrForObjectDetection",
"DetrForSegmentation", "DetrForSegmentation",
"DetrModel", "DetrModel",
"DetrPreTrainedModel",
] ]
) )
else: else:
@@ -570,6 +571,7 @@ if is_torch_available():
[ [
"BertGenerationDecoder", "BertGenerationDecoder",
"BertGenerationEncoder", "BertGenerationEncoder",
"BertGenerationPreTrainedModel",
"load_tf_weights_in_bert_generation", "load_tf_weights_in_bert_generation",
] ]
) )
@@ -597,6 +599,7 @@ if is_torch_available():
"BigBirdPegasusForQuestionAnswering", "BigBirdPegasusForQuestionAnswering",
"BigBirdPegasusForSequenceClassification", "BigBirdPegasusForSequenceClassification",
"BigBirdPegasusModel", "BigBirdPegasusModel",
"BigBirdPegasusPreTrainedModel",
] ]
) )
_import_structure["models.blenderbot"].extend( _import_structure["models.blenderbot"].extend(
@@ -605,6 +608,7 @@ if is_torch_available():
"BlenderbotForCausalLM", "BlenderbotForCausalLM",
"BlenderbotForConditionalGeneration", "BlenderbotForConditionalGeneration",
"BlenderbotModel", "BlenderbotModel",
"BlenderbotPreTrainedModel",
] ]
) )
_import_structure["models.blenderbot_small"].extend( _import_structure["models.blenderbot_small"].extend(
@@ -613,6 +617,7 @@ if is_torch_available():
"BlenderbotSmallForCausalLM", "BlenderbotSmallForCausalLM",
"BlenderbotSmallForConditionalGeneration", "BlenderbotSmallForConditionalGeneration",
"BlenderbotSmallModel", "BlenderbotSmallModel",
"BlenderbotSmallPreTrainedModel",
] ]
) )
_import_structure["models.camembert"].extend( _import_structure["models.camembert"].extend(
@@ -754,6 +759,7 @@ if is_torch_available():
"FunnelForSequenceClassification", "FunnelForSequenceClassification",
"FunnelForTokenClassification", "FunnelForTokenClassification",
"FunnelModel", "FunnelModel",
"FunnelPreTrainedModel",
"load_tf_weights_in_funnel", "load_tf_weights_in_funnel",
] ]
) )
@@ -805,6 +811,7 @@ if is_torch_available():
"LayoutLMForSequenceClassification", "LayoutLMForSequenceClassification",
"LayoutLMForTokenClassification", "LayoutLMForTokenClassification",
"LayoutLMModel", "LayoutLMModel",
"LayoutLMPreTrainedModel",
] ]
) )
_import_structure["models.led"].extend( _import_structure["models.led"].extend(
@@ -814,6 +821,7 @@ if is_torch_available():
"LEDForQuestionAnswering", "LEDForQuestionAnswering",
"LEDForSequenceClassification", "LEDForSequenceClassification",
"LEDModel", "LEDModel",
"LEDPreTrainedModel",
] ]
) )
_import_structure["models.longformer"].extend( _import_structure["models.longformer"].extend(
@@ -825,6 +833,7 @@ if is_torch_available():
"LongformerForSequenceClassification", "LongformerForSequenceClassification",
"LongformerForTokenClassification", "LongformerForTokenClassification",
"LongformerModel", "LongformerModel",
"LongformerPreTrainedModel",
"LongformerSelfAttention", "LongformerSelfAttention",
] ]
) )
@@ -854,6 +863,7 @@ if is_torch_available():
"M2M_100_PRETRAINED_MODEL_ARCHIVE_LIST", "M2M_100_PRETRAINED_MODEL_ARCHIVE_LIST",
"M2M100ForConditionalGeneration", "M2M100ForConditionalGeneration",
"M2M100Model", "M2M100Model",
"M2M100PreTrainedModel",
] ]
) )
_import_structure["models.marian"].extend(["MarianForCausalLM", "MarianModel", "MarianMTModel"]) _import_structure["models.marian"].extend(["MarianForCausalLM", "MarianModel", "MarianMTModel"])
@@ -864,6 +874,7 @@ if is_torch_available():
"MBartForQuestionAnswering", "MBartForQuestionAnswering",
"MBartForSequenceClassification", "MBartForSequenceClassification",
"MBartModel", "MBartModel",
"MBartPreTrainedModel",
] ]
) )
_import_structure["models.megatron_bert"].extend( _import_structure["models.megatron_bert"].extend(
@@ -878,6 +889,7 @@ if is_torch_available():
"MegatronBertForSequenceClassification", "MegatronBertForSequenceClassification",
"MegatronBertForTokenClassification", "MegatronBertForTokenClassification",
"MegatronBertModel", "MegatronBertModel",
"MegatronBertPreTrainedModel",
] ]
) )
_import_structure["models.mmbt"].extend(["MMBTForClassification", "MMBTModel", "ModalEmbeddings"]) _import_structure["models.mmbt"].extend(["MMBTForClassification", "MMBTModel", "ModalEmbeddings"])
@@ -923,7 +935,7 @@ if is_torch_available():
] ]
) )
_import_structure["models.pegasus"].extend( _import_structure["models.pegasus"].extend(
["PegasusForCausalLM", "PegasusForConditionalGeneration", "PegasusModel"] ["PegasusForCausalLM", "PegasusForConditionalGeneration", "PegasusModel", "PegasusPreTrainedModel"]
) )
_import_structure["models.prophetnet"].extend( _import_structure["models.prophetnet"].extend(
[ [
@@ -936,7 +948,9 @@ if is_torch_available():
"ProphetNetPreTrainedModel", "ProphetNetPreTrainedModel",
] ]
) )
_import_structure["models.rag"].extend(["RagModel", "RagSequenceForGeneration", "RagTokenForGeneration"]) _import_structure["models.rag"].extend(
["RagModel", "RagPreTrainedModel", "RagSequenceForGeneration", "RagTokenForGeneration"]
)
_import_structure["models.reformer"].extend( _import_structure["models.reformer"].extend(
[ [
"REFORMER_PRETRAINED_MODEL_ARCHIVE_LIST", "REFORMER_PRETRAINED_MODEL_ARCHIVE_LIST",
@@ -947,6 +961,7 @@ if is_torch_available():
"ReformerLayer", "ReformerLayer",
"ReformerModel", "ReformerModel",
"ReformerModelWithLMHead", "ReformerModelWithLMHead",
"ReformerPreTrainedModel",
] ]
) )
_import_structure["models.retribert"].extend( _import_structure["models.retribert"].extend(
@@ -962,6 +977,7 @@ if is_torch_available():
"RobertaForSequenceClassification", "RobertaForSequenceClassification",
"RobertaForTokenClassification", "RobertaForTokenClassification",
"RobertaModel", "RobertaModel",
"RobertaPreTrainedModel",
] ]
) )
_import_structure["models.roformer"].extend( _import_structure["models.roformer"].extend(
@@ -984,6 +1000,7 @@ if is_torch_available():
"SPEECH_TO_TEXT_PRETRAINED_MODEL_ARCHIVE_LIST", "SPEECH_TO_TEXT_PRETRAINED_MODEL_ARCHIVE_LIST",
"Speech2TextForConditionalGeneration", "Speech2TextForConditionalGeneration",
"Speech2TextModel", "Speech2TextModel",
"Speech2TextPreTrainedModel",
] ]
) )
_import_structure["models.squeezebert"].extend( _import_structure["models.squeezebert"].extend(
@@ -1016,6 +1033,7 @@ if is_torch_available():
"TapasForQuestionAnswering", "TapasForQuestionAnswering",
"TapasForSequenceClassification", "TapasForSequenceClassification",
"TapasModel", "TapasModel",
"TapasPreTrainedModel",
] ]
) )
_import_structure["models.transfo_xl"].extend( _import_structure["models.transfo_xl"].extend(
@@ -1197,9 +1215,11 @@ if is_tf_available():
"TFBertPreTrainedModel", "TFBertPreTrainedModel",
] ]
) )
_import_structure["models.blenderbot"].extend(["TFBlenderbotForConditionalGeneration", "TFBlenderbotModel"]) _import_structure["models.blenderbot"].extend(
["TFBlenderbotForConditionalGeneration", "TFBlenderbotModel", "TFBlenderbotPreTrainedModel"]
)
_import_structure["models.blenderbot_small"].extend( _import_structure["models.blenderbot_small"].extend(
["TFBlenderbotSmallForConditionalGeneration", "TFBlenderbotSmallModel"] ["TFBlenderbotSmallForConditionalGeneration", "TFBlenderbotSmallModel", "TFBlenderbotSmallPreTrainedModel"]
) )
_import_structure["models.camembert"].extend( _import_structure["models.camembert"].extend(
[ [
@@ -1281,6 +1301,7 @@ if is_tf_available():
"TFFlaubertForSequenceClassification", "TFFlaubertForSequenceClassification",
"TFFlaubertForTokenClassification", "TFFlaubertForTokenClassification",
"TFFlaubertModel", "TFFlaubertModel",
"TFFlaubertPreTrainedModel",
"TFFlaubertWithLMHeadModel", "TFFlaubertWithLMHeadModel",
] ]
) )
@@ -1295,6 +1316,7 @@ if is_tf_available():
"TFFunnelForSequenceClassification", "TFFunnelForSequenceClassification",
"TFFunnelForTokenClassification", "TFFunnelForTokenClassification",
"TFFunnelModel", "TFFunnelModel",
"TFFunnelPreTrainedModel",
] ]
) )
_import_structure["models.gpt2"].extend( _import_structure["models.gpt2"].extend(
@@ -1329,6 +1351,7 @@ if is_tf_available():
"TFLongformerForSequenceClassification", "TFLongformerForSequenceClassification",
"TFLongformerForTokenClassification", "TFLongformerForTokenClassification",
"TFLongformerModel", "TFLongformerModel",
"TFLongformerPreTrainedModel",
"TFLongformerSelfAttention", "TFLongformerSelfAttention",
] ]
) )
@@ -1342,8 +1365,10 @@ if is_tf_available():
"TFLxmertVisualFeatureEncoder", "TFLxmertVisualFeatureEncoder",
] ]
) )
_import_structure["models.marian"].extend(["TFMarianModel", "TFMarianMTModel"]) _import_structure["models.marian"].extend(["TFMarianModel", "TFMarianMTModel", "TFMarianPreTrainedModel"])
_import_structure["models.mbart"].extend(["TFMBartForConditionalGeneration", "TFMBartModel"]) _import_structure["models.mbart"].extend(
["TFMBartForConditionalGeneration", "TFMBartModel", "TFMBartPreTrainedModel"]
)
_import_structure["models.mobilebert"].extend( _import_structure["models.mobilebert"].extend(
[ [
"TF_MOBILEBERT_PRETRAINED_MODEL_ARCHIVE_LIST", "TF_MOBILEBERT_PRETRAINED_MODEL_ARCHIVE_LIST",
@@ -1384,10 +1409,13 @@ if is_tf_available():
"TFOpenAIGPTPreTrainedModel", "TFOpenAIGPTPreTrainedModel",
] ]
) )
_import_structure["models.pegasus"].extend(["TFPegasusForConditionalGeneration", "TFPegasusModel"]) _import_structure["models.pegasus"].extend(
["TFPegasusForConditionalGeneration", "TFPegasusModel", "TFPegasusPreTrainedModel"]
)
_import_structure["models.rag"].extend( _import_structure["models.rag"].extend(
[ [
"TFRagModel", "TFRagModel",
"TFRagPreTrainedModel",
"TFRagSequenceForGeneration", "TFRagSequenceForGeneration",
"TFRagTokenForGeneration", "TFRagTokenForGeneration",
] ]
@@ -1538,6 +1566,7 @@ if is_flax_available():
"FlaxBartForQuestionAnswering", "FlaxBartForQuestionAnswering",
"FlaxBartForSequenceClassification", "FlaxBartForSequenceClassification",
"FlaxBartModel", "FlaxBartModel",
"FlaxBartPreTrainedModel",
] ]
) )
_import_structure["models.bert"].extend( _import_structure["models.bert"].extend(
@@ -1570,7 +1599,9 @@ if is_flax_available():
"FlaxCLIPModel", "FlaxCLIPModel",
"FlaxCLIPPreTrainedModel", "FlaxCLIPPreTrainedModel",
"FlaxCLIPTextModel", "FlaxCLIPTextModel",
"FlaxCLIPTextPreTrainedModel",
"FlaxCLIPVisionModel", "FlaxCLIPVisionModel",
"FlaxCLIPVisionPreTrainedModel",
] ]
) )
_import_structure["models.electra"].extend( _import_structure["models.electra"].extend(
@@ -1585,7 +1616,7 @@ if is_flax_available():
"FlaxElectraPreTrainedModel", "FlaxElectraPreTrainedModel",
] ]
) )
_import_structure["models.gpt2"].extend(["FlaxGPT2LMHeadModel", "FlaxGPT2Model"]) _import_structure["models.gpt2"].extend(["FlaxGPT2LMHeadModel", "FlaxGPT2Model", "FlaxGPT2PreTrainedModel"])
_import_structure["models.roberta"].extend( _import_structure["models.roberta"].extend(
[ [
"FlaxRobertaForMaskedLM", "FlaxRobertaForMaskedLM",
@@ -1597,8 +1628,8 @@ if is_flax_available():
"FlaxRobertaPreTrainedModel", "FlaxRobertaPreTrainedModel",
] ]
) )
_import_structure["models.t5"].extend(["FlaxT5ForConditionalGeneration", "FlaxT5Model"]) _import_structure["models.t5"].extend(["FlaxT5ForConditionalGeneration", "FlaxT5Model", "FlaxT5PreTrainedModel"])
_import_structure["models.vit"].extend(["FlaxViTForImageClassification", "FlaxViTModel"]) _import_structure["models.vit"].extend(["FlaxViTForImageClassification", "FlaxViTModel", "FlaxViTPreTrainedModel"])
else: else:
from .utils import dummy_flax_objects from .utils import dummy_flax_objects
@@ -1949,6 +1980,7 @@ if TYPE_CHECKING:
DetrForObjectDetection, DetrForObjectDetection,
DetrForSegmentation, DetrForSegmentation,
DetrModel, DetrModel,
DetrPreTrainedModel,
) )
else: else:
from .utils.dummy_timm_objects import * from .utils.dummy_timm_objects import *
@@ -2074,6 +2106,7 @@ if TYPE_CHECKING:
from .models.bert_generation import ( from .models.bert_generation import (
BertGenerationDecoder, BertGenerationDecoder,
BertGenerationEncoder, BertGenerationEncoder,
BertGenerationPreTrainedModel,
load_tf_weights_in_bert_generation, load_tf_weights_in_bert_generation,
) )
from .models.big_bird import ( from .models.big_bird import (
@@ -2097,18 +2130,21 @@ if TYPE_CHECKING:
BigBirdPegasusForQuestionAnswering, BigBirdPegasusForQuestionAnswering,
BigBirdPegasusForSequenceClassification, BigBirdPegasusForSequenceClassification,
BigBirdPegasusModel, BigBirdPegasusModel,
BigBirdPegasusPreTrainedModel,
) )
from .models.blenderbot import ( from .models.blenderbot import (
BLENDERBOT_PRETRAINED_MODEL_ARCHIVE_LIST, BLENDERBOT_PRETRAINED_MODEL_ARCHIVE_LIST,
BlenderbotForCausalLM, BlenderbotForCausalLM,
BlenderbotForConditionalGeneration, BlenderbotForConditionalGeneration,
BlenderbotModel, BlenderbotModel,
BlenderbotPreTrainedModel,
) )
from .models.blenderbot_small import ( from .models.blenderbot_small import (
BLENDERBOT_SMALL_PRETRAINED_MODEL_ARCHIVE_LIST, BLENDERBOT_SMALL_PRETRAINED_MODEL_ARCHIVE_LIST,
BlenderbotSmallForCausalLM, BlenderbotSmallForCausalLM,
BlenderbotSmallForConditionalGeneration, BlenderbotSmallForConditionalGeneration,
BlenderbotSmallModel, BlenderbotSmallModel,
BlenderbotSmallPreTrainedModel,
) )
from .models.camembert import ( from .models.camembert import (
CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_LIST, CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
@@ -2226,6 +2262,7 @@ if TYPE_CHECKING:
FunnelForSequenceClassification, FunnelForSequenceClassification,
FunnelForTokenClassification, FunnelForTokenClassification,
FunnelModel, FunnelModel,
FunnelPreTrainedModel,
load_tf_weights_in_funnel, load_tf_weights_in_funnel,
) )
from .models.gpt2 import ( from .models.gpt2 import (
@@ -2267,6 +2304,7 @@ if TYPE_CHECKING:
LayoutLMForSequenceClassification, LayoutLMForSequenceClassification,
LayoutLMForTokenClassification, LayoutLMForTokenClassification,
LayoutLMModel, LayoutLMModel,
LayoutLMPreTrainedModel,
) )
from .models.led import ( from .models.led import (
LED_PRETRAINED_MODEL_ARCHIVE_LIST, LED_PRETRAINED_MODEL_ARCHIVE_LIST,
@@ -2274,6 +2312,7 @@ if TYPE_CHECKING:
LEDForQuestionAnswering, LEDForQuestionAnswering,
LEDForSequenceClassification, LEDForSequenceClassification,
LEDModel, LEDModel,
LEDPreTrainedModel,
) )
from .models.longformer import ( from .models.longformer import (
LONGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST, LONGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
@@ -2283,6 +2322,7 @@ if TYPE_CHECKING:
LongformerForSequenceClassification, LongformerForSequenceClassification,
LongformerForTokenClassification, LongformerForTokenClassification,
LongformerModel, LongformerModel,
LongformerPreTrainedModel,
LongformerSelfAttention, LongformerSelfAttention,
) )
from .models.luke import ( from .models.luke import (
@@ -2302,7 +2342,12 @@ if TYPE_CHECKING:
LxmertVisualFeatureEncoder, LxmertVisualFeatureEncoder,
LxmertXLayer, LxmertXLayer,
) )
from .models.m2m_100 import M2M_100_PRETRAINED_MODEL_ARCHIVE_LIST, M2M100ForConditionalGeneration, M2M100Model from .models.m2m_100 import (
M2M_100_PRETRAINED_MODEL_ARCHIVE_LIST,
M2M100ForConditionalGeneration,
M2M100Model,
M2M100PreTrainedModel,
)
from .models.marian import MarianForCausalLM, MarianModel, MarianMTModel from .models.marian import MarianForCausalLM, MarianModel, MarianMTModel
from .models.mbart import ( from .models.mbart import (
MBartForCausalLM, MBartForCausalLM,
@@ -2310,6 +2355,7 @@ if TYPE_CHECKING:
MBartForQuestionAnswering, MBartForQuestionAnswering,
MBartForSequenceClassification, MBartForSequenceClassification,
MBartModel, MBartModel,
MBartPreTrainedModel,
) )
from .models.megatron_bert import ( from .models.megatron_bert import (
MEGATRON_BERT_PRETRAINED_MODEL_ARCHIVE_LIST, MEGATRON_BERT_PRETRAINED_MODEL_ARCHIVE_LIST,
@@ -2322,6 +2368,7 @@ if TYPE_CHECKING:
MegatronBertForSequenceClassification, MegatronBertForSequenceClassification,
MegatronBertForTokenClassification, MegatronBertForTokenClassification,
MegatronBertModel, MegatronBertModel,
MegatronBertPreTrainedModel,
) )
from .models.mmbt import MMBTForClassification, MMBTModel, ModalEmbeddings from .models.mmbt import MMBTForClassification, MMBTModel, ModalEmbeddings
from .models.mobilebert import ( from .models.mobilebert import (
@@ -2359,7 +2406,12 @@ if TYPE_CHECKING:
OpenAIGPTPreTrainedModel, OpenAIGPTPreTrainedModel,
load_tf_weights_in_openai_gpt, load_tf_weights_in_openai_gpt,
) )
from .models.pegasus import PegasusForCausalLM, PegasusForConditionalGeneration, PegasusModel from .models.pegasus import (
PegasusForCausalLM,
PegasusForConditionalGeneration,
PegasusModel,
PegasusPreTrainedModel,
)
from .models.prophetnet import ( from .models.prophetnet import (
PROPHETNET_PRETRAINED_MODEL_ARCHIVE_LIST, PROPHETNET_PRETRAINED_MODEL_ARCHIVE_LIST,
ProphetNetDecoder, ProphetNetDecoder,
@@ -2369,7 +2421,7 @@ if TYPE_CHECKING:
ProphetNetModel, ProphetNetModel,
ProphetNetPreTrainedModel, ProphetNetPreTrainedModel,
) )
from .models.rag import RagModel, RagSequenceForGeneration, RagTokenForGeneration from .models.rag import RagModel, RagPreTrainedModel, RagSequenceForGeneration, RagTokenForGeneration
from .models.reformer import ( from .models.reformer import (
REFORMER_PRETRAINED_MODEL_ARCHIVE_LIST, REFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
ReformerAttention, ReformerAttention,
@@ -2379,6 +2431,7 @@ if TYPE_CHECKING:
ReformerLayer, ReformerLayer,
ReformerModel, ReformerModel,
ReformerModelWithLMHead, ReformerModelWithLMHead,
ReformerPreTrainedModel,
) )
from .models.retribert import RETRIBERT_PRETRAINED_MODEL_ARCHIVE_LIST, RetriBertModel, RetriBertPreTrainedModel from .models.retribert import RETRIBERT_PRETRAINED_MODEL_ARCHIVE_LIST, RetriBertModel, RetriBertPreTrainedModel
from .models.roberta import ( from .models.roberta import (
@@ -2390,6 +2443,7 @@ if TYPE_CHECKING:
RobertaForSequenceClassification, RobertaForSequenceClassification,
RobertaForTokenClassification, RobertaForTokenClassification,
RobertaModel, RobertaModel,
RobertaPreTrainedModel,
) )
from .models.roformer import ( from .models.roformer import (
ROFORMER_PRETRAINED_MODEL_ARCHIVE_LIST, ROFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
@@ -2408,6 +2462,7 @@ if TYPE_CHECKING:
SPEECH_TO_TEXT_PRETRAINED_MODEL_ARCHIVE_LIST, SPEECH_TO_TEXT_PRETRAINED_MODEL_ARCHIVE_LIST,
Speech2TextForConditionalGeneration, Speech2TextForConditionalGeneration,
Speech2TextModel, Speech2TextModel,
Speech2TextPreTrainedModel,
) )
from .models.squeezebert import ( from .models.squeezebert import (
SQUEEZEBERT_PRETRAINED_MODEL_ARCHIVE_LIST, SQUEEZEBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
@@ -2434,6 +2489,7 @@ if TYPE_CHECKING:
TapasForQuestionAnswering, TapasForQuestionAnswering,
TapasForSequenceClassification, TapasForSequenceClassification,
TapasModel, TapasModel,
TapasPreTrainedModel,
) )
from .models.transfo_xl import ( from .models.transfo_xl import (
TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_LIST, TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_LIST,
@@ -2600,8 +2656,16 @@ if TYPE_CHECKING:
TFBertModel, TFBertModel,
TFBertPreTrainedModel, TFBertPreTrainedModel,
) )
from .models.blenderbot import TFBlenderbotForConditionalGeneration, TFBlenderbotModel from .models.blenderbot import (
from .models.blenderbot_small import TFBlenderbotSmallForConditionalGeneration, TFBlenderbotSmallModel TFBlenderbotForConditionalGeneration,
TFBlenderbotModel,
TFBlenderbotPreTrainedModel,
)
from .models.blenderbot_small import (
TFBlenderbotSmallForConditionalGeneration,
TFBlenderbotSmallModel,
TFBlenderbotSmallPreTrainedModel,
)
from .models.camembert import ( from .models.camembert import (
TF_CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_LIST, TF_CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
TFCamembertForMaskedLM, TFCamembertForMaskedLM,
@@ -2669,6 +2733,7 @@ if TYPE_CHECKING:
TFFlaubertForSequenceClassification, TFFlaubertForSequenceClassification,
TFFlaubertForTokenClassification, TFFlaubertForTokenClassification,
TFFlaubertModel, TFFlaubertModel,
TFFlaubertPreTrainedModel,
TFFlaubertWithLMHeadModel, TFFlaubertWithLMHeadModel,
) )
from .models.funnel import ( from .models.funnel import (
@@ -2681,6 +2746,7 @@ if TYPE_CHECKING:
TFFunnelForSequenceClassification, TFFunnelForSequenceClassification,
TFFunnelForTokenClassification, TFFunnelForTokenClassification,
TFFunnelModel, TFFunnelModel,
TFFunnelPreTrainedModel,
) )
from .models.gpt2 import ( from .models.gpt2 import (
TF_GPT2_PRETRAINED_MODEL_ARCHIVE_LIST, TF_GPT2_PRETRAINED_MODEL_ARCHIVE_LIST,
@@ -2700,6 +2766,7 @@ if TYPE_CHECKING:
TFLongformerForSequenceClassification, TFLongformerForSequenceClassification,
TFLongformerForTokenClassification, TFLongformerForTokenClassification,
TFLongformerModel, TFLongformerModel,
TFLongformerPreTrainedModel,
TFLongformerSelfAttention, TFLongformerSelfAttention,
) )
from .models.lxmert import ( from .models.lxmert import (
@@ -2710,8 +2777,8 @@ if TYPE_CHECKING:
TFLxmertPreTrainedModel, TFLxmertPreTrainedModel,
TFLxmertVisualFeatureEncoder, TFLxmertVisualFeatureEncoder,
) )
from .models.marian import TFMarianModel, TFMarianMTModel from .models.marian import TFMarianModel, TFMarianMTModel, TFMarianPreTrainedModel
from .models.mbart import TFMBartForConditionalGeneration, TFMBartModel from .models.mbart import TFMBartForConditionalGeneration, TFMBartModel, TFMBartPreTrainedModel
from .models.mobilebert import ( from .models.mobilebert import (
TF_MOBILEBERT_PRETRAINED_MODEL_ARCHIVE_LIST, TF_MOBILEBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
TFMobileBertForMaskedLM, TFMobileBertForMaskedLM,
@@ -2746,8 +2813,8 @@ if TYPE_CHECKING:
TFOpenAIGPTModel, TFOpenAIGPTModel,
TFOpenAIGPTPreTrainedModel, TFOpenAIGPTPreTrainedModel,
) )
from .models.pegasus import TFPegasusForConditionalGeneration, TFPegasusModel from .models.pegasus import TFPegasusForConditionalGeneration, TFPegasusModel, TFPegasusPreTrainedModel
from .models.rag import TFRagModel, TFRagSequenceForGeneration, TFRagTokenForGeneration from .models.rag import TFRagModel, TFRagPreTrainedModel, TFRagSequenceForGeneration, TFRagTokenForGeneration
from .models.roberta import ( from .models.roberta import (
TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST, TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST,
TFRobertaForMaskedLM, TFRobertaForMaskedLM,
@@ -2878,6 +2945,7 @@ if TYPE_CHECKING:
FlaxBartForQuestionAnswering, FlaxBartForQuestionAnswering,
FlaxBartForSequenceClassification, FlaxBartForSequenceClassification,
FlaxBartModel, FlaxBartModel,
FlaxBartPreTrainedModel,
) )
from .models.bert import ( from .models.bert import (
FlaxBertForMaskedLM, FlaxBertForMaskedLM,
@@ -2900,7 +2968,14 @@ if TYPE_CHECKING:
FlaxBigBirdModel, FlaxBigBirdModel,
FlaxBigBirdPreTrainedModel, FlaxBigBirdPreTrainedModel,
) )
from .models.clip import FlaxCLIPModel, FlaxCLIPPreTrainedModel, FlaxCLIPTextModel, FlaxCLIPVisionModel from .models.clip import (
FlaxCLIPModel,
FlaxCLIPPreTrainedModel,
FlaxCLIPTextModel,
FlaxCLIPTextPreTrainedModel,
FlaxCLIPVisionModel,
FlaxCLIPVisionPreTrainedModel,
)
from .models.electra import ( from .models.electra import (
FlaxElectraForMaskedLM, FlaxElectraForMaskedLM,
FlaxElectraForMultipleChoice, FlaxElectraForMultipleChoice,
@@ -2911,7 +2986,7 @@ if TYPE_CHECKING:
FlaxElectraModel, FlaxElectraModel,
FlaxElectraPreTrainedModel, FlaxElectraPreTrainedModel,
) )
from .models.gpt2 import FlaxGPT2LMHeadModel, FlaxGPT2Model from .models.gpt2 import FlaxGPT2LMHeadModel, FlaxGPT2Model, FlaxGPT2PreTrainedModel
from .models.roberta import ( from .models.roberta import (
FlaxRobertaForMaskedLM, FlaxRobertaForMaskedLM,
FlaxRobertaForMultipleChoice, FlaxRobertaForMultipleChoice,
@@ -2921,8 +2996,8 @@ if TYPE_CHECKING:
FlaxRobertaModel, FlaxRobertaModel,
FlaxRobertaPreTrainedModel, FlaxRobertaPreTrainedModel,
) )
from .models.t5 import FlaxT5ForConditionalGeneration, FlaxT5Model from .models.t5 import FlaxT5ForConditionalGeneration, FlaxT5Model, FlaxT5PreTrainedModel
from .models.vit import FlaxViTForImageClassification, FlaxViTModel from .models.vit import FlaxViTForImageClassification, FlaxViTModel, FlaxViTPreTrainedModel
else: else:
# Import the same objects as dummies to get them in the namespace. # Import the same objects as dummies to get them in the namespace.
# They will raise an import error if the user tries to instantiate / use them. # They will raise an import error if the user tries to instantiate / use them.

View File

@@ -55,6 +55,7 @@ if is_flax_available():
"FlaxBartForQuestionAnswering", "FlaxBartForQuestionAnswering",
"FlaxBartForSequenceClassification", "FlaxBartForSequenceClassification",
"FlaxBartModel", "FlaxBartModel",
"FlaxBartPreTrainedModel",
] ]
if TYPE_CHECKING: if TYPE_CHECKING:
@@ -85,6 +86,7 @@ if TYPE_CHECKING:
FlaxBartForQuestionAnswering, FlaxBartForQuestionAnswering,
FlaxBartForSequenceClassification, FlaxBartForSequenceClassification,
FlaxBartModel, FlaxBartModel,
FlaxBartPreTrainedModel,
) )
else: else:

View File

@@ -32,6 +32,7 @@ if is_torch_available():
_import_structure["modeling_bert_generation"] = [ _import_structure["modeling_bert_generation"] = [
"BertGenerationDecoder", "BertGenerationDecoder",
"BertGenerationEncoder", "BertGenerationEncoder",
"BertGenerationPreTrainedModel",
"load_tf_weights_in_bert_generation", "load_tf_weights_in_bert_generation",
] ]
@@ -46,6 +47,7 @@ if TYPE_CHECKING:
from .modeling_bert_generation import ( from .modeling_bert_generation import (
BertGenerationDecoder, BertGenerationDecoder,
BertGenerationEncoder, BertGenerationEncoder,
BertGenerationPreTrainedModel,
load_tf_weights_in_bert_generation, load_tf_weights_in_bert_generation,
) )

View File

@@ -37,7 +37,11 @@ if is_torch_available():
if is_tf_available(): if is_tf_available():
_import_structure["modeling_tf_blenderbot"] = ["TFBlenderbotForConditionalGeneration", "TFBlenderbotModel"] _import_structure["modeling_tf_blenderbot"] = [
"TFBlenderbotForConditionalGeneration",
"TFBlenderbotModel",
"TFBlenderbotPreTrainedModel",
]
if TYPE_CHECKING: if TYPE_CHECKING:
@@ -54,7 +58,11 @@ if TYPE_CHECKING:
) )
if is_tf_available(): if is_tf_available():
from .modeling_tf_blenderbot import TFBlenderbotForConditionalGeneration, TFBlenderbotModel from .modeling_tf_blenderbot import (
TFBlenderbotForConditionalGeneration,
TFBlenderbotModel,
TFBlenderbotPreTrainedModel,
)
else: else:
import importlib import importlib

View File

@@ -38,6 +38,7 @@ if is_tf_available():
_import_structure["modeling_tf_blenderbot_small"] = [ _import_structure["modeling_tf_blenderbot_small"] = [
"TFBlenderbotSmallForConditionalGeneration", "TFBlenderbotSmallForConditionalGeneration",
"TFBlenderbotSmallModel", "TFBlenderbotSmallModel",
"TFBlenderbotSmallPreTrainedModel",
] ]
if TYPE_CHECKING: if TYPE_CHECKING:
@@ -54,7 +55,11 @@ if TYPE_CHECKING:
) )
if is_tf_available(): if is_tf_available():
from .modeling_tf_blenderbot_small import TFBlenderbotSmallForConditionalGeneration, TFBlenderbotSmallModel from .modeling_tf_blenderbot_small import (
TFBlenderbotSmallForConditionalGeneration,
TFBlenderbotSmallModel,
TFBlenderbotSmallPreTrainedModel,
)
else: else:
import importlib import importlib

View File

@@ -52,7 +52,9 @@ if is_flax_available():
"FlaxCLIPModel", "FlaxCLIPModel",
"FlaxCLIPPreTrainedModel", "FlaxCLIPPreTrainedModel",
"FlaxCLIPTextModel", "FlaxCLIPTextModel",
"FlaxCLIPTextPreTrainedModel",
"FlaxCLIPVisionModel", "FlaxCLIPVisionModel",
"FlaxCLIPVisionPreTrainedModel",
] ]
@@ -77,7 +79,14 @@ if TYPE_CHECKING:
) )
if is_flax_available(): if is_flax_available():
from .modeling_flax_clip import FlaxCLIPModel, FlaxCLIPPreTrainedModel, FlaxCLIPTextModel, FlaxCLIPVisionModel from .modeling_flax_clip import (
FlaxCLIPModel,
FlaxCLIPPreTrainedModel,
FlaxCLIPTextModel,
FlaxCLIPTextPreTrainedModel,
FlaxCLIPVisionModel,
FlaxCLIPVisionPreTrainedModel,
)
else: else:

View File

@@ -46,6 +46,7 @@ if is_tf_available():
"TFFlaubertForSequenceClassification", "TFFlaubertForSequenceClassification",
"TFFlaubertForTokenClassification", "TFFlaubertForTokenClassification",
"TFFlaubertModel", "TFFlaubertModel",
"TFFlaubertPreTrainedModel",
"TFFlaubertWithLMHeadModel", "TFFlaubertWithLMHeadModel",
] ]
@@ -74,6 +75,7 @@ if TYPE_CHECKING:
TFFlaubertForSequenceClassification, TFFlaubertForSequenceClassification,
TFFlaubertForTokenClassification, TFFlaubertForTokenClassification,
TFFlaubertModel, TFFlaubertModel,
TFFlaubertPreTrainedModel,
TFFlaubertWithLMHeadModel, TFFlaubertWithLMHeadModel,
) )

View File

@@ -41,6 +41,7 @@ if is_torch_available():
"FunnelForSequenceClassification", "FunnelForSequenceClassification",
"FunnelForTokenClassification", "FunnelForTokenClassification",
"FunnelModel", "FunnelModel",
"FunnelPreTrainedModel",
"load_tf_weights_in_funnel", "load_tf_weights_in_funnel",
] ]
@@ -55,6 +56,7 @@ if is_tf_available():
"TFFunnelForSequenceClassification", "TFFunnelForSequenceClassification",
"TFFunnelForTokenClassification", "TFFunnelForTokenClassification",
"TFFunnelModel", "TFFunnelModel",
"TFFunnelPreTrainedModel",
] ]
@@ -76,6 +78,7 @@ if TYPE_CHECKING:
FunnelForSequenceClassification, FunnelForSequenceClassification,
FunnelForTokenClassification, FunnelForTokenClassification,
FunnelModel, FunnelModel,
FunnelPreTrainedModel,
load_tf_weights_in_funnel, load_tf_weights_in_funnel,
) )
@@ -90,6 +93,7 @@ if TYPE_CHECKING:
TFFunnelForSequenceClassification, TFFunnelForSequenceClassification,
TFFunnelForTokenClassification, TFFunnelForTokenClassification,
TFFunnelModel, TFFunnelModel,
TFFunnelPreTrainedModel,
) )
else: else:

View File

@@ -58,7 +58,7 @@ if is_tf_available():
] ]
if is_flax_available(): if is_flax_available():
_import_structure["modeling_flax_gpt2"] = ["FlaxGPT2LMHeadModel", "FlaxGPT2Model"] _import_structure["modeling_flax_gpt2"] = ["FlaxGPT2LMHeadModel", "FlaxGPT2Model", "FlaxGPT2PreTrainedModel"]
if TYPE_CHECKING: if TYPE_CHECKING:
from .configuration_gpt2 import GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP, GPT2Config from .configuration_gpt2 import GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP, GPT2Config
@@ -90,7 +90,7 @@ if TYPE_CHECKING:
) )
if is_flax_available(): if is_flax_available():
from .modeling_flax_gpt2 import FlaxGPT2LMHeadModel, FlaxGPT2Model from .modeling_flax_gpt2 import FlaxGPT2LMHeadModel, FlaxGPT2Model, FlaxGPT2PreTrainedModel
else: else:
import importlib import importlib

View File

@@ -38,6 +38,7 @@ if is_torch_available():
"LayoutLMForSequenceClassification", "LayoutLMForSequenceClassification",
"LayoutLMForTokenClassification", "LayoutLMForTokenClassification",
"LayoutLMModel", "LayoutLMModel",
"LayoutLMPreTrainedModel",
] ]
if is_tf_available(): if is_tf_available():
@@ -66,6 +67,7 @@ if TYPE_CHECKING:
LayoutLMForSequenceClassification, LayoutLMForSequenceClassification,
LayoutLMForTokenClassification, LayoutLMForTokenClassification,
LayoutLMModel, LayoutLMModel,
LayoutLMPreTrainedModel,
) )
if is_tf_available(): if is_tf_available():
from .modeling_tf_layoutlm import ( from .modeling_tf_layoutlm import (

View File

@@ -38,6 +38,7 @@ if is_torch_available():
"LongformerForSequenceClassification", "LongformerForSequenceClassification",
"LongformerForTokenClassification", "LongformerForTokenClassification",
"LongformerModel", "LongformerModel",
"LongformerPreTrainedModel",
"LongformerSelfAttention", "LongformerSelfAttention",
] ]
@@ -50,6 +51,7 @@ if is_tf_available():
"TFLongformerForSequenceClassification", "TFLongformerForSequenceClassification",
"TFLongformerForTokenClassification", "TFLongformerForTokenClassification",
"TFLongformerModel", "TFLongformerModel",
"TFLongformerPreTrainedModel",
"TFLongformerSelfAttention", "TFLongformerSelfAttention",
] ]
@@ -70,6 +72,7 @@ if TYPE_CHECKING:
LongformerForSequenceClassification, LongformerForSequenceClassification,
LongformerForTokenClassification, LongformerForTokenClassification,
LongformerModel, LongformerModel,
LongformerPreTrainedModel,
LongformerSelfAttention, LongformerSelfAttention,
) )
@@ -82,6 +85,7 @@ if TYPE_CHECKING:
TFLongformerForSequenceClassification, TFLongformerForSequenceClassification,
TFLongformerForTokenClassification, TFLongformerForTokenClassification,
TFLongformerModel, TFLongformerModel,
TFLongformerPreTrainedModel,
TFLongformerSelfAttention, TFLongformerSelfAttention,
) )

View File

@@ -43,7 +43,7 @@ if is_torch_available():
] ]
if is_tf_available(): if is_tf_available():
_import_structure["modeling_tf_marian"] = ["TFMarianModel", "TFMarianMTModel"] _import_structure["modeling_tf_marian"] = ["TFMarianModel", "TFMarianMTModel", "TFMarianPreTrainedModel"]
if TYPE_CHECKING: if TYPE_CHECKING:
@@ -62,7 +62,7 @@ if TYPE_CHECKING:
) )
if is_tf_available(): if is_tf_available():
from .modeling_tf_marian import TFMarianModel, TFMarianMTModel from .modeling_tf_marian import TFMarianModel, TFMarianMTModel, TFMarianPreTrainedModel
else: else:
import importlib import importlib

View File

@@ -50,7 +50,11 @@ if is_torch_available():
] ]
if is_tf_available(): if is_tf_available():
_import_structure["modeling_tf_mbart"] = ["TFMBartForConditionalGeneration", "TFMBartModel"] _import_structure["modeling_tf_mbart"] = [
"TFMBartForConditionalGeneration",
"TFMBartModel",
"TFMBartPreTrainedModel",
]
if TYPE_CHECKING: if TYPE_CHECKING:
@@ -76,7 +80,7 @@ if TYPE_CHECKING:
) )
if is_tf_available(): if is_tf_available():
from .modeling_tf_mbart import TFMBartForConditionalGeneration, TFMBartModel from .modeling_tf_mbart import TFMBartForConditionalGeneration, TFMBartModel, TFMBartPreTrainedModel
else: else:
import importlib import importlib

View File

@@ -36,6 +36,7 @@ if is_torch_available():
"MegatronBertForSequenceClassification", "MegatronBertForSequenceClassification",
"MegatronBertForTokenClassification", "MegatronBertForTokenClassification",
"MegatronBertModel", "MegatronBertModel",
"MegatronBertPreTrainedModel",
] ]
if TYPE_CHECKING: if TYPE_CHECKING:
@@ -53,6 +54,7 @@ if TYPE_CHECKING:
MegatronBertForSequenceClassification, MegatronBertForSequenceClassification,
MegatronBertForTokenClassification, MegatronBertForTokenClassification,
MegatronBertModel, MegatronBertModel,
MegatronBertPreTrainedModel,
) )
else: else:

View File

@@ -46,7 +46,11 @@ if is_torch_available():
] ]
if is_tf_available(): if is_tf_available():
_import_structure["modeling_tf_pegasus"] = ["TFPegasusForConditionalGeneration", "TFPegasusModel"] _import_structure["modeling_tf_pegasus"] = [
"TFPegasusForConditionalGeneration",
"TFPegasusModel",
"TFPegasusPreTrainedModel",
]
if TYPE_CHECKING: if TYPE_CHECKING:
@@ -68,7 +72,7 @@ if TYPE_CHECKING:
) )
if is_tf_available(): if is_tf_available():
from .modeling_tf_pegasus import TFPegasusForConditionalGeneration, TFPegasusModel from .modeling_tf_pegasus import TFPegasusForConditionalGeneration, TFPegasusModel, TFPegasusPreTrainedModel
else: else:
import importlib import importlib

View File

@@ -28,10 +28,20 @@ _import_structure = {
} }
if is_torch_available(): if is_torch_available():
_import_structure["modeling_rag"] = ["RagModel", "RagSequenceForGeneration", "RagTokenForGeneration"] _import_structure["modeling_rag"] = [
"RagModel",
"RagPreTrainedModel",
"RagSequenceForGeneration",
"RagTokenForGeneration",
]
if is_tf_available(): if is_tf_available():
_import_structure["modeling_tf_rag"] = ["TFRagModel", "TFRagSequenceForGeneration", "TFRagTokenForGeneration"] _import_structure["modeling_tf_rag"] = [
"TFRagModel",
"TFRagPreTrainedModel",
"TFRagSequenceForGeneration",
"TFRagTokenForGeneration",
]
if TYPE_CHECKING: if TYPE_CHECKING:
@@ -40,10 +50,15 @@ if TYPE_CHECKING:
from .tokenization_rag import RagTokenizer from .tokenization_rag import RagTokenizer
if is_torch_available(): if is_torch_available():
from .modeling_rag import RagModel, RagSequenceForGeneration, RagTokenForGeneration from .modeling_rag import RagModel, RagPreTrainedModel, RagSequenceForGeneration, RagTokenForGeneration
if is_tf_available(): if is_tf_available():
from .modeling_tf_rag import TFRagModel, TFRagSequenceForGeneration, TFRagTokenForGeneration from .modeling_tf_rag import (
TFRagModel,
TFRagPreTrainedModel,
TFRagSequenceForGeneration,
TFRagTokenForGeneration,
)
else: else:
import importlib import importlib

View File

@@ -41,6 +41,7 @@ if is_torch_available():
"ReformerLayer", "ReformerLayer",
"ReformerModel", "ReformerModel",
"ReformerModelWithLMHead", "ReformerModelWithLMHead",
"ReformerPreTrainedModel",
] ]
@@ -63,6 +64,7 @@ if TYPE_CHECKING:
ReformerLayer, ReformerLayer,
ReformerModel, ReformerModel,
ReformerModelWithLMHead, ReformerModelWithLMHead,
ReformerPreTrainedModel,
) )
else: else:

View File

@@ -45,6 +45,7 @@ if is_torch_available():
"RobertaForSequenceClassification", "RobertaForSequenceClassification",
"RobertaForTokenClassification", "RobertaForTokenClassification",
"RobertaModel", "RobertaModel",
"RobertaPreTrainedModel",
] ]
if is_tf_available(): if is_tf_available():
@@ -89,6 +90,7 @@ if TYPE_CHECKING:
RobertaForSequenceClassification, RobertaForSequenceClassification,
RobertaForTokenClassification, RobertaForTokenClassification,
RobertaModel, RobertaModel,
RobertaPreTrainedModel,
) )
if is_tf_available(): if is_tf_available():

View File

@@ -33,6 +33,7 @@ if is_torch_available():
"TapasForQuestionAnswering", "TapasForQuestionAnswering",
"TapasForSequenceClassification", "TapasForSequenceClassification",
"TapasModel", "TapasModel",
"TapasPreTrainedModel",
] ]
@@ -47,6 +48,7 @@ if TYPE_CHECKING:
TapasForQuestionAnswering, TapasForQuestionAnswering,
TapasForSequenceClassification, TapasForSequenceClassification,
TapasModel, TapasModel,
TapasPreTrainedModel,
) )
else: else:

View File

@@ -37,7 +37,11 @@ if is_torch_available():
if is_flax_available(): if is_flax_available():
_import_structure["modeling_flax_vit"] = ["FlaxViTForImageClassification", "FlaxViTModel"] _import_structure["modeling_flax_vit"] = [
"FlaxViTForImageClassification",
"FlaxViTModel",
"FlaxViTPreTrainedModel",
]
if TYPE_CHECKING: if TYPE_CHECKING:
from .configuration_vit import VIT_PRETRAINED_CONFIG_ARCHIVE_MAP, ViTConfig from .configuration_vit import VIT_PRETRAINED_CONFIG_ARCHIVE_MAP, ViTConfig
@@ -54,7 +58,7 @@ if TYPE_CHECKING:
) )
if is_flax_available(): if is_flax_available():
from .modeling_flax_vit import FlaxViTForImageClassification, FlaxViTModel from .modeling_flax_vit import FlaxViTForImageClassification, FlaxViTModel, FlaxViTPreTrainedModel
else: else:

View File

@@ -244,6 +244,15 @@ class FlaxBartModel:
requires_backends(cls, ["flax"]) requires_backends(cls, ["flax"])
class FlaxBartPreTrainedModel:
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["flax"])
class FlaxBertForMaskedLM: class FlaxBertForMaskedLM:
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"]) requires_backends(self, ["flax"])
@@ -412,6 +421,15 @@ class FlaxCLIPTextModel:
requires_backends(cls, ["flax"]) requires_backends(cls, ["flax"])
class FlaxCLIPTextPreTrainedModel:
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["flax"])
class FlaxCLIPVisionModel: class FlaxCLIPVisionModel:
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"]) requires_backends(self, ["flax"])
@@ -421,6 +439,15 @@ class FlaxCLIPVisionModel:
requires_backends(cls, ["flax"]) requires_backends(cls, ["flax"])
class FlaxCLIPVisionPreTrainedModel:
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["flax"])
class FlaxElectraForMaskedLM: class FlaxElectraForMaskedLM:
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"]) requires_backends(self, ["flax"])
@@ -507,6 +534,15 @@ class FlaxGPT2Model:
requires_backends(cls, ["flax"]) requires_backends(cls, ["flax"])
class FlaxGPT2PreTrainedModel:
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["flax"])
class FlaxRobertaForMaskedLM: class FlaxRobertaForMaskedLM:
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"]) requires_backends(self, ["flax"])
@@ -588,6 +624,15 @@ class FlaxT5Model:
requires_backends(cls, ["flax"]) requires_backends(cls, ["flax"])
class FlaxT5PreTrainedModel:
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["flax"])
class FlaxViTForImageClassification: class FlaxViTForImageClassification:
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"]) requires_backends(self, ["flax"])
@@ -600,3 +645,12 @@ class FlaxViTModel:
@classmethod @classmethod
def from_pretrained(cls, *args, **kwargs): def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["flax"]) requires_backends(cls, ["flax"])
class FlaxViTPreTrainedModel:
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["flax"])

View File

@@ -692,6 +692,15 @@ class BertGenerationEncoder:
requires_backends(self, ["torch"]) requires_backends(self, ["torch"])
class BertGenerationPreTrainedModel:
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
def load_tf_weights_in_bert_generation(*args, **kwargs): def load_tf_weights_in_bert_generation(*args, **kwargs):
requires_backends(load_tf_weights_in_bert_generation, ["torch"]) requires_backends(load_tf_weights_in_bert_generation, ["torch"])
@@ -833,6 +842,15 @@ class BigBirdPegasusModel:
requires_backends(cls, ["torch"]) requires_backends(cls, ["torch"])
class BigBirdPegasusPreTrainedModel:
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
BLENDERBOT_PRETRAINED_MODEL_ARCHIVE_LIST = None BLENDERBOT_PRETRAINED_MODEL_ARCHIVE_LIST = None
@@ -863,6 +881,15 @@ class BlenderbotModel:
requires_backends(cls, ["torch"]) requires_backends(cls, ["torch"])
class BlenderbotPreTrainedModel:
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
BLENDERBOT_SMALL_PRETRAINED_MODEL_ARCHIVE_LIST = None BLENDERBOT_SMALL_PRETRAINED_MODEL_ARCHIVE_LIST = None
@@ -893,6 +920,15 @@ class BlenderbotSmallModel:
requires_backends(cls, ["torch"]) requires_backends(cls, ["torch"])
class BlenderbotSmallPreTrainedModel:
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_LIST = None CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_LIST = None
@@ -1610,6 +1646,15 @@ class FunnelModel:
requires_backends(cls, ["torch"]) requires_backends(cls, ["torch"])
class FunnelPreTrainedModel:
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
def load_tf_weights_in_funnel(*args, **kwargs): def load_tf_weights_in_funnel(*args, **kwargs):
requires_backends(load_tf_weights_in_funnel, ["torch"]) requires_backends(load_tf_weights_in_funnel, ["torch"])
@@ -1840,6 +1885,15 @@ class LayoutLMModel:
requires_backends(cls, ["torch"]) requires_backends(cls, ["torch"])
class LayoutLMPreTrainedModel:
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
LED_PRETRAINED_MODEL_ARCHIVE_LIST = None LED_PRETRAINED_MODEL_ARCHIVE_LIST = None
@@ -1879,6 +1933,15 @@ class LEDModel:
requires_backends(cls, ["torch"]) requires_backends(cls, ["torch"])
class LEDPreTrainedModel:
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
LONGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = None LONGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = None
@@ -1936,6 +1999,15 @@ class LongformerModel:
requires_backends(cls, ["torch"]) requires_backends(cls, ["torch"])
class LongformerPreTrainedModel:
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class LongformerSelfAttention: class LongformerSelfAttention:
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"]) requires_backends(self, ["torch"])
@@ -2045,6 +2117,15 @@ class M2M100Model:
requires_backends(cls, ["torch"]) requires_backends(cls, ["torch"])
class M2M100PreTrainedModel:
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class MarianForCausalLM: class MarianForCausalLM:
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"]) requires_backends(self, ["torch"])
@@ -2117,6 +2198,15 @@ class MBartModel:
requires_backends(cls, ["torch"]) requires_backends(cls, ["torch"])
class MBartPreTrainedModel:
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
MEGATRON_BERT_PRETRAINED_MODEL_ARCHIVE_LIST = None MEGATRON_BERT_PRETRAINED_MODEL_ARCHIVE_LIST = None
@@ -2193,6 +2283,15 @@ class MegatronBertModel:
requires_backends(cls, ["torch"]) requires_backends(cls, ["torch"])
class MegatronBertPreTrainedModel:
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class MMBTForClassification: class MMBTForClassification:
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"]) requires_backends(self, ["torch"])
@@ -2474,6 +2573,15 @@ class PegasusModel:
requires_backends(cls, ["torch"]) requires_backends(cls, ["torch"])
class PegasusPreTrainedModel:
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
PROPHETNET_PRETRAINED_MODEL_ARCHIVE_LIST = None PROPHETNET_PRETRAINED_MODEL_ARCHIVE_LIST = None
@@ -2532,6 +2640,15 @@ class RagModel:
requires_backends(cls, ["torch"]) requires_backends(cls, ["torch"])
class RagPreTrainedModel:
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class RagSequenceForGeneration: class RagSequenceForGeneration:
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"]) requires_backends(self, ["torch"])
@@ -2600,6 +2717,15 @@ class ReformerModelWithLMHead:
requires_backends(cls, ["torch"]) requires_backends(cls, ["torch"])
class ReformerPreTrainedModel:
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
RETRIBERT_PRETRAINED_MODEL_ARCHIVE_LIST = None RETRIBERT_PRETRAINED_MODEL_ARCHIVE_LIST = None
@@ -2687,6 +2813,15 @@ class RobertaModel:
requires_backends(cls, ["torch"]) requires_backends(cls, ["torch"])
class RobertaPreTrainedModel:
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
ROFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = None ROFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = None
@@ -2792,6 +2927,15 @@ class Speech2TextModel:
requires_backends(cls, ["torch"]) requires_backends(cls, ["torch"])
class Speech2TextPreTrainedModel:
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
SQUEEZEBERT_PRETRAINED_MODEL_ARCHIVE_LIST = None SQUEEZEBERT_PRETRAINED_MODEL_ARCHIVE_LIST = None
@@ -2945,6 +3089,15 @@ class TapasModel:
requires_backends(cls, ["torch"]) requires_backends(cls, ["torch"])
class TapasPreTrainedModel:
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_LIST = None TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_LIST = None

View File

@@ -431,6 +431,15 @@ class TFBlenderbotModel:
requires_backends(cls, ["tf"]) requires_backends(cls, ["tf"])
class TFBlenderbotPreTrainedModel:
def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["tf"])
class TFBlenderbotSmallForConditionalGeneration: class TFBlenderbotSmallForConditionalGeneration:
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"]) requires_backends(self, ["tf"])
@@ -449,6 +458,15 @@ class TFBlenderbotSmallModel:
requires_backends(cls, ["tf"]) requires_backends(cls, ["tf"])
class TFBlenderbotSmallPreTrainedModel:
def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["tf"])
TF_CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_LIST = None TF_CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_LIST = None
@@ -845,6 +863,15 @@ class TFFlaubertModel:
requires_backends(cls, ["tf"]) requires_backends(cls, ["tf"])
class TFFlaubertPreTrainedModel:
def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["tf"])
class TFFlaubertWithLMHeadModel: class TFFlaubertWithLMHeadModel:
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"]) requires_backends(self, ["tf"])
@@ -925,6 +952,15 @@ class TFFunnelModel:
requires_backends(cls, ["tf"]) requires_backends(cls, ["tf"])
class TFFunnelPreTrainedModel:
def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["tf"])
TF_GPT2_PRETRAINED_MODEL_ARCHIVE_LIST = None TF_GPT2_PRETRAINED_MODEL_ARCHIVE_LIST = None
@@ -1062,6 +1098,15 @@ class TFLongformerModel:
requires_backends(cls, ["tf"]) requires_backends(cls, ["tf"])
class TFLongformerPreTrainedModel:
def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["tf"])
class TFLongformerSelfAttention: class TFLongformerSelfAttention:
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"]) requires_backends(self, ["tf"])
@@ -1121,6 +1166,15 @@ class TFMarianMTModel:
requires_backends(cls, ["tf"]) requires_backends(cls, ["tf"])
class TFMarianPreTrainedModel:
def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["tf"])
class TFMBartForConditionalGeneration: class TFMBartForConditionalGeneration:
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"]) requires_backends(self, ["tf"])
@@ -1139,6 +1193,15 @@ class TFMBartModel:
requires_backends(cls, ["tf"]) requires_backends(cls, ["tf"])
class TFMBartPreTrainedModel:
def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["tf"])
TF_MOBILEBERT_PRETRAINED_MODEL_ARCHIVE_LIST = None TF_MOBILEBERT_PRETRAINED_MODEL_ARCHIVE_LIST = None
@@ -1389,6 +1452,15 @@ class TFPegasusModel:
requires_backends(cls, ["tf"]) requires_backends(cls, ["tf"])
class TFPegasusPreTrainedModel:
def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["tf"])
class TFRagModel: class TFRagModel:
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"]) requires_backends(self, ["tf"])
@@ -1398,6 +1470,15 @@ class TFRagModel:
requires_backends(cls, ["tf"]) requires_backends(cls, ["tf"])
class TFRagPreTrainedModel:
def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["tf"])
class TFRagSequenceForGeneration: class TFRagSequenceForGeneration:
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
requires_backends(self, ["tf"]) requires_backends(self, ["tf"])

View File

@@ -30,3 +30,12 @@ class DetrModel:
@classmethod @classmethod
def from_pretrained(cls, *args, **kwargs): def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["timm", "vision"]) requires_backends(cls, ["timm", "vision"])
class DetrPreTrainedModel:
def __init__(self, *args, **kwargs):
requires_backends(self, ["timm", "vision"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["timm", "vision"])

View File

@@ -52,6 +52,7 @@
"{{cookiecutter.camelcase_modelname}}ForQuestionAnswering", "{{cookiecutter.camelcase_modelname}}ForQuestionAnswering",
"{{cookiecutter.camelcase_modelname}}ForSequenceClassification", "{{cookiecutter.camelcase_modelname}}ForSequenceClassification",
"{{cookiecutter.camelcase_modelname}}Model", "{{cookiecutter.camelcase_modelname}}Model",
"{{cookiecutter.camelcase_modelname}}PreTrainedModel",
] ]
) )
{% endif -%} {% endif -%}
@@ -120,6 +121,7 @@
{{cookiecutter.camelcase_modelname}}ForQuestionAnswering, {{cookiecutter.camelcase_modelname}}ForQuestionAnswering,
{{cookiecutter.camelcase_modelname}}ForSequenceClassification, {{cookiecutter.camelcase_modelname}}ForSequenceClassification,
{{cookiecutter.camelcase_modelname}}Model, {{cookiecutter.camelcase_modelname}}Model,
{{cookiecutter.camelcase_modelname}}PreTrainedModel,
) )
{% endif -%} {% endif -%}
# End. # End.

View File

@@ -31,9 +31,16 @@ PATH_TO_TRANSFORMERS = "src/transformers"
PATH_TO_TESTS = "tests" PATH_TO_TESTS = "tests"
PATH_TO_DOC = "docs/source" PATH_TO_DOC = "docs/source"
# Models that are private on purpose. `is_a_private_model` returns True for every name listed here, so
# `check_models_are_in_init` does not require them in the main init. The list is also prepended to
# IGNORE_NON_TESTED and IGNORE_NON_AUTO_CONFIGURED below. Update this list with models that are supposed
# to be private.
PRIVATE_MODELS = [
    "DPRSpanPredictor",
    "T5Stack",
    "TFDPRSpanPredictor",
]
# Update this list for models that are not tested with a comment explaining the reason it should not be. # Update this list for models that are not tested with a comment explaining the reason it should not be.
# Being in this list is an exception and should **not** be the rule. # Being in this list is an exception and should **not** be the rule.
IGNORE_NON_TESTED = [ IGNORE_NON_TESTED = PRIVATE_MODELS.copy() + [
# models to ignore for not tested # models to ignore for not tested
"BigBirdPegasusEncoder", # Building part of bigger (tested) model. "BigBirdPegasusEncoder", # Building part of bigger (tested) model.
"BigBirdPegasusDecoder", # Building part of bigger (tested) model. "BigBirdPegasusDecoder", # Building part of bigger (tested) model.
@@ -63,12 +70,9 @@ IGNORE_NON_TESTED = [
"PegasusEncoder", # Building part of bigger (tested) model. "PegasusEncoder", # Building part of bigger (tested) model.
"PegasusDecoderWrapper", # Building part of bigger (tested) model. "PegasusDecoderWrapper", # Building part of bigger (tested) model.
"DPREncoder", # Building part of bigger (tested) model. "DPREncoder", # Building part of bigger (tested) model.
"DPRSpanPredictor", # Building part of bigger (tested) model.
"ProphetNetDecoderWrapper", # Building part of bigger (tested) model. "ProphetNetDecoderWrapper", # Building part of bigger (tested) model.
"ReformerForMaskedLM", # Needs to be setup as decoder. "ReformerForMaskedLM", # Needs to be setup as decoder.
"T5Stack", # Building part of bigger (tested) model.
"TFDPREncoder", # Building part of bigger (tested) model. "TFDPREncoder", # Building part of bigger (tested) model.
"TFDPRSpanPredictor", # Building part of bigger (tested) model.
"TFElectraMainLayer", # Building part of bigger (tested) model (should it be a TFPreTrainedModel ?) "TFElectraMainLayer", # Building part of bigger (tested) model (should it be a TFPreTrainedModel ?)
"TFRobertaForMultipleChoice", # TODO: fix "TFRobertaForMultipleChoice", # TODO: fix
"SeparableConv1D", # Building part of bigger (tested) model. "SeparableConv1D", # Building part of bigger (tested) model.
@@ -92,7 +96,7 @@ TEST_FILES_WITH_NO_COMMON_TESTS = [
# Update this list for models that are not in any of the auto MODEL_XXX_MAPPING. Being in this list is an exception and # Update this list for models that are not in any of the auto MODEL_XXX_MAPPING. Being in this list is an exception and
# should **not** be the rule. # should **not** be the rule.
IGNORE_NON_AUTO_CONFIGURED = [ IGNORE_NON_AUTO_CONFIGURED = PRIVATE_MODELS.copy() + [
# models to ignore for model xxx mapping # models to ignore for model xxx mapping
"CLIPTextModel", "CLIPTextModel",
"CLIPVisionModel", "CLIPVisionModel",
@@ -100,7 +104,6 @@ IGNORE_NON_AUTO_CONFIGURED = [
"FlaxCLIPVisionModel", "FlaxCLIPVisionModel",
"DetrForSegmentation", "DetrForSegmentation",
"DPRReader", "DPRReader",
"DPRSpanPredictor",
"FlaubertForQuestionAnswering", "FlaubertForQuestionAnswering",
"GPT2DoubleHeadsModel", "GPT2DoubleHeadsModel",
"LukeForEntityClassification", "LukeForEntityClassification",
@@ -110,9 +113,7 @@ IGNORE_NON_AUTO_CONFIGURED = [
"RagModel", "RagModel",
"RagSequenceForGeneration", "RagSequenceForGeneration",
"RagTokenForGeneration", "RagTokenForGeneration",
"T5Stack",
"TFDPRReader", "TFDPRReader",
"TFDPRSpanPredictor",
"TFGPT2DoubleHeadsModel", "TFGPT2DoubleHeadsModel",
"TFOpenAIGPTDoubleHeadsModel", "TFOpenAIGPTDoubleHeadsModel",
"TFRagModel", "TFRagModel",
@@ -173,12 +174,12 @@ def get_model_modules():
return modules return modules
def get_models(module): def get_models(module, include_pretrained=False):
"""Get the objects in module that are models.""" """Get the objects in module that are models."""
models = [] models = []
model_classes = (transformers.PreTrainedModel, transformers.TFPreTrainedModel, transformers.FlaxPreTrainedModel) model_classes = (transformers.PreTrainedModel, transformers.TFPreTrainedModel, transformers.FlaxPreTrainedModel)
for attr_name in dir(module): for attr_name in dir(module):
if "Pretrained" in attr_name or "PreTrained" in attr_name: if not include_pretrained and ("Pretrained" in attr_name or "PreTrained" in attr_name):
continue continue
attr = getattr(module, attr_name) attr = getattr(module, attr_name)
if isinstance(attr, type) and issubclass(attr, model_classes) and attr.__module__ == module.__name__: if isinstance(attr, type) and issubclass(attr, model_classes) and attr.__module__ == module.__name__:
@@ -186,6 +187,36 @@ def get_models(module):
return models return models
def is_a_private_model(model):
    """Returns True if the model should not be in the main init.

    Args:
        model (str): Name of a model class.

    Returns:
        bool: True when the name is listed in PRIVATE_MODELS or ends with "Wrapper", "Encoder" or
        "Decoder"; False otherwise.
    """
    if model in PRIVATE_MODELS:
        return True

    # Wrapper, Encoder and Decoder are all privates; str.endswith accepts a tuple of suffixes.
    return model.endswith(("Wrapper", "Encoder", "Decoder"))
def check_models_are_in_init():
    """Checks all models defined in the library are in the main init.

    Every model returned by `get_models(module, include_pretrained=True)` for each module of
    `get_model_modules()` must be an attribute of the `transformers` package, unless
    `is_a_private_model` says it is private.

    Raises:
        Exception: If at least one non-private model is missing from the main init. The message lists
            the missing names, comma-separated, in discovery order.
    """
    # dir() returns a list; build a set once so each membership test is O(1).
    public_names = set(dir(transformers))
    models_not_in_init = [
        model[0]
        for module in get_model_modules()
        for model in get_models(module, include_pretrained=True)
        # Keep only models that are missing from the init and are not private.
        if model[0] not in public_names and not is_a_private_model(model[0])
    ]
    if models_not_in_init:
        raise Exception(f"The following models should be in the main init: {','.join(models_not_in_init)}.")
# If some test_modeling files should be ignored when checking models are all tested, they should be added in the # If some test_modeling files should be ignored when checking models are all tested, they should be added in the
# nested list _ignore_files of this function. # nested list _ignore_files of this function.
def get_model_test_files(): def get_model_test_files():
@@ -229,6 +260,7 @@ def find_tested_models(test_file):
def check_models_are_tested(module, test_file): def check_models_are_tested(module, test_file):
"""Check models defined in module are tested in test_file.""" """Check models defined in module are tested in test_file."""
# XxxPreTrainedModel are not tested
defined_models = get_models(module) defined_models = get_models(module)
tested_models = find_tested_models(test_file) tested_models = find_tested_models(test_file)
if tested_models is None: if tested_models is None:
@@ -515,6 +547,8 @@ def check_all_objects_are_documented():
def check_repo_quality(): def check_repo_quality():
"""Check all models are properly tested and documented.""" """Check all models are properly tested and documented."""
print("Checking all models are public.")
check_models_are_in_init()
print("Checking all models are properly tested.") print("Checking all models are properly tested.")
check_all_decorator_order() check_all_decorator_order()
check_all_models_are_tested() check_all_models_are_tested()