Add all XxxPreTrainedModel to the main init (#12314)
* Add all XxxPreTrainedModel to the main init * Add to template * Add to template bis * Add FlaxT5
This commit is contained in:
@@ -427,6 +427,7 @@ if is_timm_available() and is_vision_available():
|
|||||||
"DetrForObjectDetection",
|
"DetrForObjectDetection",
|
||||||
"DetrForSegmentation",
|
"DetrForSegmentation",
|
||||||
"DetrModel",
|
"DetrModel",
|
||||||
|
"DetrPreTrainedModel",
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
@@ -570,6 +571,7 @@ if is_torch_available():
|
|||||||
[
|
[
|
||||||
"BertGenerationDecoder",
|
"BertGenerationDecoder",
|
||||||
"BertGenerationEncoder",
|
"BertGenerationEncoder",
|
||||||
|
"BertGenerationPreTrainedModel",
|
||||||
"load_tf_weights_in_bert_generation",
|
"load_tf_weights_in_bert_generation",
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
@@ -597,6 +599,7 @@ if is_torch_available():
|
|||||||
"BigBirdPegasusForQuestionAnswering",
|
"BigBirdPegasusForQuestionAnswering",
|
||||||
"BigBirdPegasusForSequenceClassification",
|
"BigBirdPegasusForSequenceClassification",
|
||||||
"BigBirdPegasusModel",
|
"BigBirdPegasusModel",
|
||||||
|
"BigBirdPegasusPreTrainedModel",
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
_import_structure["models.blenderbot"].extend(
|
_import_structure["models.blenderbot"].extend(
|
||||||
@@ -605,6 +608,7 @@ if is_torch_available():
|
|||||||
"BlenderbotForCausalLM",
|
"BlenderbotForCausalLM",
|
||||||
"BlenderbotForConditionalGeneration",
|
"BlenderbotForConditionalGeneration",
|
||||||
"BlenderbotModel",
|
"BlenderbotModel",
|
||||||
|
"BlenderbotPreTrainedModel",
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
_import_structure["models.blenderbot_small"].extend(
|
_import_structure["models.blenderbot_small"].extend(
|
||||||
@@ -613,6 +617,7 @@ if is_torch_available():
|
|||||||
"BlenderbotSmallForCausalLM",
|
"BlenderbotSmallForCausalLM",
|
||||||
"BlenderbotSmallForConditionalGeneration",
|
"BlenderbotSmallForConditionalGeneration",
|
||||||
"BlenderbotSmallModel",
|
"BlenderbotSmallModel",
|
||||||
|
"BlenderbotSmallPreTrainedModel",
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
_import_structure["models.camembert"].extend(
|
_import_structure["models.camembert"].extend(
|
||||||
@@ -754,6 +759,7 @@ if is_torch_available():
|
|||||||
"FunnelForSequenceClassification",
|
"FunnelForSequenceClassification",
|
||||||
"FunnelForTokenClassification",
|
"FunnelForTokenClassification",
|
||||||
"FunnelModel",
|
"FunnelModel",
|
||||||
|
"FunnelPreTrainedModel",
|
||||||
"load_tf_weights_in_funnel",
|
"load_tf_weights_in_funnel",
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
@@ -805,6 +811,7 @@ if is_torch_available():
|
|||||||
"LayoutLMForSequenceClassification",
|
"LayoutLMForSequenceClassification",
|
||||||
"LayoutLMForTokenClassification",
|
"LayoutLMForTokenClassification",
|
||||||
"LayoutLMModel",
|
"LayoutLMModel",
|
||||||
|
"LayoutLMPreTrainedModel",
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
_import_structure["models.led"].extend(
|
_import_structure["models.led"].extend(
|
||||||
@@ -814,6 +821,7 @@ if is_torch_available():
|
|||||||
"LEDForQuestionAnswering",
|
"LEDForQuestionAnswering",
|
||||||
"LEDForSequenceClassification",
|
"LEDForSequenceClassification",
|
||||||
"LEDModel",
|
"LEDModel",
|
||||||
|
"LEDPreTrainedModel",
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
_import_structure["models.longformer"].extend(
|
_import_structure["models.longformer"].extend(
|
||||||
@@ -825,6 +833,7 @@ if is_torch_available():
|
|||||||
"LongformerForSequenceClassification",
|
"LongformerForSequenceClassification",
|
||||||
"LongformerForTokenClassification",
|
"LongformerForTokenClassification",
|
||||||
"LongformerModel",
|
"LongformerModel",
|
||||||
|
"LongformerPreTrainedModel",
|
||||||
"LongformerSelfAttention",
|
"LongformerSelfAttention",
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
@@ -854,6 +863,7 @@ if is_torch_available():
|
|||||||
"M2M_100_PRETRAINED_MODEL_ARCHIVE_LIST",
|
"M2M_100_PRETRAINED_MODEL_ARCHIVE_LIST",
|
||||||
"M2M100ForConditionalGeneration",
|
"M2M100ForConditionalGeneration",
|
||||||
"M2M100Model",
|
"M2M100Model",
|
||||||
|
"M2M100PreTrainedModel",
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
_import_structure["models.marian"].extend(["MarianForCausalLM", "MarianModel", "MarianMTModel"])
|
_import_structure["models.marian"].extend(["MarianForCausalLM", "MarianModel", "MarianMTModel"])
|
||||||
@@ -864,6 +874,7 @@ if is_torch_available():
|
|||||||
"MBartForQuestionAnswering",
|
"MBartForQuestionAnswering",
|
||||||
"MBartForSequenceClassification",
|
"MBartForSequenceClassification",
|
||||||
"MBartModel",
|
"MBartModel",
|
||||||
|
"MBartPreTrainedModel",
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
_import_structure["models.megatron_bert"].extend(
|
_import_structure["models.megatron_bert"].extend(
|
||||||
@@ -878,6 +889,7 @@ if is_torch_available():
|
|||||||
"MegatronBertForSequenceClassification",
|
"MegatronBertForSequenceClassification",
|
||||||
"MegatronBertForTokenClassification",
|
"MegatronBertForTokenClassification",
|
||||||
"MegatronBertModel",
|
"MegatronBertModel",
|
||||||
|
"MegatronBertPreTrainedModel",
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
_import_structure["models.mmbt"].extend(["MMBTForClassification", "MMBTModel", "ModalEmbeddings"])
|
_import_structure["models.mmbt"].extend(["MMBTForClassification", "MMBTModel", "ModalEmbeddings"])
|
||||||
@@ -923,7 +935,7 @@ if is_torch_available():
|
|||||||
]
|
]
|
||||||
)
|
)
|
||||||
_import_structure["models.pegasus"].extend(
|
_import_structure["models.pegasus"].extend(
|
||||||
["PegasusForCausalLM", "PegasusForConditionalGeneration", "PegasusModel"]
|
["PegasusForCausalLM", "PegasusForConditionalGeneration", "PegasusModel", "PegasusPreTrainedModel"]
|
||||||
)
|
)
|
||||||
_import_structure["models.prophetnet"].extend(
|
_import_structure["models.prophetnet"].extend(
|
||||||
[
|
[
|
||||||
@@ -936,7 +948,9 @@ if is_torch_available():
|
|||||||
"ProphetNetPreTrainedModel",
|
"ProphetNetPreTrainedModel",
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
_import_structure["models.rag"].extend(["RagModel", "RagSequenceForGeneration", "RagTokenForGeneration"])
|
_import_structure["models.rag"].extend(
|
||||||
|
["RagModel", "RagPreTrainedModel", "RagSequenceForGeneration", "RagTokenForGeneration"]
|
||||||
|
)
|
||||||
_import_structure["models.reformer"].extend(
|
_import_structure["models.reformer"].extend(
|
||||||
[
|
[
|
||||||
"REFORMER_PRETRAINED_MODEL_ARCHIVE_LIST",
|
"REFORMER_PRETRAINED_MODEL_ARCHIVE_LIST",
|
||||||
@@ -947,6 +961,7 @@ if is_torch_available():
|
|||||||
"ReformerLayer",
|
"ReformerLayer",
|
||||||
"ReformerModel",
|
"ReformerModel",
|
||||||
"ReformerModelWithLMHead",
|
"ReformerModelWithLMHead",
|
||||||
|
"ReformerPreTrainedModel",
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
_import_structure["models.retribert"].extend(
|
_import_structure["models.retribert"].extend(
|
||||||
@@ -962,6 +977,7 @@ if is_torch_available():
|
|||||||
"RobertaForSequenceClassification",
|
"RobertaForSequenceClassification",
|
||||||
"RobertaForTokenClassification",
|
"RobertaForTokenClassification",
|
||||||
"RobertaModel",
|
"RobertaModel",
|
||||||
|
"RobertaPreTrainedModel",
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
_import_structure["models.roformer"].extend(
|
_import_structure["models.roformer"].extend(
|
||||||
@@ -984,6 +1000,7 @@ if is_torch_available():
|
|||||||
"SPEECH_TO_TEXT_PRETRAINED_MODEL_ARCHIVE_LIST",
|
"SPEECH_TO_TEXT_PRETRAINED_MODEL_ARCHIVE_LIST",
|
||||||
"Speech2TextForConditionalGeneration",
|
"Speech2TextForConditionalGeneration",
|
||||||
"Speech2TextModel",
|
"Speech2TextModel",
|
||||||
|
"Speech2TextPreTrainedModel",
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
_import_structure["models.squeezebert"].extend(
|
_import_structure["models.squeezebert"].extend(
|
||||||
@@ -1016,6 +1033,7 @@ if is_torch_available():
|
|||||||
"TapasForQuestionAnswering",
|
"TapasForQuestionAnswering",
|
||||||
"TapasForSequenceClassification",
|
"TapasForSequenceClassification",
|
||||||
"TapasModel",
|
"TapasModel",
|
||||||
|
"TapasPreTrainedModel",
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
_import_structure["models.transfo_xl"].extend(
|
_import_structure["models.transfo_xl"].extend(
|
||||||
@@ -1197,9 +1215,11 @@ if is_tf_available():
|
|||||||
"TFBertPreTrainedModel",
|
"TFBertPreTrainedModel",
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
_import_structure["models.blenderbot"].extend(["TFBlenderbotForConditionalGeneration", "TFBlenderbotModel"])
|
_import_structure["models.blenderbot"].extend(
|
||||||
|
["TFBlenderbotForConditionalGeneration", "TFBlenderbotModel", "TFBlenderbotPreTrainedModel"]
|
||||||
|
)
|
||||||
_import_structure["models.blenderbot_small"].extend(
|
_import_structure["models.blenderbot_small"].extend(
|
||||||
["TFBlenderbotSmallForConditionalGeneration", "TFBlenderbotSmallModel"]
|
["TFBlenderbotSmallForConditionalGeneration", "TFBlenderbotSmallModel", "TFBlenderbotSmallPreTrainedModel"]
|
||||||
)
|
)
|
||||||
_import_structure["models.camembert"].extend(
|
_import_structure["models.camembert"].extend(
|
||||||
[
|
[
|
||||||
@@ -1281,6 +1301,7 @@ if is_tf_available():
|
|||||||
"TFFlaubertForSequenceClassification",
|
"TFFlaubertForSequenceClassification",
|
||||||
"TFFlaubertForTokenClassification",
|
"TFFlaubertForTokenClassification",
|
||||||
"TFFlaubertModel",
|
"TFFlaubertModel",
|
||||||
|
"TFFlaubertPreTrainedModel",
|
||||||
"TFFlaubertWithLMHeadModel",
|
"TFFlaubertWithLMHeadModel",
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
@@ -1295,6 +1316,7 @@ if is_tf_available():
|
|||||||
"TFFunnelForSequenceClassification",
|
"TFFunnelForSequenceClassification",
|
||||||
"TFFunnelForTokenClassification",
|
"TFFunnelForTokenClassification",
|
||||||
"TFFunnelModel",
|
"TFFunnelModel",
|
||||||
|
"TFFunnelPreTrainedModel",
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
_import_structure["models.gpt2"].extend(
|
_import_structure["models.gpt2"].extend(
|
||||||
@@ -1329,6 +1351,7 @@ if is_tf_available():
|
|||||||
"TFLongformerForSequenceClassification",
|
"TFLongformerForSequenceClassification",
|
||||||
"TFLongformerForTokenClassification",
|
"TFLongformerForTokenClassification",
|
||||||
"TFLongformerModel",
|
"TFLongformerModel",
|
||||||
|
"TFLongformerPreTrainedModel",
|
||||||
"TFLongformerSelfAttention",
|
"TFLongformerSelfAttention",
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
@@ -1342,8 +1365,10 @@ if is_tf_available():
|
|||||||
"TFLxmertVisualFeatureEncoder",
|
"TFLxmertVisualFeatureEncoder",
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
_import_structure["models.marian"].extend(["TFMarianModel", "TFMarianMTModel"])
|
_import_structure["models.marian"].extend(["TFMarianModel", "TFMarianMTModel", "TFMarianPreTrainedModel"])
|
||||||
_import_structure["models.mbart"].extend(["TFMBartForConditionalGeneration", "TFMBartModel"])
|
_import_structure["models.mbart"].extend(
|
||||||
|
["TFMBartForConditionalGeneration", "TFMBartModel", "TFMBartPreTrainedModel"]
|
||||||
|
)
|
||||||
_import_structure["models.mobilebert"].extend(
|
_import_structure["models.mobilebert"].extend(
|
||||||
[
|
[
|
||||||
"TF_MOBILEBERT_PRETRAINED_MODEL_ARCHIVE_LIST",
|
"TF_MOBILEBERT_PRETRAINED_MODEL_ARCHIVE_LIST",
|
||||||
@@ -1384,10 +1409,13 @@ if is_tf_available():
|
|||||||
"TFOpenAIGPTPreTrainedModel",
|
"TFOpenAIGPTPreTrainedModel",
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
_import_structure["models.pegasus"].extend(["TFPegasusForConditionalGeneration", "TFPegasusModel"])
|
_import_structure["models.pegasus"].extend(
|
||||||
|
["TFPegasusForConditionalGeneration", "TFPegasusModel", "TFPegasusPreTrainedModel"]
|
||||||
|
)
|
||||||
_import_structure["models.rag"].extend(
|
_import_structure["models.rag"].extend(
|
||||||
[
|
[
|
||||||
"TFRagModel",
|
"TFRagModel",
|
||||||
|
"TFRagPreTrainedModel",
|
||||||
"TFRagSequenceForGeneration",
|
"TFRagSequenceForGeneration",
|
||||||
"TFRagTokenForGeneration",
|
"TFRagTokenForGeneration",
|
||||||
]
|
]
|
||||||
@@ -1538,6 +1566,7 @@ if is_flax_available():
|
|||||||
"FlaxBartForQuestionAnswering",
|
"FlaxBartForQuestionAnswering",
|
||||||
"FlaxBartForSequenceClassification",
|
"FlaxBartForSequenceClassification",
|
||||||
"FlaxBartModel",
|
"FlaxBartModel",
|
||||||
|
"FlaxBartPreTrainedModel",
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
_import_structure["models.bert"].extend(
|
_import_structure["models.bert"].extend(
|
||||||
@@ -1570,7 +1599,9 @@ if is_flax_available():
|
|||||||
"FlaxCLIPModel",
|
"FlaxCLIPModel",
|
||||||
"FlaxCLIPPreTrainedModel",
|
"FlaxCLIPPreTrainedModel",
|
||||||
"FlaxCLIPTextModel",
|
"FlaxCLIPTextModel",
|
||||||
|
"FlaxCLIPTextPreTrainedModel",
|
||||||
"FlaxCLIPVisionModel",
|
"FlaxCLIPVisionModel",
|
||||||
|
"FlaxCLIPVisionPreTrainedModel",
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
_import_structure["models.electra"].extend(
|
_import_structure["models.electra"].extend(
|
||||||
@@ -1585,7 +1616,7 @@ if is_flax_available():
|
|||||||
"FlaxElectraPreTrainedModel",
|
"FlaxElectraPreTrainedModel",
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
_import_structure["models.gpt2"].extend(["FlaxGPT2LMHeadModel", "FlaxGPT2Model"])
|
_import_structure["models.gpt2"].extend(["FlaxGPT2LMHeadModel", "FlaxGPT2Model", "FlaxGPT2PreTrainedModel"])
|
||||||
_import_structure["models.roberta"].extend(
|
_import_structure["models.roberta"].extend(
|
||||||
[
|
[
|
||||||
"FlaxRobertaForMaskedLM",
|
"FlaxRobertaForMaskedLM",
|
||||||
@@ -1597,8 +1628,8 @@ if is_flax_available():
|
|||||||
"FlaxRobertaPreTrainedModel",
|
"FlaxRobertaPreTrainedModel",
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
_import_structure["models.t5"].extend(["FlaxT5ForConditionalGeneration", "FlaxT5Model"])
|
_import_structure["models.t5"].extend(["FlaxT5ForConditionalGeneration", "FlaxT5Model", "FlaxT5PreTrainedModel"])
|
||||||
_import_structure["models.vit"].extend(["FlaxViTForImageClassification", "FlaxViTModel"])
|
_import_structure["models.vit"].extend(["FlaxViTForImageClassification", "FlaxViTModel", "FlaxViTPreTrainedModel"])
|
||||||
else:
|
else:
|
||||||
from .utils import dummy_flax_objects
|
from .utils import dummy_flax_objects
|
||||||
|
|
||||||
@@ -1949,6 +1980,7 @@ if TYPE_CHECKING:
|
|||||||
DetrForObjectDetection,
|
DetrForObjectDetection,
|
||||||
DetrForSegmentation,
|
DetrForSegmentation,
|
||||||
DetrModel,
|
DetrModel,
|
||||||
|
DetrPreTrainedModel,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
from .utils.dummy_timm_objects import *
|
from .utils.dummy_timm_objects import *
|
||||||
@@ -2074,6 +2106,7 @@ if TYPE_CHECKING:
|
|||||||
from .models.bert_generation import (
|
from .models.bert_generation import (
|
||||||
BertGenerationDecoder,
|
BertGenerationDecoder,
|
||||||
BertGenerationEncoder,
|
BertGenerationEncoder,
|
||||||
|
BertGenerationPreTrainedModel,
|
||||||
load_tf_weights_in_bert_generation,
|
load_tf_weights_in_bert_generation,
|
||||||
)
|
)
|
||||||
from .models.big_bird import (
|
from .models.big_bird import (
|
||||||
@@ -2097,18 +2130,21 @@ if TYPE_CHECKING:
|
|||||||
BigBirdPegasusForQuestionAnswering,
|
BigBirdPegasusForQuestionAnswering,
|
||||||
BigBirdPegasusForSequenceClassification,
|
BigBirdPegasusForSequenceClassification,
|
||||||
BigBirdPegasusModel,
|
BigBirdPegasusModel,
|
||||||
|
BigBirdPegasusPreTrainedModel,
|
||||||
)
|
)
|
||||||
from .models.blenderbot import (
|
from .models.blenderbot import (
|
||||||
BLENDERBOT_PRETRAINED_MODEL_ARCHIVE_LIST,
|
BLENDERBOT_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||||
BlenderbotForCausalLM,
|
BlenderbotForCausalLM,
|
||||||
BlenderbotForConditionalGeneration,
|
BlenderbotForConditionalGeneration,
|
||||||
BlenderbotModel,
|
BlenderbotModel,
|
||||||
|
BlenderbotPreTrainedModel,
|
||||||
)
|
)
|
||||||
from .models.blenderbot_small import (
|
from .models.blenderbot_small import (
|
||||||
BLENDERBOT_SMALL_PRETRAINED_MODEL_ARCHIVE_LIST,
|
BLENDERBOT_SMALL_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||||
BlenderbotSmallForCausalLM,
|
BlenderbotSmallForCausalLM,
|
||||||
BlenderbotSmallForConditionalGeneration,
|
BlenderbotSmallForConditionalGeneration,
|
||||||
BlenderbotSmallModel,
|
BlenderbotSmallModel,
|
||||||
|
BlenderbotSmallPreTrainedModel,
|
||||||
)
|
)
|
||||||
from .models.camembert import (
|
from .models.camembert import (
|
||||||
CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
|
CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||||
@@ -2226,6 +2262,7 @@ if TYPE_CHECKING:
|
|||||||
FunnelForSequenceClassification,
|
FunnelForSequenceClassification,
|
||||||
FunnelForTokenClassification,
|
FunnelForTokenClassification,
|
||||||
FunnelModel,
|
FunnelModel,
|
||||||
|
FunnelPreTrainedModel,
|
||||||
load_tf_weights_in_funnel,
|
load_tf_weights_in_funnel,
|
||||||
)
|
)
|
||||||
from .models.gpt2 import (
|
from .models.gpt2 import (
|
||||||
@@ -2267,6 +2304,7 @@ if TYPE_CHECKING:
|
|||||||
LayoutLMForSequenceClassification,
|
LayoutLMForSequenceClassification,
|
||||||
LayoutLMForTokenClassification,
|
LayoutLMForTokenClassification,
|
||||||
LayoutLMModel,
|
LayoutLMModel,
|
||||||
|
LayoutLMPreTrainedModel,
|
||||||
)
|
)
|
||||||
from .models.led import (
|
from .models.led import (
|
||||||
LED_PRETRAINED_MODEL_ARCHIVE_LIST,
|
LED_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||||
@@ -2274,6 +2312,7 @@ if TYPE_CHECKING:
|
|||||||
LEDForQuestionAnswering,
|
LEDForQuestionAnswering,
|
||||||
LEDForSequenceClassification,
|
LEDForSequenceClassification,
|
||||||
LEDModel,
|
LEDModel,
|
||||||
|
LEDPreTrainedModel,
|
||||||
)
|
)
|
||||||
from .models.longformer import (
|
from .models.longformer import (
|
||||||
LONGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
|
LONGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||||
@@ -2283,6 +2322,7 @@ if TYPE_CHECKING:
|
|||||||
LongformerForSequenceClassification,
|
LongformerForSequenceClassification,
|
||||||
LongformerForTokenClassification,
|
LongformerForTokenClassification,
|
||||||
LongformerModel,
|
LongformerModel,
|
||||||
|
LongformerPreTrainedModel,
|
||||||
LongformerSelfAttention,
|
LongformerSelfAttention,
|
||||||
)
|
)
|
||||||
from .models.luke import (
|
from .models.luke import (
|
||||||
@@ -2302,7 +2342,12 @@ if TYPE_CHECKING:
|
|||||||
LxmertVisualFeatureEncoder,
|
LxmertVisualFeatureEncoder,
|
||||||
LxmertXLayer,
|
LxmertXLayer,
|
||||||
)
|
)
|
||||||
from .models.m2m_100 import M2M_100_PRETRAINED_MODEL_ARCHIVE_LIST, M2M100ForConditionalGeneration, M2M100Model
|
from .models.m2m_100 import (
|
||||||
|
M2M_100_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||||
|
M2M100ForConditionalGeneration,
|
||||||
|
M2M100Model,
|
||||||
|
M2M100PreTrainedModel,
|
||||||
|
)
|
||||||
from .models.marian import MarianForCausalLM, MarianModel, MarianMTModel
|
from .models.marian import MarianForCausalLM, MarianModel, MarianMTModel
|
||||||
from .models.mbart import (
|
from .models.mbart import (
|
||||||
MBartForCausalLM,
|
MBartForCausalLM,
|
||||||
@@ -2310,6 +2355,7 @@ if TYPE_CHECKING:
|
|||||||
MBartForQuestionAnswering,
|
MBartForQuestionAnswering,
|
||||||
MBartForSequenceClassification,
|
MBartForSequenceClassification,
|
||||||
MBartModel,
|
MBartModel,
|
||||||
|
MBartPreTrainedModel,
|
||||||
)
|
)
|
||||||
from .models.megatron_bert import (
|
from .models.megatron_bert import (
|
||||||
MEGATRON_BERT_PRETRAINED_MODEL_ARCHIVE_LIST,
|
MEGATRON_BERT_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||||
@@ -2322,6 +2368,7 @@ if TYPE_CHECKING:
|
|||||||
MegatronBertForSequenceClassification,
|
MegatronBertForSequenceClassification,
|
||||||
MegatronBertForTokenClassification,
|
MegatronBertForTokenClassification,
|
||||||
MegatronBertModel,
|
MegatronBertModel,
|
||||||
|
MegatronBertPreTrainedModel,
|
||||||
)
|
)
|
||||||
from .models.mmbt import MMBTForClassification, MMBTModel, ModalEmbeddings
|
from .models.mmbt import MMBTForClassification, MMBTModel, ModalEmbeddings
|
||||||
from .models.mobilebert import (
|
from .models.mobilebert import (
|
||||||
@@ -2359,7 +2406,12 @@ if TYPE_CHECKING:
|
|||||||
OpenAIGPTPreTrainedModel,
|
OpenAIGPTPreTrainedModel,
|
||||||
load_tf_weights_in_openai_gpt,
|
load_tf_weights_in_openai_gpt,
|
||||||
)
|
)
|
||||||
from .models.pegasus import PegasusForCausalLM, PegasusForConditionalGeneration, PegasusModel
|
from .models.pegasus import (
|
||||||
|
PegasusForCausalLM,
|
||||||
|
PegasusForConditionalGeneration,
|
||||||
|
PegasusModel,
|
||||||
|
PegasusPreTrainedModel,
|
||||||
|
)
|
||||||
from .models.prophetnet import (
|
from .models.prophetnet import (
|
||||||
PROPHETNET_PRETRAINED_MODEL_ARCHIVE_LIST,
|
PROPHETNET_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||||
ProphetNetDecoder,
|
ProphetNetDecoder,
|
||||||
@@ -2369,7 +2421,7 @@ if TYPE_CHECKING:
|
|||||||
ProphetNetModel,
|
ProphetNetModel,
|
||||||
ProphetNetPreTrainedModel,
|
ProphetNetPreTrainedModel,
|
||||||
)
|
)
|
||||||
from .models.rag import RagModel, RagSequenceForGeneration, RagTokenForGeneration
|
from .models.rag import RagModel, RagPreTrainedModel, RagSequenceForGeneration, RagTokenForGeneration
|
||||||
from .models.reformer import (
|
from .models.reformer import (
|
||||||
REFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
|
REFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||||
ReformerAttention,
|
ReformerAttention,
|
||||||
@@ -2379,6 +2431,7 @@ if TYPE_CHECKING:
|
|||||||
ReformerLayer,
|
ReformerLayer,
|
||||||
ReformerModel,
|
ReformerModel,
|
||||||
ReformerModelWithLMHead,
|
ReformerModelWithLMHead,
|
||||||
|
ReformerPreTrainedModel,
|
||||||
)
|
)
|
||||||
from .models.retribert import RETRIBERT_PRETRAINED_MODEL_ARCHIVE_LIST, RetriBertModel, RetriBertPreTrainedModel
|
from .models.retribert import RETRIBERT_PRETRAINED_MODEL_ARCHIVE_LIST, RetriBertModel, RetriBertPreTrainedModel
|
||||||
from .models.roberta import (
|
from .models.roberta import (
|
||||||
@@ -2390,6 +2443,7 @@ if TYPE_CHECKING:
|
|||||||
RobertaForSequenceClassification,
|
RobertaForSequenceClassification,
|
||||||
RobertaForTokenClassification,
|
RobertaForTokenClassification,
|
||||||
RobertaModel,
|
RobertaModel,
|
||||||
|
RobertaPreTrainedModel,
|
||||||
)
|
)
|
||||||
from .models.roformer import (
|
from .models.roformer import (
|
||||||
ROFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
|
ROFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||||
@@ -2408,6 +2462,7 @@ if TYPE_CHECKING:
|
|||||||
SPEECH_TO_TEXT_PRETRAINED_MODEL_ARCHIVE_LIST,
|
SPEECH_TO_TEXT_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||||
Speech2TextForConditionalGeneration,
|
Speech2TextForConditionalGeneration,
|
||||||
Speech2TextModel,
|
Speech2TextModel,
|
||||||
|
Speech2TextPreTrainedModel,
|
||||||
)
|
)
|
||||||
from .models.squeezebert import (
|
from .models.squeezebert import (
|
||||||
SQUEEZEBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
|
SQUEEZEBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||||
@@ -2434,6 +2489,7 @@ if TYPE_CHECKING:
|
|||||||
TapasForQuestionAnswering,
|
TapasForQuestionAnswering,
|
||||||
TapasForSequenceClassification,
|
TapasForSequenceClassification,
|
||||||
TapasModel,
|
TapasModel,
|
||||||
|
TapasPreTrainedModel,
|
||||||
)
|
)
|
||||||
from .models.transfo_xl import (
|
from .models.transfo_xl import (
|
||||||
TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_LIST,
|
TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||||
@@ -2600,8 +2656,16 @@ if TYPE_CHECKING:
|
|||||||
TFBertModel,
|
TFBertModel,
|
||||||
TFBertPreTrainedModel,
|
TFBertPreTrainedModel,
|
||||||
)
|
)
|
||||||
from .models.blenderbot import TFBlenderbotForConditionalGeneration, TFBlenderbotModel
|
from .models.blenderbot import (
|
||||||
from .models.blenderbot_small import TFBlenderbotSmallForConditionalGeneration, TFBlenderbotSmallModel
|
TFBlenderbotForConditionalGeneration,
|
||||||
|
TFBlenderbotModel,
|
||||||
|
TFBlenderbotPreTrainedModel,
|
||||||
|
)
|
||||||
|
from .models.blenderbot_small import (
|
||||||
|
TFBlenderbotSmallForConditionalGeneration,
|
||||||
|
TFBlenderbotSmallModel,
|
||||||
|
TFBlenderbotSmallPreTrainedModel,
|
||||||
|
)
|
||||||
from .models.camembert import (
|
from .models.camembert import (
|
||||||
TF_CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
|
TF_CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||||
TFCamembertForMaskedLM,
|
TFCamembertForMaskedLM,
|
||||||
@@ -2669,6 +2733,7 @@ if TYPE_CHECKING:
|
|||||||
TFFlaubertForSequenceClassification,
|
TFFlaubertForSequenceClassification,
|
||||||
TFFlaubertForTokenClassification,
|
TFFlaubertForTokenClassification,
|
||||||
TFFlaubertModel,
|
TFFlaubertModel,
|
||||||
|
TFFlaubertPreTrainedModel,
|
||||||
TFFlaubertWithLMHeadModel,
|
TFFlaubertWithLMHeadModel,
|
||||||
)
|
)
|
||||||
from .models.funnel import (
|
from .models.funnel import (
|
||||||
@@ -2681,6 +2746,7 @@ if TYPE_CHECKING:
|
|||||||
TFFunnelForSequenceClassification,
|
TFFunnelForSequenceClassification,
|
||||||
TFFunnelForTokenClassification,
|
TFFunnelForTokenClassification,
|
||||||
TFFunnelModel,
|
TFFunnelModel,
|
||||||
|
TFFunnelPreTrainedModel,
|
||||||
)
|
)
|
||||||
from .models.gpt2 import (
|
from .models.gpt2 import (
|
||||||
TF_GPT2_PRETRAINED_MODEL_ARCHIVE_LIST,
|
TF_GPT2_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||||
@@ -2700,6 +2766,7 @@ if TYPE_CHECKING:
|
|||||||
TFLongformerForSequenceClassification,
|
TFLongformerForSequenceClassification,
|
||||||
TFLongformerForTokenClassification,
|
TFLongformerForTokenClassification,
|
||||||
TFLongformerModel,
|
TFLongformerModel,
|
||||||
|
TFLongformerPreTrainedModel,
|
||||||
TFLongformerSelfAttention,
|
TFLongformerSelfAttention,
|
||||||
)
|
)
|
||||||
from .models.lxmert import (
|
from .models.lxmert import (
|
||||||
@@ -2710,8 +2777,8 @@ if TYPE_CHECKING:
|
|||||||
TFLxmertPreTrainedModel,
|
TFLxmertPreTrainedModel,
|
||||||
TFLxmertVisualFeatureEncoder,
|
TFLxmertVisualFeatureEncoder,
|
||||||
)
|
)
|
||||||
from .models.marian import TFMarianModel, TFMarianMTModel
|
from .models.marian import TFMarianModel, TFMarianMTModel, TFMarianPreTrainedModel
|
||||||
from .models.mbart import TFMBartForConditionalGeneration, TFMBartModel
|
from .models.mbart import TFMBartForConditionalGeneration, TFMBartModel, TFMBartPreTrainedModel
|
||||||
from .models.mobilebert import (
|
from .models.mobilebert import (
|
||||||
TF_MOBILEBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
|
TF_MOBILEBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||||
TFMobileBertForMaskedLM,
|
TFMobileBertForMaskedLM,
|
||||||
@@ -2746,8 +2813,8 @@ if TYPE_CHECKING:
|
|||||||
TFOpenAIGPTModel,
|
TFOpenAIGPTModel,
|
||||||
TFOpenAIGPTPreTrainedModel,
|
TFOpenAIGPTPreTrainedModel,
|
||||||
)
|
)
|
||||||
from .models.pegasus import TFPegasusForConditionalGeneration, TFPegasusModel
|
from .models.pegasus import TFPegasusForConditionalGeneration, TFPegasusModel, TFPegasusPreTrainedModel
|
||||||
from .models.rag import TFRagModel, TFRagSequenceForGeneration, TFRagTokenForGeneration
|
from .models.rag import TFRagModel, TFRagPreTrainedModel, TFRagSequenceForGeneration, TFRagTokenForGeneration
|
||||||
from .models.roberta import (
|
from .models.roberta import (
|
||||||
TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST,
|
TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||||
TFRobertaForMaskedLM,
|
TFRobertaForMaskedLM,
|
||||||
@@ -2878,6 +2945,7 @@ if TYPE_CHECKING:
|
|||||||
FlaxBartForQuestionAnswering,
|
FlaxBartForQuestionAnswering,
|
||||||
FlaxBartForSequenceClassification,
|
FlaxBartForSequenceClassification,
|
||||||
FlaxBartModel,
|
FlaxBartModel,
|
||||||
|
FlaxBartPreTrainedModel,
|
||||||
)
|
)
|
||||||
from .models.bert import (
|
from .models.bert import (
|
||||||
FlaxBertForMaskedLM,
|
FlaxBertForMaskedLM,
|
||||||
@@ -2900,7 +2968,14 @@ if TYPE_CHECKING:
|
|||||||
FlaxBigBirdModel,
|
FlaxBigBirdModel,
|
||||||
FlaxBigBirdPreTrainedModel,
|
FlaxBigBirdPreTrainedModel,
|
||||||
)
|
)
|
||||||
from .models.clip import FlaxCLIPModel, FlaxCLIPPreTrainedModel, FlaxCLIPTextModel, FlaxCLIPVisionModel
|
from .models.clip import (
|
||||||
|
FlaxCLIPModel,
|
||||||
|
FlaxCLIPPreTrainedModel,
|
||||||
|
FlaxCLIPTextModel,
|
||||||
|
FlaxCLIPTextPreTrainedModel,
|
||||||
|
FlaxCLIPVisionModel,
|
||||||
|
FlaxCLIPVisionPreTrainedModel,
|
||||||
|
)
|
||||||
from .models.electra import (
|
from .models.electra import (
|
||||||
FlaxElectraForMaskedLM,
|
FlaxElectraForMaskedLM,
|
||||||
FlaxElectraForMultipleChoice,
|
FlaxElectraForMultipleChoice,
|
||||||
@@ -2911,7 +2986,7 @@ if TYPE_CHECKING:
|
|||||||
FlaxElectraModel,
|
FlaxElectraModel,
|
||||||
FlaxElectraPreTrainedModel,
|
FlaxElectraPreTrainedModel,
|
||||||
)
|
)
|
||||||
from .models.gpt2 import FlaxGPT2LMHeadModel, FlaxGPT2Model
|
from .models.gpt2 import FlaxGPT2LMHeadModel, FlaxGPT2Model, FlaxGPT2PreTrainedModel
|
||||||
from .models.roberta import (
|
from .models.roberta import (
|
||||||
FlaxRobertaForMaskedLM,
|
FlaxRobertaForMaskedLM,
|
||||||
FlaxRobertaForMultipleChoice,
|
FlaxRobertaForMultipleChoice,
|
||||||
@@ -2921,8 +2996,8 @@ if TYPE_CHECKING:
|
|||||||
FlaxRobertaModel,
|
FlaxRobertaModel,
|
||||||
FlaxRobertaPreTrainedModel,
|
FlaxRobertaPreTrainedModel,
|
||||||
)
|
)
|
||||||
from .models.t5 import FlaxT5ForConditionalGeneration, FlaxT5Model
|
from .models.t5 import FlaxT5ForConditionalGeneration, FlaxT5Model, FlaxT5PreTrainedModel
|
||||||
from .models.vit import FlaxViTForImageClassification, FlaxViTModel
|
from .models.vit import FlaxViTForImageClassification, FlaxViTModel, FlaxViTPreTrainedModel
|
||||||
else:
|
else:
|
||||||
# Import the same objects as dummies to get them in the namespace.
|
# Import the same objects as dummies to get them in the namespace.
|
||||||
# They will raise an import error if the user tries to instantiate / use them.
|
# They will raise an import error if the user tries to instantiate / use them.
|
||||||
|
|||||||
@@ -55,6 +55,7 @@ if is_flax_available():
|
|||||||
"FlaxBartForQuestionAnswering",
|
"FlaxBartForQuestionAnswering",
|
||||||
"FlaxBartForSequenceClassification",
|
"FlaxBartForSequenceClassification",
|
||||||
"FlaxBartModel",
|
"FlaxBartModel",
|
||||||
|
"FlaxBartPreTrainedModel",
|
||||||
]
|
]
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@@ -85,6 +86,7 @@ if TYPE_CHECKING:
|
|||||||
FlaxBartForQuestionAnswering,
|
FlaxBartForQuestionAnswering,
|
||||||
FlaxBartForSequenceClassification,
|
FlaxBartForSequenceClassification,
|
||||||
FlaxBartModel,
|
FlaxBartModel,
|
||||||
|
FlaxBartPreTrainedModel,
|
||||||
)
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -32,6 +32,7 @@ if is_torch_available():
|
|||||||
_import_structure["modeling_bert_generation"] = [
|
_import_structure["modeling_bert_generation"] = [
|
||||||
"BertGenerationDecoder",
|
"BertGenerationDecoder",
|
||||||
"BertGenerationEncoder",
|
"BertGenerationEncoder",
|
||||||
|
"BertGenerationPreTrainedModel",
|
||||||
"load_tf_weights_in_bert_generation",
|
"load_tf_weights_in_bert_generation",
|
||||||
]
|
]
|
||||||
|
|
||||||
@@ -46,6 +47,7 @@ if TYPE_CHECKING:
|
|||||||
from .modeling_bert_generation import (
|
from .modeling_bert_generation import (
|
||||||
BertGenerationDecoder,
|
BertGenerationDecoder,
|
||||||
BertGenerationEncoder,
|
BertGenerationEncoder,
|
||||||
|
BertGenerationPreTrainedModel,
|
||||||
load_tf_weights_in_bert_generation,
|
load_tf_weights_in_bert_generation,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -37,7 +37,11 @@ if is_torch_available():
|
|||||||
|
|
||||||
|
|
||||||
if is_tf_available():
|
if is_tf_available():
|
||||||
_import_structure["modeling_tf_blenderbot"] = ["TFBlenderbotForConditionalGeneration", "TFBlenderbotModel"]
|
_import_structure["modeling_tf_blenderbot"] = [
|
||||||
|
"TFBlenderbotForConditionalGeneration",
|
||||||
|
"TFBlenderbotModel",
|
||||||
|
"TFBlenderbotPreTrainedModel",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@@ -54,7 +58,11 @@ if TYPE_CHECKING:
|
|||||||
)
|
)
|
||||||
|
|
||||||
if is_tf_available():
|
if is_tf_available():
|
||||||
from .modeling_tf_blenderbot import TFBlenderbotForConditionalGeneration, TFBlenderbotModel
|
from .modeling_tf_blenderbot import (
|
||||||
|
TFBlenderbotForConditionalGeneration,
|
||||||
|
TFBlenderbotModel,
|
||||||
|
TFBlenderbotPreTrainedModel,
|
||||||
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
import importlib
|
import importlib
|
||||||
|
|||||||
@@ -38,6 +38,7 @@ if is_tf_available():
|
|||||||
_import_structure["modeling_tf_blenderbot_small"] = [
|
_import_structure["modeling_tf_blenderbot_small"] = [
|
||||||
"TFBlenderbotSmallForConditionalGeneration",
|
"TFBlenderbotSmallForConditionalGeneration",
|
||||||
"TFBlenderbotSmallModel",
|
"TFBlenderbotSmallModel",
|
||||||
|
"TFBlenderbotSmallPreTrainedModel",
|
||||||
]
|
]
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@@ -54,7 +55,11 @@ if TYPE_CHECKING:
|
|||||||
)
|
)
|
||||||
|
|
||||||
if is_tf_available():
|
if is_tf_available():
|
||||||
from .modeling_tf_blenderbot_small import TFBlenderbotSmallForConditionalGeneration, TFBlenderbotSmallModel
|
from .modeling_tf_blenderbot_small import (
|
||||||
|
TFBlenderbotSmallForConditionalGeneration,
|
||||||
|
TFBlenderbotSmallModel,
|
||||||
|
TFBlenderbotSmallPreTrainedModel,
|
||||||
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
import importlib
|
import importlib
|
||||||
|
|||||||
@@ -52,7 +52,9 @@ if is_flax_available():
|
|||||||
"FlaxCLIPModel",
|
"FlaxCLIPModel",
|
||||||
"FlaxCLIPPreTrainedModel",
|
"FlaxCLIPPreTrainedModel",
|
||||||
"FlaxCLIPTextModel",
|
"FlaxCLIPTextModel",
|
||||||
|
"FlaxCLIPTextPreTrainedModel",
|
||||||
"FlaxCLIPVisionModel",
|
"FlaxCLIPVisionModel",
|
||||||
|
"FlaxCLIPVisionPreTrainedModel",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@@ -77,7 +79,14 @@ if TYPE_CHECKING:
|
|||||||
)
|
)
|
||||||
|
|
||||||
if is_flax_available():
|
if is_flax_available():
|
||||||
from .modeling_flax_clip import FlaxCLIPModel, FlaxCLIPPreTrainedModel, FlaxCLIPTextModel, FlaxCLIPVisionModel
|
from .modeling_flax_clip import (
|
||||||
|
FlaxCLIPModel,
|
||||||
|
FlaxCLIPPreTrainedModel,
|
||||||
|
FlaxCLIPTextModel,
|
||||||
|
FlaxCLIPTextPreTrainedModel,
|
||||||
|
FlaxCLIPVisionModel,
|
||||||
|
FlaxCLIPVisionPreTrainedModel,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -46,6 +46,7 @@ if is_tf_available():
|
|||||||
"TFFlaubertForSequenceClassification",
|
"TFFlaubertForSequenceClassification",
|
||||||
"TFFlaubertForTokenClassification",
|
"TFFlaubertForTokenClassification",
|
||||||
"TFFlaubertModel",
|
"TFFlaubertModel",
|
||||||
|
"TFFlaubertPreTrainedModel",
|
||||||
"TFFlaubertWithLMHeadModel",
|
"TFFlaubertWithLMHeadModel",
|
||||||
]
|
]
|
||||||
|
|
||||||
@@ -74,6 +75,7 @@ if TYPE_CHECKING:
|
|||||||
TFFlaubertForSequenceClassification,
|
TFFlaubertForSequenceClassification,
|
||||||
TFFlaubertForTokenClassification,
|
TFFlaubertForTokenClassification,
|
||||||
TFFlaubertModel,
|
TFFlaubertModel,
|
||||||
|
TFFlaubertPreTrainedModel,
|
||||||
TFFlaubertWithLMHeadModel,
|
TFFlaubertWithLMHeadModel,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -41,6 +41,7 @@ if is_torch_available():
|
|||||||
"FunnelForSequenceClassification",
|
"FunnelForSequenceClassification",
|
||||||
"FunnelForTokenClassification",
|
"FunnelForTokenClassification",
|
||||||
"FunnelModel",
|
"FunnelModel",
|
||||||
|
"FunnelPreTrainedModel",
|
||||||
"load_tf_weights_in_funnel",
|
"load_tf_weights_in_funnel",
|
||||||
]
|
]
|
||||||
|
|
||||||
@@ -55,6 +56,7 @@ if is_tf_available():
|
|||||||
"TFFunnelForSequenceClassification",
|
"TFFunnelForSequenceClassification",
|
||||||
"TFFunnelForTokenClassification",
|
"TFFunnelForTokenClassification",
|
||||||
"TFFunnelModel",
|
"TFFunnelModel",
|
||||||
|
"TFFunnelPreTrainedModel",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@@ -76,6 +78,7 @@ if TYPE_CHECKING:
|
|||||||
FunnelForSequenceClassification,
|
FunnelForSequenceClassification,
|
||||||
FunnelForTokenClassification,
|
FunnelForTokenClassification,
|
||||||
FunnelModel,
|
FunnelModel,
|
||||||
|
FunnelPreTrainedModel,
|
||||||
load_tf_weights_in_funnel,
|
load_tf_weights_in_funnel,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -90,6 +93,7 @@ if TYPE_CHECKING:
|
|||||||
TFFunnelForSequenceClassification,
|
TFFunnelForSequenceClassification,
|
||||||
TFFunnelForTokenClassification,
|
TFFunnelForTokenClassification,
|
||||||
TFFunnelModel,
|
TFFunnelModel,
|
||||||
|
TFFunnelPreTrainedModel,
|
||||||
)
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -58,7 +58,7 @@ if is_tf_available():
|
|||||||
]
|
]
|
||||||
|
|
||||||
if is_flax_available():
|
if is_flax_available():
|
||||||
_import_structure["modeling_flax_gpt2"] = ["FlaxGPT2LMHeadModel", "FlaxGPT2Model"]
|
_import_structure["modeling_flax_gpt2"] = ["FlaxGPT2LMHeadModel", "FlaxGPT2Model", "FlaxGPT2PreTrainedModel"]
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .configuration_gpt2 import GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP, GPT2Config
|
from .configuration_gpt2 import GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP, GPT2Config
|
||||||
@@ -90,7 +90,7 @@ if TYPE_CHECKING:
|
|||||||
)
|
)
|
||||||
|
|
||||||
if is_flax_available():
|
if is_flax_available():
|
||||||
from .modeling_flax_gpt2 import FlaxGPT2LMHeadModel, FlaxGPT2Model
|
from .modeling_flax_gpt2 import FlaxGPT2LMHeadModel, FlaxGPT2Model, FlaxGPT2PreTrainedModel
|
||||||
|
|
||||||
else:
|
else:
|
||||||
import importlib
|
import importlib
|
||||||
|
|||||||
@@ -38,6 +38,7 @@ if is_torch_available():
|
|||||||
"LayoutLMForSequenceClassification",
|
"LayoutLMForSequenceClassification",
|
||||||
"LayoutLMForTokenClassification",
|
"LayoutLMForTokenClassification",
|
||||||
"LayoutLMModel",
|
"LayoutLMModel",
|
||||||
|
"LayoutLMPreTrainedModel",
|
||||||
]
|
]
|
||||||
|
|
||||||
if is_tf_available():
|
if is_tf_available():
|
||||||
@@ -66,6 +67,7 @@ if TYPE_CHECKING:
|
|||||||
LayoutLMForSequenceClassification,
|
LayoutLMForSequenceClassification,
|
||||||
LayoutLMForTokenClassification,
|
LayoutLMForTokenClassification,
|
||||||
LayoutLMModel,
|
LayoutLMModel,
|
||||||
|
LayoutLMPreTrainedModel,
|
||||||
)
|
)
|
||||||
if is_tf_available():
|
if is_tf_available():
|
||||||
from .modeling_tf_layoutlm import (
|
from .modeling_tf_layoutlm import (
|
||||||
|
|||||||
@@ -38,6 +38,7 @@ if is_torch_available():
|
|||||||
"LongformerForSequenceClassification",
|
"LongformerForSequenceClassification",
|
||||||
"LongformerForTokenClassification",
|
"LongformerForTokenClassification",
|
||||||
"LongformerModel",
|
"LongformerModel",
|
||||||
|
"LongformerPreTrainedModel",
|
||||||
"LongformerSelfAttention",
|
"LongformerSelfAttention",
|
||||||
]
|
]
|
||||||
|
|
||||||
@@ -50,6 +51,7 @@ if is_tf_available():
|
|||||||
"TFLongformerForSequenceClassification",
|
"TFLongformerForSequenceClassification",
|
||||||
"TFLongformerForTokenClassification",
|
"TFLongformerForTokenClassification",
|
||||||
"TFLongformerModel",
|
"TFLongformerModel",
|
||||||
|
"TFLongformerPreTrainedModel",
|
||||||
"TFLongformerSelfAttention",
|
"TFLongformerSelfAttention",
|
||||||
]
|
]
|
||||||
|
|
||||||
@@ -70,6 +72,7 @@ if TYPE_CHECKING:
|
|||||||
LongformerForSequenceClassification,
|
LongformerForSequenceClassification,
|
||||||
LongformerForTokenClassification,
|
LongformerForTokenClassification,
|
||||||
LongformerModel,
|
LongformerModel,
|
||||||
|
LongformerPreTrainedModel,
|
||||||
LongformerSelfAttention,
|
LongformerSelfAttention,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -82,6 +85,7 @@ if TYPE_CHECKING:
|
|||||||
TFLongformerForSequenceClassification,
|
TFLongformerForSequenceClassification,
|
||||||
TFLongformerForTokenClassification,
|
TFLongformerForTokenClassification,
|
||||||
TFLongformerModel,
|
TFLongformerModel,
|
||||||
|
TFLongformerPreTrainedModel,
|
||||||
TFLongformerSelfAttention,
|
TFLongformerSelfAttention,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -43,7 +43,7 @@ if is_torch_available():
|
|||||||
]
|
]
|
||||||
|
|
||||||
if is_tf_available():
|
if is_tf_available():
|
||||||
_import_structure["modeling_tf_marian"] = ["TFMarianModel", "TFMarianMTModel"]
|
_import_structure["modeling_tf_marian"] = ["TFMarianModel", "TFMarianMTModel", "TFMarianPreTrainedModel"]
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@@ -62,7 +62,7 @@ if TYPE_CHECKING:
|
|||||||
)
|
)
|
||||||
|
|
||||||
if is_tf_available():
|
if is_tf_available():
|
||||||
from .modeling_tf_marian import TFMarianModel, TFMarianMTModel
|
from .modeling_tf_marian import TFMarianModel, TFMarianMTModel, TFMarianPreTrainedModel
|
||||||
|
|
||||||
else:
|
else:
|
||||||
import importlib
|
import importlib
|
||||||
|
|||||||
@@ -50,7 +50,11 @@ if is_torch_available():
|
|||||||
]
|
]
|
||||||
|
|
||||||
if is_tf_available():
|
if is_tf_available():
|
||||||
_import_structure["modeling_tf_mbart"] = ["TFMBartForConditionalGeneration", "TFMBartModel"]
|
_import_structure["modeling_tf_mbart"] = [
|
||||||
|
"TFMBartForConditionalGeneration",
|
||||||
|
"TFMBartModel",
|
||||||
|
"TFMBartPreTrainedModel",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@@ -76,7 +80,7 @@ if TYPE_CHECKING:
|
|||||||
)
|
)
|
||||||
|
|
||||||
if is_tf_available():
|
if is_tf_available():
|
||||||
from .modeling_tf_mbart import TFMBartForConditionalGeneration, TFMBartModel
|
from .modeling_tf_mbart import TFMBartForConditionalGeneration, TFMBartModel, TFMBartPreTrainedModel
|
||||||
|
|
||||||
else:
|
else:
|
||||||
import importlib
|
import importlib
|
||||||
|
|||||||
@@ -36,6 +36,7 @@ if is_torch_available():
|
|||||||
"MegatronBertForSequenceClassification",
|
"MegatronBertForSequenceClassification",
|
||||||
"MegatronBertForTokenClassification",
|
"MegatronBertForTokenClassification",
|
||||||
"MegatronBertModel",
|
"MegatronBertModel",
|
||||||
|
"MegatronBertPreTrainedModel",
|
||||||
]
|
]
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@@ -53,6 +54,7 @@ if TYPE_CHECKING:
|
|||||||
MegatronBertForSequenceClassification,
|
MegatronBertForSequenceClassification,
|
||||||
MegatronBertForTokenClassification,
|
MegatronBertForTokenClassification,
|
||||||
MegatronBertModel,
|
MegatronBertModel,
|
||||||
|
MegatronBertPreTrainedModel,
|
||||||
)
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -46,7 +46,11 @@ if is_torch_available():
|
|||||||
]
|
]
|
||||||
|
|
||||||
if is_tf_available():
|
if is_tf_available():
|
||||||
_import_structure["modeling_tf_pegasus"] = ["TFPegasusForConditionalGeneration", "TFPegasusModel"]
|
_import_structure["modeling_tf_pegasus"] = [
|
||||||
|
"TFPegasusForConditionalGeneration",
|
||||||
|
"TFPegasusModel",
|
||||||
|
"TFPegasusPreTrainedModel",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@@ -68,7 +72,7 @@ if TYPE_CHECKING:
|
|||||||
)
|
)
|
||||||
|
|
||||||
if is_tf_available():
|
if is_tf_available():
|
||||||
from .modeling_tf_pegasus import TFPegasusForConditionalGeneration, TFPegasusModel
|
from .modeling_tf_pegasus import TFPegasusForConditionalGeneration, TFPegasusModel, TFPegasusPreTrainedModel
|
||||||
|
|
||||||
else:
|
else:
|
||||||
import importlib
|
import importlib
|
||||||
|
|||||||
@@ -28,10 +28,20 @@ _import_structure = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
_import_structure["modeling_rag"] = ["RagModel", "RagSequenceForGeneration", "RagTokenForGeneration"]
|
_import_structure["modeling_rag"] = [
|
||||||
|
"RagModel",
|
||||||
|
"RagPreTrainedModel",
|
||||||
|
"RagSequenceForGeneration",
|
||||||
|
"RagTokenForGeneration",
|
||||||
|
]
|
||||||
|
|
||||||
if is_tf_available():
|
if is_tf_available():
|
||||||
_import_structure["modeling_tf_rag"] = ["TFRagModel", "TFRagSequenceForGeneration", "TFRagTokenForGeneration"]
|
_import_structure["modeling_tf_rag"] = [
|
||||||
|
"TFRagModel",
|
||||||
|
"TFRagPreTrainedModel",
|
||||||
|
"TFRagSequenceForGeneration",
|
||||||
|
"TFRagTokenForGeneration",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@@ -40,10 +50,15 @@ if TYPE_CHECKING:
|
|||||||
from .tokenization_rag import RagTokenizer
|
from .tokenization_rag import RagTokenizer
|
||||||
|
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
from .modeling_rag import RagModel, RagSequenceForGeneration, RagTokenForGeneration
|
from .modeling_rag import RagModel, RagPreTrainedModel, RagSequenceForGeneration, RagTokenForGeneration
|
||||||
|
|
||||||
if is_tf_available():
|
if is_tf_available():
|
||||||
from .modeling_tf_rag import TFRagModel, TFRagSequenceForGeneration, TFRagTokenForGeneration
|
from .modeling_tf_rag import (
|
||||||
|
TFRagModel,
|
||||||
|
TFRagPreTrainedModel,
|
||||||
|
TFRagSequenceForGeneration,
|
||||||
|
TFRagTokenForGeneration,
|
||||||
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
import importlib
|
import importlib
|
||||||
|
|||||||
@@ -41,6 +41,7 @@ if is_torch_available():
|
|||||||
"ReformerLayer",
|
"ReformerLayer",
|
||||||
"ReformerModel",
|
"ReformerModel",
|
||||||
"ReformerModelWithLMHead",
|
"ReformerModelWithLMHead",
|
||||||
|
"ReformerPreTrainedModel",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@@ -63,6 +64,7 @@ if TYPE_CHECKING:
|
|||||||
ReformerLayer,
|
ReformerLayer,
|
||||||
ReformerModel,
|
ReformerModel,
|
||||||
ReformerModelWithLMHead,
|
ReformerModelWithLMHead,
|
||||||
|
ReformerPreTrainedModel,
|
||||||
)
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -45,6 +45,7 @@ if is_torch_available():
|
|||||||
"RobertaForSequenceClassification",
|
"RobertaForSequenceClassification",
|
||||||
"RobertaForTokenClassification",
|
"RobertaForTokenClassification",
|
||||||
"RobertaModel",
|
"RobertaModel",
|
||||||
|
"RobertaPreTrainedModel",
|
||||||
]
|
]
|
||||||
|
|
||||||
if is_tf_available():
|
if is_tf_available():
|
||||||
@@ -89,6 +90,7 @@ if TYPE_CHECKING:
|
|||||||
RobertaForSequenceClassification,
|
RobertaForSequenceClassification,
|
||||||
RobertaForTokenClassification,
|
RobertaForTokenClassification,
|
||||||
RobertaModel,
|
RobertaModel,
|
||||||
|
RobertaPreTrainedModel,
|
||||||
)
|
)
|
||||||
|
|
||||||
if is_tf_available():
|
if is_tf_available():
|
||||||
|
|||||||
@@ -33,6 +33,7 @@ if is_torch_available():
|
|||||||
"TapasForQuestionAnswering",
|
"TapasForQuestionAnswering",
|
||||||
"TapasForSequenceClassification",
|
"TapasForSequenceClassification",
|
||||||
"TapasModel",
|
"TapasModel",
|
||||||
|
"TapasPreTrainedModel",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@@ -47,6 +48,7 @@ if TYPE_CHECKING:
|
|||||||
TapasForQuestionAnswering,
|
TapasForQuestionAnswering,
|
||||||
TapasForSequenceClassification,
|
TapasForSequenceClassification,
|
||||||
TapasModel,
|
TapasModel,
|
||||||
|
TapasPreTrainedModel,
|
||||||
)
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -37,7 +37,11 @@ if is_torch_available():
|
|||||||
|
|
||||||
|
|
||||||
if is_flax_available():
|
if is_flax_available():
|
||||||
_import_structure["modeling_flax_vit"] = ["FlaxViTForImageClassification", "FlaxViTModel"]
|
_import_structure["modeling_flax_vit"] = [
|
||||||
|
"FlaxViTForImageClassification",
|
||||||
|
"FlaxViTModel",
|
||||||
|
"FlaxViTPreTrainedModel",
|
||||||
|
]
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .configuration_vit import VIT_PRETRAINED_CONFIG_ARCHIVE_MAP, ViTConfig
|
from .configuration_vit import VIT_PRETRAINED_CONFIG_ARCHIVE_MAP, ViTConfig
|
||||||
@@ -54,7 +58,7 @@ if TYPE_CHECKING:
|
|||||||
)
|
)
|
||||||
|
|
||||||
if is_flax_available():
|
if is_flax_available():
|
||||||
from .modeling_flax_vit import FlaxViTForImageClassification, FlaxViTModel
|
from .modeling_flax_vit import FlaxViTForImageClassification, FlaxViTModel, FlaxViTPreTrainedModel
|
||||||
|
|
||||||
|
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -244,6 +244,15 @@ class FlaxBartModel:
|
|||||||
requires_backends(cls, ["flax"])
|
requires_backends(cls, ["flax"])
|
||||||
|
|
||||||
|
|
||||||
|
class FlaxBartPreTrainedModel:
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["flax"])
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_pretrained(cls, *args, **kwargs):
|
||||||
|
requires_backends(cls, ["flax"])
|
||||||
|
|
||||||
|
|
||||||
class FlaxBertForMaskedLM:
|
class FlaxBertForMaskedLM:
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
requires_backends(self, ["flax"])
|
requires_backends(self, ["flax"])
|
||||||
@@ -412,6 +421,15 @@ class FlaxCLIPTextModel:
|
|||||||
requires_backends(cls, ["flax"])
|
requires_backends(cls, ["flax"])
|
||||||
|
|
||||||
|
|
||||||
|
class FlaxCLIPTextPreTrainedModel:
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["flax"])
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_pretrained(cls, *args, **kwargs):
|
||||||
|
requires_backends(cls, ["flax"])
|
||||||
|
|
||||||
|
|
||||||
class FlaxCLIPVisionModel:
|
class FlaxCLIPVisionModel:
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
requires_backends(self, ["flax"])
|
requires_backends(self, ["flax"])
|
||||||
@@ -421,6 +439,15 @@ class FlaxCLIPVisionModel:
|
|||||||
requires_backends(cls, ["flax"])
|
requires_backends(cls, ["flax"])
|
||||||
|
|
||||||
|
|
||||||
|
class FlaxCLIPVisionPreTrainedModel:
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["flax"])
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_pretrained(cls, *args, **kwargs):
|
||||||
|
requires_backends(cls, ["flax"])
|
||||||
|
|
||||||
|
|
||||||
class FlaxElectraForMaskedLM:
|
class FlaxElectraForMaskedLM:
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
requires_backends(self, ["flax"])
|
requires_backends(self, ["flax"])
|
||||||
@@ -507,6 +534,15 @@ class FlaxGPT2Model:
|
|||||||
requires_backends(cls, ["flax"])
|
requires_backends(cls, ["flax"])
|
||||||
|
|
||||||
|
|
||||||
|
class FlaxGPT2PreTrainedModel:
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["flax"])
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_pretrained(cls, *args, **kwargs):
|
||||||
|
requires_backends(cls, ["flax"])
|
||||||
|
|
||||||
|
|
||||||
class FlaxRobertaForMaskedLM:
|
class FlaxRobertaForMaskedLM:
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
requires_backends(self, ["flax"])
|
requires_backends(self, ["flax"])
|
||||||
@@ -588,6 +624,15 @@ class FlaxT5Model:
|
|||||||
requires_backends(cls, ["flax"])
|
requires_backends(cls, ["flax"])
|
||||||
|
|
||||||
|
|
||||||
|
class FlaxT5PreTrainedModel:
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["flax"])
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_pretrained(cls, *args, **kwargs):
|
||||||
|
requires_backends(cls, ["flax"])
|
||||||
|
|
||||||
|
|
||||||
class FlaxViTForImageClassification:
|
class FlaxViTForImageClassification:
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
requires_backends(self, ["flax"])
|
requires_backends(self, ["flax"])
|
||||||
@@ -600,3 +645,12 @@ class FlaxViTModel:
|
|||||||
@classmethod
|
@classmethod
|
||||||
def from_pretrained(cls, *args, **kwargs):
|
def from_pretrained(cls, *args, **kwargs):
|
||||||
requires_backends(cls, ["flax"])
|
requires_backends(cls, ["flax"])
|
||||||
|
|
||||||
|
|
||||||
|
class FlaxViTPreTrainedModel:
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["flax"])
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_pretrained(cls, *args, **kwargs):
|
||||||
|
requires_backends(cls, ["flax"])
|
||||||
|
|||||||
@@ -692,6 +692,15 @@ class BertGenerationEncoder:
|
|||||||
requires_backends(self, ["torch"])
|
requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
|
class BertGenerationPreTrainedModel:
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_pretrained(cls, *args, **kwargs):
|
||||||
|
requires_backends(cls, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
def load_tf_weights_in_bert_generation(*args, **kwargs):
|
def load_tf_weights_in_bert_generation(*args, **kwargs):
|
||||||
requires_backends(load_tf_weights_in_bert_generation, ["torch"])
|
requires_backends(load_tf_weights_in_bert_generation, ["torch"])
|
||||||
|
|
||||||
@@ -833,6 +842,15 @@ class BigBirdPegasusModel:
|
|||||||
requires_backends(cls, ["torch"])
|
requires_backends(cls, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
|
class BigBirdPegasusPreTrainedModel:
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_pretrained(cls, *args, **kwargs):
|
||||||
|
requires_backends(cls, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
BLENDERBOT_PRETRAINED_MODEL_ARCHIVE_LIST = None
|
BLENDERBOT_PRETRAINED_MODEL_ARCHIVE_LIST = None
|
||||||
|
|
||||||
|
|
||||||
@@ -863,6 +881,15 @@ class BlenderbotModel:
|
|||||||
requires_backends(cls, ["torch"])
|
requires_backends(cls, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
|
class BlenderbotPreTrainedModel:
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_pretrained(cls, *args, **kwargs):
|
||||||
|
requires_backends(cls, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
BLENDERBOT_SMALL_PRETRAINED_MODEL_ARCHIVE_LIST = None
|
BLENDERBOT_SMALL_PRETRAINED_MODEL_ARCHIVE_LIST = None
|
||||||
|
|
||||||
|
|
||||||
@@ -893,6 +920,15 @@ class BlenderbotSmallModel:
|
|||||||
requires_backends(cls, ["torch"])
|
requires_backends(cls, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
|
class BlenderbotSmallPreTrainedModel:
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_pretrained(cls, *args, **kwargs):
|
||||||
|
requires_backends(cls, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_LIST = None
|
CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_LIST = None
|
||||||
|
|
||||||
|
|
||||||
@@ -1610,6 +1646,15 @@ class FunnelModel:
|
|||||||
requires_backends(cls, ["torch"])
|
requires_backends(cls, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
|
class FunnelPreTrainedModel:
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_pretrained(cls, *args, **kwargs):
|
||||||
|
requires_backends(cls, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
def load_tf_weights_in_funnel(*args, **kwargs):
|
def load_tf_weights_in_funnel(*args, **kwargs):
|
||||||
requires_backends(load_tf_weights_in_funnel, ["torch"])
|
requires_backends(load_tf_weights_in_funnel, ["torch"])
|
||||||
|
|
||||||
@@ -1840,6 +1885,15 @@ class LayoutLMModel:
|
|||||||
requires_backends(cls, ["torch"])
|
requires_backends(cls, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
|
class LayoutLMPreTrainedModel:
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_pretrained(cls, *args, **kwargs):
|
||||||
|
requires_backends(cls, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
LED_PRETRAINED_MODEL_ARCHIVE_LIST = None
|
LED_PRETRAINED_MODEL_ARCHIVE_LIST = None
|
||||||
|
|
||||||
|
|
||||||
@@ -1879,6 +1933,15 @@ class LEDModel:
|
|||||||
requires_backends(cls, ["torch"])
|
requires_backends(cls, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
|
class LEDPreTrainedModel:
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_pretrained(cls, *args, **kwargs):
|
||||||
|
requires_backends(cls, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
LONGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = None
|
LONGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = None
|
||||||
|
|
||||||
|
|
||||||
@@ -1936,6 +1999,15 @@ class LongformerModel:
|
|||||||
requires_backends(cls, ["torch"])
|
requires_backends(cls, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
|
class LongformerPreTrainedModel:
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_pretrained(cls, *args, **kwargs):
|
||||||
|
requires_backends(cls, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
class LongformerSelfAttention:
|
class LongformerSelfAttention:
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
requires_backends(self, ["torch"])
|
requires_backends(self, ["torch"])
|
||||||
@@ -2045,6 +2117,15 @@ class M2M100Model:
|
|||||||
requires_backends(cls, ["torch"])
|
requires_backends(cls, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
|
class M2M100PreTrainedModel:
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_pretrained(cls, *args, **kwargs):
|
||||||
|
requires_backends(cls, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
class MarianForCausalLM:
|
class MarianForCausalLM:
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
requires_backends(self, ["torch"])
|
requires_backends(self, ["torch"])
|
||||||
@@ -2117,6 +2198,15 @@ class MBartModel:
|
|||||||
requires_backends(cls, ["torch"])
|
requires_backends(cls, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
|
class MBartPreTrainedModel:
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_pretrained(cls, *args, **kwargs):
|
||||||
|
requires_backends(cls, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
MEGATRON_BERT_PRETRAINED_MODEL_ARCHIVE_LIST = None
|
MEGATRON_BERT_PRETRAINED_MODEL_ARCHIVE_LIST = None
|
||||||
|
|
||||||
|
|
||||||
@@ -2193,6 +2283,15 @@ class MegatronBertModel:
|
|||||||
requires_backends(cls, ["torch"])
|
requires_backends(cls, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
|
class MegatronBertPreTrainedModel:
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_pretrained(cls, *args, **kwargs):
|
||||||
|
requires_backends(cls, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
class MMBTForClassification:
|
class MMBTForClassification:
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
requires_backends(self, ["torch"])
|
requires_backends(self, ["torch"])
|
||||||
@@ -2474,6 +2573,15 @@ class PegasusModel:
|
|||||||
requires_backends(cls, ["torch"])
|
requires_backends(cls, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
|
class PegasusPreTrainedModel:
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_pretrained(cls, *args, **kwargs):
|
||||||
|
requires_backends(cls, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
PROPHETNET_PRETRAINED_MODEL_ARCHIVE_LIST = None
|
PROPHETNET_PRETRAINED_MODEL_ARCHIVE_LIST = None
|
||||||
|
|
||||||
|
|
||||||
@@ -2532,6 +2640,15 @@ class RagModel:
|
|||||||
requires_backends(cls, ["torch"])
|
requires_backends(cls, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
|
class RagPreTrainedModel:
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_pretrained(cls, *args, **kwargs):
|
||||||
|
requires_backends(cls, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
class RagSequenceForGeneration:
|
class RagSequenceForGeneration:
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
requires_backends(self, ["torch"])
|
requires_backends(self, ["torch"])
|
||||||
@@ -2600,6 +2717,15 @@ class ReformerModelWithLMHead:
|
|||||||
requires_backends(cls, ["torch"])
|
requires_backends(cls, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
|
class ReformerPreTrainedModel:
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_pretrained(cls, *args, **kwargs):
|
||||||
|
requires_backends(cls, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
RETRIBERT_PRETRAINED_MODEL_ARCHIVE_LIST = None
|
RETRIBERT_PRETRAINED_MODEL_ARCHIVE_LIST = None
|
||||||
|
|
||||||
|
|
||||||
@@ -2687,6 +2813,15 @@ class RobertaModel:
|
|||||||
requires_backends(cls, ["torch"])
|
requires_backends(cls, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
|
class RobertaPreTrainedModel:
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_pretrained(cls, *args, **kwargs):
|
||||||
|
requires_backends(cls, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
ROFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = None
|
ROFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = None
|
||||||
|
|
||||||
|
|
||||||
@@ -2792,6 +2927,15 @@ class Speech2TextModel:
|
|||||||
requires_backends(cls, ["torch"])
|
requires_backends(cls, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
|
class Speech2TextPreTrainedModel:
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_pretrained(cls, *args, **kwargs):
|
||||||
|
requires_backends(cls, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
SQUEEZEBERT_PRETRAINED_MODEL_ARCHIVE_LIST = None
|
SQUEEZEBERT_PRETRAINED_MODEL_ARCHIVE_LIST = None
|
||||||
|
|
||||||
|
|
||||||
@@ -2945,6 +3089,15 @@ class TapasModel:
|
|||||||
requires_backends(cls, ["torch"])
|
requires_backends(cls, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
|
class TapasPreTrainedModel:
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["torch"])
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_pretrained(cls, *args, **kwargs):
|
||||||
|
requires_backends(cls, ["torch"])
|
||||||
|
|
||||||
|
|
||||||
TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_LIST = None
|
TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_LIST = None
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -431,6 +431,15 @@ class TFBlenderbotModel:
|
|||||||
requires_backends(cls, ["tf"])
|
requires_backends(cls, ["tf"])
|
||||||
|
|
||||||
|
|
||||||
|
class TFBlenderbotPreTrainedModel:
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["tf"])
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_pretrained(cls, *args, **kwargs):
|
||||||
|
requires_backends(cls, ["tf"])
|
||||||
|
|
||||||
|
|
||||||
class TFBlenderbotSmallForConditionalGeneration:
|
class TFBlenderbotSmallForConditionalGeneration:
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
requires_backends(self, ["tf"])
|
requires_backends(self, ["tf"])
|
||||||
@@ -449,6 +458,15 @@ class TFBlenderbotSmallModel:
|
|||||||
requires_backends(cls, ["tf"])
|
requires_backends(cls, ["tf"])
|
||||||
|
|
||||||
|
|
||||||
|
class TFBlenderbotSmallPreTrainedModel:
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["tf"])
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_pretrained(cls, *args, **kwargs):
|
||||||
|
requires_backends(cls, ["tf"])
|
||||||
|
|
||||||
|
|
||||||
TF_CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_LIST = None
|
TF_CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_LIST = None
|
||||||
|
|
||||||
|
|
||||||
@@ -845,6 +863,15 @@ class TFFlaubertModel:
|
|||||||
requires_backends(cls, ["tf"])
|
requires_backends(cls, ["tf"])
|
||||||
|
|
||||||
|
|
||||||
|
class TFFlaubertPreTrainedModel:
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["tf"])
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_pretrained(cls, *args, **kwargs):
|
||||||
|
requires_backends(cls, ["tf"])
|
||||||
|
|
||||||
|
|
||||||
class TFFlaubertWithLMHeadModel:
|
class TFFlaubertWithLMHeadModel:
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
requires_backends(self, ["tf"])
|
requires_backends(self, ["tf"])
|
||||||
@@ -925,6 +952,15 @@ class TFFunnelModel:
|
|||||||
requires_backends(cls, ["tf"])
|
requires_backends(cls, ["tf"])
|
||||||
|
|
||||||
|
|
||||||
|
class TFFunnelPreTrainedModel:
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["tf"])
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_pretrained(cls, *args, **kwargs):
|
||||||
|
requires_backends(cls, ["tf"])
|
||||||
|
|
||||||
|
|
||||||
TF_GPT2_PRETRAINED_MODEL_ARCHIVE_LIST = None
|
TF_GPT2_PRETRAINED_MODEL_ARCHIVE_LIST = None
|
||||||
|
|
||||||
|
|
||||||
@@ -1062,6 +1098,15 @@ class TFLongformerModel:
|
|||||||
requires_backends(cls, ["tf"])
|
requires_backends(cls, ["tf"])
|
||||||
|
|
||||||
|
|
||||||
|
class TFLongformerPreTrainedModel:
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["tf"])
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_pretrained(cls, *args, **kwargs):
|
||||||
|
requires_backends(cls, ["tf"])
|
||||||
|
|
||||||
|
|
||||||
class TFLongformerSelfAttention:
|
class TFLongformerSelfAttention:
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
requires_backends(self, ["tf"])
|
requires_backends(self, ["tf"])
|
||||||
@@ -1121,6 +1166,15 @@ class TFMarianMTModel:
|
|||||||
requires_backends(cls, ["tf"])
|
requires_backends(cls, ["tf"])
|
||||||
|
|
||||||
|
|
||||||
|
class TFMarianPreTrainedModel:
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["tf"])
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_pretrained(cls, *args, **kwargs):
|
||||||
|
requires_backends(cls, ["tf"])
|
||||||
|
|
||||||
|
|
||||||
class TFMBartForConditionalGeneration:
|
class TFMBartForConditionalGeneration:
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
requires_backends(self, ["tf"])
|
requires_backends(self, ["tf"])
|
||||||
@@ -1139,6 +1193,15 @@ class TFMBartModel:
|
|||||||
requires_backends(cls, ["tf"])
|
requires_backends(cls, ["tf"])
|
||||||
|
|
||||||
|
|
||||||
|
class TFMBartPreTrainedModel:
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["tf"])
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_pretrained(cls, *args, **kwargs):
|
||||||
|
requires_backends(cls, ["tf"])
|
||||||
|
|
||||||
|
|
||||||
TF_MOBILEBERT_PRETRAINED_MODEL_ARCHIVE_LIST = None
|
TF_MOBILEBERT_PRETRAINED_MODEL_ARCHIVE_LIST = None
|
||||||
|
|
||||||
|
|
||||||
@@ -1389,6 +1452,15 @@ class TFPegasusModel:
|
|||||||
requires_backends(cls, ["tf"])
|
requires_backends(cls, ["tf"])
|
||||||
|
|
||||||
|
|
||||||
|
class TFPegasusPreTrainedModel:
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["tf"])
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_pretrained(cls, *args, **kwargs):
|
||||||
|
requires_backends(cls, ["tf"])
|
||||||
|
|
||||||
|
|
||||||
class TFRagModel:
|
class TFRagModel:
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
requires_backends(self, ["tf"])
|
requires_backends(self, ["tf"])
|
||||||
@@ -1398,6 +1470,15 @@ class TFRagModel:
|
|||||||
requires_backends(cls, ["tf"])
|
requires_backends(cls, ["tf"])
|
||||||
|
|
||||||
|
|
||||||
|
class TFRagPreTrainedModel:
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["tf"])
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_pretrained(cls, *args, **kwargs):
|
||||||
|
requires_backends(cls, ["tf"])
|
||||||
|
|
||||||
|
|
||||||
class TFRagSequenceForGeneration:
|
class TFRagSequenceForGeneration:
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
requires_backends(self, ["tf"])
|
requires_backends(self, ["tf"])
|
||||||
|
|||||||
@@ -30,3 +30,12 @@ class DetrModel:
|
|||||||
@classmethod
|
@classmethod
|
||||||
def from_pretrained(cls, *args, **kwargs):
|
def from_pretrained(cls, *args, **kwargs):
|
||||||
requires_backends(cls, ["timm", "vision"])
|
requires_backends(cls, ["timm", "vision"])
|
||||||
|
|
||||||
|
|
||||||
|
class DetrPreTrainedModel:
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
requires_backends(self, ["timm", "vision"])
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_pretrained(cls, *args, **kwargs):
|
||||||
|
requires_backends(cls, ["timm", "vision"])
|
||||||
|
|||||||
@@ -52,6 +52,7 @@
|
|||||||
"{{cookiecutter.camelcase_modelname}}ForQuestionAnswering",
|
"{{cookiecutter.camelcase_modelname}}ForQuestionAnswering",
|
||||||
"{{cookiecutter.camelcase_modelname}}ForSequenceClassification",
|
"{{cookiecutter.camelcase_modelname}}ForSequenceClassification",
|
||||||
"{{cookiecutter.camelcase_modelname}}Model",
|
"{{cookiecutter.camelcase_modelname}}Model",
|
||||||
|
"{{cookiecutter.camelcase_modelname}}PreTrainedModel",
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
{% endif -%}
|
{% endif -%}
|
||||||
@@ -120,6 +121,7 @@
|
|||||||
{{cookiecutter.camelcase_modelname}}ForQuestionAnswering,
|
{{cookiecutter.camelcase_modelname}}ForQuestionAnswering,
|
||||||
{{cookiecutter.camelcase_modelname}}ForSequenceClassification,
|
{{cookiecutter.camelcase_modelname}}ForSequenceClassification,
|
||||||
{{cookiecutter.camelcase_modelname}}Model,
|
{{cookiecutter.camelcase_modelname}}Model,
|
||||||
|
{{cookiecutter.camelcase_modelname}}PreTrainedModel,
|
||||||
)
|
)
|
||||||
{% endif -%}
|
{% endif -%}
|
||||||
# End.
|
# End.
|
||||||
|
|||||||
@@ -31,9 +31,16 @@ PATH_TO_TRANSFORMERS = "src/transformers"
|
|||||||
PATH_TO_TESTS = "tests"
|
PATH_TO_TESTS = "tests"
|
||||||
PATH_TO_DOC = "docs/source"
|
PATH_TO_DOC = "docs/source"
|
||||||
|
|
||||||
|
# Update this list with models that are supposed to be private.
|
||||||
|
PRIVATE_MODELS = [
|
||||||
|
"DPRSpanPredictor",
|
||||||
|
"T5Stack",
|
||||||
|
"TFDPRSpanPredictor",
|
||||||
|
]
|
||||||
|
|
||||||
# Update this list for models that are not tested with a comment explaining the reason it should not be.
|
# Update this list for models that are not tested with a comment explaining the reason it should not be.
|
||||||
# Being in this list is an exception and should **not** be the rule.
|
# Being in this list is an exception and should **not** be the rule.
|
||||||
IGNORE_NON_TESTED = [
|
IGNORE_NON_TESTED = PRIVATE_MODELS.copy() + [
|
||||||
# models to ignore for not tested
|
# models to ignore for not tested
|
||||||
"BigBirdPegasusEncoder", # Building part of bigger (tested) model.
|
"BigBirdPegasusEncoder", # Building part of bigger (tested) model.
|
||||||
"BigBirdPegasusDecoder", # Building part of bigger (tested) model.
|
"BigBirdPegasusDecoder", # Building part of bigger (tested) model.
|
||||||
@@ -63,12 +70,9 @@ IGNORE_NON_TESTED = [
|
|||||||
"PegasusEncoder", # Building part of bigger (tested) model.
|
"PegasusEncoder", # Building part of bigger (tested) model.
|
||||||
"PegasusDecoderWrapper", # Building part of bigger (tested) model.
|
"PegasusDecoderWrapper", # Building part of bigger (tested) model.
|
||||||
"DPREncoder", # Building part of bigger (tested) model.
|
"DPREncoder", # Building part of bigger (tested) model.
|
||||||
"DPRSpanPredictor", # Building part of bigger (tested) model.
|
|
||||||
"ProphetNetDecoderWrapper", # Building part of bigger (tested) model.
|
"ProphetNetDecoderWrapper", # Building part of bigger (tested) model.
|
||||||
"ReformerForMaskedLM", # Needs to be setup as decoder.
|
"ReformerForMaskedLM", # Needs to be setup as decoder.
|
||||||
"T5Stack", # Building part of bigger (tested) model.
|
|
||||||
"TFDPREncoder", # Building part of bigger (tested) model.
|
"TFDPREncoder", # Building part of bigger (tested) model.
|
||||||
"TFDPRSpanPredictor", # Building part of bigger (tested) model.
|
|
||||||
"TFElectraMainLayer", # Building part of bigger (tested) model (should it be a TFPreTrainedModel ?)
|
"TFElectraMainLayer", # Building part of bigger (tested) model (should it be a TFPreTrainedModel ?)
|
||||||
"TFRobertaForMultipleChoice", # TODO: fix
|
"TFRobertaForMultipleChoice", # TODO: fix
|
||||||
"SeparableConv1D", # Building part of bigger (tested) model.
|
"SeparableConv1D", # Building part of bigger (tested) model.
|
||||||
@@ -92,7 +96,7 @@ TEST_FILES_WITH_NO_COMMON_TESTS = [
|
|||||||
|
|
||||||
# Update this list for models that are not in any of the auto MODEL_XXX_MAPPING. Being in this list is an exception and
|
# Update this list for models that are not in any of the auto MODEL_XXX_MAPPING. Being in this list is an exception and
|
||||||
# should **not** be the rule.
|
# should **not** be the rule.
|
||||||
IGNORE_NON_AUTO_CONFIGURED = [
|
IGNORE_NON_AUTO_CONFIGURED = PRIVATE_MODELS.copy() + [
|
||||||
# models to ignore for model xxx mapping
|
# models to ignore for model xxx mapping
|
||||||
"CLIPTextModel",
|
"CLIPTextModel",
|
||||||
"CLIPVisionModel",
|
"CLIPVisionModel",
|
||||||
@@ -100,7 +104,6 @@ IGNORE_NON_AUTO_CONFIGURED = [
|
|||||||
"FlaxCLIPVisionModel",
|
"FlaxCLIPVisionModel",
|
||||||
"DetrForSegmentation",
|
"DetrForSegmentation",
|
||||||
"DPRReader",
|
"DPRReader",
|
||||||
"DPRSpanPredictor",
|
|
||||||
"FlaubertForQuestionAnswering",
|
"FlaubertForQuestionAnswering",
|
||||||
"GPT2DoubleHeadsModel",
|
"GPT2DoubleHeadsModel",
|
||||||
"LukeForEntityClassification",
|
"LukeForEntityClassification",
|
||||||
@@ -110,9 +113,7 @@ IGNORE_NON_AUTO_CONFIGURED = [
|
|||||||
"RagModel",
|
"RagModel",
|
||||||
"RagSequenceForGeneration",
|
"RagSequenceForGeneration",
|
||||||
"RagTokenForGeneration",
|
"RagTokenForGeneration",
|
||||||
"T5Stack",
|
|
||||||
"TFDPRReader",
|
"TFDPRReader",
|
||||||
"TFDPRSpanPredictor",
|
|
||||||
"TFGPT2DoubleHeadsModel",
|
"TFGPT2DoubleHeadsModel",
|
||||||
"TFOpenAIGPTDoubleHeadsModel",
|
"TFOpenAIGPTDoubleHeadsModel",
|
||||||
"TFRagModel",
|
"TFRagModel",
|
||||||
@@ -173,12 +174,12 @@ def get_model_modules():
|
|||||||
return modules
|
return modules
|
||||||
|
|
||||||
|
|
||||||
def get_models(module):
|
def get_models(module, include_pretrained=False):
|
||||||
"""Get the objects in module that are models."""
|
"""Get the objects in module that are models."""
|
||||||
models = []
|
models = []
|
||||||
model_classes = (transformers.PreTrainedModel, transformers.TFPreTrainedModel, transformers.FlaxPreTrainedModel)
|
model_classes = (transformers.PreTrainedModel, transformers.TFPreTrainedModel, transformers.FlaxPreTrainedModel)
|
||||||
for attr_name in dir(module):
|
for attr_name in dir(module):
|
||||||
if "Pretrained" in attr_name or "PreTrained" in attr_name:
|
if not include_pretrained and ("Pretrained" in attr_name or "PreTrained" in attr_name):
|
||||||
continue
|
continue
|
||||||
attr = getattr(module, attr_name)
|
attr = getattr(module, attr_name)
|
||||||
if isinstance(attr, type) and issubclass(attr, model_classes) and attr.__module__ == module.__name__:
|
if isinstance(attr, type) and issubclass(attr, model_classes) and attr.__module__ == module.__name__:
|
||||||
@@ -186,6 +187,36 @@ def get_models(module):
|
|||||||
return models
|
return models
|
||||||
|
|
||||||
|
|
||||||
|
def is_a_private_model(model):
|
||||||
|
"""Returns True if the model should not be in the main init."""
|
||||||
|
if model in PRIVATE_MODELS:
|
||||||
|
return True
|
||||||
|
|
||||||
|
# Wrapper, Encoder and Decoder are all privates
|
||||||
|
if model.endswith("Wrapper"):
|
||||||
|
return True
|
||||||
|
if model.endswith("Encoder"):
|
||||||
|
return True
|
||||||
|
if model.endswith("Decoder"):
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def check_models_are_in_init():
|
||||||
|
"""Checks all models defined in the library are in the main init."""
|
||||||
|
models_not_in_init = []
|
||||||
|
dir_transformers = dir(transformers)
|
||||||
|
for module in get_model_modules():
|
||||||
|
models_not_in_init += [
|
||||||
|
model[0] for model in get_models(module, include_pretrained=True) if model[0] not in dir_transformers
|
||||||
|
]
|
||||||
|
|
||||||
|
# Remove private models
|
||||||
|
models_not_in_init = [model for model in models_not_in_init if not is_a_private_model(model)]
|
||||||
|
if len(models_not_in_init) > 0:
|
||||||
|
raise Exception(f"The following models should be in the main init: {','.join(models_not_in_init)}.")
|
||||||
|
|
||||||
|
|
||||||
# If some test_modeling files should be ignored when checking models are all tested, they should be added in the
|
# If some test_modeling files should be ignored when checking models are all tested, they should be added in the
|
||||||
# nested list _ignore_files of this function.
|
# nested list _ignore_files of this function.
|
||||||
def get_model_test_files():
|
def get_model_test_files():
|
||||||
@@ -229,6 +260,7 @@ def find_tested_models(test_file):
|
|||||||
|
|
||||||
def check_models_are_tested(module, test_file):
|
def check_models_are_tested(module, test_file):
|
||||||
"""Check models defined in module are tested in test_file."""
|
"""Check models defined in module are tested in test_file."""
|
||||||
|
# XxxPreTrainedModel are not tested
|
||||||
defined_models = get_models(module)
|
defined_models = get_models(module)
|
||||||
tested_models = find_tested_models(test_file)
|
tested_models = find_tested_models(test_file)
|
||||||
if tested_models is None:
|
if tested_models is None:
|
||||||
@@ -515,6 +547,8 @@ def check_all_objects_are_documented():
|
|||||||
|
|
||||||
def check_repo_quality():
|
def check_repo_quality():
|
||||||
"""Check all models are properly tested and documented."""
|
"""Check all models are properly tested and documented."""
|
||||||
|
print("Checking all models are public.")
|
||||||
|
check_models_are_in_init()
|
||||||
print("Checking all models are properly tested.")
|
print("Checking all models are properly tested.")
|
||||||
check_all_decorator_order()
|
check_all_decorator_order()
|
||||||
check_all_models_are_tested()
|
check_all_models_are_tested()
|
||||||
|
|||||||
Reference in New Issue
Block a user