From 758ed3332b219dd3529a1d3639fa30aa4954e0f3 Mon Sep 17 00:00:00 2001 From: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Date: Thu, 7 Jan 2021 09:36:14 -0500 Subject: [PATCH] Transformers fast import part 2 (#9446) * Main init work * Add version * Change from absolute to relative imports * Fix imports * One more typo * More typos * Styling * Make quality script pass * Add necessary replace in template * Fix typos * Spaces are ignored in replace for some reason * Forgot one models. * Fixes for import Co-authored-by: LysandreJik * Add documentation * Styling Co-authored-by: LysandreJik --- src/transformers/__init__.py | 2874 ++++++++++++----- src/transformers/benchmark/benchmark_utils.py | 5 +- src/transformers/commands/add_new_model.py | 3 +- src/transformers/commands/convert.py | 21 +- src/transformers/commands/download.py | 4 +- src/transformers/commands/env.py | 6 +- src/transformers/commands/lfs.py | 2 +- src/transformers/commands/run.py | 5 +- src/transformers/commands/serving.py | 6 +- src/transformers/commands/train.py | 8 +- src/transformers/commands/transformers_cli.py | 16 +- src/transformers/commands/user.py | 5 +- src/transformers/convert_graph_to_onnx.py | 7 +- .../convert_pytorch_checkpoint_to_tf2.py | 8 +- ...ert_slow_tokenizers_checkpoints_to_fast.py | 5 +- ...nvert_tf_hub_seq_to_seq_bert_to_pytorch.py | 2 +- .../data/metrics/squad_metrics.py | 3 +- .../data/test_generation_utils.py | 7 +- src/transformers/file_utils.py | 48 +- src/transformers/integrations.py | 2 +- src/transformers/models/__init__.py | 68 + ...lbert_original_tf_checkpoint_to_pytorch.py | 4 +- ..._original_pytorch_checkpoint_to_pytorch.py | 12 +- .../models/bart/tokenization_bart.py | 3 +- .../models/bart/tokenization_bart_fast.py | 3 +- ...bert_original_tf2_checkpoint_to_pytorch.py | 4 +- ..._bert_original_tf_checkpoint_to_pytorch.py | 4 +- ..._bert_pytorch_checkpoint_to_original_tf.py | 2 +- ..._original_pytorch_checkpoint_to_pytorch.py | 4 +- ..._original_pytorch_checkpoint_to_pytorch.py | 2 +- ...vert_dpr_original_checkpoint_to_pytorch.py | 3 +- ...ectra_original_tf_checkpoint_to_pytorch.py | 4 +- ..._original_pytorch_checkpoint_to_pytorch.py | 7 +- ...unnel_original_tf_checkpoint_to_pytorch.py | 2 +- ..._gpt2_original_tf_checkpoint_to_pytorch.py | 5 +- ...r_original_pytorch_lightning_to_pytorch.py | 2 +- ...xmert_original_tf_checkpoint_to_pytorch.py | 2 +- .../convert_marian_tatoeba_to_pytorch.py | 2 +- .../marian/convert_marian_to_pytorch.py | 4 +- ...rt_mbart_original_checkpoint_to_pytorch.py | 4 +- ...ebert_original_tf_checkpoint_to_pytorch.py | 4 +- ...penai_original_tf_checkpoint_to_pytorch.py | 5 +- .../pegasus/convert_pegasus_tf_to_pytorch.py | 4 +- ..._original_pytorch_checkpoint_to_pytorch.py | 4 +- ...ert_reformer_trax_checkpoint_to_pytorch.py | 4 +- ..._original_pytorch_checkpoint_to_pytorch.py | 16 +- ...rt_t5_original_tf_checkpoint_to_pytorch.py | 4 +- src/transformers/models/t5/modeling_tf_t5.py | 3 +- ...tapas_original_tf_checkpoint_to_pytorch.py | 6 +- .../models/tapas/tokenization_tapas.py | 4 +- ...fo_xl_original_tf_checkpoint_to_pytorch.py | 15 +- ..._original_pytorch_checkpoint_to_pytorch.py | 6 +- ...xlnet_original_tf_checkpoint_to_pytorch.py | 7 +- ...ce_{{cookiecutter.lowercase_modelname}}.py | 162 +- utils/check_dummies.py | 374 +-- utils/check_repo.py | 7 +- 56 files changed, 2426 insertions(+), 1377 deletions(-) diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 2874b08931..e7d566f345 100755 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -16,6 +16,12 @@ # See the License for the specific language governing permissions and # limitations under the License. +# When adding a new object to this init, remember to add it twice: once inside the `_import_structure` dictionary and +# once inside the `if TYPE_CHECKING` branch. The `TYPE_CHECKING` should have import statements as usual, but they are +# only there for type checking. The `_import_structure` is a dictionary submodule to list of object names, and is used +# to defer the actual importing for when the objects are requested. This way `import transformers` provides the names +# in the namespace without actually importing anything (and especially none of the backends). + __version__ = "4.2.0dev0" # Work around to update TensorFlow's absl.logging threshold which alters the @@ -31,963 +37,2061 @@ else: absl.logging.set_stderrthreshold("info") absl.logging._warn_preinit_stderr = False +from typing import TYPE_CHECKING + +# Check the dependencies satisfy the minimal versions required. from . import dependency_versions_check - -# Configuration -from .configuration_utils import PretrainedConfig - -# Data -from .data import ( - DataProcessor, - InputExample, - InputFeatures, - SingleSentenceClassificationProcessor, - SquadExample, - SquadFeatures, - SquadV1Processor, - SquadV2Processor, - glue_compute_metrics, - glue_convert_examples_to_features, - glue_output_modes, - glue_processors, - glue_tasks_num_labels, - squad_convert_examples_to_features, - xnli_compute_metrics, - xnli_output_modes, - xnli_processors, - xnli_tasks_num_labels, -) - -# Files and general utilities from .file_utils import ( - CONFIG_NAME, - MODEL_CARD_NAME, - PYTORCH_PRETRAINED_BERT_CACHE, - PYTORCH_TRANSFORMERS_CACHE, - SPIECE_UNDERLINE, - TF2_WEIGHTS_NAME, - TF_WEIGHTS_NAME, - TRANSFORMERS_CACHE, - WEIGHTS_NAME, - add_end_docstrings, - add_start_docstrings, - cached_path, - is_apex_available, - is_datasets_available, - is_faiss_available, + _BaseLazyModule, is_flax_available, - is_psutil_available, - is_py3nvml_available, is_sentencepiece_available, - is_sklearn_available, is_tf_available, is_tokenizers_available, is_torch_available, - is_torch_tpu_available, ) -from .hf_argparser import HfArgumentParser - -# Model Cards -from .modelcard import ModelCard - -# TF 2.0 <=> PyTorch conversion utilities -from .modeling_tf_pytorch_utils import ( - convert_tf_weight_name_to_pt_weight_name, - load_pytorch_checkpoint_in_tf2_model, - load_pytorch_model_in_tf2_model, - load_pytorch_weights_in_tf2_model, - load_tf2_checkpoint_in_pytorch_model, - load_tf2_model_in_pytorch_model, - load_tf2_weights_in_pytorch_model, -) -from .models.albert import ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, AlbertConfig -from .models.auto import ( - ALL_PRETRAINED_CONFIG_ARCHIVE_MAP, - CONFIG_MAPPING, - MODEL_NAMES_MAPPING, - TOKENIZER_MAPPING, - AutoConfig, - AutoTokenizer, -) -from .models.bart import BartConfig, BartTokenizer -from .models.bert import ( - BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, - BasicTokenizer, - BertConfig, - BertTokenizer, - WordpieceTokenizer, -) -from .models.bert_generation import BertGenerationConfig -from .models.bert_japanese import BertJapaneseTokenizer, CharacterTokenizer, MecabTokenizer -from .models.bertweet import BertweetTokenizer -from .models.blenderbot import BLENDERBOT_PRETRAINED_CONFIG_ARCHIVE_MAP, BlenderbotConfig, BlenderbotTokenizer -from .models.blenderbot_small import ( - BLENDERBOT_SMALL_PRETRAINED_CONFIG_ARCHIVE_MAP, - BlenderbotSmallConfig, - BlenderbotSmallTokenizer, -) -from .models.camembert import CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, CamembertConfig -from .models.ctrl import CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP, CTRLConfig, CTRLTokenizer -from .models.deberta import DEBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, DebertaConfig, DebertaTokenizer -from .models.distilbert import DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, DistilBertConfig, DistilBertTokenizer -from .models.dpr import ( - DPR_PRETRAINED_CONFIG_ARCHIVE_MAP, - DPRConfig, - DPRContextEncoderTokenizer, - DPRQuestionEncoderTokenizer, - DPRReaderOutput, - DPRReaderTokenizer, -) -from .models.electra import ELECTRA_PRETRAINED_CONFIG_ARCHIVE_MAP, ElectraConfig, ElectraTokenizer -from .models.encoder_decoder import EncoderDecoderConfig -from .models.flaubert import FLAUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, FlaubertConfig, FlaubertTokenizer -from .models.fsmt import FSMT_PRETRAINED_CONFIG_ARCHIVE_MAP, FSMTConfig, FSMTTokenizer -from .models.funnel import FUNNEL_PRETRAINED_CONFIG_ARCHIVE_MAP, FunnelConfig, FunnelTokenizer -from .models.gpt2 import GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP, GPT2Config, GPT2Tokenizer -from .models.herbert import HerbertTokenizer -from .models.layoutlm import LAYOUTLM_PRETRAINED_CONFIG_ARCHIVE_MAP, LayoutLMConfig, LayoutLMTokenizer -from .models.led import LED_PRETRAINED_CONFIG_ARCHIVE_MAP, LEDConfig, LEDTokenizer -from .models.longformer import LONGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, LongformerConfig, LongformerTokenizer -from .models.lxmert import LXMERT_PRETRAINED_CONFIG_ARCHIVE_MAP, LxmertConfig, LxmertTokenizer -from .models.marian import MarianConfig -from .models.mbart import MBartConfig -from .models.mmbt import MMBTConfig -from .models.mobilebert import MOBILEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, MobileBertConfig, MobileBertTokenizer -from .models.mpnet import MPNET_PRETRAINED_CONFIG_ARCHIVE_MAP, MPNetConfig, MPNetTokenizer -from .models.mt5 import MT5Config -from .models.openai import OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP, OpenAIGPTConfig, OpenAIGPTTokenizer -from .models.pegasus import PegasusConfig -from .models.phobert import PhobertTokenizer -from .models.prophetnet import PROPHETNET_PRETRAINED_CONFIG_ARCHIVE_MAP, ProphetNetConfig, ProphetNetTokenizer -from .models.rag import RagConfig, RagRetriever, RagTokenizer -from .models.reformer import REFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, ReformerConfig -from .models.retribert import RETRIBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, RetriBertConfig, RetriBertTokenizer -from .models.roberta import ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, RobertaConfig, RobertaTokenizer -from .models.squeezebert import SQUEEZEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, SqueezeBertConfig, SqueezeBertTokenizer -from .models.t5 import T5_PRETRAINED_CONFIG_ARCHIVE_MAP, T5Config -from .models.tapas import TAPAS_PRETRAINED_CONFIG_ARCHIVE_MAP, TapasConfig, TapasTokenizer -from .models.transfo_xl import ( - TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP, - TransfoXLConfig, - TransfoXLCorpus, - TransfoXLTokenizer, -) -from .models.xlm import XLM_PRETRAINED_CONFIG_ARCHIVE_MAP, XLMConfig, XLMTokenizer -from .models.xlm_prophetnet import XLM_PROPHETNET_PRETRAINED_CONFIG_ARCHIVE_MAP, XLMProphetNetConfig -from .models.xlm_roberta import XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, XLMRobertaConfig -from .models.xlnet import XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP, XLNetConfig - -# Pipelines -from .pipelines import ( - Conversation, - ConversationalPipeline, - CsvPipelineDataFormat, - FeatureExtractionPipeline, - FillMaskPipeline, - JsonPipelineDataFormat, - NerPipeline, - PipedPipelineDataFormat, - Pipeline, - PipelineDataFormat, - QuestionAnsweringPipeline, - SummarizationPipeline, - TableQuestionAnsweringPipeline, - Text2TextGenerationPipeline, - TextClassificationPipeline, - TextGenerationPipeline, - TokenClassificationPipeline, - TranslationPipeline, - ZeroShotClassificationPipeline, - pipeline, -) - -# Tokenization -from .tokenization_utils import PreTrainedTokenizer -from .tokenization_utils_base import ( - AddedToken, - BatchEncoding, - CharSpan, - PreTrainedTokenizerBase, - SpecialTokensMixin, - TensorType, - TokenSpan, -) - - -# Integrations: this needs to come before other ml imports -# in order to allow any 3rd-party code to initialize properly -from .integrations import ( # isort:skip - is_comet_available, - is_optuna_available, - is_ray_available, - is_ray_tune_available, - is_tensorboard_available, - is_wandb_available, -) - - -if is_sentencepiece_available(): - from .models.albert import AlbertTokenizer - from .models.barthez import BarthezTokenizer - from .models.bert_generation import BertGenerationTokenizer - from .models.camembert import CamembertTokenizer - from .models.marian import MarianTokenizer - from .models.mbart import MBartTokenizer - from .models.mt5 import MT5Tokenizer - from .models.pegasus import PegasusTokenizer - from .models.reformer import ReformerTokenizer - from .models.t5 import T5Tokenizer - from .models.xlm_prophetnet import XLMProphetNetTokenizer - from .models.xlm_roberta import XLMRobertaTokenizer - from .models.xlnet import XLNetTokenizer -else: - from .utils.dummy_sentencepiece_objects import * - -if is_tokenizers_available(): - from .models.albert import AlbertTokenizerFast - from .models.bart import BartTokenizerFast - from .models.barthez import BarthezTokenizerFast - from .models.bert import BertTokenizerFast - from .models.camembert import CamembertTokenizerFast - from .models.distilbert import DistilBertTokenizerFast - from .models.dpr import DPRContextEncoderTokenizerFast, DPRQuestionEncoderTokenizerFast, DPRReaderTokenizerFast - from .models.electra import ElectraTokenizerFast - from .models.funnel import FunnelTokenizerFast - from .models.gpt2 import GPT2TokenizerFast - from .models.herbert import HerbertTokenizerFast - from .models.layoutlm import LayoutLMTokenizerFast - from .models.led import LEDTokenizerFast - from .models.longformer import LongformerTokenizerFast - from .models.lxmert import LxmertTokenizerFast - from .models.mbart import MBartTokenizerFast - from .models.mobilebert import MobileBertTokenizerFast - from .models.mpnet import MPNetTokenizerFast - from .models.mt5 import MT5TokenizerFast - from .models.openai import OpenAIGPTTokenizerFast - from .models.pegasus import PegasusTokenizerFast - from .models.reformer import ReformerTokenizerFast - from .models.retribert import RetriBertTokenizerFast - from .models.roberta import RobertaTokenizerFast - from .models.squeezebert import SqueezeBertTokenizerFast - from .models.t5 import T5TokenizerFast - from .models.xlm_roberta import XLMRobertaTokenizerFast - from .models.xlnet import XLNetTokenizerFast - from .tokenization_utils_fast import PreTrainedTokenizerFast - - if is_sentencepiece_available(): - from .convert_slow_tokenizer import SLOW_TO_FAST_CONVERTERS, convert_slow_tokenizer -else: - from .utils.dummy_tokenizers_objects import * - -# Trainer -from .trainer_callback import ( - DefaultFlowCallback, - EarlyStoppingCallback, - PrinterCallback, - ProgressCallback, - TrainerCallback, - TrainerControl, - TrainerState, -) -from .trainer_utils import EvalPrediction, EvaluationStrategy, SchedulerType, set_seed -from .training_args import TrainingArguments -from .training_args_seq2seq import Seq2SeqTrainingArguments -from .training_args_tf import TFTrainingArguments from .utils import logging logger = logging.get_logger(__name__) # pylint: disable=invalid-name -# Modeling +# Base objects, independent of any specific backend +_import_structure = { + "configuration_utils": ["PretrainedConfig"], + "data": [ + "DataProcessor", + "InputExample", + "InputFeatures", + "SingleSentenceClassificationProcessor", + "SquadExample", + "SquadFeatures", + "SquadV1Processor", + "SquadV2Processor", + "glue_compute_metrics", + "glue_convert_examples_to_features", + "glue_output_modes", + "glue_processors", + "glue_tasks_num_labels", + "squad_convert_examples_to_features", + "xnli_compute_metrics", + "xnli_output_modes", + "xnli_processors", + "xnli_tasks_num_labels", + ], + "file_utils": [ + "CONFIG_NAME", + "MODEL_CARD_NAME", + "PYTORCH_PRETRAINED_BERT_CACHE", + "PYTORCH_TRANSFORMERS_CACHE", + "SPIECE_UNDERLINE", + "TF2_WEIGHTS_NAME", + "TF_WEIGHTS_NAME", + "TRANSFORMERS_CACHE", + "WEIGHTS_NAME", + "add_end_docstrings", + "add_start_docstrings", + "cached_path", + "is_apex_available", + "is_datasets_available", + "is_faiss_available", + "is_flax_available", + "is_psutil_available", + "is_py3nvml_available", + "is_sentencepiece_available", + "is_sklearn_available", + "is_tf_available", + "is_tokenizers_available", + "is_torch_available", + "is_torch_tpu_available", + ], + "hf_argparser": ["HfArgumentParser"], + "integrations": [ + "is_comet_available", + "is_optuna_available", + "is_ray_available", + "is_ray_tune_available", + "is_tensorboard_available", + "is_wandb_available", + ], + "modelcard": ["ModelCard"], + "modeling_tf_pytorch_utils": [ + "convert_tf_weight_name_to_pt_weight_name", + "load_pytorch_checkpoint_in_tf2_model", + "load_pytorch_model_in_tf2_model", + "load_pytorch_weights_in_tf2_model", + "load_tf2_checkpoint_in_pytorch_model", + "load_tf2_model_in_pytorch_model", + "load_tf2_weights_in_pytorch_model", + ], + "models": [], + # Models + "models.albert": ["ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "AlbertConfig"], + "models.auto": [ + "ALL_PRETRAINED_CONFIG_ARCHIVE_MAP", + "CONFIG_MAPPING", + "MODEL_NAMES_MAPPING", + "TOKENIZER_MAPPING", + "AutoConfig", + "AutoTokenizer", + ], + "models.bart": ["BartConfig", "BartTokenizer"], + "models.barthez": [], + "models.bert": [ + "BERT_PRETRAINED_CONFIG_ARCHIVE_MAP", + "BasicTokenizer", + "BertConfig", + "BertTokenizer", + "WordpieceTokenizer", + ], + "models.bert_generation": ["BertGenerationConfig"], + "models.bert_japanese": ["BertJapaneseTokenizer", "CharacterTokenizer", "MecabTokenizer"], + "models.bertweet": ["BertweetTokenizer"], + "models.blenderbot": ["BLENDERBOT_PRETRAINED_CONFIG_ARCHIVE_MAP", "BlenderbotConfig", "BlenderbotTokenizer"], + "models.blenderbot_small": [ + "BLENDERBOT_SMALL_PRETRAINED_CONFIG_ARCHIVE_MAP", + "BlenderbotSmallConfig", + "BlenderbotSmallTokenizer", + ], + "models.camembert": ["CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "CamembertConfig"], + "models.ctrl": ["CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP", "CTRLConfig", "CTRLTokenizer"], + "models.deberta": ["DEBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP", "DebertaConfig", "DebertaTokenizer"], + "models.distilbert": ["DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "DistilBertConfig", "DistilBertTokenizer"], + "models.dpr": [ + "DPR_PRETRAINED_CONFIG_ARCHIVE_MAP", + "DPRConfig", + "DPRContextEncoderTokenizer", + "DPRQuestionEncoderTokenizer", + "DPRReaderOutput", + "DPRReaderTokenizer", + ], + "models.electra": ["ELECTRA_PRETRAINED_CONFIG_ARCHIVE_MAP", "ElectraConfig", "ElectraTokenizer"], + "models.encoder_decoder": ["EncoderDecoderConfig"], + "models.flaubert": ["FLAUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "FlaubertConfig", "FlaubertTokenizer"], + "models.fsmt": ["FSMT_PRETRAINED_CONFIG_ARCHIVE_MAP", "FSMTConfig", "FSMTTokenizer"], + "models.funnel": ["FUNNEL_PRETRAINED_CONFIG_ARCHIVE_MAP", "FunnelConfig", "FunnelTokenizer"], + "models.gpt2": ["GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP", "GPT2Config", "GPT2Tokenizer"], + "models.herbert": ["HerbertTokenizer"], + "models.layoutlm": ["LAYOUTLM_PRETRAINED_CONFIG_ARCHIVE_MAP", "LayoutLMConfig", "LayoutLMTokenizer"], + "models.led": ["LED_PRETRAINED_CONFIG_ARCHIVE_MAP", "LEDConfig", "LEDTokenizer"], + "models.longformer": ["LONGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP", "LongformerConfig", "LongformerTokenizer"], + "models.lxmert": ["LXMERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "LxmertConfig", "LxmertTokenizer"], + "models.marian": ["MarianConfig"], + "models.mbart": ["MBartConfig"], + "models.mmbt": ["MMBTConfig"], + "models.mobilebert": ["MOBILEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "MobileBertConfig", "MobileBertTokenizer"], + "models.mpnet": ["MPNET_PRETRAINED_CONFIG_ARCHIVE_MAP", "MPNetConfig", "MPNetTokenizer"], + "models.mt5": ["MT5Config"], + "models.openai": ["OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP", "OpenAIGPTConfig", "OpenAIGPTTokenizer"], + "models.pegasus": ["PegasusConfig"], + "models.phobert": ["PhobertTokenizer"], + "models.prophetnet": ["PROPHETNET_PRETRAINED_CONFIG_ARCHIVE_MAP", "ProphetNetConfig", "ProphetNetTokenizer"], + "models.rag": ["RagConfig", "RagRetriever", "RagTokenizer"], + "models.reformer": ["REFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP", "ReformerConfig"], + "models.retribert": ["RETRIBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "RetriBertConfig", "RetriBertTokenizer"], + "models.roberta": ["ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP", "RobertaConfig", "RobertaTokenizer"], + "models.squeezebert": ["SQUEEZEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "SqueezeBertConfig", "SqueezeBertTokenizer"], + "models.t5": ["T5_PRETRAINED_CONFIG_ARCHIVE_MAP", "T5Config"], + "models.tapas": ["TAPAS_PRETRAINED_CONFIG_ARCHIVE_MAP", "TapasConfig", "TapasTokenizer"], + "models.transfo_xl": [ + "TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP", + "TransfoXLConfig", + "TransfoXLCorpus", + "TransfoXLTokenizer", + ], + "models.xlm": ["XLM_PRETRAINED_CONFIG_ARCHIVE_MAP", "XLMConfig", "XLMTokenizer"], + "models.xlm_prophetnet": ["XLM_PROPHETNET_PRETRAINED_CONFIG_ARCHIVE_MAP", "XLMProphetNetConfig"], + "models.xlm_roberta": ["XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP", "XLMRobertaConfig"], + "models.xlnet": ["XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP", "XLNetConfig"], + "pipelines": [ + "Conversation", + "ConversationalPipeline", + "CsvPipelineDataFormat", + "FeatureExtractionPipeline", + "FillMaskPipeline", + "JsonPipelineDataFormat", + "NerPipeline", + "PipedPipelineDataFormat", + "Pipeline", + "PipelineDataFormat", + "QuestionAnsweringPipeline", + "SummarizationPipeline", + "TableQuestionAnsweringPipeline", + "Text2TextGenerationPipeline", + "TextClassificationPipeline", + "TextGenerationPipeline", + "TokenClassificationPipeline", + "TranslationPipeline", + "ZeroShotClassificationPipeline", + "pipeline", + ], + "tokenization_utils": ["PreTrainedTokenizer"], + "tokenization_utils_base": [ + "AddedToken", + "BatchEncoding", + "CharSpan", + "PreTrainedTokenizerBase", + "SpecialTokensMixin", + "TensorType", + "TokenSpan", + ], + "trainer_callback": [ + "DefaultFlowCallback", + "EarlyStoppingCallback", + "PrinterCallback", + "ProgressCallback", + "TrainerCallback", + "TrainerControl", + "TrainerState", + ], + "trainer_utils": ["EvalPrediction", "EvaluationStrategy", "SchedulerType", "set_seed"], + "training_args": ["TrainingArguments"], + "training_args_seq2seq": ["Seq2SeqTrainingArguments"], + "training_args_tf": ["TFTrainingArguments"], + "utils": ["logging"], +} + +# sentencepiece-backed objects +if is_sentencepiece_available(): + _import_structure["models.albert"].append("AlbertTokenizer") + _import_structure["models.barthez"].append("BarthezTokenizer") + _import_structure["models.bert_generation"].append("BertGenerationTokenizer") + _import_structure["models.camembert"].append("CamembertTokenizer") + _import_structure["models.marian"].append("MarianTokenizer") + _import_structure["models.mbart"].append("MBartTokenizer") + _import_structure["models.mt5"].append("MT5Tokenizer") + _import_structure["models.pegasus"].append("PegasusTokenizer") + _import_structure["models.reformer"].append("ReformerTokenizer") + _import_structure["models.t5"].append("T5Tokenizer") + _import_structure["models.xlm_prophetnet"].append("XLMProphetNetTokenizer") + _import_structure["models.xlm_roberta"].append("XLMRobertaTokenizer") + _import_structure["models.xlnet"].append("XLNetTokenizer") +else: + from .utils import dummy_sentencepiece_objects + + _import_structure["utils.dummy_sentencepiece_objects"] = [ + name for name in dir(dummy_sentencepiece_objects) if not name.startswith("_") + ] + +# tokenziers-backed objects +if is_tokenizers_available(): + # Fast tokenizers + _import_structure["models.albert"].append("AlbertTokenizerFast") + _import_structure["models.bart"].append("BartTokenizerFast") + _import_structure["models.barthez"].append("BarthezTokenizerFast") + _import_structure["models.bert"].append("BertTokenizerFast") + _import_structure["models.camembert"].append("CamembertTokenizerFast") + _import_structure["models.distilbert"].append("DistilBertTokenizerFast") + _import_structure["models.dpr"].extend( + ["DPRContextEncoderTokenizerFast", "DPRQuestionEncoderTokenizerFast", "DPRReaderTokenizerFast"] + ) + _import_structure["models.electra"].append("ElectraTokenizerFast") + _import_structure["models.funnel"].append("FunnelTokenizerFast") + _import_structure["models.gpt2"].append("GPT2TokenizerFast") + _import_structure["models.herbert"].append("HerbertTokenizerFast") + _import_structure["models.layoutlm"].append("LayoutLMTokenizerFast") + _import_structure["models.led"].append("LEDTokenizerFast") + _import_structure["models.longformer"].append("LongformerTokenizerFast") + _import_structure["models.lxmert"].append("LxmertTokenizerFast") + _import_structure["models.mbart"].append("MBartTokenizerFast") + _import_structure["models.mobilebert"].append("MobileBertTokenizerFast") + _import_structure["models.mpnet"].append("MPNetTokenizerFast") + _import_structure["models.mt5"].append("MT5TokenizerFast") + _import_structure["models.openai"].append("OpenAIGPTTokenizerFast") + _import_structure["models.pegasus"].append("PegasusTokenizerFast") + _import_structure["models.reformer"].append("ReformerTokenizerFast") + _import_structure["models.retribert"].append("RetriBertTokenizerFast") + _import_structure["models.roberta"].append("RobertaTokenizerFast") + _import_structure["models.squeezebert"].append("SqueezeBertTokenizerFast") + _import_structure["models.t5"].append("T5TokenizerFast") + _import_structure["models.xlm_roberta"].append("XLMRobertaTokenizerFast") + _import_structure["models.xlnet"].append("XLNetTokenizerFast") + _import_structure["tokenization_utils_fast"] = ["PreTrainedTokenizerFast"] + + if is_sentencepiece_available(): + _import_structure["convert_slow_tokenizer"] = ["SLOW_TO_FAST_CONVERTERS", "convert_slow_tokenizer"] +else: + from .utils import dummy_tokenizers_objects + + _import_structure["utils.dummy_tokenizers_objects"] = [ + name for name in dir(dummy_tokenizers_objects) if not name.startswith("_") + ] + +# PyTorch-backed objects if is_torch_available(): - - # Benchmarks - from .benchmark.benchmark import PyTorchBenchmark - from .benchmark.benchmark_args import PyTorchBenchmarkArguments - from .data.data_collator import ( - DataCollator, - DataCollatorForLanguageModeling, - DataCollatorForPermutationLanguageModeling, - DataCollatorForSOP, - DataCollatorForTokenClassification, - DataCollatorForWholeWordMask, - DataCollatorWithPadding, - default_data_collator, + _import_structure["benchmark.benchmark"] = ["PyTorchBenchmark"] + _import_structure["benchmark.benchmark_args"] = ["PyTorchBenchmarkArguments"] + _import_structure["data.data_collator"] = [ + "DataCollator", + "DataCollatorForLanguageModeling", + "DataCollatorForPermutationLanguageModeling", + "DataCollatorForSOP", + "DataCollatorForTokenClassification", + "DataCollatorForWholeWordMask", + "DataCollatorWithPadding", + "default_data_collator", + ] + _import_structure["data.datasets"] = [ + "GlueDataset", + "GlueDataTrainingArguments", + "LineByLineTextDataset", + "LineByLineWithRefDataset", + "LineByLineWithSOPTextDataset", + "SquadDataset", + "SquadDataTrainingArguments", + "TextDataset", + "TextDatasetForNextSentencePrediction", + ] + _import_structure["generation_beam_search"] = ["BeamScorer", "BeamSearchScorer"] + _import_structure["generation_logits_process"] = [ + "HammingDiversityLogitsProcessor", + "LogitsProcessor", + "LogitsProcessorList", + "LogitsWarper", + "MinLengthLogitsProcessor", + "NoBadWordsLogitsProcessor", + "NoRepeatNGramLogitsProcessor", + "PrefixConstrainedLogitsProcessor", + "RepetitionPenaltyLogitsProcessor", + "TemperatureLogitsWarper", + "TopKLogitsWarper", + "TopPLogitsWarper", + ] + _import_structure["generation_utils"] = ["top_k_top_p_filtering"] + _import_structure["modeling_utils"] = ["Conv1D", "PreTrainedModel", "apply_chunking_to_forward", "prune_layer"] + # PyTorch models structure + _import_structure["models.albert"].extend( + [ + "ALBERT_PRETRAINED_MODEL_ARCHIVE_LIST", + "AlbertForMaskedLM", + "AlbertForMultipleChoice", + "AlbertForPreTraining", + "AlbertForQuestionAnswering", + "AlbertForSequenceClassification", + "AlbertForTokenClassification", + "AlbertModel", + "AlbertPreTrainedModel", + "load_tf_weights_in_albert", + ] ) - from .data.datasets import ( - GlueDataset, - GlueDataTrainingArguments, - LineByLineTextDataset, - LineByLineWithRefDataset, - LineByLineWithSOPTextDataset, - SquadDataset, - SquadDataTrainingArguments, - TextDataset, - TextDatasetForNextSentencePrediction, + _import_structure["models.auto"].extend( + [ + "MODEL_FOR_CAUSAL_LM_MAPPING", + "MODEL_FOR_MASKED_LM_MAPPING", + "MODEL_FOR_MULTIPLE_CHOICE_MAPPING", + "MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING", + "MODEL_FOR_PRETRAINING_MAPPING", + "MODEL_FOR_QUESTION_ANSWERING_MAPPING", + "MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING", + "MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING", + "MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING", + "MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING", + "MODEL_MAPPING", + "MODEL_WITH_LM_HEAD_MAPPING", + "AutoModel", + "AutoModelForCausalLM", + "AutoModelForMaskedLM", + "AutoModelForMultipleChoice", + "AutoModelForNextSentencePrediction", + "AutoModelForPreTraining", + "AutoModelForQuestionAnswering", + "AutoModelForSeq2SeqLM", + "AutoModelForSequenceClassification", + "AutoModelForTableQuestionAnswering", + "AutoModelForTokenClassification", + "AutoModelWithLMHead", + ] ) - from .generation_beam_search import BeamScorer, BeamSearchScorer - from .generation_logits_process import ( - HammingDiversityLogitsProcessor, - LogitsProcessor, - LogitsProcessorList, - LogitsWarper, - MinLengthLogitsProcessor, - NoBadWordsLogitsProcessor, - NoRepeatNGramLogitsProcessor, - PrefixConstrainedLogitsProcessor, - RepetitionPenaltyLogitsProcessor, - TemperatureLogitsWarper, - TopKLogitsWarper, - TopPLogitsWarper, + _import_structure["models.bart"].extend( + [ + "BART_PRETRAINED_MODEL_ARCHIVE_LIST", + "BartForConditionalGeneration", + "BartForQuestionAnswering", + "BartForSequenceClassification", + "BartModel", + "BartPretrainedModel", + "PretrainedBartModel", + ] ) - from .generation_utils import top_k_top_p_filtering - from .modeling_utils import Conv1D, PreTrainedModel, apply_chunking_to_forward, prune_layer - from .models.albert import ( - ALBERT_PRETRAINED_MODEL_ARCHIVE_LIST, - AlbertForMaskedLM, - AlbertForMultipleChoice, - AlbertForPreTraining, - AlbertForQuestionAnswering, - AlbertForSequenceClassification, - AlbertForTokenClassification, - AlbertModel, - AlbertPreTrainedModel, - load_tf_weights_in_albert, + _import_structure["models.bert"].extend( + [ + "BERT_PRETRAINED_MODEL_ARCHIVE_LIST", + "BertForMaskedLM", + "BertForMultipleChoice", + "BertForNextSentencePrediction", + "BertForPreTraining", + "BertForQuestionAnswering", + "BertForSequenceClassification", + "BertForTokenClassification", + "BertLayer", + "BertLMHeadModel", + "BertModel", + "BertPreTrainedModel", + "load_tf_weights_in_bert", + ] ) - from .models.auto import ( - MODEL_FOR_CAUSAL_LM_MAPPING, - MODEL_FOR_MASKED_LM_MAPPING, - MODEL_FOR_MULTIPLE_CHOICE_MAPPING, - MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING, - MODEL_FOR_PRETRAINING_MAPPING, - MODEL_FOR_QUESTION_ANSWERING_MAPPING, - MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING, - MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING, - MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING, - MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING, - MODEL_MAPPING, - MODEL_WITH_LM_HEAD_MAPPING, - AutoModel, - AutoModelForCausalLM, - AutoModelForMaskedLM, - AutoModelForMultipleChoice, - AutoModelForNextSentencePrediction, - AutoModelForPreTraining, - AutoModelForQuestionAnswering, - AutoModelForSeq2SeqLM, - AutoModelForSequenceClassification, - AutoModelForTableQuestionAnswering, - AutoModelForTokenClassification, - AutoModelWithLMHead, + _import_structure["models.bert_generation"].extend( + [ + "BertGenerationDecoder", + "BertGenerationEncoder", + "load_tf_weights_in_bert_generation", + ] ) - from .models.bart import ( - BART_PRETRAINED_MODEL_ARCHIVE_LIST, - BartForConditionalGeneration, - BartForQuestionAnswering, - BartForSequenceClassification, - BartModel, - BartPretrainedModel, - PretrainedBartModel, + _import_structure["models.blenderbot"].extend( + [ + "BLENDERBOT_PRETRAINED_MODEL_ARCHIVE_LIST", + "BlenderbotForConditionalGeneration", + "BlenderbotModel", + ] ) - from .models.bert import ( - BERT_PRETRAINED_MODEL_ARCHIVE_LIST, - BertForMaskedLM, - BertForMultipleChoice, - BertForNextSentencePrediction, - BertForPreTraining, - BertForQuestionAnswering, - BertForSequenceClassification, - BertForTokenClassification, - BertLayer, - BertLMHeadModel, - BertModel, - BertPreTrainedModel, - load_tf_weights_in_bert, + _import_structure["models.blenderbot_small"].extend( + [ + "BLENDERBOT_SMALL_PRETRAINED_MODEL_ARCHIVE_LIST", + "BlenderbotSmallForConditionalGeneration", + "BlenderbotSmallModel", + ] ) - from .models.bert_generation import ( - BertGenerationDecoder, - BertGenerationEncoder, - load_tf_weights_in_bert_generation, + _import_structure["models.camembert"].extend( + [ + "CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_LIST", + "CamembertForCausalLM", + "CamembertForMaskedLM", + "CamembertForMultipleChoice", + "CamembertForQuestionAnswering", + "CamembertForSequenceClassification", + "CamembertForTokenClassification", + "CamembertModel", + ] ) - from .models.blenderbot import ( - BLENDERBOT_PRETRAINED_MODEL_ARCHIVE_LIST, - BlenderbotForConditionalGeneration, - BlenderbotModel, + _import_structure["models.ctrl"].extend( + [ + "CTRL_PRETRAINED_MODEL_ARCHIVE_LIST", + "CTRLForSequenceClassification", + "CTRLLMHeadModel", + "CTRLModel", + "CTRLPreTrainedModel", + ] ) - from .models.blenderbot_small import ( - BLENDERBOT_SMALL_PRETRAINED_MODEL_ARCHIVE_LIST, - BlenderbotSmallForConditionalGeneration, - BlenderbotSmallModel, + _import_structure["models.deberta"].extend( + [ + "DEBERTA_PRETRAINED_MODEL_ARCHIVE_LIST", + "DebertaForSequenceClassification", + "DebertaModel", + "DebertaPreTrainedModel", + ] ) - from .models.camembert import ( - CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_LIST, - CamembertForCausalLM, - CamembertForMaskedLM, - CamembertForMultipleChoice, - CamembertForQuestionAnswering, - CamembertForSequenceClassification, - CamembertForTokenClassification, - CamembertModel, + _import_structure["models.distilbert"].extend( + [ + "DISTILBERT_PRETRAINED_MODEL_ARCHIVE_LIST", + "DistilBertForMaskedLM", + "DistilBertForMultipleChoice", + "DistilBertForQuestionAnswering", + "DistilBertForSequenceClassification", + "DistilBertForTokenClassification", + "DistilBertModel", + "DistilBertPreTrainedModel", + ] ) - from .models.ctrl import ( - CTRL_PRETRAINED_MODEL_ARCHIVE_LIST, - CTRLForSequenceClassification, - CTRLLMHeadModel, - CTRLModel, - CTRLPreTrainedModel, + _import_structure["models.dpr"].extend( + [ + "DPR_CONTEXT_ENCODER_PRETRAINED_MODEL_ARCHIVE_LIST", + "DPR_QUESTION_ENCODER_PRETRAINED_MODEL_ARCHIVE_LIST", + "DPR_READER_PRETRAINED_MODEL_ARCHIVE_LIST", + "DPRContextEncoder", + "DPRPretrainedContextEncoder", + "DPRPretrainedQuestionEncoder", + "DPRPretrainedReader", + "DPRQuestionEncoder", + "DPRReader", + ] ) - from .models.deberta import ( - DEBERTA_PRETRAINED_MODEL_ARCHIVE_LIST, - DebertaForSequenceClassification, - DebertaModel, - DebertaPreTrainedModel, + _import_structure["models.electra"].extend( + [ + "ELECTRA_PRETRAINED_MODEL_ARCHIVE_LIST", + "ElectraForMaskedLM", + "ElectraForMultipleChoice", + "ElectraForPreTraining", + "ElectraForQuestionAnswering", + "ElectraForSequenceClassification", + "ElectraForTokenClassification", + "ElectraModel", + "ElectraPreTrainedModel", + "load_tf_weights_in_electra", + ] ) - from .models.distilbert import ( - DISTILBERT_PRETRAINED_MODEL_ARCHIVE_LIST, - DistilBertForMaskedLM, - DistilBertForMultipleChoice, - DistilBertForQuestionAnswering, - DistilBertForSequenceClassification, - DistilBertForTokenClassification, - DistilBertModel, - DistilBertPreTrainedModel, + _import_structure["models.encoder_decoder"].append("EncoderDecoderModel") + _import_structure["models.flaubert"].extend( + [ + "FLAUBERT_PRETRAINED_MODEL_ARCHIVE_LIST", + "FlaubertForMultipleChoice", + "FlaubertForQuestionAnswering", + "FlaubertForQuestionAnsweringSimple", + "FlaubertForSequenceClassification", + "FlaubertForTokenClassification", + "FlaubertModel", + "FlaubertWithLMHeadModel", + ] ) - from .models.dpr import ( - DPR_CONTEXT_ENCODER_PRETRAINED_MODEL_ARCHIVE_LIST, - DPR_QUESTION_ENCODER_PRETRAINED_MODEL_ARCHIVE_LIST, - DPR_READER_PRETRAINED_MODEL_ARCHIVE_LIST, - DPRContextEncoder, - DPRPretrainedContextEncoder, - DPRPretrainedQuestionEncoder, - DPRPretrainedReader, - DPRQuestionEncoder, - DPRReader, + _import_structure["models.fsmt"].extend(["FSMTForConditionalGeneration", "FSMTModel", "PretrainedFSMTModel"]) + _import_structure["models.funnel"].extend( + [ + "FUNNEL_PRETRAINED_MODEL_ARCHIVE_LIST", + "FunnelBaseModel", + "FunnelForMaskedLM", + "FunnelForMultipleChoice", + "FunnelForPreTraining", + "FunnelForQuestionAnswering", + "FunnelForSequenceClassification", + "FunnelForTokenClassification", + "FunnelModel", + "load_tf_weights_in_funnel", + ] ) - from .models.electra import ( - ELECTRA_PRETRAINED_MODEL_ARCHIVE_LIST, - ElectraForMaskedLM, - ElectraForMultipleChoice, - ElectraForPreTraining, - ElectraForQuestionAnswering, - ElectraForSequenceClassification, - ElectraForTokenClassification, - ElectraModel, - ElectraPreTrainedModel, - load_tf_weights_in_electra, + _import_structure["models.gpt2"].extend( + [ + "GPT2_PRETRAINED_MODEL_ARCHIVE_LIST", + "GPT2DoubleHeadsModel", + "GPT2ForSequenceClassification", + "GPT2LMHeadModel", + "GPT2Model", + "GPT2PreTrainedModel", + "load_tf_weights_in_gpt2", + ] ) - from .models.encoder_decoder import EncoderDecoderModel - from .models.flaubert import ( - FLAUBERT_PRETRAINED_MODEL_ARCHIVE_LIST, - FlaubertForMultipleChoice, - FlaubertForQuestionAnswering, - FlaubertForQuestionAnsweringSimple, - FlaubertForSequenceClassification, - FlaubertForTokenClassification, - FlaubertModel, - FlaubertWithLMHeadModel, + _import_structure["models.layoutlm"].extend( + [ + "LAYOUTLM_PRETRAINED_MODEL_ARCHIVE_LIST", + "LayoutLMForMaskedLM", + "LayoutLMForTokenClassification", + "LayoutLMModel", + ] ) - from .models.fsmt import FSMTForConditionalGeneration, FSMTModel, PretrainedFSMTModel - from .models.funnel import ( - FUNNEL_PRETRAINED_MODEL_ARCHIVE_LIST, - FunnelBaseModel, - FunnelForMaskedLM, - FunnelForMultipleChoice, - FunnelForPreTraining, - FunnelForQuestionAnswering, - FunnelForSequenceClassification, - FunnelForTokenClassification, - FunnelModel, - load_tf_weights_in_funnel, + _import_structure["models.led"].extend( + [ + "LED_PRETRAINED_MODEL_ARCHIVE_LIST", + "LEDForConditionalGeneration", + "LEDForQuestionAnswering", + "LEDForSequenceClassification", + "LEDModel", + ] ) - from .models.gpt2 import ( - GPT2_PRETRAINED_MODEL_ARCHIVE_LIST, - GPT2DoubleHeadsModel, - GPT2ForSequenceClassification, - GPT2LMHeadModel, - GPT2Model, - GPT2PreTrainedModel, - load_tf_weights_in_gpt2, + _import_structure["models.longformer"].extend( + [ + "LONGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST", + "LongformerForMaskedLM", + "LongformerForMultipleChoice", + "LongformerForQuestionAnswering", + "LongformerForSequenceClassification", + "LongformerForTokenClassification", + "LongformerModel", + "LongformerSelfAttention", + ] ) - from .models.layoutlm import ( - LAYOUTLM_PRETRAINED_MODEL_ARCHIVE_LIST, - LayoutLMForMaskedLM, - LayoutLMForTokenClassification, - LayoutLMModel, + _import_structure["models.lxmert"].extend( + [ + "LxmertEncoder", + "LxmertForPreTraining", + "LxmertForQuestionAnswering", + "LxmertModel", + "LxmertPreTrainedModel", + "LxmertVisualFeatureEncoder", + "LxmertXLayer", + ] ) - from .models.led import ( - LED_PRETRAINED_MODEL_ARCHIVE_LIST, - LEDForConditionalGeneration, - LEDForQuestionAnswering, - LEDForSequenceClassification, - LEDModel, + _import_structure["models.marian"].extend(["MarianModel", "MarianMTModel"]) + _import_structure["models.mbart"].extend( + [ + "MBartForConditionalGeneration", + "MBartForQuestionAnswering", + "MBartForSequenceClassification", + "MBartModel", + ] ) - from .models.longformer import ( - LONGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST, - LongformerForMaskedLM, - LongformerForMultipleChoice, - LongformerForQuestionAnswering, - LongformerForSequenceClassification, - LongformerForTokenClassification, - LongformerModel, - LongformerSelfAttention, + _import_structure["models.mmbt"].extend(["MMBTForClassification", "MMBTModel", "ModalEmbeddings"]) + _import_structure["models.mobilebert"].extend( + [ + "MOBILEBERT_PRETRAINED_MODEL_ARCHIVE_LIST", + "MobileBertForMaskedLM", + "MobileBertForMultipleChoice", + "MobileBertForNextSentencePrediction", + "MobileBertForPreTraining", + "MobileBertForQuestionAnswering", + "MobileBertForSequenceClassification", + "MobileBertForTokenClassification", + "MobileBertLayer", + "MobileBertModel", + "MobileBertPreTrainedModel", + "load_tf_weights_in_mobilebert", + ] ) - from .models.lxmert import ( - LxmertEncoder, - LxmertForPreTraining, - LxmertForQuestionAnswering, - LxmertModel, - LxmertPreTrainedModel, - LxmertVisualFeatureEncoder, - LxmertXLayer, + _import_structure["models.mpnet"].extend( + [ + "MPNET_PRETRAINED_MODEL_ARCHIVE_LIST", + "MPNetForMaskedLM", + "MPNetForMultipleChoice", + "MPNetForQuestionAnswering", + "MPNetForSequenceClassification", + "MPNetForTokenClassification", + "MPNetLayer", + "MPNetModel", + "MPNetPreTrainedModel", + ] ) - from .models.marian import MarianModel, MarianMTModel - from .models.mbart import ( - MBartForConditionalGeneration, - MBartForQuestionAnswering, - MBartForSequenceClassification, - MBartModel, + _import_structure["models.mt5"].extend(["MT5EncoderModel", "MT5ForConditionalGeneration", "MT5Model"]) + _import_structure["models.openai"].extend( + [ + "OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_LIST", + "OpenAIGPTDoubleHeadsModel", + "OpenAIGPTForSequenceClassification", + "OpenAIGPTLMHeadModel", + "OpenAIGPTModel", + "OpenAIGPTPreTrainedModel", + "load_tf_weights_in_openai_gpt", + ] ) - from .models.mmbt import MMBTForClassification, MMBTModel, ModalEmbeddings - from .models.mobilebert import ( - MOBILEBERT_PRETRAINED_MODEL_ARCHIVE_LIST, - MobileBertForMaskedLM, - MobileBertForMultipleChoice, - MobileBertForNextSentencePrediction, - MobileBertForPreTraining, - MobileBertForQuestionAnswering, - MobileBertForSequenceClassification, - MobileBertForTokenClassification, - MobileBertLayer, - MobileBertModel, - MobileBertPreTrainedModel, - load_tf_weights_in_mobilebert, + _import_structure["models.pegasus"].extend(["PegasusForConditionalGeneration", "PegasusModel"]) + _import_structure["models.prophetnet"].extend( + [ + "PROPHETNET_PRETRAINED_MODEL_ARCHIVE_LIST", + "ProphetNetDecoder", + "ProphetNetEncoder", + "ProphetNetForCausalLM", + "ProphetNetForConditionalGeneration", + "ProphetNetModel", + "ProphetNetPreTrainedModel", + ] ) - from .models.mpnet import ( - MPNET_PRETRAINED_MODEL_ARCHIVE_LIST, - MPNetForMaskedLM, - MPNetForMultipleChoice, - MPNetForQuestionAnswering, - MPNetForSequenceClassification, - MPNetForTokenClassification, - MPNetLayer, - MPNetModel, - MPNetPreTrainedModel, + _import_structure["models.rag"].extend(["RagModel", "RagSequenceForGeneration", "RagTokenForGeneration"]) + _import_structure["models.reformer"].extend( + [ + "REFORMER_PRETRAINED_MODEL_ARCHIVE_LIST", + "ReformerAttention", + "ReformerForMaskedLM", + "ReformerForQuestionAnswering", + "ReformerForSequenceClassification", + "ReformerLayer", + "ReformerModel", + "ReformerModelWithLMHead", + ] ) - from .models.mt5 import MT5EncoderModel, MT5ForConditionalGeneration, MT5Model - from .models.openai import ( - OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_LIST, - OpenAIGPTDoubleHeadsModel, - OpenAIGPTForSequenceClassification, - OpenAIGPTLMHeadModel, - OpenAIGPTModel, - OpenAIGPTPreTrainedModel, - load_tf_weights_in_openai_gpt, + _import_structure["models.retribert"].extend( + ["RETRIBERT_PRETRAINED_MODEL_ARCHIVE_LIST", "RetriBertModel", "RetriBertPreTrainedModel"] ) - from .models.pegasus import PegasusForConditionalGeneration, PegasusModel - from .models.prophetnet import ( - PROPHETNET_PRETRAINED_MODEL_ARCHIVE_LIST, - ProphetNetDecoder, - ProphetNetEncoder, - ProphetNetForCausalLM, - ProphetNetForConditionalGeneration, - ProphetNetModel, - ProphetNetPreTrainedModel, + _import_structure["models.roberta"].extend( + [ + "ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST", + "RobertaForCausalLM", + "RobertaForMaskedLM", + "RobertaForMultipleChoice", + "RobertaForQuestionAnswering", + "RobertaForSequenceClassification", + "RobertaForTokenClassification", + "RobertaModel", + ] ) - from .models.rag import RagModel, RagSequenceForGeneration, RagTokenForGeneration - from .models.reformer import ( - REFORMER_PRETRAINED_MODEL_ARCHIVE_LIST, - ReformerAttention, - ReformerForMaskedLM, - ReformerForQuestionAnswering, - ReformerForSequenceClassification, - ReformerLayer, - ReformerModel, - ReformerModelWithLMHead, + _import_structure["models.squeezebert"].extend( + [ + "SQUEEZEBERT_PRETRAINED_MODEL_ARCHIVE_LIST", + "SqueezeBertForMaskedLM", + "SqueezeBertForMultipleChoice", + "SqueezeBertForQuestionAnswering", + "SqueezeBertForSequenceClassification", + "SqueezeBertForTokenClassification", + "SqueezeBertModel", + "SqueezeBertModule", + "SqueezeBertPreTrainedModel", + ] ) - from .models.retribert import RETRIBERT_PRETRAINED_MODEL_ARCHIVE_LIST, RetriBertModel, RetriBertPreTrainedModel - from .models.roberta import ( - ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST, - RobertaForCausalLM, - RobertaForMaskedLM, - RobertaForMultipleChoice, - RobertaForQuestionAnswering, - RobertaForSequenceClassification, - RobertaForTokenClassification, - RobertaModel, + _import_structure["models.t5"].extend( + [ + "T5_PRETRAINED_MODEL_ARCHIVE_LIST", + "T5EncoderModel", + "T5ForConditionalGeneration", + "T5Model", + "T5PreTrainedModel", + "load_tf_weights_in_t5", + ] ) - from .models.squeezebert import ( - SQUEEZEBERT_PRETRAINED_MODEL_ARCHIVE_LIST, - SqueezeBertForMaskedLM, - SqueezeBertForMultipleChoice, - SqueezeBertForQuestionAnswering, - SqueezeBertForSequenceClassification, - SqueezeBertForTokenClassification, - SqueezeBertModel, - SqueezeBertModule, - SqueezeBertPreTrainedModel, + _import_structure["models.tapas"].extend( + [ + "TAPAS_PRETRAINED_MODEL_ARCHIVE_LIST", + "TapasForMaskedLM", + "TapasForQuestionAnswering", + "TapasForSequenceClassification", + "TapasModel", + ] ) - from .models.t5 import ( - T5_PRETRAINED_MODEL_ARCHIVE_LIST, - T5EncoderModel, - T5ForConditionalGeneration, - T5Model, - T5PreTrainedModel, - load_tf_weights_in_t5, + _import_structure["models.transfo_xl"].extend( + [ + "TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_LIST", + "AdaptiveEmbedding", + "TransfoXLForSequenceClassification", + "TransfoXLLMHeadModel", + "TransfoXLModel", + "TransfoXLPreTrainedModel", + "load_tf_weights_in_transfo_xl", + ] ) - from .models.tapas import ( - TAPAS_PRETRAINED_MODEL_ARCHIVE_LIST, - TapasForMaskedLM, - TapasForQuestionAnswering, - TapasForSequenceClassification, - TapasModel, + _import_structure["models.xlm"].extend( + [ + "XLM_PRETRAINED_MODEL_ARCHIVE_LIST", + "XLMForMultipleChoice", + "XLMForQuestionAnswering", + "XLMForQuestionAnsweringSimple", + "XLMForSequenceClassification", + "XLMForTokenClassification", + "XLMModel", + "XLMPreTrainedModel", + "XLMWithLMHeadModel", + ] ) - from .models.transfo_xl import ( - TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_LIST, - AdaptiveEmbedding, - TransfoXLForSequenceClassification, - TransfoXLLMHeadModel, - TransfoXLModel, - TransfoXLPreTrainedModel, - load_tf_weights_in_transfo_xl, + _import_structure["models.xlm_prophetnet"].extend( + [ + "XLM_PROPHETNET_PRETRAINED_MODEL_ARCHIVE_LIST", + "XLMProphetNetDecoder", + "XLMProphetNetEncoder", + "XLMProphetNetForCausalLM", + "XLMProphetNetForConditionalGeneration", + "XLMProphetNetModel", + ] ) - from .models.xlm import ( - XLM_PRETRAINED_MODEL_ARCHIVE_LIST, - XLMForMultipleChoice, - XLMForQuestionAnswering, - XLMForQuestionAnsweringSimple, - XLMForSequenceClassification, - XLMForTokenClassification, - XLMModel, - XLMPreTrainedModel, - XLMWithLMHeadModel, + _import_structure["models.xlm_roberta"].extend( + [ + "XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST", + "XLMRobertaForCausalLM", + "XLMRobertaForMaskedLM", + "XLMRobertaForMultipleChoice", + "XLMRobertaForQuestionAnswering", + "XLMRobertaForSequenceClassification", + "XLMRobertaForTokenClassification", + "XLMRobertaModel", + ] ) - from .models.xlm_prophetnet import ( - XLM_PROPHETNET_PRETRAINED_MODEL_ARCHIVE_LIST, - XLMProphetNetDecoder, - XLMProphetNetEncoder, - XLMProphetNetForCausalLM, - XLMProphetNetForConditionalGeneration, - XLMProphetNetModel, + _import_structure["models.xlnet"].extend( + [ + "XLNET_PRETRAINED_MODEL_ARCHIVE_LIST", + "XLNetForMultipleChoice", + "XLNetForQuestionAnswering", + "XLNetForQuestionAnsweringSimple", + "XLNetForSequenceClassification", + "XLNetForTokenClassification", + "XLNetLMHeadModel", + "XLNetModel", + "XLNetPreTrainedModel", + "load_tf_weights_in_xlnet", + ] ) - from .models.xlm_roberta import ( - XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST, - XLMRobertaForCausalLM, - XLMRobertaForMaskedLM, - XLMRobertaForMultipleChoice, - XLMRobertaForQuestionAnswering, - XLMRobertaForSequenceClassification, - XLMRobertaForTokenClassification, - XLMRobertaModel, - ) - from .models.xlnet import ( - XLNET_PRETRAINED_MODEL_ARCHIVE_LIST, - XLNetForMultipleChoice, - XLNetForQuestionAnswering, - XLNetForQuestionAnsweringSimple, - XLNetForSequenceClassification, - XLNetForTokenClassification, - XLNetLMHeadModel, - XLNetModel, - XLNetPreTrainedModel, - load_tf_weights_in_xlnet, - ) - - # Optimization - from .optimization import ( - Adafactor, - AdamW, - get_constant_schedule, - get_constant_schedule_with_warmup, - get_cosine_schedule_with_warmup, - get_cosine_with_hard_restarts_schedule_with_warmup, - get_linear_schedule_with_warmup, - get_polynomial_decay_schedule_with_warmup, - get_scheduler, - ) - - # Trainer - from .trainer import Trainer - from .trainer_pt_utils import torch_distributed_zero_first - from .trainer_seq2seq import Seq2SeqTrainer + _import_structure["optimization"] = [ + "Adafactor", + "AdamW", + "get_constant_schedule", + "get_constant_schedule_with_warmup", + "get_cosine_schedule_with_warmup", + "get_cosine_with_hard_restarts_schedule_with_warmup", + "get_linear_schedule_with_warmup", + "get_polynomial_decay_schedule_with_warmup", + "get_scheduler", + ] + _import_structure["trainer"] = ["Trainer"] + _import_structure["trainer_pt_utils"] = ["torch_distributed_zero_first"] + _import_structure["trainer_seq2seq"] = ["Seq2SeqTrainer"] else: - from .utils.dummy_pt_objects import * + from .utils import dummy_pt_objects -# TensorFlow + _import_structure["utils.dummy_pt_objects"] = [name for name in dir(dummy_pt_objects) if not name.startswith("_")] + +# TensorFlow-backed objects if is_tf_available(): - - from .benchmark.benchmark_args_tf import TensorFlowBenchmarkArguments - - # Benchmarks - from .benchmark.benchmark_tf import TensorFlowBenchmark - from .generation_tf_utils import tf_top_k_top_p_filtering - from .modeling_tf_utils import TFPreTrainedModel, TFSequenceSummary, TFSharedEmbeddings, shape_list - from .models.albert import ( - TF_ALBERT_PRETRAINED_MODEL_ARCHIVE_LIST, - TFAlbertForMaskedLM, - TFAlbertForMultipleChoice, - TFAlbertForPreTraining, - TFAlbertForQuestionAnswering, - TFAlbertForSequenceClassification, - TFAlbertForTokenClassification, - TFAlbertMainLayer, - TFAlbertModel, - TFAlbertPreTrainedModel, + _import_structure["benchmark.benchmark_args_tf"] = ["TensorFlowBenchmarkArguments"] + _import_structure["benchmark.benchmark_tf"] = ["TensorFlowBenchmark"] + _import_structure["generation_tf_utils"] = ["tf_top_k_top_p_filtering"] + _import_structure["modeling_tf_utils"] = [ + "TFPreTrainedModel", + "TFSequenceSummary", + "TFSharedEmbeddings", + "shape_list", + ] + # TensorFlow models structure + _import_structure["models.albert"].extend( + [ + "TF_ALBERT_PRETRAINED_MODEL_ARCHIVE_LIST", + "TFAlbertForMaskedLM", + "TFAlbertForMultipleChoice", + "TFAlbertForPreTraining", + "TFAlbertForQuestionAnswering", + "TFAlbertForSequenceClassification", + "TFAlbertForTokenClassification", + "TFAlbertMainLayer", + "TFAlbertModel", + "TFAlbertPreTrainedModel", + ] ) + _import_structure["models.auto"].extend( + [ + "TF_MODEL_FOR_CAUSAL_LM_MAPPING", + "TF_MODEL_FOR_MASKED_LM_MAPPING", + "TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING", + "TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING", + "TF_MODEL_FOR_PRETRAINING_MAPPING", + "TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING", + "TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING", + "TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING", + "TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING", + "TF_MODEL_MAPPING", + "TF_MODEL_WITH_LM_HEAD_MAPPING", + "TFAutoModel", + "TFAutoModelForCausalLM", + "TFAutoModelForMaskedLM", + "TFAutoModelForMultipleChoice", + "TFAutoModelForPreTraining", + "TFAutoModelForQuestionAnswering", + "TFAutoModelForSeq2SeqLM", + "TFAutoModelForSequenceClassification", + "TFAutoModelForTokenClassification", + "TFAutoModelWithLMHead", + ] + ) + _import_structure["models.bart"].extend(["TFBartForConditionalGeneration", "TFBartModel", "TFBartPretrainedModel"]) + _import_structure["models.bert"].extend( + [ + "TF_BERT_PRETRAINED_MODEL_ARCHIVE_LIST", + "TFBertEmbeddings", + "TFBertForMaskedLM", + "TFBertForMultipleChoice", + "TFBertForNextSentencePrediction", + "TFBertForPreTraining", + "TFBertForQuestionAnswering", + "TFBertForSequenceClassification", + "TFBertForTokenClassification", + "TFBertLMHeadModel", + "TFBertMainLayer", + "TFBertModel", + "TFBertPreTrainedModel", + ] + ) + _import_structure["models.blenderbot"].append("TFBlenderbotForConditionalGeneration") + _import_structure["models.camembert"].extend( + [ + "TF_CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_LIST", + "TFCamembertForMaskedLM", + "TFCamembertForMultipleChoice", + "TFCamembertForQuestionAnswering", + "TFCamembertForSequenceClassification", + "TFCamembertForTokenClassification", + "TFCamembertModel", + ] + ) + _import_structure["models.ctrl"].extend( + [ + "TF_CTRL_PRETRAINED_MODEL_ARCHIVE_LIST", + "TFCTRLForSequenceClassification", + "TFCTRLLMHeadModel", + "TFCTRLModel", + "TFCTRLPreTrainedModel", + ] + ) + _import_structure["models.distilbert"].extend( + [ + "TF_DISTILBERT_PRETRAINED_MODEL_ARCHIVE_LIST", + "TFDistilBertForMaskedLM", + "TFDistilBertForMultipleChoice", + "TFDistilBertForQuestionAnswering", + "TFDistilBertForSequenceClassification", + "TFDistilBertForTokenClassification", + "TFDistilBertMainLayer", + "TFDistilBertModel", + "TFDistilBertPreTrainedModel", + ] + ) + _import_structure["models.dpr"].extend( + [ + "TF_DPR_CONTEXT_ENCODER_PRETRAINED_MODEL_ARCHIVE_LIST", + "TF_DPR_QUESTION_ENCODER_PRETRAINED_MODEL_ARCHIVE_LIST", + "TF_DPR_READER_PRETRAINED_MODEL_ARCHIVE_LIST", + "TFDPRContextEncoder", + "TFDPRPretrainedContextEncoder", + "TFDPRPretrainedQuestionEncoder", + "TFDPRPretrainedReader", + "TFDPRQuestionEncoder", + "TFDPRReader", + ] + ) + _import_structure["models.electra"].extend( + [ + "TF_ELECTRA_PRETRAINED_MODEL_ARCHIVE_LIST", + "TFElectraForMaskedLM", + "TFElectraForMultipleChoice", + "TFElectraForPreTraining", + "TFElectraForQuestionAnswering", + "TFElectraForSequenceClassification", + "TFElectraForTokenClassification", + "TFElectraModel", + "TFElectraPreTrainedModel", + ] + ) + _import_structure["models.flaubert"].extend( + [ + "TF_FLAUBERT_PRETRAINED_MODEL_ARCHIVE_LIST", + "TFFlaubertForMultipleChoice", + "TFFlaubertForQuestionAnsweringSimple", + "TFFlaubertForSequenceClassification", + "TFFlaubertForTokenClassification", + "TFFlaubertModel", + "TFFlaubertWithLMHeadModel", + ] + ) + _import_structure["models.funnel"].extend( + [ + "TF_FUNNEL_PRETRAINED_MODEL_ARCHIVE_LIST", + "TFFunnelBaseModel", + "TFFunnelForMaskedLM", + "TFFunnelForMultipleChoice", + "TFFunnelForPreTraining", + "TFFunnelForQuestionAnswering", + "TFFunnelForSequenceClassification", + "TFFunnelForTokenClassification", + "TFFunnelModel", + ] + ) + _import_structure["models.gpt2"].extend( + [ + "TF_GPT2_PRETRAINED_MODEL_ARCHIVE_LIST", + "TFGPT2DoubleHeadsModel", + "TFGPT2ForSequenceClassification", + "TFGPT2LMHeadModel", + "TFGPT2MainLayer", + "TFGPT2Model", + "TFGPT2PreTrainedModel", + ] + ) + _import_structure["models.led"].extend(["TFLEDForConditionalGeneration", "TFLEDModel", "TFLEDPreTrainedModel"]) + _import_structure["models.longformer"].extend( + [ + "TF_LONGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST", + "TFLongformerForMaskedLM", + "TFLongformerForMultipleChoice", + "TFLongformerForQuestionAnswering", + "TFLongformerForSequenceClassification", + "TFLongformerForTokenClassification", + "TFLongformerModel", + "TFLongformerSelfAttention", + ] + ) + _import_structure["models.lxmert"].extend( + [ + "TF_LXMERT_PRETRAINED_MODEL_ARCHIVE_LIST", + "TFLxmertForPreTraining", + "TFLxmertMainLayer", + "TFLxmertModel", + "TFLxmertPreTrainedModel", + "TFLxmertVisualFeatureEncoder", + ] + ) + _import_structure["models.marian"].append("TFMarianMTModel") + _import_structure["models.mbart"].append("TFMBartForConditionalGeneration") + _import_structure["models.mobilebert"].extend( + [ + "TF_MOBILEBERT_PRETRAINED_MODEL_ARCHIVE_LIST", + "TFMobileBertForMaskedLM", + "TFMobileBertForMultipleChoice", + "TFMobileBertForNextSentencePrediction", + "TFMobileBertForPreTraining", + "TFMobileBertForQuestionAnswering", + "TFMobileBertForSequenceClassification", + "TFMobileBertForTokenClassification", + "TFMobileBertMainLayer", + "TFMobileBertModel", + "TFMobileBertPreTrainedModel", + ] + ) + _import_structure["models.mpnet"].extend( + [ + "TF_MPNET_PRETRAINED_MODEL_ARCHIVE_LIST", + "TFMPNetForMaskedLM", + "TFMPNetForMultipleChoice", + "TFMPNetForQuestionAnswering", + "TFMPNetForSequenceClassification", + "TFMPNetForTokenClassification", + "TFMPNetMainLayer", + "TFMPNetModel", + "TFMPNetPreTrainedModel", + ] + ) + _import_structure["models.mt5"].extend(["TFMT5EncoderModel", "TFMT5ForConditionalGeneration", "TFMT5Model"]) + _import_structure["models.openai"].extend( + [ + "TF_OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_LIST", + "TFOpenAIGPTDoubleHeadsModel", + "TFOpenAIGPTForSequenceClassification", + "TFOpenAIGPTLMHeadModel", + "TFOpenAIGPTMainLayer", + "TFOpenAIGPTModel", + "TFOpenAIGPTPreTrainedModel", + ] + ) + _import_structure["models.pegasus"].append("TFPegasusForConditionalGeneration") + _import_structure["models.roberta"].extend( + [ + "TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST", + "TFRobertaForMaskedLM", + "TFRobertaForMultipleChoice", + "TFRobertaForQuestionAnswering", + "TFRobertaForSequenceClassification", + "TFRobertaForTokenClassification", + "TFRobertaMainLayer", + "TFRobertaModel", + "TFRobertaPreTrainedModel", + ] + ) + _import_structure["models.t5"].extend( + [ + "TF_T5_PRETRAINED_MODEL_ARCHIVE_LIST", + "TFT5EncoderModel", + "TFT5ForConditionalGeneration", + "TFT5Model", + "TFT5PreTrainedModel", + ] + ) + _import_structure["models.transfo_xl"].extend( + [ + "TF_TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_LIST", + "TFAdaptiveEmbedding", + "TFTransfoXLForSequenceClassification", + "TFTransfoXLLMHeadModel", + "TFTransfoXLMainLayer", + "TFTransfoXLModel", + "TFTransfoXLPreTrainedModel", + ] + ) + _import_structure["models.xlm"].extend( + [ + "TF_XLM_PRETRAINED_MODEL_ARCHIVE_LIST", + "TFXLMForMultipleChoice", + "TFXLMForQuestionAnsweringSimple", + "TFXLMForSequenceClassification", + "TFXLMForTokenClassification", + "TFXLMMainLayer", + "TFXLMModel", + "TFXLMPreTrainedModel", + "TFXLMWithLMHeadModel", + ] + ) + _import_structure["models.xlm_roberta"].extend( + [ + "TF_XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST", + "TFXLMRobertaForMaskedLM", + "TFXLMRobertaForMultipleChoice", + "TFXLMRobertaForQuestionAnswering", + "TFXLMRobertaForSequenceClassification", + "TFXLMRobertaForTokenClassification", + "TFXLMRobertaModel", + ] + ) + _import_structure["models.xlnet"].extend( + [ + "TF_XLNET_PRETRAINED_MODEL_ARCHIVE_LIST", + "TFXLNetForMultipleChoice", + "TFXLNetForQuestionAnsweringSimple", + "TFXLNetForSequenceClassification", + "TFXLNetForTokenClassification", + "TFXLNetLMHeadModel", + "TFXLNetMainLayer", + "TFXLNetModel", + "TFXLNetPreTrainedModel", + ] + ) + _import_structure["optimization_tf"] = ["AdamWeightDecay", "GradientAccumulator", "WarmUp", "create_optimizer"] + _import_structure["trainer_tf"] = ["TFTrainer"] + +else: + from .utils import dummy_tf_objects + + _import_structure["utils.dummy_tf_objects"] = [name for name in dir(dummy_tf_objects) if not name.startswith("_")] + +# FLAX-backed objects +if is_flax_available(): + _import_structure["modeling_flax_utils"] = ["FlaxPreTrainedModel"] + _import_structure["models.auto"].extend(["FLAX_MODEL_MAPPING", "FlaxAutoModel"]) + _import_structure["models.bert"].extend(["FlaxBertForMaskedLM", "FlaxBertModel"]) + _import_structure["models.roberta"].append("FlaxRobertaModel") +else: + from .utils import dummy_flax_objects + + _import_structure["utils.dummy_flax_objects"] = [ + name for name in dir(dummy_flax_objects) if not name.startswith("_") + ] + + +# Direct imports for type-checking +if TYPE_CHECKING: + # Configuration + from .configuration_utils import PretrainedConfig + + # Data + from .data import ( + DataProcessor, + InputExample, + InputFeatures, + SingleSentenceClassificationProcessor, + SquadExample, + SquadFeatures, + SquadV1Processor, + SquadV2Processor, + glue_compute_metrics, + glue_convert_examples_to_features, + glue_output_modes, + glue_processors, + glue_tasks_num_labels, + squad_convert_examples_to_features, + xnli_compute_metrics, + xnli_output_modes, + xnli_processors, + xnli_tasks_num_labels, + ) + + # Files and general utilities + from .file_utils import ( + CONFIG_NAME, + MODEL_CARD_NAME, + PYTORCH_PRETRAINED_BERT_CACHE, + PYTORCH_TRANSFORMERS_CACHE, + SPIECE_UNDERLINE, + TF2_WEIGHTS_NAME, + TF_WEIGHTS_NAME, + TRANSFORMERS_CACHE, + WEIGHTS_NAME, + add_end_docstrings, + add_start_docstrings, + cached_path, + is_apex_available, + is_datasets_available, + is_faiss_available, + is_flax_available, + is_psutil_available, + is_py3nvml_available, + is_sentencepiece_available, + is_sklearn_available, + is_tf_available, + is_tokenizers_available, + is_torch_available, + is_torch_tpu_available, + ) + from .hf_argparser import HfArgumentParser + + # Integrations + from .integrations import ( + is_comet_available, + is_optuna_available, + is_ray_available, + is_ray_tune_available, + is_tensorboard_available, + is_wandb_available, + ) + + # Model Cards + from .modelcard import ModelCard + + # TF 2.0 <=> PyTorch conversion utilities + from .modeling_tf_pytorch_utils import ( + convert_tf_weight_name_to_pt_weight_name, + load_pytorch_checkpoint_in_tf2_model, + load_pytorch_model_in_tf2_model, + load_pytorch_weights_in_tf2_model, + load_tf2_checkpoint_in_pytorch_model, + load_tf2_model_in_pytorch_model, + load_tf2_weights_in_pytorch_model, + ) + from .models.albert import ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, AlbertConfig from .models.auto import ( - TF_MODEL_FOR_CAUSAL_LM_MAPPING, - TF_MODEL_FOR_MASKED_LM_MAPPING, - TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING, - TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING, - TF_MODEL_FOR_PRETRAINING_MAPPING, - TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING, - TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING, - TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING, - TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING, - TF_MODEL_MAPPING, - TF_MODEL_WITH_LM_HEAD_MAPPING, - TFAutoModel, - TFAutoModelForCausalLM, - TFAutoModelForMaskedLM, - TFAutoModelForMultipleChoice, - TFAutoModelForPreTraining, - TFAutoModelForQuestionAnswering, - TFAutoModelForSeq2SeqLM, - TFAutoModelForSequenceClassification, - TFAutoModelForTokenClassification, - TFAutoModelWithLMHead, + ALL_PRETRAINED_CONFIG_ARCHIVE_MAP, + CONFIG_MAPPING, + MODEL_NAMES_MAPPING, + TOKENIZER_MAPPING, + AutoConfig, + AutoTokenizer, ) - from .models.bart import TFBartForConditionalGeneration, TFBartModel, TFBartPretrainedModel + from .models.bart import BartConfig, BartTokenizer from .models.bert import ( - TF_BERT_PRETRAINED_MODEL_ARCHIVE_LIST, - TFBertEmbeddings, - TFBertForMaskedLM, - TFBertForMultipleChoice, - TFBertForNextSentencePrediction, - TFBertForPreTraining, - TFBertForQuestionAnswering, - TFBertForSequenceClassification, - TFBertForTokenClassification, - TFBertLMHeadModel, - TFBertMainLayer, - TFBertModel, - TFBertPreTrainedModel, + BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, + BasicTokenizer, + BertConfig, + BertTokenizer, + WordpieceTokenizer, ) - from .models.blenderbot import TFBlenderbotForConditionalGeneration - from .models.camembert import ( - TF_CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_LIST, - TFCamembertForMaskedLM, - TFCamembertForMultipleChoice, - TFCamembertForQuestionAnswering, - TFCamembertForSequenceClassification, - TFCamembertForTokenClassification, - TFCamembertModel, - ) - from .models.ctrl import ( - TF_CTRL_PRETRAINED_MODEL_ARCHIVE_LIST, - TFCTRLForSequenceClassification, - TFCTRLLMHeadModel, - TFCTRLModel, - TFCTRLPreTrainedModel, - ) - from .models.distilbert import ( - TF_DISTILBERT_PRETRAINED_MODEL_ARCHIVE_LIST, - TFDistilBertForMaskedLM, - TFDistilBertForMultipleChoice, - TFDistilBertForQuestionAnswering, - TFDistilBertForSequenceClassification, - TFDistilBertForTokenClassification, - TFDistilBertMainLayer, - TFDistilBertModel, - TFDistilBertPreTrainedModel, + from .models.bert_generation import BertGenerationConfig + from .models.bert_japanese import BertJapaneseTokenizer, CharacterTokenizer, MecabTokenizer + from .models.bertweet import BertweetTokenizer + from .models.blenderbot import BLENDERBOT_PRETRAINED_CONFIG_ARCHIVE_MAP, BlenderbotConfig, BlenderbotTokenizer + from .models.blenderbot_small import ( + BLENDERBOT_SMALL_PRETRAINED_CONFIG_ARCHIVE_MAP, + BlenderbotSmallConfig, + BlenderbotSmallTokenizer, ) + from .models.camembert import CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, CamembertConfig + from .models.ctrl import CTRL_PRETRAINED_CONFIG_ARCHIVE_MAP, CTRLConfig, CTRLTokenizer + from .models.deberta import DEBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, DebertaConfig, DebertaTokenizer + from .models.distilbert import DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, DistilBertConfig, DistilBertTokenizer from .models.dpr import ( - TF_DPR_CONTEXT_ENCODER_PRETRAINED_MODEL_ARCHIVE_LIST, - TF_DPR_QUESTION_ENCODER_PRETRAINED_MODEL_ARCHIVE_LIST, - TF_DPR_READER_PRETRAINED_MODEL_ARCHIVE_LIST, - TFDPRContextEncoder, - TFDPRPretrainedContextEncoder, - TFDPRPretrainedQuestionEncoder, - TFDPRPretrainedReader, - TFDPRQuestionEncoder, - TFDPRReader, - ) - from .models.electra import ( - TF_ELECTRA_PRETRAINED_MODEL_ARCHIVE_LIST, - TFElectraForMaskedLM, - TFElectraForMultipleChoice, - TFElectraForPreTraining, - TFElectraForQuestionAnswering, - TFElectraForSequenceClassification, - TFElectraForTokenClassification, - TFElectraModel, - TFElectraPreTrainedModel, - ) - from .models.flaubert import ( - TF_FLAUBERT_PRETRAINED_MODEL_ARCHIVE_LIST, - TFFlaubertForMultipleChoice, - TFFlaubertForQuestionAnsweringSimple, - TFFlaubertForSequenceClassification, - TFFlaubertForTokenClassification, - TFFlaubertModel, - TFFlaubertWithLMHeadModel, - ) - from .models.funnel import ( - TF_FUNNEL_PRETRAINED_MODEL_ARCHIVE_LIST, - TFFunnelBaseModel, - TFFunnelForMaskedLM, - TFFunnelForMultipleChoice, - TFFunnelForPreTraining, - TFFunnelForQuestionAnswering, - TFFunnelForSequenceClassification, - TFFunnelForTokenClassification, - TFFunnelModel, - ) - from .models.gpt2 import ( - TF_GPT2_PRETRAINED_MODEL_ARCHIVE_LIST, - TFGPT2DoubleHeadsModel, - TFGPT2ForSequenceClassification, - TFGPT2LMHeadModel, - TFGPT2MainLayer, - TFGPT2Model, - TFGPT2PreTrainedModel, - ) - from .models.led import TFLEDForConditionalGeneration, TFLEDModel, TFLEDPreTrainedModel - from .models.longformer import ( - TF_LONGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST, - TFLongformerForMaskedLM, - TFLongformerForMultipleChoice, - TFLongformerForQuestionAnswering, - TFLongformerForSequenceClassification, - TFLongformerForTokenClassification, - TFLongformerModel, - TFLongformerSelfAttention, - ) - from .models.lxmert import ( - TF_LXMERT_PRETRAINED_MODEL_ARCHIVE_LIST, - TFLxmertForPreTraining, - TFLxmertMainLayer, - TFLxmertModel, - TFLxmertPreTrainedModel, - TFLxmertVisualFeatureEncoder, - ) - from .models.marian import TFMarianMTModel - from .models.mbart import TFMBartForConditionalGeneration - from .models.mobilebert import ( - TF_MOBILEBERT_PRETRAINED_MODEL_ARCHIVE_LIST, - TFMobileBertForMaskedLM, - TFMobileBertForMultipleChoice, - TFMobileBertForNextSentencePrediction, - TFMobileBertForPreTraining, - TFMobileBertForQuestionAnswering, - TFMobileBertForSequenceClassification, - TFMobileBertForTokenClassification, - TFMobileBertMainLayer, - TFMobileBertModel, - TFMobileBertPreTrainedModel, - ) - from .models.mpnet import ( - TF_MPNET_PRETRAINED_MODEL_ARCHIVE_LIST, - TFMPNetForMaskedLM, - TFMPNetForMultipleChoice, - TFMPNetForQuestionAnswering, - TFMPNetForSequenceClassification, - TFMPNetForTokenClassification, - TFMPNetMainLayer, - TFMPNetModel, - TFMPNetPreTrainedModel, - ) - from .models.mt5 import TFMT5EncoderModel, TFMT5ForConditionalGeneration, TFMT5Model - from .models.openai import ( - TF_OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_LIST, - TFOpenAIGPTDoubleHeadsModel, - TFOpenAIGPTForSequenceClassification, - TFOpenAIGPTLMHeadModel, - TFOpenAIGPTMainLayer, - TFOpenAIGPTModel, - TFOpenAIGPTPreTrainedModel, - ) - from .models.pegasus import TFPegasusForConditionalGeneration - from .models.roberta import ( - TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST, - TFRobertaForMaskedLM, - TFRobertaForMultipleChoice, - TFRobertaForQuestionAnswering, - TFRobertaForSequenceClassification, - TFRobertaForTokenClassification, - TFRobertaMainLayer, - TFRobertaModel, - TFRobertaPreTrainedModel, - ) - from .models.t5 import ( - TF_T5_PRETRAINED_MODEL_ARCHIVE_LIST, - TFT5EncoderModel, - TFT5ForConditionalGeneration, - TFT5Model, - TFT5PreTrainedModel, + DPR_PRETRAINED_CONFIG_ARCHIVE_MAP, + DPRConfig, + DPRContextEncoderTokenizer, + DPRQuestionEncoderTokenizer, + DPRReaderOutput, + DPRReaderTokenizer, ) + from .models.electra import ELECTRA_PRETRAINED_CONFIG_ARCHIVE_MAP, ElectraConfig, ElectraTokenizer + from .models.encoder_decoder import EncoderDecoderConfig + from .models.flaubert import FLAUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, FlaubertConfig, FlaubertTokenizer + from .models.fsmt import FSMT_PRETRAINED_CONFIG_ARCHIVE_MAP, FSMTConfig, FSMTTokenizer + from .models.funnel import FUNNEL_PRETRAINED_CONFIG_ARCHIVE_MAP, FunnelConfig, FunnelTokenizer + from .models.gpt2 import GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP, GPT2Config, GPT2Tokenizer + from .models.herbert import HerbertTokenizer + from .models.layoutlm import LAYOUTLM_PRETRAINED_CONFIG_ARCHIVE_MAP, LayoutLMConfig, LayoutLMTokenizer + from .models.led import LED_PRETRAINED_CONFIG_ARCHIVE_MAP, LEDConfig, LEDTokenizer + from .models.longformer import LONGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, LongformerConfig, LongformerTokenizer + from .models.lxmert import LXMERT_PRETRAINED_CONFIG_ARCHIVE_MAP, LxmertConfig, LxmertTokenizer + from .models.marian import MarianConfig + from .models.mbart import MBartConfig + from .models.mmbt import MMBTConfig + from .models.mobilebert import MOBILEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, MobileBertConfig, MobileBertTokenizer + from .models.mpnet import MPNET_PRETRAINED_CONFIG_ARCHIVE_MAP, MPNetConfig, MPNetTokenizer + from .models.mt5 import MT5Config + from .models.openai import OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP, OpenAIGPTConfig, OpenAIGPTTokenizer + from .models.pegasus import PegasusConfig + from .models.phobert import PhobertTokenizer + from .models.prophetnet import PROPHETNET_PRETRAINED_CONFIG_ARCHIVE_MAP, ProphetNetConfig, ProphetNetTokenizer + from .models.rag import RagConfig, RagRetriever, RagTokenizer + from .models.reformer import REFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, ReformerConfig + from .models.retribert import RETRIBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, RetriBertConfig, RetriBertTokenizer + from .models.roberta import ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, RobertaConfig, RobertaTokenizer + from .models.squeezebert import SQUEEZEBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, SqueezeBertConfig, SqueezeBertTokenizer + from .models.t5 import T5_PRETRAINED_CONFIG_ARCHIVE_MAP, T5Config + from .models.tapas import TAPAS_PRETRAINED_CONFIG_ARCHIVE_MAP, TapasConfig, TapasTokenizer from .models.transfo_xl import ( - TF_TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_LIST, - TFAdaptiveEmbedding, - TFTransfoXLForSequenceClassification, - TFTransfoXLLMHeadModel, - TFTransfoXLMainLayer, - TFTransfoXLModel, - TFTransfoXLPreTrainedModel, + TRANSFO_XL_PRETRAINED_CONFIG_ARCHIVE_MAP, + TransfoXLConfig, + TransfoXLCorpus, + TransfoXLTokenizer, ) - from .models.xlm import ( - TF_XLM_PRETRAINED_MODEL_ARCHIVE_LIST, - TFXLMForMultipleChoice, - TFXLMForQuestionAnsweringSimple, - TFXLMForSequenceClassification, - TFXLMForTokenClassification, - TFXLMMainLayer, - TFXLMModel, - TFXLMPreTrainedModel, - TFXLMWithLMHeadModel, - ) - from .models.xlm_roberta import ( - TF_XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST, - TFXLMRobertaForMaskedLM, - TFXLMRobertaForMultipleChoice, - TFXLMRobertaForQuestionAnswering, - TFXLMRobertaForSequenceClassification, - TFXLMRobertaForTokenClassification, - TFXLMRobertaModel, - ) - from .models.xlnet import ( - TF_XLNET_PRETRAINED_MODEL_ARCHIVE_LIST, - TFXLNetForMultipleChoice, - TFXLNetForQuestionAnsweringSimple, - TFXLNetForSequenceClassification, - TFXLNetForTokenClassification, - TFXLNetLMHeadModel, - TFXLNetMainLayer, - TFXLNetModel, - TFXLNetPreTrainedModel, + from .models.xlm import XLM_PRETRAINED_CONFIG_ARCHIVE_MAP, XLMConfig, XLMTokenizer + from .models.xlm_prophetnet import XLM_PROPHETNET_PRETRAINED_CONFIG_ARCHIVE_MAP, XLMProphetNetConfig + from .models.xlm_roberta import XLM_ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, XLMRobertaConfig + from .models.xlnet import XLNET_PRETRAINED_CONFIG_ARCHIVE_MAP, XLNetConfig + + # Pipelines + from .pipelines import ( + Conversation, + ConversationalPipeline, + CsvPipelineDataFormat, + FeatureExtractionPipeline, + FillMaskPipeline, + JsonPipelineDataFormat, + NerPipeline, + PipedPipelineDataFormat, + Pipeline, + PipelineDataFormat, + QuestionAnsweringPipeline, + SummarizationPipeline, + TableQuestionAnsweringPipeline, + Text2TextGenerationPipeline, + TextClassificationPipeline, + TextGenerationPipeline, + TokenClassificationPipeline, + TranslationPipeline, + ZeroShotClassificationPipeline, + pipeline, ) - # Optimization - from .optimization_tf import AdamWeightDecay, GradientAccumulator, WarmUp, create_optimizer + # Tokenization + from .tokenization_utils import PreTrainedTokenizer + from .tokenization_utils_base import ( + AddedToken, + BatchEncoding, + CharSpan, + PreTrainedTokenizerBase, + SpecialTokensMixin, + TensorType, + TokenSpan, + ) # Trainer - from .trainer_tf import TFTrainer + from .trainer_callback import ( + DefaultFlowCallback, + EarlyStoppingCallback, + PrinterCallback, + ProgressCallback, + TrainerCallback, + TrainerControl, + TrainerState, + ) + from .trainer_utils import EvalPrediction, EvaluationStrategy, SchedulerType, set_seed + from .training_args import TrainingArguments + from .training_args_seq2seq import Seq2SeqTrainingArguments + from .training_args_tf import TFTrainingArguments + if is_sentencepiece_available(): + from .models.albert import AlbertTokenizer + from .models.barthez import BarthezTokenizer + from .models.bert_generation import BertGenerationTokenizer + from .models.camembert import CamembertTokenizer + from .models.marian import MarianTokenizer + from .models.mbart import MBartTokenizer + from .models.mt5 import MT5Tokenizer + from .models.pegasus import PegasusTokenizer + from .models.reformer import ReformerTokenizer + from .models.t5 import T5Tokenizer + from .models.xlm_prophetnet import XLMProphetNetTokenizer + from .models.xlm_roberta import XLMRobertaTokenizer + from .models.xlnet import XLNetTokenizer + else: + from .utils.dummy_sentencepiece_objects import * + + if is_tokenizers_available(): + from .models.albert import AlbertTokenizerFast + from .models.bart import BartTokenizerFast + from .models.barthez import BarthezTokenizerFast + from .models.bert import BertTokenizerFast + from .models.camembert import CamembertTokenizerFast + from .models.distilbert import DistilBertTokenizerFast + from .models.dpr import DPRContextEncoderTokenizerFast, DPRQuestionEncoderTokenizerFast, DPRReaderTokenizerFast + from .models.electra import ElectraTokenizerFast + from .models.funnel import FunnelTokenizerFast + from .models.gpt2 import GPT2TokenizerFast + from .models.herbert import HerbertTokenizerFast + from .models.layoutlm import LayoutLMTokenizerFast + from .models.led import LEDTokenizerFast + from .models.longformer import LongformerTokenizerFast + from .models.lxmert import LxmertTokenizerFast + from .models.mbart import MBartTokenizerFast + from .models.mobilebert import MobileBertTokenizerFast + from .models.mpnet import MPNetTokenizerFast + from .models.mt5 import MT5TokenizerFast + from .models.openai import OpenAIGPTTokenizerFast + from .models.pegasus import PegasusTokenizerFast + from .models.reformer import ReformerTokenizerFast + from .models.retribert import RetriBertTokenizerFast + from .models.roberta import RobertaTokenizerFast + from .models.squeezebert import SqueezeBertTokenizerFast + from .models.t5 import T5TokenizerFast + from .models.xlm_roberta import XLMRobertaTokenizerFast + from .models.xlnet import XLNetTokenizerFast + from .tokenization_utils_fast import PreTrainedTokenizerFast + + if is_sentencepiece_available(): + from .convert_slow_tokenizer import SLOW_TO_FAST_CONVERTERS, convert_slow_tokenizer + else: + from .utils.dummy_tokenizers_objects import * + + # Modeling + if is_torch_available(): + + # Benchmarks + from .benchmark.benchmark import PyTorchBenchmark + from .benchmark.benchmark_args import PyTorchBenchmarkArguments + from .data.data_collator import ( + DataCollator, + DataCollatorForLanguageModeling, + DataCollatorForPermutationLanguageModeling, + DataCollatorForSOP, + DataCollatorForTokenClassification, + DataCollatorForWholeWordMask, + DataCollatorWithPadding, + default_data_collator, + ) + from .data.datasets import ( + GlueDataset, + GlueDataTrainingArguments, + LineByLineTextDataset, + LineByLineWithRefDataset, + LineByLineWithSOPTextDataset, + SquadDataset, + SquadDataTrainingArguments, + TextDataset, + TextDatasetForNextSentencePrediction, + ) + from .generation_beam_search import BeamScorer, BeamSearchScorer + from .generation_logits_process import ( + HammingDiversityLogitsProcessor, + LogitsProcessor, + LogitsProcessorList, + LogitsWarper, + MinLengthLogitsProcessor, + NoBadWordsLogitsProcessor, + NoRepeatNGramLogitsProcessor, + PrefixConstrainedLogitsProcessor, + RepetitionPenaltyLogitsProcessor, + TemperatureLogitsWarper, + TopKLogitsWarper, + TopPLogitsWarper, + ) + from .generation_utils import top_k_top_p_filtering + from .modeling_utils import Conv1D, PreTrainedModel, apply_chunking_to_forward, prune_layer + from .models.albert import ( + ALBERT_PRETRAINED_MODEL_ARCHIVE_LIST, + AlbertForMaskedLM, + AlbertForMultipleChoice, + AlbertForPreTraining, + AlbertForQuestionAnswering, + AlbertForSequenceClassification, + AlbertForTokenClassification, + AlbertModel, + AlbertPreTrainedModel, + load_tf_weights_in_albert, + ) + from .models.auto import ( + MODEL_FOR_CAUSAL_LM_MAPPING, + MODEL_FOR_MASKED_LM_MAPPING, + MODEL_FOR_MULTIPLE_CHOICE_MAPPING, + MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING, + MODEL_FOR_PRETRAINING_MAPPING, + MODEL_FOR_QUESTION_ANSWERING_MAPPING, + MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING, + MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING, + MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING, + MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING, + MODEL_MAPPING, + MODEL_WITH_LM_HEAD_MAPPING, + AutoModel, + AutoModelForCausalLM, + AutoModelForMaskedLM, + AutoModelForMultipleChoice, + AutoModelForNextSentencePrediction, + AutoModelForPreTraining, + AutoModelForQuestionAnswering, + AutoModelForSeq2SeqLM, + AutoModelForSequenceClassification, + AutoModelForTableQuestionAnswering, + AutoModelForTokenClassification, + AutoModelWithLMHead, + ) + from .models.bart import ( + BART_PRETRAINED_MODEL_ARCHIVE_LIST, + BartForConditionalGeneration, + BartForQuestionAnswering, + BartForSequenceClassification, + BartModel, + BartPretrainedModel, + PretrainedBartModel, + ) + from .models.bert import ( + BERT_PRETRAINED_MODEL_ARCHIVE_LIST, + BertForMaskedLM, + BertForMultipleChoice, + BertForNextSentencePrediction, + BertForPreTraining, + BertForQuestionAnswering, + BertForSequenceClassification, + BertForTokenClassification, + BertLayer, + BertLMHeadModel, + BertModel, + BertPreTrainedModel, + load_tf_weights_in_bert, + ) + from .models.bert_generation import ( + BertGenerationDecoder, + BertGenerationEncoder, + load_tf_weights_in_bert_generation, + ) + from .models.blenderbot import ( + BLENDERBOT_PRETRAINED_MODEL_ARCHIVE_LIST, + BlenderbotForConditionalGeneration, + BlenderbotModel, + ) + from .models.blenderbot_small import ( + BLENDERBOT_SMALL_PRETRAINED_MODEL_ARCHIVE_LIST, + BlenderbotSmallForConditionalGeneration, + BlenderbotSmallModel, + ) + from .models.camembert import ( + CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_LIST, + CamembertForCausalLM, + CamembertForMaskedLM, + CamembertForMultipleChoice, + CamembertForQuestionAnswering, + CamembertForSequenceClassification, + CamembertForTokenClassification, + CamembertModel, + ) + from .models.ctrl import ( + CTRL_PRETRAINED_MODEL_ARCHIVE_LIST, + CTRLForSequenceClassification, + CTRLLMHeadModel, + CTRLModel, + CTRLPreTrainedModel, + ) + from .models.deberta import ( + DEBERTA_PRETRAINED_MODEL_ARCHIVE_LIST, + DebertaForSequenceClassification, + DebertaModel, + DebertaPreTrainedModel, + ) + from .models.distilbert import ( + DISTILBERT_PRETRAINED_MODEL_ARCHIVE_LIST, + DistilBertForMaskedLM, + DistilBertForMultipleChoice, + DistilBertForQuestionAnswering, + DistilBertForSequenceClassification, + DistilBertForTokenClassification, + DistilBertModel, + DistilBertPreTrainedModel, + ) + from .models.dpr import ( + DPR_CONTEXT_ENCODER_PRETRAINED_MODEL_ARCHIVE_LIST, + DPR_QUESTION_ENCODER_PRETRAINED_MODEL_ARCHIVE_LIST, + DPR_READER_PRETRAINED_MODEL_ARCHIVE_LIST, + DPRContextEncoder, + DPRPretrainedContextEncoder, + DPRPretrainedQuestionEncoder, + DPRPretrainedReader, + DPRQuestionEncoder, + DPRReader, + ) + from .models.electra import ( + ELECTRA_PRETRAINED_MODEL_ARCHIVE_LIST, + ElectraForMaskedLM, + ElectraForMultipleChoice, + ElectraForPreTraining, + ElectraForQuestionAnswering, + ElectraForSequenceClassification, + ElectraForTokenClassification, + ElectraModel, + ElectraPreTrainedModel, + load_tf_weights_in_electra, + ) + from .models.encoder_decoder import EncoderDecoderModel + from .models.flaubert import ( + FLAUBERT_PRETRAINED_MODEL_ARCHIVE_LIST, + FlaubertForMultipleChoice, + FlaubertForQuestionAnswering, + FlaubertForQuestionAnsweringSimple, + FlaubertForSequenceClassification, + FlaubertForTokenClassification, + FlaubertModel, + FlaubertWithLMHeadModel, + ) + from .models.fsmt import FSMTForConditionalGeneration, FSMTModel, PretrainedFSMTModel + from .models.funnel import ( + FUNNEL_PRETRAINED_MODEL_ARCHIVE_LIST, + FunnelBaseModel, + FunnelForMaskedLM, + FunnelForMultipleChoice, + FunnelForPreTraining, + FunnelForQuestionAnswering, + FunnelForSequenceClassification, + FunnelForTokenClassification, + FunnelModel, + load_tf_weights_in_funnel, + ) + from .models.gpt2 import ( + GPT2_PRETRAINED_MODEL_ARCHIVE_LIST, + GPT2DoubleHeadsModel, + GPT2ForSequenceClassification, + GPT2LMHeadModel, + GPT2Model, + GPT2PreTrainedModel, + load_tf_weights_in_gpt2, + ) + from .models.layoutlm import ( + LAYOUTLM_PRETRAINED_MODEL_ARCHIVE_LIST, + LayoutLMForMaskedLM, + LayoutLMForTokenClassification, + LayoutLMModel, + ) + from .models.led import ( + LED_PRETRAINED_MODEL_ARCHIVE_LIST, + LEDForConditionalGeneration, + LEDForQuestionAnswering, + LEDForSequenceClassification, + LEDModel, + ) + from .models.longformer import ( + LONGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST, + LongformerForMaskedLM, + LongformerForMultipleChoice, + LongformerForQuestionAnswering, + LongformerForSequenceClassification, + LongformerForTokenClassification, + LongformerModel, + LongformerSelfAttention, + ) + from .models.lxmert import ( + LxmertEncoder, + LxmertForPreTraining, + LxmertForQuestionAnswering, + LxmertModel, + LxmertPreTrainedModel, + LxmertVisualFeatureEncoder, + LxmertXLayer, + ) + from .models.marian import MarianModel, MarianMTModel + from .models.mbart import ( + MBartForConditionalGeneration, + MBartForQuestionAnswering, + MBartForSequenceClassification, + MBartModel, + ) + from .models.mmbt import MMBTForClassification, MMBTModel, ModalEmbeddings + from .models.mobilebert import ( + MOBILEBERT_PRETRAINED_MODEL_ARCHIVE_LIST, + MobileBertForMaskedLM, + MobileBertForMultipleChoice, + MobileBertForNextSentencePrediction, + MobileBertForPreTraining, + MobileBertForQuestionAnswering, + MobileBertForSequenceClassification, + MobileBertForTokenClassification, + MobileBertLayer, + MobileBertModel, + MobileBertPreTrainedModel, + load_tf_weights_in_mobilebert, + ) + from .models.mpnet import ( + MPNET_PRETRAINED_MODEL_ARCHIVE_LIST, + MPNetForMaskedLM, + MPNetForMultipleChoice, + MPNetForQuestionAnswering, + MPNetForSequenceClassification, + MPNetForTokenClassification, + MPNetLayer, + MPNetModel, + MPNetPreTrainedModel, + ) + from .models.mt5 import MT5EncoderModel, MT5ForConditionalGeneration, MT5Model + from .models.openai import ( + OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_LIST, + OpenAIGPTDoubleHeadsModel, + OpenAIGPTForSequenceClassification, + OpenAIGPTLMHeadModel, + OpenAIGPTModel, + OpenAIGPTPreTrainedModel, + load_tf_weights_in_openai_gpt, + ) + from .models.pegasus import PegasusForConditionalGeneration, PegasusModel + from .models.prophetnet import ( + PROPHETNET_PRETRAINED_MODEL_ARCHIVE_LIST, + ProphetNetDecoder, + ProphetNetEncoder, + ProphetNetForCausalLM, + ProphetNetForConditionalGeneration, + ProphetNetModel, + ProphetNetPreTrainedModel, + ) + from .models.rag import RagModel, RagSequenceForGeneration, RagTokenForGeneration + from .models.reformer import ( + REFORMER_PRETRAINED_MODEL_ARCHIVE_LIST, + ReformerAttention, + ReformerForMaskedLM, + ReformerForQuestionAnswering, + ReformerForSequenceClassification, + ReformerLayer, + ReformerModel, + ReformerModelWithLMHead, + ) + from .models.retribert import RETRIBERT_PRETRAINED_MODEL_ARCHIVE_LIST, RetriBertModel, RetriBertPreTrainedModel + from .models.roberta import ( + ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST, + RobertaForCausalLM, + RobertaForMaskedLM, + RobertaForMultipleChoice, + RobertaForQuestionAnswering, + RobertaForSequenceClassification, + RobertaForTokenClassification, + RobertaModel, + ) + from .models.squeezebert import ( + SQUEEZEBERT_PRETRAINED_MODEL_ARCHIVE_LIST, + SqueezeBertForMaskedLM, + SqueezeBertForMultipleChoice, + SqueezeBertForQuestionAnswering, + SqueezeBertForSequenceClassification, + SqueezeBertForTokenClassification, + SqueezeBertModel, + SqueezeBertModule, + SqueezeBertPreTrainedModel, + ) + from .models.t5 import ( + T5_PRETRAINED_MODEL_ARCHIVE_LIST, + T5EncoderModel, + T5ForConditionalGeneration, + T5Model, + T5PreTrainedModel, + load_tf_weights_in_t5, + ) + from .models.tapas import ( + TAPAS_PRETRAINED_MODEL_ARCHIVE_LIST, + TapasForMaskedLM, + TapasForQuestionAnswering, + TapasForSequenceClassification, + TapasModel, + ) + from .models.transfo_xl import ( + TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_LIST, + AdaptiveEmbedding, + TransfoXLForSequenceClassification, + TransfoXLLMHeadModel, + TransfoXLModel, + TransfoXLPreTrainedModel, + load_tf_weights_in_transfo_xl, + ) + from .models.xlm import ( + XLM_PRETRAINED_MODEL_ARCHIVE_LIST, + XLMForMultipleChoice, + XLMForQuestionAnswering, + XLMForQuestionAnsweringSimple, + XLMForSequenceClassification, + XLMForTokenClassification, + XLMModel, + XLMPreTrainedModel, + XLMWithLMHeadModel, + ) + from .models.xlm_prophetnet import ( + XLM_PROPHETNET_PRETRAINED_MODEL_ARCHIVE_LIST, + XLMProphetNetDecoder, + XLMProphetNetEncoder, + XLMProphetNetForCausalLM, + XLMProphetNetForConditionalGeneration, + XLMProphetNetModel, + ) + from .models.xlm_roberta import ( + XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST, + XLMRobertaForCausalLM, + XLMRobertaForMaskedLM, + XLMRobertaForMultipleChoice, + XLMRobertaForQuestionAnswering, + XLMRobertaForSequenceClassification, + XLMRobertaForTokenClassification, + XLMRobertaModel, + ) + from .models.xlnet import ( + XLNET_PRETRAINED_MODEL_ARCHIVE_LIST, + XLNetForMultipleChoice, + XLNetForQuestionAnswering, + XLNetForQuestionAnsweringSimple, + XLNetForSequenceClassification, + XLNetForTokenClassification, + XLNetLMHeadModel, + XLNetModel, + XLNetPreTrainedModel, + load_tf_weights_in_xlnet, + ) + + # Optimization + from .optimization import ( + Adafactor, + AdamW, + get_constant_schedule, + get_constant_schedule_with_warmup, + get_cosine_schedule_with_warmup, + get_cosine_with_hard_restarts_schedule_with_warmup, + get_linear_schedule_with_warmup, + get_polynomial_decay_schedule_with_warmup, + get_scheduler, + ) + + # Trainer + from .trainer import Trainer + from .trainer_pt_utils import torch_distributed_zero_first + from .trainer_seq2seq import Seq2SeqTrainer + else: + from .utils.dummy_pt_objects import * + + # TensorFlow + if is_tf_available(): + + from .benchmark.benchmark_args_tf import TensorFlowBenchmarkArguments + + # Benchmarks + from .benchmark.benchmark_tf import TensorFlowBenchmark + from .generation_tf_utils import tf_top_k_top_p_filtering + from .modeling_tf_utils import TFPreTrainedModel, TFSequenceSummary, TFSharedEmbeddings, shape_list + from .models.albert import ( + TF_ALBERT_PRETRAINED_MODEL_ARCHIVE_LIST, + TFAlbertForMaskedLM, + TFAlbertForMultipleChoice, + TFAlbertForPreTraining, + TFAlbertForQuestionAnswering, + TFAlbertForSequenceClassification, + TFAlbertForTokenClassification, + TFAlbertMainLayer, + TFAlbertModel, + TFAlbertPreTrainedModel, + ) + from .models.auto import ( + TF_MODEL_FOR_CAUSAL_LM_MAPPING, + TF_MODEL_FOR_MASKED_LM_MAPPING, + TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING, + TF_MODEL_FOR_NEXT_SENTENCE_PREDICTION_MAPPING, + TF_MODEL_FOR_PRETRAINING_MAPPING, + TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING, + TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING, + TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING, + TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING, + TF_MODEL_MAPPING, + TF_MODEL_WITH_LM_HEAD_MAPPING, + TFAutoModel, + TFAutoModelForCausalLM, + TFAutoModelForMaskedLM, + TFAutoModelForMultipleChoice, + TFAutoModelForPreTraining, + TFAutoModelForQuestionAnswering, + TFAutoModelForSeq2SeqLM, + TFAutoModelForSequenceClassification, + TFAutoModelForTokenClassification, + TFAutoModelWithLMHead, + ) + from .models.bart import TFBartForConditionalGeneration, TFBartModel, TFBartPretrainedModel + from .models.bert import ( + TF_BERT_PRETRAINED_MODEL_ARCHIVE_LIST, + TFBertEmbeddings, + TFBertForMaskedLM, + TFBertForMultipleChoice, + TFBertForNextSentencePrediction, + TFBertForPreTraining, + TFBertForQuestionAnswering, + TFBertForSequenceClassification, + TFBertForTokenClassification, + TFBertLMHeadModel, + TFBertMainLayer, + TFBertModel, + TFBertPreTrainedModel, + ) + from .models.blenderbot import TFBlenderbotForConditionalGeneration + from .models.camembert import ( + TF_CAMEMBERT_PRETRAINED_MODEL_ARCHIVE_LIST, + TFCamembertForMaskedLM, + TFCamembertForMultipleChoice, + TFCamembertForQuestionAnswering, + TFCamembertForSequenceClassification, + TFCamembertForTokenClassification, + TFCamembertModel, + ) + from .models.ctrl import ( + TF_CTRL_PRETRAINED_MODEL_ARCHIVE_LIST, + TFCTRLForSequenceClassification, + TFCTRLLMHeadModel, + TFCTRLModel, + TFCTRLPreTrainedModel, + ) + from .models.distilbert import ( + TF_DISTILBERT_PRETRAINED_MODEL_ARCHIVE_LIST, + TFDistilBertForMaskedLM, + TFDistilBertForMultipleChoice, + TFDistilBertForQuestionAnswering, + TFDistilBertForSequenceClassification, + TFDistilBertForTokenClassification, + TFDistilBertMainLayer, + TFDistilBertModel, + TFDistilBertPreTrainedModel, + ) + from .models.dpr import ( + TF_DPR_CONTEXT_ENCODER_PRETRAINED_MODEL_ARCHIVE_LIST, + TF_DPR_QUESTION_ENCODER_PRETRAINED_MODEL_ARCHIVE_LIST, + TF_DPR_READER_PRETRAINED_MODEL_ARCHIVE_LIST, + TFDPRContextEncoder, + TFDPRPretrainedContextEncoder, + TFDPRPretrainedQuestionEncoder, + TFDPRPretrainedReader, + TFDPRQuestionEncoder, + TFDPRReader, + ) + from .models.electra import ( + TF_ELECTRA_PRETRAINED_MODEL_ARCHIVE_LIST, + TFElectraForMaskedLM, + TFElectraForMultipleChoice, + TFElectraForPreTraining, + TFElectraForQuestionAnswering, + TFElectraForSequenceClassification, + TFElectraForTokenClassification, + TFElectraModel, + TFElectraPreTrainedModel, + ) + from .models.flaubert import ( + TF_FLAUBERT_PRETRAINED_MODEL_ARCHIVE_LIST, + TFFlaubertForMultipleChoice, + TFFlaubertForQuestionAnsweringSimple, + TFFlaubertForSequenceClassification, + TFFlaubertForTokenClassification, + TFFlaubertModel, + TFFlaubertWithLMHeadModel, + ) + from .models.funnel import ( + TF_FUNNEL_PRETRAINED_MODEL_ARCHIVE_LIST, + TFFunnelBaseModel, + TFFunnelForMaskedLM, + TFFunnelForMultipleChoice, + TFFunnelForPreTraining, + TFFunnelForQuestionAnswering, + TFFunnelForSequenceClassification, + TFFunnelForTokenClassification, + TFFunnelModel, + ) + from .models.gpt2 import ( + TF_GPT2_PRETRAINED_MODEL_ARCHIVE_LIST, + TFGPT2DoubleHeadsModel, + TFGPT2ForSequenceClassification, + TFGPT2LMHeadModel, + TFGPT2MainLayer, + TFGPT2Model, + TFGPT2PreTrainedModel, + ) + from .models.led import TFLEDForConditionalGeneration, TFLEDModel, TFLEDPreTrainedModel + from .models.longformer import ( + TF_LONGFORMER_PRETRAINED_MODEL_ARCHIVE_LIST, + TFLongformerForMaskedLM, + TFLongformerForMultipleChoice, + TFLongformerForQuestionAnswering, + TFLongformerForSequenceClassification, + TFLongformerForTokenClassification, + TFLongformerModel, + TFLongformerSelfAttention, + ) + from .models.lxmert import ( + TF_LXMERT_PRETRAINED_MODEL_ARCHIVE_LIST, + TFLxmertForPreTraining, + TFLxmertMainLayer, + TFLxmertModel, + TFLxmertPreTrainedModel, + TFLxmertVisualFeatureEncoder, + ) + from .models.marian import TFMarianMTModel + from .models.mbart import TFMBartForConditionalGeneration + from .models.mobilebert import ( + TF_MOBILEBERT_PRETRAINED_MODEL_ARCHIVE_LIST, + TFMobileBertForMaskedLM, + TFMobileBertForMultipleChoice, + TFMobileBertForNextSentencePrediction, + TFMobileBertForPreTraining, + TFMobileBertForQuestionAnswering, + TFMobileBertForSequenceClassification, + TFMobileBertForTokenClassification, + TFMobileBertMainLayer, + TFMobileBertModel, + TFMobileBertPreTrainedModel, + ) + from .models.mpnet import ( + TF_MPNET_PRETRAINED_MODEL_ARCHIVE_LIST, + TFMPNetForMaskedLM, + TFMPNetForMultipleChoice, + TFMPNetForQuestionAnswering, + TFMPNetForSequenceClassification, + TFMPNetForTokenClassification, + TFMPNetMainLayer, + TFMPNetModel, + TFMPNetPreTrainedModel, + ) + from .models.mt5 import TFMT5EncoderModel, TFMT5ForConditionalGeneration, TFMT5Model + from .models.openai import ( + TF_OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_LIST, + TFOpenAIGPTDoubleHeadsModel, + TFOpenAIGPTForSequenceClassification, + TFOpenAIGPTLMHeadModel, + TFOpenAIGPTMainLayer, + TFOpenAIGPTModel, + TFOpenAIGPTPreTrainedModel, + ) + from .models.pegasus import TFPegasusForConditionalGeneration + from .models.roberta import ( + TF_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST, + TFRobertaForMaskedLM, + TFRobertaForMultipleChoice, + TFRobertaForQuestionAnswering, + TFRobertaForSequenceClassification, + TFRobertaForTokenClassification, + TFRobertaMainLayer, + TFRobertaModel, + TFRobertaPreTrainedModel, + ) + from .models.t5 import ( + TF_T5_PRETRAINED_MODEL_ARCHIVE_LIST, + TFT5EncoderModel, + TFT5ForConditionalGeneration, + TFT5Model, + TFT5PreTrainedModel, + ) + from .models.transfo_xl import ( + TF_TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_LIST, + TFAdaptiveEmbedding, + TFTransfoXLForSequenceClassification, + TFTransfoXLLMHeadModel, + TFTransfoXLMainLayer, + TFTransfoXLModel, + TFTransfoXLPreTrainedModel, + ) + from .models.xlm import ( + TF_XLM_PRETRAINED_MODEL_ARCHIVE_LIST, + TFXLMForMultipleChoice, + TFXLMForQuestionAnsweringSimple, + TFXLMForSequenceClassification, + TFXLMForTokenClassification, + TFXLMMainLayer, + TFXLMModel, + TFXLMPreTrainedModel, + TFXLMWithLMHeadModel, + ) + from .models.xlm_roberta import ( + TF_XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST, + TFXLMRobertaForMaskedLM, + TFXLMRobertaForMultipleChoice, + TFXLMRobertaForQuestionAnswering, + TFXLMRobertaForSequenceClassification, + TFXLMRobertaForTokenClassification, + TFXLMRobertaModel, + ) + from .models.xlnet import ( + TF_XLNET_PRETRAINED_MODEL_ARCHIVE_LIST, + TFXLNetForMultipleChoice, + TFXLNetForQuestionAnsweringSimple, + TFXLNetForSequenceClassification, + TFXLNetForTokenClassification, + TFXLNetLMHeadModel, + TFXLNetMainLayer, + TFXLNetModel, + TFXLNetPreTrainedModel, + ) + + # Optimization + from .optimization_tf import AdamWeightDecay, GradientAccumulator, WarmUp, create_optimizer + + # Trainer + from .trainer_tf import TFTrainer + + else: + # Import the same objects as dummies to get them in the namespace. + # They will raise an import error if the user tries to instantiate / use them. + from .utils.dummy_tf_objects import * + + if is_flax_available(): + from .modeling_flax_utils import FlaxPreTrainedModel + from .models.auto import FLAX_MODEL_MAPPING, FlaxAutoModel + from .models.bert import FlaxBertForMaskedLM, FlaxBertModel + from .models.roberta import FlaxRobertaModel + else: + # Import the same objects as dummies to get them in the namespace. + # They will raise an import error if the user tries to instantiate / use them. + from .utils.dummy_flax_objects import * else: - # Import the same objects as dummies to get them in the namespace. - # They will raise an import error if the user tries to instantiate / use them. - from .utils.dummy_tf_objects import * + import importlib + import os + import sys + class _LazyModule(_BaseLazyModule): + """ + Module class that surfaces all objects but only performs associated imports when the objects are requested. + """ -if is_flax_available(): - from .modeling_flax_utils import FlaxPreTrainedModel - from .models.auto import FLAX_MODEL_MAPPING, FlaxAutoModel - from .models.bert import FlaxBertForMaskedLM, FlaxBertModel - from .models.roberta import FlaxRobertaModel -else: - # Import the same objects as dummies to get them in the namespace. - # They will raise an import error if the user tries to instantiate / use them. - from .utils.dummy_flax_objects import * + __file__ = globals()["__file__"] + __path__ = [os.path.dirname(__file__)] + + def _get_module(self, module_name: str): + return importlib.import_module("." + module_name, self.__name__) + + def __getattr__(self, name: str): + # Special handling for the version, which is a constant from this module and not imported in a submodule. + if name == "__version__": + return __version__ + return super().__getattr__(name) + + sys.modules[__name__] = _LazyModule(__name__, _import_structure) if not is_tf_available() and not is_torch_available() and not is_flax_available(): diff --git a/src/transformers/benchmark/benchmark_utils.py b/src/transformers/benchmark/benchmark_utils.py index 95b7921140..5b054614c3 100644 --- a/src/transformers/benchmark/benchmark_utils.py +++ b/src/transformers/benchmark/benchmark_utils.py @@ -30,9 +30,8 @@ from multiprocessing import Pipe, Process, Queue from multiprocessing.connection import Connection from typing import Callable, Iterable, List, NamedTuple, Optional, Union -from transformers import AutoConfig, PretrainedConfig -from transformers import __version__ as version - +from .. import AutoConfig, PretrainedConfig +from .. import __version__ as version from ..file_utils import is_psutil_available, is_py3nvml_available, is_tf_available, is_torch_available from ..utils import logging from .benchmark_args_utils import BenchmarkArguments diff --git a/src/transformers/commands/add_new_model.py b/src/transformers/commands/add_new_model.py index 6b27e0b24c..66dc7201e6 100644 --- a/src/transformers/commands/add_new_model.py +++ b/src/transformers/commands/add_new_model.py @@ -19,9 +19,8 @@ from argparse import ArgumentParser, Namespace from pathlib import Path from typing import List -from transformers.commands import BaseTransformersCLICommand - from ..utils import logging +from . import BaseTransformersCLICommand try: diff --git a/src/transformers/commands/convert.py b/src/transformers/commands/convert.py index 965c0a91c8..30767f26f9 100644 --- a/src/transformers/commands/convert.py +++ b/src/transformers/commands/convert.py @@ -14,9 +14,8 @@ from argparse import ArgumentParser, Namespace -from transformers.commands import BaseTransformersCLICommand - from ..utils import logging +from . import BaseTransformersCLICommand def convert_command_factory(args: Namespace): @@ -87,7 +86,7 @@ class ConvertCommand(BaseTransformersCLICommand): def run(self): if self._model_type == "albert": try: - from transformers.models.albert.convert_albert_original_tf_checkpoint_to_pytorch import ( + from ..models.albert.convert_albert_original_tf_checkpoint_to_pytorch import ( convert_tf_checkpoint_to_pytorch, ) except ImportError: @@ -96,7 +95,7 @@ class ConvertCommand(BaseTransformersCLICommand): convert_tf_checkpoint_to_pytorch(self._tf_checkpoint, self._config, self._pytorch_dump_output) elif self._model_type == "bert": try: - from transformers.models.bert.convert_bert_original_tf_checkpoint_to_pytorch import ( + from ..models.bert.convert_bert_original_tf_checkpoint_to_pytorch import ( convert_tf_checkpoint_to_pytorch, ) except ImportError: @@ -105,7 +104,7 @@ class ConvertCommand(BaseTransformersCLICommand): convert_tf_checkpoint_to_pytorch(self._tf_checkpoint, self._config, self._pytorch_dump_output) elif self._model_type == "funnel": try: - from transformers.models.funnel.convert_funnel_original_tf_checkpoint_to_pytorch import ( + from ..models.funnel.convert_funnel_original_tf_checkpoint_to_pytorch import ( convert_tf_checkpoint_to_pytorch, ) except ImportError: @@ -113,14 +112,14 @@ class ConvertCommand(BaseTransformersCLICommand): convert_tf_checkpoint_to_pytorch(self._tf_checkpoint, self._config, self._pytorch_dump_output) elif self._model_type == "gpt": - from transformers.models.openai.convert_openai_original_tf_checkpoint_to_pytorch import ( + from ..models.openai.convert_openai_original_tf_checkpoint_to_pytorch import ( convert_openai_checkpoint_to_pytorch, ) convert_openai_checkpoint_to_pytorch(self._tf_checkpoint, self._config, self._pytorch_dump_output) elif self._model_type == "transfo_xl": try: - from transformers.models.transfo_xl.convert_transfo_xl_original_tf_checkpoint_to_pytorch import ( + from ..models.transfo_xl.convert_transfo_xl_original_tf_checkpoint_to_pytorch import ( convert_transfo_xl_checkpoint_to_pytorch, ) except ImportError: @@ -137,7 +136,7 @@ class ConvertCommand(BaseTransformersCLICommand): ) elif self._model_type == "gpt2": try: - from transformers.models.gpt2.convert_gpt2_original_tf_checkpoint_to_pytorch import ( + from ..models.gpt2.convert_gpt2_original_tf_checkpoint_to_pytorch import ( convert_gpt2_checkpoint_to_pytorch, ) except ImportError: @@ -146,7 +145,7 @@ class ConvertCommand(BaseTransformersCLICommand): convert_gpt2_checkpoint_to_pytorch(self._tf_checkpoint, self._config, self._pytorch_dump_output) elif self._model_type == "xlnet": try: - from transformers.models.xlnet.convert_xlnet_original_tf_checkpoint_to_pytorch import ( + from ..models.xlnet.convert_xlnet_original_tf_checkpoint_to_pytorch import ( convert_xlnet_checkpoint_to_pytorch, ) except ImportError: @@ -156,13 +155,13 @@ class ConvertCommand(BaseTransformersCLICommand): self._tf_checkpoint, self._config, self._pytorch_dump_output, self._finetuning_task_name ) elif self._model_type == "xlm": - from transformers.models.xlm.convert_xlm_original_pytorch_checkpoint_to_pytorch import ( + from ..models.xlm.convert_xlm_original_pytorch_checkpoint_to_pytorch import ( convert_xlm_checkpoint_to_pytorch, ) convert_xlm_checkpoint_to_pytorch(self._tf_checkpoint, self._pytorch_dump_output) elif self._model_type == "lxmert": - from transformers.models.lxmert.convert_lxmert_original_pytorch_checkpoint_to_pytorch import ( + from ..models.lxmert.convert_lxmert_original_pytorch_checkpoint_to_pytorch import ( convert_lxmert_checkpoint_to_pytorch, ) diff --git a/src/transformers/commands/download.py b/src/transformers/commands/download.py index b4953b6e5d..3c224555df 100644 --- a/src/transformers/commands/download.py +++ b/src/transformers/commands/download.py @@ -14,7 +14,7 @@ from argparse import ArgumentParser -from transformers.commands import BaseTransformersCLICommand +from . import BaseTransformersCLICommand def download_command_factory(args): @@ -40,7 +40,7 @@ class DownloadCommand(BaseTransformersCLICommand): self._force = force def run(self): - from transformers import AutoModel, AutoTokenizer + from ..models.auto import AutoModel, AutoTokenizer AutoModel.from_pretrained(self._model, cache_dir=self._cache, force_download=self._force) AutoTokenizer.from_pretrained(self._model, cache_dir=self._cache, force_download=self._force) diff --git a/src/transformers/commands/env.py b/src/transformers/commands/env.py index edc74d91ff..beee192ab4 100644 --- a/src/transformers/commands/env.py +++ b/src/transformers/commands/env.py @@ -15,9 +15,9 @@ import platform from argparse import ArgumentParser -from transformers import __version__ as version -from transformers import is_tf_available, is_torch_available -from transformers.commands import BaseTransformersCLICommand +from .. import __version__ as version +from ..file_utils import is_tf_available, is_torch_available +from . import BaseTransformersCLICommand def info_command_factory(_): diff --git a/src/transformers/commands/lfs.py b/src/transformers/commands/lfs.py index 5ca2c2cace..42b00f0d2f 100644 --- a/src/transformers/commands/lfs.py +++ b/src/transformers/commands/lfs.py @@ -25,9 +25,9 @@ from contextlib import AbstractContextManager from typing import Dict, List, Optional import requests -from transformers.commands import BaseTransformersCLICommand from ..utils import logging +from . import BaseTransformersCLICommand logger = logging.get_logger(__name__) # pylint: disable=invalid-name diff --git a/src/transformers/commands/run.py b/src/transformers/commands/run.py index 5d2eb77cc0..768b90007a 100644 --- a/src/transformers/commands/run.py +++ b/src/transformers/commands/run.py @@ -14,10 +14,9 @@ from argparse import ArgumentParser -from transformers.commands import BaseTransformersCLICommand -from transformers.pipelines import SUPPORTED_TASKS, Pipeline, PipelineDataFormat, pipeline - +from ..pipelines import SUPPORTED_TASKS, Pipeline, PipelineDataFormat, pipeline from ..utils import logging +from . import BaseTransformersCLICommand logger = logging.get_logger(__name__) # pylint: disable=invalid-name diff --git a/src/transformers/commands/serving.py b/src/transformers/commands/serving.py index 169b82d38b..7bef8d5eeb 100644 --- a/src/transformers/commands/serving.py +++ b/src/transformers/commands/serving.py @@ -15,11 +15,9 @@ from argparse import ArgumentParser, Namespace from typing import Any, List, Optional -from transformers import Pipeline -from transformers.commands import BaseTransformersCLICommand -from transformers.pipelines import SUPPORTED_TASKS, pipeline - +from ..pipelines import SUPPORTED_TASKS, Pipeline, pipeline from ..utils import logging +from . import BaseTransformersCLICommand try: diff --git a/src/transformers/commands/train.py b/src/transformers/commands/train.py index 2d7b2d6e4b..a2d3029221 100644 --- a/src/transformers/commands/train.py +++ b/src/transformers/commands/train.py @@ -15,11 +15,11 @@ import os from argparse import ArgumentParser, Namespace -from transformers import SingleSentenceClassificationProcessor as Processor -from transformers import TextClassificationPipeline, is_tf_available, is_torch_available -from transformers.commands import BaseTransformersCLICommand - +from ..data import SingleSentenceClassificationProcessor as Processor +from ..file_utils import is_tf_available, is_torch_available +from ..pipelines import TextClassificationPipeline from ..utils import logging +from . import BaseTransformersCLICommand if not is_tf_available() and not is_torch_available(): diff --git a/src/transformers/commands/transformers_cli.py b/src/transformers/commands/transformers_cli.py index 7940c38787..d63f6bc9c6 100644 --- a/src/transformers/commands/transformers_cli.py +++ b/src/transformers/commands/transformers_cli.py @@ -15,14 +15,14 @@ from argparse import ArgumentParser -from transformers.commands.add_new_model import AddNewModelCommand -from transformers.commands.convert import ConvertCommand -from transformers.commands.download import DownloadCommand -from transformers.commands.env import EnvironmentCommand -from transformers.commands.lfs import LfsCommands -from transformers.commands.run import RunCommand -from transformers.commands.serving import ServeCommand -from transformers.commands.user import UserCommands +from .add_new_model import AddNewModelCommand +from .convert import ConvertCommand +from .download import DownloadCommand +from .env import EnvironmentCommand +from .lfs import LfsCommands +from .run import RunCommand +from .serving import ServeCommand +from .user import UserCommands def main(): diff --git a/src/transformers/commands/user.py b/src/transformers/commands/user.py index d321a5a11a..9a16dec22b 100644 --- a/src/transformers/commands/user.py +++ b/src/transformers/commands/user.py @@ -20,8 +20,9 @@ from getpass import getpass from typing import List, Union from requests.exceptions import HTTPError -from transformers.commands import BaseTransformersCLICommand -from transformers.hf_api import HfApi, HfFolder + +from ..hf_api import HfApi, HfFolder +from . import BaseTransformersCLICommand UPLOAD_MAX_FILES = 15 diff --git a/src/transformers/convert_graph_to_onnx.py b/src/transformers/convert_graph_to_onnx.py index 0446a3533d..25ca790c18 100644 --- a/src/transformers/convert_graph_to_onnx.py +++ b/src/transformers/convert_graph_to_onnx.py @@ -19,10 +19,9 @@ from typing import Dict, List, Optional, Tuple from packaging.version import Version, parse -from transformers import is_tf_available, is_torch_available -from transformers.file_utils import ModelOutput -from transformers.pipelines import Pipeline, pipeline -from transformers.tokenization_utils import BatchEncoding +from .file_utils import ModelOutput, is_tf_available, is_torch_available +from .pipelines import Pipeline, pipeline +from .tokenization_utils import BatchEncoding # This is the minimal required version to diff --git a/src/transformers/convert_pytorch_checkpoint_to_tf2.py b/src/transformers/convert_pytorch_checkpoint_to_tf2.py index 5447ede65e..4c21456d21 100755 --- a/src/transformers/convert_pytorch_checkpoint_to_tf2.py +++ b/src/transformers/convert_pytorch_checkpoint_to_tf2.py @@ -18,7 +18,7 @@ import argparse import os -from transformers import ( +from . import ( ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, BART_PRETRAINED_MODEL_ARCHIVE_LIST, BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, @@ -87,15 +87,15 @@ from transformers import ( is_torch_available, load_pytorch_checkpoint_in_tf2_model, ) -from transformers.file_utils import hf_bucket_url -from transformers.utils import logging +from .file_utils import hf_bucket_url +from .utils import logging if is_torch_available(): import numpy as np import torch - from transformers import ( + from . import ( AlbertForPreTraining, BartForConditionalGeneration, BertForPreTraining, diff --git a/src/transformers/convert_slow_tokenizers_checkpoints_to_fast.py b/src/transformers/convert_slow_tokenizers_checkpoints_to_fast.py index 631d57df26..d78608633e 100755 --- a/src/transformers/convert_slow_tokenizers_checkpoints_to_fast.py +++ b/src/transformers/convert_slow_tokenizers_checkpoints_to_fast.py @@ -18,8 +18,9 @@ import argparse import os import transformers -from transformers.convert_slow_tokenizer import SLOW_TO_FAST_CONVERTERS -from transformers.utils import logging + +from .convert_slow_tokenizer import SLOW_TO_FAST_CONVERTERS +from .utils import logging logging.set_verbosity_info() diff --git a/src/transformers/convert_tf_hub_seq_to_seq_bert_to_pytorch.py b/src/transformers/convert_tf_hub_seq_to_seq_bert_to_pytorch.py index 3dbb8a3646..5707a09977 100755 --- a/src/transformers/convert_tf_hub_seq_to_seq_bert_to_pytorch.py +++ b/src/transformers/convert_tf_hub_seq_to_seq_bert_to_pytorch.py @@ -17,7 +17,7 @@ import argparse -from transformers import ( +from . import ( BertConfig, BertGenerationConfig, BertGenerationDecoder, diff --git a/src/transformers/data/metrics/squad_metrics.py b/src/transformers/data/metrics/squad_metrics.py index 20c084d548..94ce573f75 100644 --- a/src/transformers/data/metrics/squad_metrics.py +++ b/src/transformers/data/metrics/squad_metrics.py @@ -27,8 +27,7 @@ import math import re import string -from transformers import BasicTokenizer - +from ...models.bert import BasicTokenizer from ...utils import logging diff --git a/src/transformers/data/test_generation_utils.py b/src/transformers/data/test_generation_utils.py index 32c723f4cb..ae2f7ccc92 100644 --- a/src/transformers/data/test_generation_utils.py +++ b/src/transformers/data/test_generation_utils.py @@ -17,15 +17,14 @@ import unittest import timeout_decorator -from transformers import is_torch_available -from transformers.file_utils import cached_property -from transformers.testing_utils import require_torch +from ..file_utils import cached_property, is_torch_available +from ..testing_utils import require_torch if is_torch_available(): import torch - from transformers import MarianConfig, MarianMTModel + from ..models.marian import MarianConfig, MarianMTModel @require_torch diff --git a/src/transformers/file_utils.py b/src/transformers/file_utils.py index 1f765c1a67..413409cded 100644 --- a/src/transformers/file_utils.py +++ b/src/transformers/file_utils.py @@ -33,6 +33,7 @@ from dataclasses import fields from functools import partial, wraps from hashlib import sha256 from pathlib import Path +from types import ModuleType from typing import Any, BinaryIO, Dict, List, Optional, Tuple, Union from urllib.parse import urlparse from zipfile import ZipFile, is_zipfile @@ -41,7 +42,6 @@ import numpy as np from packaging import version from tqdm.auto import tqdm -import importlib_metadata import requests from filelock import FileLock @@ -50,6 +50,13 @@ from .hf_api import HfFolder from .utils import logging +# The package importlib_metadata is in a different place, depending on the python version. +if version.parse(sys.version) < version.parse("3.8"): + import importlib_metadata +else: + import importlib.metadata as importlib_metadata + + logger = logging.get_logger(__name__) # pylint: disable=invalid-name ENV_VARS_TRUE_VALUES = {"1", "ON", "YES"} @@ -130,7 +137,7 @@ except importlib_metadata.PackageNotFoundError: _scatter_available = importlib.util.find_spec("torch_scatter") is not None try: - _scatter_version = importlib_metadata.version("torch_scatterr") + _scatter_version = importlib_metadata.version("torch_scatter") logger.debug(f"Successfully imported torch-scatter version {_scatter_version}") except importlib_metadata.PackageNotFoundError: _scatter_available = False @@ -1415,3 +1422,40 @@ class ModelOutput(OrderedDict): Convert self to a tuple containing all the attributes/keys that are not ``None``. """ return tuple(self[k] for k in self.keys()) + + +class _BaseLazyModule(ModuleType): + """ + Module class that surfaces all objects but only performs associated imports when the objects are requested. + """ + + # Very heavily inspired by optuna.integration._IntegrationModule + # https://github.com/optuna/optuna/blob/master/optuna/integration/__init__.py + def __init__(self, name, import_structure): + super().__init__(name) + self._modules = set(import_structure.keys()) + self._class_to_module = {} + for key, values in import_structure.items(): + for value in values: + self._class_to_module[value] = key + # Needed for autocompletion in an IDE + self.__all__ = list(import_structure.keys()) + sum(import_structure.values(), []) + + # Needed for autocompletion in an IDE + def __dir__(self): + return super().__dir__() + self.__all__ + + def __getattr__(self, name: str) -> Any: + if name in self._modules: + value = self._get_module(name) + elif name in self._class_to_module.keys(): + module = self._get_module(self._class_to_module[name]) + value = getattr(module, name) + else: + raise AttributeError(f"module {self.__name__} has no attribute {name}") + + setattr(self, name, value) + return value + + def _get_module(self, module_name: str) -> ModuleType: + raise NotImplementedError diff --git a/src/transformers/integrations.py b/src/transformers/integrations.py index db97827b04..0c541b846d 100644 --- a/src/transformers/integrations.py +++ b/src/transformers/integrations.py @@ -29,7 +29,7 @@ logger = logging.get_logger(__name__) # comet_ml requires to be imported before any ML frameworks -_has_comet = importlib.util.find_spec("comet_ml") and os.getenv("COMET_MODE", "").upper() != "DISABLED" +_has_comet = importlib.util.find_spec("comet_ml") is not None and os.getenv("COMET_MODE", "").upper() != "DISABLED" if _has_comet: try: import comet_ml # noqa: F401 diff --git a/src/transformers/models/__init__.py b/src/transformers/models/__init__.py index e69de29bb2..32fdd4c5e3 100644 --- a/src/transformers/models/__init__.py +++ b/src/transformers/models/__init__.py @@ -0,0 +1,68 @@ +# flake8: noqa +# There's no way to ignore "F401 '...' imported but unused" warnings in this +# module, but to preserve other warnings. So, don't check this module at all. + +# Copyright 2020 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from . import ( + albert, + auto, + bart, + barthez, + bert, + bert_generation, + bert_japanese, + bertweet, + blenderbot, + blenderbot_small, + camembert, + ctrl, + deberta, + dialogpt, + distilbert, + dpr, + electra, + encoder_decoder, + flaubert, + fsmt, + funnel, + gpt2, + herbert, + layoutlm, + led, + longformer, + lxmert, + marian, + mbart, + mmbt, + mobilebert, + mpnet, + mt5, + openai, + pegasus, + phobert, + prophetnet, + rag, + reformer, + retribert, + roberta, + squeezebert, + t5, + tapas, + transfo_xl, + xlm, + xlm_roberta, + xlnet, +) diff --git a/src/transformers/models/albert/convert_albert_original_tf_checkpoint_to_pytorch.py b/src/transformers/models/albert/convert_albert_original_tf_checkpoint_to_pytorch.py index 10c018170f..7c5e04cc1d 100644 --- a/src/transformers/models/albert/convert_albert_original_tf_checkpoint_to_pytorch.py +++ b/src/transformers/models/albert/convert_albert_original_tf_checkpoint_to_pytorch.py @@ -19,8 +19,8 @@ import argparse import torch -from transformers import AlbertConfig, AlbertForPreTraining, load_tf_weights_in_albert -from transformers.utils import logging +from ...utils import logging +from . import AlbertConfig, AlbertForPreTraining, load_tf_weights_in_albert logging.set_verbosity_info() diff --git a/src/transformers/models/bart/convert_bart_original_pytorch_checkpoint_to_pytorch.py b/src/transformers/models/bart/convert_bart_original_pytorch_checkpoint_to_pytorch.py index 8978b8b2e5..61c36702a5 100644 --- a/src/transformers/models/bart/convert_bart_original_pytorch_checkpoint_to_pytorch.py +++ b/src/transformers/models/bart/convert_bart_original_pytorch_checkpoint_to_pytorch.py @@ -23,15 +23,9 @@ import fairseq import torch from packaging import version -from transformers import ( - BartConfig, - BartForConditionalGeneration, - BartForSequenceClassification, - BartModel, - BartTokenizer, -) -from transformers.models.bart.modeling_bart import _make_linear_from_emb -from transformers.utils import logging +from ...utils import logging +from . import BartConfig, BartForConditionalGeneration, BartForSequenceClassification, BartModel, BartTokenizer +from .modeling_bart import _make_linear_from_emb FAIRSEQ_MODELS = ["bart.large", "bart.large.mnli", "bart.large.cnn", "bart_xsum/model.pt"] diff --git a/src/transformers/models/bart/tokenization_bart.py b/src/transformers/models/bart/tokenization_bart.py index 6b46e30e9d..57a8e2448b 100644 --- a/src/transformers/models/bart/tokenization_bart.py +++ b/src/transformers/models/bart/tokenization_bart.py @@ -15,8 +15,7 @@ from typing import List, Optional -from transformers import add_start_docstrings - +from ...file_utils import add_start_docstrings from ...tokenization_utils_base import PREPARE_SEQ2SEQ_BATCH_DOCSTRING, BatchEncoding from ...utils import logging from ..roberta.tokenization_roberta import RobertaTokenizer diff --git a/src/transformers/models/bart/tokenization_bart_fast.py b/src/transformers/models/bart/tokenization_bart_fast.py index 30b77275f2..19678f9d52 100644 --- a/src/transformers/models/bart/tokenization_bart_fast.py +++ b/src/transformers/models/bart/tokenization_bart_fast.py @@ -15,8 +15,7 @@ from typing import List, Optional -from transformers import add_start_docstrings - +from ...file_utils import add_start_docstrings from ...tokenization_utils_base import PREPARE_SEQ2SEQ_BATCH_DOCSTRING, BatchEncoding from ...utils import logging from ..roberta.tokenization_roberta_fast import RobertaTokenizerFast diff --git a/src/transformers/models/bert/convert_bert_original_tf2_checkpoint_to_pytorch.py b/src/transformers/models/bert/convert_bert_original_tf2_checkpoint_to_pytorch.py index c780c0f835..f343fec275 100644 --- a/src/transformers/models/bert/convert_bert_original_tf2_checkpoint_to_pytorch.py +++ b/src/transformers/models/bert/convert_bert_original_tf2_checkpoint_to_pytorch.py @@ -28,8 +28,8 @@ import re import tensorflow as tf import torch -from transformers import BertConfig, BertModel -from transformers.utils import logging +from ...utils import logging +from . import BertConfig, BertModel logging.set_verbosity_info() diff --git a/src/transformers/models/bert/convert_bert_original_tf_checkpoint_to_pytorch.py b/src/transformers/models/bert/convert_bert_original_tf_checkpoint_to_pytorch.py index d1cb69a2eb..049d50de9c 100755 --- a/src/transformers/models/bert/convert_bert_original_tf_checkpoint_to_pytorch.py +++ b/src/transformers/models/bert/convert_bert_original_tf_checkpoint_to_pytorch.py @@ -19,8 +19,8 @@ import argparse import torch -from transformers import BertConfig, BertForPreTraining, load_tf_weights_in_bert -from transformers.utils import logging +from ...utils import logging +from . import BertConfig, BertForPreTraining, load_tf_weights_in_bert logging.set_verbosity_info() diff --git a/src/transformers/models/bert/convert_bert_pytorch_checkpoint_to_original_tf.py b/src/transformers/models/bert/convert_bert_pytorch_checkpoint_to_original_tf.py index 07685f6450..25aeb7762f 100644 --- a/src/transformers/models/bert/convert_bert_pytorch_checkpoint_to_original_tf.py +++ b/src/transformers/models/bert/convert_bert_pytorch_checkpoint_to_original_tf.py @@ -22,7 +22,7 @@ import numpy as np import tensorflow as tf import torch -from transformers import BertModel +from . import BertModel def convert_pytorch_checkpoint_to_tf(model: BertModel, ckpt_dir: str, model_name: str): diff --git a/src/transformers/models/blenderbot/convert_blenderbot_original_pytorch_checkpoint_to_pytorch.py b/src/transformers/models/blenderbot/convert_blenderbot_original_pytorch_checkpoint_to_pytorch.py index d31cf67c1e..7958da85d5 100644 --- a/src/transformers/models/blenderbot/convert_blenderbot_original_pytorch_checkpoint_to_pytorch.py +++ b/src/transformers/models/blenderbot/convert_blenderbot_original_pytorch_checkpoint_to_pytorch.py @@ -18,8 +18,8 @@ import argparse import torch -from transformers import BartConfig, BartForConditionalGeneration -from transformers.utils import logging +from ...models.bart import BartConfig, BartForConditionalGeneration +from ...utils import logging logging.set_verbosity_info() diff --git a/src/transformers/models/dialogpt/convert_dialogpt_original_pytorch_checkpoint_to_pytorch.py b/src/transformers/models/dialogpt/convert_dialogpt_original_pytorch_checkpoint_to_pytorch.py index f588a2fde8..00d9aa3359 100644 --- a/src/transformers/models/dialogpt/convert_dialogpt_original_pytorch_checkpoint_to_pytorch.py +++ b/src/transformers/models/dialogpt/convert_dialogpt_original_pytorch_checkpoint_to_pytorch.py @@ -17,7 +17,7 @@ import os import torch -from transformers.file_utils import WEIGHTS_NAME +from ...file_utils import WEIGHTS_NAME DIALOGPT_MODELS = ["small", "medium", "large"] diff --git a/src/transformers/models/dpr/convert_dpr_original_checkpoint_to_pytorch.py b/src/transformers/models/dpr/convert_dpr_original_checkpoint_to_pytorch.py index 7f6c20a07b..a9cd4c1dcb 100644 --- a/src/transformers/models/dpr/convert_dpr_original_checkpoint_to_pytorch.py +++ b/src/transformers/models/dpr/convert_dpr_original_checkpoint_to_pytorch.py @@ -19,7 +19,8 @@ from pathlib import Path import torch from torch.serialization import default_restore_location -from transformers import BertConfig, DPRConfig, DPRContextEncoder, DPRQuestionEncoder, DPRReader +from ...models.bert import BertConfig +from . import DPRConfig, DPRContextEncoder, DPRQuestionEncoder, DPRReader CheckpointState = collections.namedtuple( diff --git a/src/transformers/models/electra/convert_electra_original_tf_checkpoint_to_pytorch.py b/src/transformers/models/electra/convert_electra_original_tf_checkpoint_to_pytorch.py index 9cbfcf665d..ffe980c6a4 100644 --- a/src/transformers/models/electra/convert_electra_original_tf_checkpoint_to_pytorch.py +++ b/src/transformers/models/electra/convert_electra_original_tf_checkpoint_to_pytorch.py @@ -19,8 +19,8 @@ import argparse import torch -from transformers import ElectraConfig, ElectraForMaskedLM, ElectraForPreTraining, load_tf_weights_in_electra -from transformers.utils import logging +from ...utils import logging +from . import ElectraConfig, ElectraForMaskedLM, ElectraForPreTraining, load_tf_weights_in_electra logging.set_verbosity_info() diff --git a/src/transformers/models/fsmt/convert_fsmt_original_pytorch_checkpoint_to_pytorch.py b/src/transformers/models/fsmt/convert_fsmt_original_pytorch_checkpoint_to_pytorch.py index e27650d7dd..922f0b6d04 100755 --- a/src/transformers/models/fsmt/convert_fsmt_original_pytorch_checkpoint_to_pytorch.py +++ b/src/transformers/models/fsmt/convert_fsmt_original_pytorch_checkpoint_to_pytorch.py @@ -31,9 +31,10 @@ import torch from fairseq import hub_utils from fairseq.data.dictionary import Dictionary -from transformers import WEIGHTS_NAME, logging -from transformers.models.fsmt import VOCAB_FILES_NAMES, FSMTConfig, FSMTForConditionalGeneration -from transformers.tokenization_utils_base import TOKENIZER_CONFIG_FILE +from ...file_utils import WEIGHTS_NAME +from ...tokenization_utils_base import TOKENIZER_CONFIG_FILE +from ...utils import logging +from . import VOCAB_FILES_NAMES, FSMTConfig, FSMTForConditionalGeneration logging.set_verbosity_warning() diff --git a/src/transformers/models/funnel/convert_funnel_original_tf_checkpoint_to_pytorch.py b/src/transformers/models/funnel/convert_funnel_original_tf_checkpoint_to_pytorch.py index 5d93fc24db..03f461658c 100755 --- a/src/transformers/models/funnel/convert_funnel_original_tf_checkpoint_to_pytorch.py +++ b/src/transformers/models/funnel/convert_funnel_original_tf_checkpoint_to_pytorch.py @@ -20,7 +20,7 @@ import logging import torch -from transformers import FunnelConfig, FunnelForPreTraining, load_tf_weights_in_funnel +from . import FunnelConfig, FunnelForPreTraining, load_tf_weights_in_funnel logging.basicConfig(level=logging.INFO) diff --git a/src/transformers/models/gpt2/convert_gpt2_original_tf_checkpoint_to_pytorch.py b/src/transformers/models/gpt2/convert_gpt2_original_tf_checkpoint_to_pytorch.py index e42ebd888d..5f85d0e184 100755 --- a/src/transformers/models/gpt2/convert_gpt2_original_tf_checkpoint_to_pytorch.py +++ b/src/transformers/models/gpt2/convert_gpt2_original_tf_checkpoint_to_pytorch.py @@ -19,8 +19,9 @@ import argparse import torch -from transformers import CONFIG_NAME, WEIGHTS_NAME, GPT2Config, GPT2Model, load_tf_weights_in_gpt2 -from transformers.utils import logging +from ...file_utils import CONFIG_NAME, WEIGHTS_NAME +from ...utils import logging +from . import GPT2Config, GPT2Model, load_tf_weights_in_gpt2 logging.set_verbosity_info() diff --git a/src/transformers/models/longformer/convert_longformer_original_pytorch_lightning_to_pytorch.py b/src/transformers/models/longformer/convert_longformer_original_pytorch_lightning_to_pytorch.py index 6c310a5faf..b479746314 100644 --- a/src/transformers/models/longformer/convert_longformer_original_pytorch_lightning_to_pytorch.py +++ b/src/transformers/models/longformer/convert_longformer_original_pytorch_lightning_to_pytorch.py @@ -20,7 +20,7 @@ import argparse import pytorch_lightning as pl import torch -from transformers import LongformerForQuestionAnswering, LongformerModel +from . import LongformerForQuestionAnswering, LongformerModel class LightningModel(pl.LightningModule): diff --git a/src/transformers/models/lxmert/convert_lxmert_original_tf_checkpoint_to_pytorch.py b/src/transformers/models/lxmert/convert_lxmert_original_tf_checkpoint_to_pytorch.py index e4125ed566..4034e72024 100755 --- a/src/transformers/models/lxmert/convert_lxmert_original_tf_checkpoint_to_pytorch.py +++ b/src/transformers/models/lxmert/convert_lxmert_original_tf_checkpoint_to_pytorch.py @@ -20,7 +20,7 @@ import logging import torch -from transformers import LxmertConfig, LxmertForPreTraining, load_tf_weights_in_lxmert +from . import LxmertConfig, LxmertForPreTraining, load_tf_weights_in_lxmert logging.basicConfig(level=logging.INFO) diff --git a/src/transformers/models/marian/convert_marian_tatoeba_to_pytorch.py b/src/transformers/models/marian/convert_marian_tatoeba_to_pytorch.py index 0ab653e9a2..7144fc1b2e 100644 --- a/src/transformers/models/marian/convert_marian_tatoeba_to_pytorch.py +++ b/src/transformers/models/marian/convert_marian_tatoeba_to_pytorch.py @@ -17,7 +17,7 @@ import os from pathlib import Path from typing import List, Tuple -from transformers.models.marian.convert_marian_to_pytorch import ( +from .convert_marian_to_pytorch import ( FRONT_MATTER_TEMPLATE, _parse_readme, convert_all_sentencepiece_models, diff --git a/src/transformers/models/marian/convert_marian_to_pytorch.py b/src/transformers/models/marian/convert_marian_to_pytorch.py index a7faef942e..5c47c76d81 100644 --- a/src/transformers/models/marian/convert_marian_to_pytorch.py +++ b/src/transformers/models/marian/convert_marian_to_pytorch.py @@ -26,8 +26,8 @@ import numpy as np import torch from tqdm import tqdm -from transformers import MarianConfig, MarianMTModel, MarianTokenizer -from transformers.hf_api import HfApi +from ...hf_api import HfApi +from . import MarianConfig, MarianMTModel, MarianTokenizer def remove_suffix(text: str, suffix: str): diff --git a/src/transformers/models/mbart/convert_mbart_original_checkpoint_to_pytorch.py b/src/transformers/models/mbart/convert_mbart_original_checkpoint_to_pytorch.py index 46c933d7a4..146a34fd35 100644 --- a/src/transformers/models/mbart/convert_mbart_original_checkpoint_to_pytorch.py +++ b/src/transformers/models/mbart/convert_mbart_original_checkpoint_to_pytorch.py @@ -16,9 +16,9 @@ import argparse import torch -from transformers import BartForConditionalGeneration, MBartConfig - +from ..bart import BartForConditionalGeneration from ..bart.convert_bart_original_pytorch_checkpoint_to_pytorch import remove_ignore_keys_ +from . import MBartConfig def convert_fairseq_mbart_checkpoint_from_disk(checkpoint_path, hf_config_path="facebook/mbart-large-en-ro"): diff --git a/src/transformers/models/mobilebert/convert_mobilebert_original_tf_checkpoint_to_pytorch.py b/src/transformers/models/mobilebert/convert_mobilebert_original_tf_checkpoint_to_pytorch.py index ce5396a932..5ba8bc493d 100644 --- a/src/transformers/models/mobilebert/convert_mobilebert_original_tf_checkpoint_to_pytorch.py +++ b/src/transformers/models/mobilebert/convert_mobilebert_original_tf_checkpoint_to_pytorch.py @@ -16,8 +16,8 @@ import argparse import torch -from transformers import MobileBertConfig, MobileBertForPreTraining, load_tf_weights_in_mobilebert -from transformers.utils import logging +from ...utils import logging +from . import MobileBertConfig, MobileBertForPreTraining, load_tf_weights_in_mobilebert logging.set_verbosity_info() diff --git a/src/transformers/models/openai/convert_openai_original_tf_checkpoint_to_pytorch.py b/src/transformers/models/openai/convert_openai_original_tf_checkpoint_to_pytorch.py index 397884e32c..abe06f8ebd 100755 --- a/src/transformers/models/openai/convert_openai_original_tf_checkpoint_to_pytorch.py +++ b/src/transformers/models/openai/convert_openai_original_tf_checkpoint_to_pytorch.py @@ -19,8 +19,9 @@ import argparse import torch -from transformers import CONFIG_NAME, WEIGHTS_NAME, OpenAIGPTConfig, OpenAIGPTModel, load_tf_weights_in_openai_gpt -from transformers.utils import logging +from ...file_utils import CONFIG_NAME, WEIGHTS_NAME +from ...utils import logging +from . import OpenAIGPTConfig, OpenAIGPTModel, load_tf_weights_in_openai_gpt logging.set_verbosity_info() diff --git a/src/transformers/models/pegasus/convert_pegasus_tf_to_pytorch.py b/src/transformers/models/pegasus/convert_pegasus_tf_to_pytorch.py index 9254a0ba94..c6043f2166 100644 --- a/src/transformers/models/pegasus/convert_pegasus_tf_to_pytorch.py +++ b/src/transformers/models/pegasus/convert_pegasus_tf_to_pytorch.py @@ -22,8 +22,8 @@ import tensorflow as tf import torch from tqdm import tqdm -from transformers import PegasusConfig, PegasusForConditionalGeneration, PegasusTokenizer -from transformers.models.pegasus.configuration_pegasus import DEFAULTS, task_specific_params +from . import PegasusConfig, PegasusForConditionalGeneration, PegasusTokenizer +from .configuration_pegasus import DEFAULTS, task_specific_params PATTERNS = [ diff --git a/src/transformers/models/prophetnet/convert_prophetnet_original_pytorch_checkpoint_to_pytorch.py b/src/transformers/models/prophetnet/convert_prophetnet_original_pytorch_checkpoint_to_pytorch.py index cbd8c49956..1f4082cbe6 100644 --- a/src/transformers/models/prophetnet/convert_prophetnet_original_pytorch_checkpoint_to_pytorch.py +++ b/src/transformers/models/prophetnet/convert_prophetnet_original_pytorch_checkpoint_to_pytorch.py @@ -19,8 +19,6 @@ import argparse import torch -from transformers import ProphetNetForConditionalGeneration, XLMProphetNetForConditionalGeneration, logging - # transformers_old should correspond to branch `save_old_prophetnet_model_structure` here # original prophetnet_checkpoints are saved under `patrickvonplaten/..._old` respectively from transformers_old.modeling_prophetnet import ( @@ -30,6 +28,8 @@ from transformers_old.modeling_xlm_prophetnet import ( XLMProphetNetForConditionalGeneration as XLMProphetNetForConditionalGenerationOld, ) +from . import ProphetNetForConditionalGeneration, XLMProphetNetForConditionalGeneration, logging + logger = logging.get_logger(__name__) logging.set_verbosity_info() diff --git a/src/transformers/models/reformer/convert_reformer_trax_checkpoint_to_pytorch.py b/src/transformers/models/reformer/convert_reformer_trax_checkpoint_to_pytorch.py index ec58e2f913..23adf7e74f 100755 --- a/src/transformers/models/reformer/convert_reformer_trax_checkpoint_to_pytorch.py +++ b/src/transformers/models/reformer/convert_reformer_trax_checkpoint_to_pytorch.py @@ -21,8 +21,8 @@ import pickle import numpy as np import torch -from transformers import ReformerConfig, ReformerModelWithLMHead -from transformers.utils import logging +from ...utils import logging +from . import ReformerConfig, ReformerModelWithLMHead logging.set_verbosity_info() diff --git a/src/transformers/models/roberta/convert_roberta_original_pytorch_checkpoint_to_pytorch.py b/src/transformers/models/roberta/convert_roberta_original_pytorch_checkpoint_to_pytorch.py index 6f12f59c3b..68c8531078 100644 --- a/src/transformers/models/roberta/convert_roberta_original_pytorch_checkpoint_to_pytorch.py +++ b/src/transformers/models/roberta/convert_roberta_original_pytorch_checkpoint_to_pytorch.py @@ -24,19 +24,9 @@ from fairseq.models.roberta import RobertaModel as FairseqRobertaModel from fairseq.modules import TransformerSentenceEncoderLayer from packaging import version -from transformers.models.bert.modeling_bert import ( - BertIntermediate, - BertLayer, - BertOutput, - BertSelfAttention, - BertSelfOutput, -) -from transformers.models.roberta.modeling_roberta import ( - RobertaConfig, - RobertaForMaskedLM, - RobertaForSequenceClassification, -) -from transformers.utils import logging +from ...models.bert.modeling_bert import BertIntermediate, BertLayer, BertOutput, BertSelfAttention, BertSelfOutput +from ...utils import logging +from .modeling_roberta import RobertaConfig, RobertaForMaskedLM, RobertaForSequenceClassification if version.parse(fairseq.__version__) < version.parse("0.9.0"): diff --git a/src/transformers/models/t5/convert_t5_original_tf_checkpoint_to_pytorch.py b/src/transformers/models/t5/convert_t5_original_tf_checkpoint_to_pytorch.py index e38680df84..5a702eab92 100755 --- a/src/transformers/models/t5/convert_t5_original_tf_checkpoint_to_pytorch.py +++ b/src/transformers/models/t5/convert_t5_original_tf_checkpoint_to_pytorch.py @@ -17,8 +17,8 @@ import argparse -from transformers import T5Config, T5ForConditionalGeneration, load_tf_weights_in_t5 -from transformers.utils import logging +from ...utils import logging +from . import T5Config, T5ForConditionalGeneration, load_tf_weights_in_t5 logging.set_verbosity_info() diff --git a/src/transformers/models/t5/modeling_tf_t5.py b/src/transformers/models/t5/modeling_tf_t5.py index 0dd1547820..b70976cf39 100644 --- a/src/transformers/models/t5/modeling_tf_t5.py +++ b/src/transformers/models/t5/modeling_tf_t5.py @@ -22,8 +22,6 @@ from typing import Tuple import tensorflow as tf -from transformers.modeling_tf_utils import TFWrappedEmbeddings - from ...activations_tf import get_tf_activation from ...file_utils import ( DUMMY_INPUTS, @@ -42,6 +40,7 @@ from ...modeling_tf_utils import ( TFCausalLanguageModelingLoss, TFPreTrainedModel, TFSharedEmbeddings, + TFWrappedEmbeddings, input_processing, keras_serializable, shape_list, diff --git a/src/transformers/models/tapas/convert_tapas_original_tf_checkpoint_to_pytorch.py b/src/transformers/models/tapas/convert_tapas_original_tf_checkpoint_to_pytorch.py index 1743fc842a..6033dc8635 100644 --- a/src/transformers/models/tapas/convert_tapas_original_tf_checkpoint_to_pytorch.py +++ b/src/transformers/models/tapas/convert_tapas_original_tf_checkpoint_to_pytorch.py @@ -17,16 +17,16 @@ import argparse -from transformers.models.tapas.modeling_tapas import ( +from ...utils import logging +from . import ( TapasConfig, TapasForMaskedLM, TapasForQuestionAnswering, TapasForSequenceClassification, TapasModel, + TapasTokenizer, load_tf_weights_in_tapas, ) -from transformers.models.tapas.tokenization_tapas import TapasTokenizer -from transformers.utils import logging logging.set_verbosity_info() diff --git a/src/transformers/models/tapas/tokenization_tapas.py b/src/transformers/models/tapas/tokenization_tapas.py index e51d8fe9a9..7498b2a35e 100644 --- a/src/transformers/models/tapas/tokenization_tapas.py +++ b/src/transformers/models/tapas/tokenization_tapas.py @@ -28,9 +28,7 @@ from typing import Callable, Dict, Generator, List, Optional, Text, Tuple, Union import numpy as np -from transformers import add_end_docstrings - -from ...file_utils import is_pandas_available +from ...file_utils import add_end_docstrings, is_pandas_available from ...tokenization_utils import PreTrainedTokenizer, _is_control, _is_punctuation, _is_whitespace from ...tokenization_utils_base import ( ENCODE_KWARGS_DOCSTRING, diff --git a/src/transformers/models/transfo_xl/convert_transfo_xl_original_tf_checkpoint_to_pytorch.py b/src/transformers/models/transfo_xl/convert_transfo_xl_original_tf_checkpoint_to_pytorch.py index a5d8e194ce..ea43e10aad 100755 --- a/src/transformers/models/transfo_xl/convert_transfo_xl_original_tf_checkpoint_to_pytorch.py +++ b/src/transformers/models/transfo_xl/convert_transfo_xl_original_tf_checkpoint_to_pytorch.py @@ -22,16 +22,11 @@ import sys import torch -import transformers.models.transfo_xl.tokenization_transfo_xl as data_utils -from transformers import ( - CONFIG_NAME, - WEIGHTS_NAME, - TransfoXLConfig, - TransfoXLLMHeadModel, - load_tf_weights_in_transfo_xl, -) -from transformers.models.transfo_xl.tokenization_transfo_xl import CORPUS_NAME, VOCAB_FILES_NAMES -from transformers.utils import logging +from ...file_utils import CONFIG_NAME, WEIGHTS_NAME +from ...utils import logging +from . import TransfoXLConfig, TransfoXLLMHeadModel, load_tf_weights_in_transfo_xl +from . import tokenization_transfo_xl as data_utils +from .tokenization_transfo_xl import CORPUS_NAME, VOCAB_FILES_NAMES logging.set_verbosity_info() diff --git a/src/transformers/models/xlm/convert_xlm_original_pytorch_checkpoint_to_pytorch.py b/src/transformers/models/xlm/convert_xlm_original_pytorch_checkpoint_to_pytorch.py index 37ee8a25e8..8e50b5983b 100755 --- a/src/transformers/models/xlm/convert_xlm_original_pytorch_checkpoint_to_pytorch.py +++ b/src/transformers/models/xlm/convert_xlm_original_pytorch_checkpoint_to_pytorch.py @@ -21,9 +21,9 @@ import json import numpy import torch -from transformers import CONFIG_NAME, WEIGHTS_NAME -from transformers.models.xlm.tokenization_xlm import VOCAB_FILES_NAMES -from transformers.utils import logging +from ...file_utils import CONFIG_NAME, WEIGHTS_NAME +from ...utils import logging +from .tokenization_xlm import VOCAB_FILES_NAMES logging.set_verbosity_info() diff --git a/src/transformers/models/xlnet/convert_xlnet_original_tf_checkpoint_to_pytorch.py b/src/transformers/models/xlnet/convert_xlnet_original_tf_checkpoint_to_pytorch.py index f726466b10..a23267d709 100755 --- a/src/transformers/models/xlnet/convert_xlnet_original_tf_checkpoint_to_pytorch.py +++ b/src/transformers/models/xlnet/convert_xlnet_original_tf_checkpoint_to_pytorch.py @@ -20,16 +20,15 @@ import os import torch -from transformers import ( - CONFIG_NAME, - WEIGHTS_NAME, +from ...file_utils import CONFIG_NAME, WEIGHTS_NAME +from ...utils import logging +from . import ( XLNetConfig, XLNetForQuestionAnswering, XLNetForSequenceClassification, XLNetLMHeadModel, load_tf_weights_in_xlnet, ) -from transformers.utils import logging GLUE_TASKS_NUM_LABELS = { diff --git a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/to_replace_{{cookiecutter.lowercase_modelname}}.py b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/to_replace_{{cookiecutter.lowercase_modelname}}.py index cee8b9838d..f22986b3e0 100644 --- a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/to_replace_{{cookiecutter.lowercase_modelname}}.py +++ b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/to_replace_{{cookiecutter.lowercase_modelname}}.py @@ -24,71 +24,147 @@ ## ## Put '## COMMENT' to comment on the file. +# To replace in: "src/transformers/__init__.py" +# Below: " # PyTorch models structure" if generating PyTorch +# Replace with: +{% if cookiecutter.is_encoder_decoder_model == "False" %} + _import_structure["models.{{cookiecutter.lowercase_modelname}}"].extend( + [ + "{{cookiecutter.uppercase_modelname}}_PRETRAINED_MODEL_ARCHIVE_LIST", + "{{cookiecutter.camelcase_modelname}}ForMaskedLM", + "{{cookiecutter.camelcase_modelname}}ForCausalLM", + "{{cookiecutter.camelcase_modelname}}ForMultipleChoice", + "{{cookiecutter.camelcase_modelname}}ForQuestionAnswering", + "{{cookiecutter.camelcase_modelname}}ForSequenceClassification", + "{{cookiecutter.camelcase_modelname}}ForTokenClassification", + "{{cookiecutter.camelcase_modelname}}Layer", + "{{cookiecutter.camelcase_modelname}}Model", + "{{cookiecutter.camelcase_modelname}}PreTrainedModel", + "load_tf_weights_in_{{cookiecutter.lowercase_modelname}}", + ] + ) +{% else %} + _import_structure["models.{{cookiecutter.lowercase_modelname}}"].extend( + [ + "{{cookiecutter.uppercase_modelname}}_PRETRAINED_MODEL_ARCHIVE_LIST", + "{{cookiecutter.camelcase_modelname}}ForConditionalGeneration", + "{{cookiecutter.camelcase_modelname}}ForQuestionAnswering", + "{{cookiecutter.camelcase_modelname}}ForSequenceClassification", + "{{cookiecutter.camelcase_modelname}}Model", + ] + ) +{% endif -%} +# End. + +# Below: " # TensorFlow models structure" if generating TensorFlow +# Replace with: +{% if cookiecutter.is_encoder_decoder_model == "False" %} + _import_structure["models.{{cookiecutter.lowercase_modelname}}"].extend( + [ + "TF_{{cookiecutter.uppercase_modelname}}_PRETRAINED_MODEL_ARCHIVE_LIST", + "TF{{cookiecutter.camelcase_modelname}}ForMaskedLM", + "TF{{cookiecutter.camelcase_modelname}}ForCausalLM", + "TF{{cookiecutter.camelcase_modelname}}ForMultipleChoice", + "TF{{cookiecutter.camelcase_modelname}}ForQuestionAnswering", + "TF{{cookiecutter.camelcase_modelname}}ForSequenceClassification", + "TF{{cookiecutter.camelcase_modelname}}ForTokenClassification", + "TF{{cookiecutter.camelcase_modelname}}Layer", + "TF{{cookiecutter.camelcase_modelname}}Model", + "TF{{cookiecutter.camelcase_modelname}}PreTrainedModel", + ] + ) +{% else %} + _import_structure["models.{{cookiecutter.lowercase_modelname}}"].extend( + [ + "TF{{cookiecutter.camelcase_modelname}}ForConditionalGeneration", + "TF{{cookiecutter.camelcase_modelname}}Model", + "TF{{cookiecutter.camelcase_modelname}}PreTrainedModel", + ] + ) +{% endif -%} +# End. + +# Below: " # Fast tokenizers" +# Replace with: + _import_structure["models.{{cookiecutter.lowercase_modelname}}"].append("{{cookiecutter.camelcase_modelname}}TokenizerFast") +# End. + +# Below: " # Models" +# Replace with: + "models.{{cookiecutter.lowercase_modelname}}": ["{{cookiecutter.uppercase_modelname}}_PRETRAINED_CONFIG_ARCHIVE_MAP", "{{cookiecutter.camelcase_modelname}}Config", "{{cookiecutter.camelcase_modelname}}Tokenizer"], +# End. # To replace in: "src/transformers/__init__.py" -# Below: "if is_torch_available():" if generating PyTorch +# Below: " if is_torch_available():" if generating PyTorch # Replace with: {% if cookiecutter.is_encoder_decoder_model == "False" %} - from .models.{{cookiecutter.lowercase_modelname}} import ( - {{cookiecutter.uppercase_modelname}}_PRETRAINED_MODEL_ARCHIVE_LIST, - {{cookiecutter.camelcase_modelname}}ForMaskedLM, - {{cookiecutter.camelcase_modelname}}ForCausalLM, - {{cookiecutter.camelcase_modelname}}ForMultipleChoice, - {{cookiecutter.camelcase_modelname}}ForQuestionAnswering, - {{cookiecutter.camelcase_modelname}}ForSequenceClassification, - {{cookiecutter.camelcase_modelname}}ForTokenClassification, - {{cookiecutter.camelcase_modelname}}Layer, - {{cookiecutter.camelcase_modelname}}Model, - {{cookiecutter.camelcase_modelname}}PreTrainedModel, - load_tf_weights_in_{{cookiecutter.lowercase_modelname}}, - ) + from .models.{{cookiecutter.lowercase_modelname}} import ( + {{cookiecutter.uppercase_modelname}}_PRETRAINED_MODEL_ARCHIVE_LIST, + {{cookiecutter.camelcase_modelname}}ForMaskedLM, + {{cookiecutter.camelcase_modelname}}ForCausalLM, + {{cookiecutter.camelcase_modelname}}ForMultipleChoice, + {{cookiecutter.camelcase_modelname}}ForQuestionAnswering, + {{cookiecutter.camelcase_modelname}}ForSequenceClassification, + {{cookiecutter.camelcase_modelname}}ForTokenClassification, + {{cookiecutter.camelcase_modelname}}Layer, + {{cookiecutter.camelcase_modelname}}Model, + {{cookiecutter.camelcase_modelname}}PreTrainedModel, + load_tf_weights_in_{{cookiecutter.lowercase_modelname}}, + ) {% else %} - from .models.{{cookiecutter.lowercase_modelname}} import ( - {{cookiecutter.uppercase_modelname}}_PRETRAINED_MODEL_ARCHIVE_LIST, - {{cookiecutter.camelcase_modelname}}ForConditionalGeneration, - {{cookiecutter.camelcase_modelname}}ForQuestionAnswering, - {{cookiecutter.camelcase_modelname}}ForSequenceClassification, - {{cookiecutter.camelcase_modelname}}Model, - ) + from .models.{{cookiecutter.lowercase_modelname}} import ( + {{cookiecutter.uppercase_modelname}}_PRETRAINED_MODEL_ARCHIVE_LIST, + {{cookiecutter.camelcase_modelname}}ForConditionalGeneration, + {{cookiecutter.camelcase_modelname}}ForQuestionAnswering, + {{cookiecutter.camelcase_modelname}}ForSequenceClassification, + {{cookiecutter.camelcase_modelname}}Model, + ) {% endif -%} # End. -# Below: "if is_tf_available():" if generating TensorFlow +# Below: " if is_tf_available():" if generating TensorFlow # Replace with: {% if cookiecutter.is_encoder_decoder_model == "False" %} - from .models.{{cookiecutter.lowercase_modelname}} import ( - TF_{{cookiecutter.uppercase_modelname}}_PRETRAINED_MODEL_ARCHIVE_LIST, - TF{{cookiecutter.camelcase_modelname}}ForMaskedLM, - TF{{cookiecutter.camelcase_modelname}}ForCausalLM, - TF{{cookiecutter.camelcase_modelname}}ForMultipleChoice, - TF{{cookiecutter.camelcase_modelname}}ForQuestionAnswering, - TF{{cookiecutter.camelcase_modelname}}ForSequenceClassification, - TF{{cookiecutter.camelcase_modelname}}ForTokenClassification, - TF{{cookiecutter.camelcase_modelname}}Layer, - TF{{cookiecutter.camelcase_modelname}}Model, - TF{{cookiecutter.camelcase_modelname}}PreTrainedModel, - ) + from .models.{{cookiecutter.lowercase_modelname}} import ( + TF_{{cookiecutter.uppercase_modelname}}_PRETRAINED_MODEL_ARCHIVE_LIST, + TF{{cookiecutter.camelcase_modelname}}ForMaskedLM, + TF{{cookiecutter.camelcase_modelname}}ForCausalLM, + TF{{cookiecutter.camelcase_modelname}}ForMultipleChoice, + TF{{cookiecutter.camelcase_modelname}}ForQuestionAnswering, + TF{{cookiecutter.camelcase_modelname}}ForSequenceClassification, + TF{{cookiecutter.camelcase_modelname}}ForTokenClassification, + TF{{cookiecutter.camelcase_modelname}}Layer, + TF{{cookiecutter.camelcase_modelname}}Model, + TF{{cookiecutter.camelcase_modelname}}PreTrainedModel, + ) {% else %} - from .models.{{cookiecutter.lowercase_modelname}} import ( - TF{{cookiecutter.camelcase_modelname}}ForConditionalGeneration, - TF{{cookiecutter.camelcase_modelname}}Model, - TF{{cookiecutter.camelcase_modelname}}PreTrainedModel, - ) + from .models.{{cookiecutter.lowercase_modelname}} import ( + TF{{cookiecutter.camelcase_modelname}}ForConditionalGeneration, + TF{{cookiecutter.camelcase_modelname}}Model, + TF{{cookiecutter.camelcase_modelname}}PreTrainedModel, + ) {% endif -%} # End. -# Below: "if is_tokenizers_available():" +# Below: " if is_tokenizers_available():" # Replace with: - from .models.{{cookiecutter.lowercase_modelname}} import {{cookiecutter.camelcase_modelname}}TokenizerFast + from .models.{{cookiecutter.lowercase_modelname}} import {{cookiecutter.camelcase_modelname}}TokenizerFast # End. -# Below: "from .models.albert import ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, AlbertConfig" +# Below: " from .models.albert import ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, AlbertConfig" # Replace with: -from .models.{{cookiecutter.lowercase_modelname}} import {{cookiecutter.uppercase_modelname}}_PRETRAINED_CONFIG_ARCHIVE_MAP, {{cookiecutter.camelcase_modelname}}Config, {{cookiecutter.camelcase_modelname}}Tokenizer + from .models.{{cookiecutter.lowercase_modelname}} import {{cookiecutter.uppercase_modelname}}_PRETRAINED_CONFIG_ARCHIVE_MAP, {{cookiecutter.camelcase_modelname}}Config, {{cookiecutter.camelcase_modelname}}Tokenizer # End. +# To replace in: "src/transformers/models/__init__.py" +# Below: "from . import (" +# Replace with: + {{cookiecutter.lowercase_modelname}}, +# End. + + # To replace in: "src/transformers/models/auto/configuration_auto.py" # Below: "# Add configs here" # Replace with: diff --git a/utils/check_dummies.py b/utils/check_dummies.py index 0960682b05..f254e5a2ca 100644 --- a/utils/check_dummies.py +++ b/utils/check_dummies.py @@ -23,237 +23,79 @@ import re PATH_TO_TRANSFORMERS = "src/transformers" _re_single_line_import = re.compile(r"\s+from\s+\S*\s+import\s+([^\(\s].*)\n") +_re_test_backend = re.compile(r"^\s+if\s+is\_([a-z]*)\_available\(\):\s*$") + + +BACKENDS = ["torch", "tf", "flax", "sentencepiece", "tokenizers"] + DUMMY_CONSTANT = """ {0} = None """ -DUMMY_PT_PRETRAINED_CLASS = """ +DUMMY_PRETRAINED_CLASS = """ class {0}: def __init__(self, *args, **kwargs): - requires_pytorch(self) + requires_{1}(self) @classmethod def from_pretrained(self, *args, **kwargs): - requires_pytorch(self) + requires_{1}(self) """ -DUMMY_PT_CLASS = """ +DUMMY_CLASS = """ class {0}: def __init__(self, *args, **kwargs): - requires_pytorch(self) + requires_{1}(self) """ -DUMMY_PT_FUNCTION = """ +DUMMY_FUNCTION = """ def {0}(*args, **kwargs): - requires_pytorch({0}) + requires_{1}({0}) """ -DUMMY_TF_PRETRAINED_CLASS = """ -class {0}: - def __init__(self, *args, **kwargs): - requires_tf(self) - - @classmethod - def from_pretrained(self, *args, **kwargs): - requires_tf(self) -""" - -DUMMY_TF_CLASS = """ -class {0}: - def __init__(self, *args, **kwargs): - requires_tf(self) -""" - -DUMMY_TF_FUNCTION = """ -def {0}(*args, **kwargs): - requires_tf({0}) -""" - - -DUMMY_FLAX_PRETRAINED_CLASS = """ -class {0}: - def __init__(self, *args, **kwargs): - requires_flax(self) - - @classmethod - def from_pretrained(self, *args, **kwargs): - requires_flax(self) -""" - -DUMMY_FLAX_CLASS = """ -class {0}: - def __init__(self, *args, **kwargs): - requires_flax(self) -""" - -DUMMY_FLAX_FUNCTION = """ -def {0}(*args, **kwargs): - requires_flax({0}) -""" - - -DUMMY_SENTENCEPIECE_PRETRAINED_CLASS = """ -class {0}: - def __init__(self, *args, **kwargs): - requires_sentencepiece(self) - - @classmethod - def from_pretrained(self, *args, **kwargs): - requires_sentencepiece(self) -""" - -DUMMY_SENTENCEPIECE_CLASS = """ -class {0}: - def __init__(self, *args, **kwargs): - requires_sentencepiece(self) -""" - -DUMMY_SENTENCEPIECE_FUNCTION = """ -def {0}(*args, **kwargs): - requires_sentencepiece({0}) -""" - - -DUMMY_TOKENIZERS_PRETRAINED_CLASS = """ -class {0}: - def __init__(self, *args, **kwargs): - requires_tokenizers(self) - - @classmethod - def from_pretrained(self, *args, **kwargs): - requires_tokenizers(self) -""" - -DUMMY_TOKENIZERS_CLASS = """ -class {0}: - def __init__(self, *args, **kwargs): - requires_tokenizers(self) -""" - -DUMMY_TOKENIZERS_FUNCTION = """ -def {0}(*args, **kwargs): - requires_tokenizers({0}) -""" - -# Map all these to dummy type - -DUMMY_PRETRAINED_CLASS = { - "pt": DUMMY_PT_PRETRAINED_CLASS, - "tf": DUMMY_TF_PRETRAINED_CLASS, - "flax": DUMMY_FLAX_PRETRAINED_CLASS, - "sentencepiece": DUMMY_SENTENCEPIECE_PRETRAINED_CLASS, - "tokenizers": DUMMY_TOKENIZERS_PRETRAINED_CLASS, -} - -DUMMY_CLASS = { - "pt": DUMMY_PT_CLASS, - "tf": DUMMY_TF_CLASS, - "flax": DUMMY_FLAX_CLASS, - "sentencepiece": DUMMY_SENTENCEPIECE_CLASS, - "tokenizers": DUMMY_TOKENIZERS_CLASS, -} - -DUMMY_FUNCTION = { - "pt": DUMMY_PT_FUNCTION, - "tf": DUMMY_TF_FUNCTION, - "flax": DUMMY_FLAX_FUNCTION, - "sentencepiece": DUMMY_SENTENCEPIECE_FUNCTION, - "tokenizers": DUMMY_TOKENIZERS_FUNCTION, -} - - def read_init(): """ Read the init and extracts PyTorch, TensorFlow, SentencePiece and Tokenizers objects. """ with open(os.path.join(PATH_TO_TRANSFORMERS, "__init__.py"), "r", encoding="utf-8", newline="\n") as f: lines = f.readlines() + # Get to the point we do the actual imports for type checking line_index = 0 - # Find where the SentencePiece imports begin - sentencepiece_objects = [] - while not lines[line_index].startswith("if is_sentencepiece_available():"): - line_index += 1 - line_index += 1 - - # Until we unindent, add SentencePiece objects to the list - while len(lines[line_index]) <= 1 or lines[line_index].startswith(" "): - line = lines[line_index] - search = _re_single_line_import.search(line) - if search is not None: - sentencepiece_objects += search.groups()[0].split(", ") - elif line.startswith(" "): - sentencepiece_objects.append(line[8:-2]) + while not lines[line_index].startswith("if TYPE_CHECKING"): line_index += 1 - # Find where the Tokenizers imports begin - tokenizers_objects = [] - while not lines[line_index].startswith("if is_tokenizers_available():"): - line_index += 1 - line_index += 1 + backend_specific_objects = {} + # Go through the end of the file + while line_index < len(lines): + # If the line is an if is_backemd_available, we grab all objects associated. + if _re_test_backend.search(lines[line_index]) is not None: + backend = _re_test_backend.search(lines[line_index]).groups()[0] + line_index += 1 - # Until we unindent, add Tokenizers objects to the list - while len(lines[line_index]) <= 1 or lines[line_index].startswith(" "): - line = lines[line_index] - search = _re_single_line_import.search(line) - if search is not None: - tokenizers_objects += search.groups()[0].split(", ") - elif line.startswith(" "): - tokenizers_objects.append(line[8:-2]) - line_index += 1 + # Ignore if backend isn't tracked for dummies. + if backend not in BACKENDS: + continue - # Find where the PyTorch imports begin - pt_objects = [] - while not lines[line_index].startswith("if is_torch_available():"): - line_index += 1 - line_index += 1 + objects = [] + # Until we unindent, add backend objects to the list + while len(lines[line_index]) <= 1 or lines[line_index].startswith(" " * 8): + line = lines[line_index] + single_line_import_search = _re_single_line_import.search(line) + if single_line_import_search is not None: + objects.extend(single_line_import_search.groups()[0].split(", ")) + elif line.startswith(" " * 12): + objects.append(line[12:-2]) + line_index += 1 - # Until we unindent, add PyTorch objects to the list - while len(lines[line_index]) <= 1 or lines[line_index].startswith(" "): - line = lines[line_index] - search = _re_single_line_import.search(line) - if search is not None: - pt_objects += search.groups()[0].split(", ") - elif line.startswith(" "): - pt_objects.append(line[8:-2]) - line_index += 1 + backend_specific_objects[backend] = objects + else: + line_index += 1 - # Find where the TF imports begin - tf_objects = [] - while not lines[line_index].startswith("if is_tf_available():"): - line_index += 1 - line_index += 1 - - # Until we unindent, add PyTorch objects to the list - while len(lines[line_index]) <= 1 or lines[line_index].startswith(" "): - line = lines[line_index] - search = _re_single_line_import.search(line) - if search is not None: - tf_objects += search.groups()[0].split(", ") - elif line.startswith(" "): - tf_objects.append(line[8:-2]) - line_index += 1 - - # Find where the FLAX imports begin - flax_objects = [] - while not lines[line_index].startswith("if is_flax_available():"): - line_index += 1 - line_index += 1 - - # Until we unindent, add PyTorch objects to the list - while len(lines[line_index]) <= 1 or lines[line_index].startswith(" "): - line = lines[line_index] - search = _re_single_line_import.search(line) - if search is not None: - flax_objects += search.groups()[0].split(", ") - elif line.startswith(" "): - flax_objects.append(line[8:-2]) - line_index += 1 - - return sentencepiece_objects, tokenizers_objects, pt_objects, tf_objects, flax_objects + return backend_specific_objects -def create_dummy_object(name, type="pt"): +def create_dummy_object(name, backend_name): """ Create the code for the dummy object corresponding to `name`.""" _pretrained = [ "Config" "ForCausalLM", @@ -266,11 +108,10 @@ def create_dummy_object(name, type="pt"): "Model", "Tokenizer", ] - assert type in ["pt", "tf", "sentencepiece", "tokenizers", "flax"] if name.isupper(): return DUMMY_CONSTANT.format(name) elif name.islower(): - return (DUMMY_FUNCTION[type]).format(name) + return DUMMY_FUNCTION.format(name, backend_name) else: is_pretrained = False for part in _pretrained: @@ -278,114 +119,61 @@ def create_dummy_object(name, type="pt"): is_pretrained = True break if is_pretrained: - template = DUMMY_PRETRAINED_CLASS[type] + return DUMMY_PRETRAINED_CLASS.format(name, backend_name) else: - template = DUMMY_CLASS[type] - return template.format(name) + return DUMMY_CLASS.format(name, backend_name) def create_dummy_files(): """ Create the content of the dummy files. """ - sentencepiece_objects, tokenizers_objects, pt_objects, tf_objects, flax_objects = read_init() + backend_specific_objects = read_init() + # For special correspondence backend to module name as used in the function requires_modulename + module_names = {"torch": "pytorch"} + dummy_files = {} - sentencepiece_dummies = "# This file is autogenerated by the command `make fix-copies`, do not edit.\n" - sentencepiece_dummies += "from ..file_utils import requires_sentencepiece\n\n" - sentencepiece_dummies += "\n".join([create_dummy_object(o, type="sentencepiece") for o in sentencepiece_objects]) + for backend, objects in backend_specific_objects.items(): + backend_name = module_names.get(backend, backend) + dummy_file = "# This file is autogenerated by the command `make fix-copies`, do not edit.\n" + dummy_file += f"from ..file_utils import requires_{backend_name}\n\n" + dummy_file += "\n".join([create_dummy_object(o, backend_name) for o in objects]) + dummy_files[backend] = dummy_file - tokenizers_dummies = "# This file is autogenerated by the command `make fix-copies`, do not edit.\n" - tokenizers_dummies += "from ..file_utils import requires_tokenizers\n\n" - tokenizers_dummies += "\n".join([create_dummy_object(o, type="tokenizers") for o in tokenizers_objects]) - - pt_dummies = "# This file is autogenerated by the command `make fix-copies`, do not edit.\n" - pt_dummies += "from ..file_utils import requires_pytorch\n\n" - pt_dummies += "\n".join([create_dummy_object(o, type="pt") for o in pt_objects]) - - tf_dummies = "# This file is autogenerated by the command `make fix-copies`, do not edit.\n" - tf_dummies += "from ..file_utils import requires_tf\n\n" - tf_dummies += "\n".join([create_dummy_object(o, type="tf") for o in tf_objects]) - - flax_dummies = "# This file is autogenerated by the command `make fix-copies`, do not edit.\n" - flax_dummies += "from ..file_utils import requires_flax\n\n" - flax_dummies += "\n".join([create_dummy_object(o, type="flax") for o in flax_objects]) - - return sentencepiece_dummies, tokenizers_dummies, pt_dummies, tf_dummies, flax_dummies + return dummy_files def check_dummies(overwrite=False): """ Check if the dummy files are up to date and maybe `overwrite` with the right content. """ - sentencepiece_dummies, tokenizers_dummies, pt_dummies, tf_dummies, flax_dummies = create_dummy_files() + dummy_files = create_dummy_files() + # For special correspondence backend to shortcut as used in utils/dummy_xxx_objects.py + short_names = {"torch": "pt"} + + # Locate actual dummy modules and read their content. path = os.path.join(PATH_TO_TRANSFORMERS, "utils") - sentencepiece_file = os.path.join(path, "dummy_sentencepiece_objects.py") - tokenizers_file = os.path.join(path, "dummy_tokenizers_objects.py") - pt_file = os.path.join(path, "dummy_pt_objects.py") - tf_file = os.path.join(path, "dummy_tf_objects.py") - flax_file = os.path.join(path, "dummy_flax_objects.py") + dummy_file_paths = { + backend: os.path.join(path, f"dummy_{short_names.get(backend, backend)}_objects.py") + for backend in dummy_files.keys() + } - with open(sentencepiece_file, "r", encoding="utf-8", newline="\n") as f: - actual_sentencepiece_dummies = f.read() - with open(tokenizers_file, "r", encoding="utf-8", newline="\n") as f: - actual_tokenizers_dummies = f.read() - with open(pt_file, "r", encoding="utf-8", newline="\n") as f: - actual_pt_dummies = f.read() - with open(tf_file, "r", encoding="utf-8", newline="\n") as f: - actual_tf_dummies = f.read() - with open(flax_file, "r", encoding="utf-8", newline="\n") as f: - actual_flax_dummies = f.read() + actual_dummies = {} + for backend, file_path in dummy_file_paths.items(): + with open(file_path, "r", encoding="utf-8", newline="\n") as f: + actual_dummies[backend] = f.read() - if sentencepiece_dummies != actual_sentencepiece_dummies: - if overwrite: - print("Updating transformers.utils.dummy_sentencepiece_objects.py as the main __init__ has new objects.") - with open(sentencepiece_file, "w", encoding="utf-8", newline="\n") as f: - f.write(sentencepiece_dummies) - else: - raise ValueError( - "The main __init__ has objects that are not present in transformers.utils.dummy_sentencepiece_objects.py.", - "Run `make fix-copies` to fix this.", - ) - - if tokenizers_dummies != actual_tokenizers_dummies: - if overwrite: - print("Updating transformers.utils.dummy_tokenizers_objects.py as the main __init__ has new objects.") - with open(tokenizers_file, "w", encoding="utf-8", newline="\n") as f: - f.write(tokenizers_dummies) - else: - raise ValueError( - "The main __init__ has objects that are not present in transformers.utils.dummy_tokenizers_objects.py.", - "Run `make fix-copies` to fix this.", - ) - - if pt_dummies != actual_pt_dummies: - if overwrite: - print("Updating transformers.utils.dummy_pt_objects.py as the main __init__ has new objects.") - with open(pt_file, "w", encoding="utf-8", newline="\n") as f: - f.write(pt_dummies) - else: - raise ValueError( - "The main __init__ has objects that are not present in transformers.utils.dummy_pt_objects.py.", - "Run `make fix-copies` to fix this.", - ) - - if tf_dummies != actual_tf_dummies: - if overwrite: - print("Updating transformers.utils.dummy_tf_objects.py as the main __init__ has new objects.") - with open(tf_file, "w", encoding="utf-8", newline="\n") as f: - f.write(tf_dummies) - else: - raise ValueError( - "The main __init__ has objects that are not present in transformers.utils.dummy_pt_objects.py.", - "Run `make fix-copies` to fix this.", - ) - - if flax_dummies != actual_flax_dummies: - if overwrite: - print("Updating transformers.utils.dummy_flax_objects.py as the main __init__ has new objects.") - with open(flax_file, "w", encoding="utf-8", newline="\n") as f: - f.write(flax_dummies) - else: - raise ValueError( - "The main __init__ has objects that are not present in transformers.utils.dummy_flax_objects.py.", - "Run `make fix-copies` to fix this.", - ) + for backend in dummy_files.keys(): + if dummy_files[backend] != actual_dummies[backend]: + if overwrite: + print( + f"Updating transformers.utils.dummy_{short_names.get(backend, backend)}_objects.py as the main " + "__init__ has new objects." + ) + with open(dummy_file_paths[backend], "w", encoding="utf-8", newline="\n") as f: + f.write(dummy_files[backend]) + else: + raise ValueError( + "The main __init__ has objects that are not present in " + f"transformers.utils.dummy_{short_names.get(backend, backend)}_objects.py. Run `make fix-copies` " + "to fix this." + ) if __name__ == "__main__": diff --git a/utils/check_repo.py b/utils/check_repo.py index aefac35684..e40500e06b 100644 --- a/utils/check_repo.py +++ b/utils/check_repo.py @@ -413,9 +413,6 @@ SHOULD_HAVE_THEIR_OWN_PAGE = [ def ignore_undocumented(name): """Rules to determine if `name` should be undocumented.""" # NOT DOCUMENTED ON PURPOSE. - # Magic attributes are not documented. - if name.startswith("__"): - return True # Constants uppercase are not documented. if name.isupper(): return True @@ -459,7 +456,9 @@ def ignore_undocumented(name): def check_all_objects_are_documented(): """ Check all models are properly documented.""" documented_objs = find_all_documented_objects() - undocumented_objs = [c for c in dir(transformers) if c not in documented_objs and not ignore_undocumented(c)] + modules = transformers._modules + objects = [c for c in dir(transformers) if c not in modules and not c.startswith("_")] + undocumented_objs = [c for c in objects if c not in documented_objs and not ignore_undocumented(c)] if len(undocumented_objs) > 0: raise Exception( "The following objects are in the public init so should be documented:\n - "