Transformers fast import part 2 (#9446)

* Main init work

* Add version

* Change from absolute to relative imports

* Fix imports

* One more typo

* More typos

* Styling

* Make quality script pass

* Add necessary replace in template

* Fix typos

* Spaces are ignored in replace for some reason

* Forgot one model.

* Fixes for import

Co-authored-by: LysandreJik <lysandre.debut@reseau.eseo.fr>

* Add documentation

* Styling

Co-authored-by: LysandreJik <lysandre.debut@reseau.eseo.fr>
This commit is contained in:
Sylvain Gugger
2021-01-07 09:36:14 -05:00
committed by GitHub
parent a400fe8931
commit 758ed3332b
56 changed files with 2426 additions and 1377 deletions

File diff suppressed because it is too large Load Diff

View File

@@ -30,9 +30,8 @@ from multiprocessing import Pipe, Process, Queue
from multiprocessing.connection import Connection from multiprocessing.connection import Connection
from typing import Callable, Iterable, List, NamedTuple, Optional, Union from typing import Callable, Iterable, List, NamedTuple, Optional, Union
from transformers import AutoConfig, PretrainedConfig from .. import AutoConfig, PretrainedConfig
from transformers import __version__ as version from .. import __version__ as version
from ..file_utils import is_psutil_available, is_py3nvml_available, is_tf_available, is_torch_available from ..file_utils import is_psutil_available, is_py3nvml_available, is_tf_available, is_torch_available
from ..utils import logging from ..utils import logging
from .benchmark_args_utils import BenchmarkArguments from .benchmark_args_utils import BenchmarkArguments

View File

@@ -19,9 +19,8 @@ from argparse import ArgumentParser, Namespace
from pathlib import Path from pathlib import Path
from typing import List from typing import List
from transformers.commands import BaseTransformersCLICommand
from ..utils import logging from ..utils import logging
from . import BaseTransformersCLICommand
try: try:

View File

@@ -14,9 +14,8 @@
from argparse import ArgumentParser, Namespace from argparse import ArgumentParser, Namespace
from transformers.commands import BaseTransformersCLICommand
from ..utils import logging from ..utils import logging
from . import BaseTransformersCLICommand
def convert_command_factory(args: Namespace): def convert_command_factory(args: Namespace):
@@ -87,7 +86,7 @@ class ConvertCommand(BaseTransformersCLICommand):
def run(self): def run(self):
if self._model_type == "albert": if self._model_type == "albert":
try: try:
from transformers.models.albert.convert_albert_original_tf_checkpoint_to_pytorch import ( from ..models.albert.convert_albert_original_tf_checkpoint_to_pytorch import (
convert_tf_checkpoint_to_pytorch, convert_tf_checkpoint_to_pytorch,
) )
except ImportError: except ImportError:
@@ -96,7 +95,7 @@ class ConvertCommand(BaseTransformersCLICommand):
convert_tf_checkpoint_to_pytorch(self._tf_checkpoint, self._config, self._pytorch_dump_output) convert_tf_checkpoint_to_pytorch(self._tf_checkpoint, self._config, self._pytorch_dump_output)
elif self._model_type == "bert": elif self._model_type == "bert":
try: try:
from transformers.models.bert.convert_bert_original_tf_checkpoint_to_pytorch import ( from ..models.bert.convert_bert_original_tf_checkpoint_to_pytorch import (
convert_tf_checkpoint_to_pytorch, convert_tf_checkpoint_to_pytorch,
) )
except ImportError: except ImportError:
@@ -105,7 +104,7 @@ class ConvertCommand(BaseTransformersCLICommand):
convert_tf_checkpoint_to_pytorch(self._tf_checkpoint, self._config, self._pytorch_dump_output) convert_tf_checkpoint_to_pytorch(self._tf_checkpoint, self._config, self._pytorch_dump_output)
elif self._model_type == "funnel": elif self._model_type == "funnel":
try: try:
from transformers.models.funnel.convert_funnel_original_tf_checkpoint_to_pytorch import ( from ..models.funnel.convert_funnel_original_tf_checkpoint_to_pytorch import (
convert_tf_checkpoint_to_pytorch, convert_tf_checkpoint_to_pytorch,
) )
except ImportError: except ImportError:
@@ -113,14 +112,14 @@ class ConvertCommand(BaseTransformersCLICommand):
convert_tf_checkpoint_to_pytorch(self._tf_checkpoint, self._config, self._pytorch_dump_output) convert_tf_checkpoint_to_pytorch(self._tf_checkpoint, self._config, self._pytorch_dump_output)
elif self._model_type == "gpt": elif self._model_type == "gpt":
from transformers.models.openai.convert_openai_original_tf_checkpoint_to_pytorch import ( from ..models.openai.convert_openai_original_tf_checkpoint_to_pytorch import (
convert_openai_checkpoint_to_pytorch, convert_openai_checkpoint_to_pytorch,
) )
convert_openai_checkpoint_to_pytorch(self._tf_checkpoint, self._config, self._pytorch_dump_output) convert_openai_checkpoint_to_pytorch(self._tf_checkpoint, self._config, self._pytorch_dump_output)
elif self._model_type == "transfo_xl": elif self._model_type == "transfo_xl":
try: try:
from transformers.models.transfo_xl.convert_transfo_xl_original_tf_checkpoint_to_pytorch import ( from ..models.transfo_xl.convert_transfo_xl_original_tf_checkpoint_to_pytorch import (
convert_transfo_xl_checkpoint_to_pytorch, convert_transfo_xl_checkpoint_to_pytorch,
) )
except ImportError: except ImportError:
@@ -137,7 +136,7 @@ class ConvertCommand(BaseTransformersCLICommand):
) )
elif self._model_type == "gpt2": elif self._model_type == "gpt2":
try: try:
from transformers.models.gpt2.convert_gpt2_original_tf_checkpoint_to_pytorch import ( from ..models.gpt2.convert_gpt2_original_tf_checkpoint_to_pytorch import (
convert_gpt2_checkpoint_to_pytorch, convert_gpt2_checkpoint_to_pytorch,
) )
except ImportError: except ImportError:
@@ -146,7 +145,7 @@ class ConvertCommand(BaseTransformersCLICommand):
convert_gpt2_checkpoint_to_pytorch(self._tf_checkpoint, self._config, self._pytorch_dump_output) convert_gpt2_checkpoint_to_pytorch(self._tf_checkpoint, self._config, self._pytorch_dump_output)
elif self._model_type == "xlnet": elif self._model_type == "xlnet":
try: try:
from transformers.models.xlnet.convert_xlnet_original_tf_checkpoint_to_pytorch import ( from ..models.xlnet.convert_xlnet_original_tf_checkpoint_to_pytorch import (
convert_xlnet_checkpoint_to_pytorch, convert_xlnet_checkpoint_to_pytorch,
) )
except ImportError: except ImportError:
@@ -156,13 +155,13 @@ class ConvertCommand(BaseTransformersCLICommand):
self._tf_checkpoint, self._config, self._pytorch_dump_output, self._finetuning_task_name self._tf_checkpoint, self._config, self._pytorch_dump_output, self._finetuning_task_name
) )
elif self._model_type == "xlm": elif self._model_type == "xlm":
from transformers.models.xlm.convert_xlm_original_pytorch_checkpoint_to_pytorch import ( from ..models.xlm.convert_xlm_original_pytorch_checkpoint_to_pytorch import (
convert_xlm_checkpoint_to_pytorch, convert_xlm_checkpoint_to_pytorch,
) )
convert_xlm_checkpoint_to_pytorch(self._tf_checkpoint, self._pytorch_dump_output) convert_xlm_checkpoint_to_pytorch(self._tf_checkpoint, self._pytorch_dump_output)
elif self._model_type == "lxmert": elif self._model_type == "lxmert":
from transformers.models.lxmert.convert_lxmert_original_pytorch_checkpoint_to_pytorch import ( from ..models.lxmert.convert_lxmert_original_pytorch_checkpoint_to_pytorch import (
convert_lxmert_checkpoint_to_pytorch, convert_lxmert_checkpoint_to_pytorch,
) )

View File

@@ -14,7 +14,7 @@
from argparse import ArgumentParser from argparse import ArgumentParser
from transformers.commands import BaseTransformersCLICommand from . import BaseTransformersCLICommand
def download_command_factory(args): def download_command_factory(args):
@@ -40,7 +40,7 @@ class DownloadCommand(BaseTransformersCLICommand):
self._force = force self._force = force
def run(self): def run(self):
from transformers import AutoModel, AutoTokenizer from ..models.auto import AutoModel, AutoTokenizer
AutoModel.from_pretrained(self._model, cache_dir=self._cache, force_download=self._force) AutoModel.from_pretrained(self._model, cache_dir=self._cache, force_download=self._force)
AutoTokenizer.from_pretrained(self._model, cache_dir=self._cache, force_download=self._force) AutoTokenizer.from_pretrained(self._model, cache_dir=self._cache, force_download=self._force)

View File

@@ -15,9 +15,9 @@
import platform import platform
from argparse import ArgumentParser from argparse import ArgumentParser
from transformers import __version__ as version from .. import __version__ as version
from transformers import is_tf_available, is_torch_available from ..file_utils import is_tf_available, is_torch_available
from transformers.commands import BaseTransformersCLICommand from . import BaseTransformersCLICommand
def info_command_factory(_): def info_command_factory(_):

View File

@@ -25,9 +25,9 @@ from contextlib import AbstractContextManager
from typing import Dict, List, Optional from typing import Dict, List, Optional
import requests import requests
from transformers.commands import BaseTransformersCLICommand
from ..utils import logging from ..utils import logging
from . import BaseTransformersCLICommand
logger = logging.get_logger(__name__) # pylint: disable=invalid-name logger = logging.get_logger(__name__) # pylint: disable=invalid-name

View File

@@ -14,10 +14,9 @@
from argparse import ArgumentParser from argparse import ArgumentParser
from transformers.commands import BaseTransformersCLICommand from ..pipelines import SUPPORTED_TASKS, Pipeline, PipelineDataFormat, pipeline
from transformers.pipelines import SUPPORTED_TASKS, Pipeline, PipelineDataFormat, pipeline
from ..utils import logging from ..utils import logging
from . import BaseTransformersCLICommand
logger = logging.get_logger(__name__) # pylint: disable=invalid-name logger = logging.get_logger(__name__) # pylint: disable=invalid-name

View File

@@ -15,11 +15,9 @@
from argparse import ArgumentParser, Namespace from argparse import ArgumentParser, Namespace
from typing import Any, List, Optional from typing import Any, List, Optional
from transformers import Pipeline from ..pipelines import SUPPORTED_TASKS, Pipeline, pipeline
from transformers.commands import BaseTransformersCLICommand
from transformers.pipelines import SUPPORTED_TASKS, pipeline
from ..utils import logging from ..utils import logging
from . import BaseTransformersCLICommand
try: try:

View File

@@ -15,11 +15,11 @@
import os import os
from argparse import ArgumentParser, Namespace from argparse import ArgumentParser, Namespace
from transformers import SingleSentenceClassificationProcessor as Processor from ..data import SingleSentenceClassificationProcessor as Processor
from transformers import TextClassificationPipeline, is_tf_available, is_torch_available from ..file_utils import is_tf_available, is_torch_available
from transformers.commands import BaseTransformersCLICommand from ..pipelines import TextClassificationPipeline
from ..utils import logging from ..utils import logging
from . import BaseTransformersCLICommand
if not is_tf_available() and not is_torch_available(): if not is_tf_available() and not is_torch_available():

View File

@@ -15,14 +15,14 @@
from argparse import ArgumentParser from argparse import ArgumentParser
from transformers.commands.add_new_model import AddNewModelCommand from .add_new_model import AddNewModelCommand
from transformers.commands.convert import ConvertCommand from .convert import ConvertCommand
from transformers.commands.download import DownloadCommand from .download import DownloadCommand
from transformers.commands.env import EnvironmentCommand from .env import EnvironmentCommand
from transformers.commands.lfs import LfsCommands from .lfs import LfsCommands
from transformers.commands.run import RunCommand from .run import RunCommand
from transformers.commands.serving import ServeCommand from .serving import ServeCommand
from transformers.commands.user import UserCommands from .user import UserCommands
def main(): def main():

View File

@@ -20,8 +20,9 @@ from getpass import getpass
from typing import List, Union from typing import List, Union
from requests.exceptions import HTTPError from requests.exceptions import HTTPError
from transformers.commands import BaseTransformersCLICommand
from transformers.hf_api import HfApi, HfFolder from ..hf_api import HfApi, HfFolder
from . import BaseTransformersCLICommand
UPLOAD_MAX_FILES = 15 UPLOAD_MAX_FILES = 15

View File

@@ -19,10 +19,9 @@ from typing import Dict, List, Optional, Tuple
from packaging.version import Version, parse from packaging.version import Version, parse
from transformers import is_tf_available, is_torch_available from .file_utils import ModelOutput, is_tf_available, is_torch_available
from transformers.file_utils import ModelOutput from .pipelines import Pipeline, pipeline
from transformers.pipelines import Pipeline, pipeline from .tokenization_utils import BatchEncoding
from transformers.tokenization_utils import BatchEncoding
# This is the minimal required version to # This is the minimal required version to

View File

@@ -18,7 +18,7 @@
import argparse import argparse
import os import os
from transformers import ( from . import (
ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
BART_PRETRAINED_MODEL_ARCHIVE_LIST, BART_PRETRAINED_MODEL_ARCHIVE_LIST,
BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, BERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
@@ -87,15 +87,15 @@ from transformers import (
is_torch_available, is_torch_available,
load_pytorch_checkpoint_in_tf2_model, load_pytorch_checkpoint_in_tf2_model,
) )
from transformers.file_utils import hf_bucket_url from .file_utils import hf_bucket_url
from transformers.utils import logging from .utils import logging
if is_torch_available(): if is_torch_available():
import numpy as np import numpy as np
import torch import torch
from transformers import ( from . import (
AlbertForPreTraining, AlbertForPreTraining,
BartForConditionalGeneration, BartForConditionalGeneration,
BertForPreTraining, BertForPreTraining,

View File

@@ -18,8 +18,9 @@ import argparse
import os import os
import transformers import transformers
from transformers.convert_slow_tokenizer import SLOW_TO_FAST_CONVERTERS
from transformers.utils import logging from .convert_slow_tokenizer import SLOW_TO_FAST_CONVERTERS
from .utils import logging
logging.set_verbosity_info() logging.set_verbosity_info()

View File

@@ -17,7 +17,7 @@
import argparse import argparse
from transformers import ( from . import (
BertConfig, BertConfig,
BertGenerationConfig, BertGenerationConfig,
BertGenerationDecoder, BertGenerationDecoder,

View File

@@ -27,8 +27,7 @@ import math
import re import re
import string import string
from transformers import BasicTokenizer from ...models.bert import BasicTokenizer
from ...utils import logging from ...utils import logging

View File

@@ -17,15 +17,14 @@ import unittest
import timeout_decorator import timeout_decorator
from transformers import is_torch_available from ..file_utils import cached_property, is_torch_available
from transformers.file_utils import cached_property from ..testing_utils import require_torch
from transformers.testing_utils import require_torch
if is_torch_available(): if is_torch_available():
import torch import torch
from transformers import MarianConfig, MarianMTModel from ..models.marian import MarianConfig, MarianMTModel
@require_torch @require_torch

View File

@@ -33,6 +33,7 @@ from dataclasses import fields
from functools import partial, wraps from functools import partial, wraps
from hashlib import sha256 from hashlib import sha256
from pathlib import Path from pathlib import Path
from types import ModuleType
from typing import Any, BinaryIO, Dict, List, Optional, Tuple, Union from typing import Any, BinaryIO, Dict, List, Optional, Tuple, Union
from urllib.parse import urlparse from urllib.parse import urlparse
from zipfile import ZipFile, is_zipfile from zipfile import ZipFile, is_zipfile
@@ -41,7 +42,6 @@ import numpy as np
from packaging import version from packaging import version
from tqdm.auto import tqdm from tqdm.auto import tqdm
import importlib_metadata
import requests import requests
from filelock import FileLock from filelock import FileLock
@@ -50,6 +50,13 @@ from .hf_api import HfFolder
from .utils import logging from .utils import logging
# The package importlib_metadata is in a different place, depending on the python version.
# Compare against sys.version_info rather than parsing sys.version: the latter is the full
# interpreter banner (version number plus build details), which is not a valid version string.
if sys.version_info < (3, 8):
    import importlib_metadata
else:
    import importlib.metadata as importlib_metadata
logger = logging.get_logger(__name__) # pylint: disable=invalid-name logger = logging.get_logger(__name__) # pylint: disable=invalid-name
ENV_VARS_TRUE_VALUES = {"1", "ON", "YES"} ENV_VARS_TRUE_VALUES = {"1", "ON", "YES"}
@@ -130,7 +137,7 @@ except importlib_metadata.PackageNotFoundError:
_scatter_available = importlib.util.find_spec("torch_scatter") is not None _scatter_available = importlib.util.find_spec("torch_scatter") is not None
try: try:
_scatter_version = importlib_metadata.version("torch_scatterr") _scatter_version = importlib_metadata.version("torch_scatter")
logger.debug(f"Successfully imported torch-scatter version {_scatter_version}") logger.debug(f"Successfully imported torch-scatter version {_scatter_version}")
except importlib_metadata.PackageNotFoundError: except importlib_metadata.PackageNotFoundError:
_scatter_available = False _scatter_available = False
@@ -1415,3 +1422,40 @@ class ModelOutput(OrderedDict):
Convert self to a tuple containing all the attributes/keys that are not ``None``. Convert self to a tuple containing all the attributes/keys that are not ``None``.
""" """
return tuple(self[k] for k in self.keys()) return tuple(self[k] for k in self.keys())
class _BaseLazyModule(ModuleType):
    """
    Module class that surfaces all objects but only performs associated imports when the objects are requested.

    Args:
        name: Name of the module, forwarded to ``ModuleType``.
        import_structure: Mapping from a submodule name to an iterable of the object names that submodule provides.

    Subclasses must override :meth:`_get_module` to perform the actual import of a submodule.
    """

    # Very heavily inspired by optuna.integration._IntegrationModule
    # https://github.com/optuna/optuna/blob/master/optuna/integration/__init__.py
    def __init__(self, name, import_structure):
        super().__init__(name)
        self._modules = set(import_structure)
        # Reverse lookup: object name -> name of the submodule that defines it.
        self._class_to_module = {}
        for key, values in import_structure.items():
            for value in values:
                self._class_to_module[value] = key
        # Needed for autocompletion in an IDE
        self.__all__ = list(import_structure) + [
            value for values in import_structure.values() for value in values
        ]

    # Needed for autocompletion in an IDE
    def __dir__(self):
        """
        Return the regular module attributes followed by the lazily available names, each listed once.
        """
        result = list(super().__dir__())
        seen = set(result)
        # Names already resolved by __getattr__ live in the instance dict and are therefore already in `result`;
        # only add the ones that have not been materialized yet.
        for attr in self.__all__:
            if attr not in seen:
                seen.add(attr)
                result.append(attr)
        return result

    def __getattr__(self, name: str) -> Any:
        """
        Resolve ``name`` on first access and cache it on the module.

        Only called when normal attribute lookup fails. Raises ``AttributeError`` when ``name`` is neither a known
        submodule nor a known object name.
        """
        if name in self._modules:
            value = self._get_module(name)
        elif name in self._class_to_module:
            module = self._get_module(self._class_to_module[name])
            value = getattr(module, name)
        else:
            raise AttributeError(f"module {self.__name__} has no attribute {name}")

        # Cache the result so later lookups bypass __getattr__ entirely.
        setattr(self, name, value)
        return value

    def _get_module(self, module_name: str) -> ModuleType:
        """
        Import and return the submodule called ``module_name``. Must be implemented by subclasses.
        """
        raise NotImplementedError

View File

@@ -29,7 +29,7 @@ logger = logging.get_logger(__name__)
# comet_ml requires to be imported before any ML frameworks # comet_ml requires to be imported before any ML frameworks
_has_comet = importlib.util.find_spec("comet_ml") and os.getenv("COMET_MODE", "").upper() != "DISABLED" _has_comet = importlib.util.find_spec("comet_ml") is not None and os.getenv("COMET_MODE", "").upper() != "DISABLED"
if _has_comet: if _has_comet:
try: try:
import comet_ml # noqa: F401 import comet_ml # noqa: F401

View File

@@ -0,0 +1,68 @@
# flake8: noqa
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module while preserving other warnings. So, don't check this module at all.
# Copyright 2020 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from . import (
albert,
auto,
bart,
barthez,
bert,
bert_generation,
bert_japanese,
bertweet,
blenderbot,
blenderbot_small,
camembert,
ctrl,
deberta,
dialogpt,
distilbert,
dpr,
electra,
encoder_decoder,
flaubert,
fsmt,
funnel,
gpt2,
herbert,
layoutlm,
led,
longformer,
lxmert,
marian,
mbart,
mmbt,
mobilebert,
mpnet,
mt5,
openai,
pegasus,
phobert,
prophetnet,
rag,
reformer,
retribert,
roberta,
squeezebert,
t5,
tapas,
transfo_xl,
xlm,
xlm_roberta,
xlnet,
)

View File

@@ -19,8 +19,8 @@ import argparse
import torch import torch
from transformers import AlbertConfig, AlbertForPreTraining, load_tf_weights_in_albert from ...utils import logging
from transformers.utils import logging from . import AlbertConfig, AlbertForPreTraining, load_tf_weights_in_albert
logging.set_verbosity_info() logging.set_verbosity_info()

View File

@@ -23,15 +23,9 @@ import fairseq
import torch import torch
from packaging import version from packaging import version
from transformers import ( from ...utils import logging
BartConfig, from . import BartConfig, BartForConditionalGeneration, BartForSequenceClassification, BartModel, BartTokenizer
BartForConditionalGeneration, from .modeling_bart import _make_linear_from_emb
BartForSequenceClassification,
BartModel,
BartTokenizer,
)
from transformers.models.bart.modeling_bart import _make_linear_from_emb
from transformers.utils import logging
FAIRSEQ_MODELS = ["bart.large", "bart.large.mnli", "bart.large.cnn", "bart_xsum/model.pt"] FAIRSEQ_MODELS = ["bart.large", "bart.large.mnli", "bart.large.cnn", "bart_xsum/model.pt"]

View File

@@ -15,8 +15,7 @@
from typing import List, Optional from typing import List, Optional
from transformers import add_start_docstrings from ...file_utils import add_start_docstrings
from ...tokenization_utils_base import PREPARE_SEQ2SEQ_BATCH_DOCSTRING, BatchEncoding from ...tokenization_utils_base import PREPARE_SEQ2SEQ_BATCH_DOCSTRING, BatchEncoding
from ...utils import logging from ...utils import logging
from ..roberta.tokenization_roberta import RobertaTokenizer from ..roberta.tokenization_roberta import RobertaTokenizer

View File

@@ -15,8 +15,7 @@
from typing import List, Optional from typing import List, Optional
from transformers import add_start_docstrings from ...file_utils import add_start_docstrings
from ...tokenization_utils_base import PREPARE_SEQ2SEQ_BATCH_DOCSTRING, BatchEncoding from ...tokenization_utils_base import PREPARE_SEQ2SEQ_BATCH_DOCSTRING, BatchEncoding
from ...utils import logging from ...utils import logging
from ..roberta.tokenization_roberta_fast import RobertaTokenizerFast from ..roberta.tokenization_roberta_fast import RobertaTokenizerFast

View File

@@ -28,8 +28,8 @@ import re
import tensorflow as tf import tensorflow as tf
import torch import torch
from transformers import BertConfig, BertModel from ...utils import logging
from transformers.utils import logging from . import BertConfig, BertModel
logging.set_verbosity_info() logging.set_verbosity_info()

View File

@@ -19,8 +19,8 @@ import argparse
import torch import torch
from transformers import BertConfig, BertForPreTraining, load_tf_weights_in_bert from ...utils import logging
from transformers.utils import logging from . import BertConfig, BertForPreTraining, load_tf_weights_in_bert
logging.set_verbosity_info() logging.set_verbosity_info()

View File

@@ -22,7 +22,7 @@ import numpy as np
import tensorflow as tf import tensorflow as tf
import torch import torch
from transformers import BertModel from . import BertModel
def convert_pytorch_checkpoint_to_tf(model: BertModel, ckpt_dir: str, model_name: str): def convert_pytorch_checkpoint_to_tf(model: BertModel, ckpt_dir: str, model_name: str):

View File

@@ -18,8 +18,8 @@ import argparse
import torch import torch
from transformers import BartConfig, BartForConditionalGeneration from ...models.bart import BartConfig, BartForConditionalGeneration
from transformers.utils import logging from ...utils import logging
logging.set_verbosity_info() logging.set_verbosity_info()

View File

@@ -17,7 +17,7 @@ import os
import torch import torch
from transformers.file_utils import WEIGHTS_NAME from ...file_utils import WEIGHTS_NAME
DIALOGPT_MODELS = ["small", "medium", "large"] DIALOGPT_MODELS = ["small", "medium", "large"]

View File

@@ -19,7 +19,8 @@ from pathlib import Path
import torch import torch
from torch.serialization import default_restore_location from torch.serialization import default_restore_location
from transformers import BertConfig, DPRConfig, DPRContextEncoder, DPRQuestionEncoder, DPRReader from ...models.bert import BertConfig
from . import DPRConfig, DPRContextEncoder, DPRQuestionEncoder, DPRReader
CheckpointState = collections.namedtuple( CheckpointState = collections.namedtuple(

View File

@@ -19,8 +19,8 @@ import argparse
import torch import torch
from transformers import ElectraConfig, ElectraForMaskedLM, ElectraForPreTraining, load_tf_weights_in_electra from ...utils import logging
from transformers.utils import logging from . import ElectraConfig, ElectraForMaskedLM, ElectraForPreTraining, load_tf_weights_in_electra
logging.set_verbosity_info() logging.set_verbosity_info()

View File

@@ -31,9 +31,10 @@ import torch
from fairseq import hub_utils from fairseq import hub_utils
from fairseq.data.dictionary import Dictionary from fairseq.data.dictionary import Dictionary
from transformers import WEIGHTS_NAME, logging from ...file_utils import WEIGHTS_NAME
from transformers.models.fsmt import VOCAB_FILES_NAMES, FSMTConfig, FSMTForConditionalGeneration from ...tokenization_utils_base import TOKENIZER_CONFIG_FILE
from transformers.tokenization_utils_base import TOKENIZER_CONFIG_FILE from ...utils import logging
from . import VOCAB_FILES_NAMES, FSMTConfig, FSMTForConditionalGeneration
logging.set_verbosity_warning() logging.set_verbosity_warning()

View File

@@ -20,7 +20,7 @@ import logging
import torch import torch
from transformers import FunnelConfig, FunnelForPreTraining, load_tf_weights_in_funnel from . import FunnelConfig, FunnelForPreTraining, load_tf_weights_in_funnel
logging.basicConfig(level=logging.INFO) logging.basicConfig(level=logging.INFO)

View File

@@ -19,8 +19,9 @@ import argparse
import torch import torch
from transformers import CONFIG_NAME, WEIGHTS_NAME, GPT2Config, GPT2Model, load_tf_weights_in_gpt2 from ...file_utils import CONFIG_NAME, WEIGHTS_NAME
from transformers.utils import logging from ...utils import logging
from . import GPT2Config, GPT2Model, load_tf_weights_in_gpt2
logging.set_verbosity_info() logging.set_verbosity_info()

View File

@@ -20,7 +20,7 @@ import argparse
import pytorch_lightning as pl import pytorch_lightning as pl
import torch import torch
from transformers import LongformerForQuestionAnswering, LongformerModel from . import LongformerForQuestionAnswering, LongformerModel
class LightningModel(pl.LightningModule): class LightningModel(pl.LightningModule):

View File

@@ -20,7 +20,7 @@ import logging
import torch import torch
from transformers import LxmertConfig, LxmertForPreTraining, load_tf_weights_in_lxmert from . import LxmertConfig, LxmertForPreTraining, load_tf_weights_in_lxmert
logging.basicConfig(level=logging.INFO) logging.basicConfig(level=logging.INFO)

View File

@@ -17,7 +17,7 @@ import os
from pathlib import Path from pathlib import Path
from typing import List, Tuple from typing import List, Tuple
from transformers.models.marian.convert_marian_to_pytorch import ( from .convert_marian_to_pytorch import (
FRONT_MATTER_TEMPLATE, FRONT_MATTER_TEMPLATE,
_parse_readme, _parse_readme,
convert_all_sentencepiece_models, convert_all_sentencepiece_models,

View File

@@ -26,8 +26,8 @@ import numpy as np
import torch import torch
from tqdm import tqdm from tqdm import tqdm
from transformers import MarianConfig, MarianMTModel, MarianTokenizer from ...hf_api import HfApi
from transformers.hf_api import HfApi from . import MarianConfig, MarianMTModel, MarianTokenizer
def remove_suffix(text: str, suffix: str): def remove_suffix(text: str, suffix: str):

View File

@@ -16,9 +16,9 @@ import argparse
import torch import torch
from transformers import BartForConditionalGeneration, MBartConfig from ..bart import BartForConditionalGeneration
from ..bart.convert_bart_original_pytorch_checkpoint_to_pytorch import remove_ignore_keys_ from ..bart.convert_bart_original_pytorch_checkpoint_to_pytorch import remove_ignore_keys_
from . import MBartConfig
def convert_fairseq_mbart_checkpoint_from_disk(checkpoint_path, hf_config_path="facebook/mbart-large-en-ro"): def convert_fairseq_mbart_checkpoint_from_disk(checkpoint_path, hf_config_path="facebook/mbart-large-en-ro"):

View File

@@ -16,8 +16,8 @@ import argparse
import torch import torch
from transformers import MobileBertConfig, MobileBertForPreTraining, load_tf_weights_in_mobilebert from ...utils import logging
from transformers.utils import logging from . import MobileBertConfig, MobileBertForPreTraining, load_tf_weights_in_mobilebert
logging.set_verbosity_info() logging.set_verbosity_info()

View File

@@ -19,8 +19,9 @@ import argparse
import torch import torch
from transformers import CONFIG_NAME, WEIGHTS_NAME, OpenAIGPTConfig, OpenAIGPTModel, load_tf_weights_in_openai_gpt from ...file_utils import CONFIG_NAME, WEIGHTS_NAME
from transformers.utils import logging from ...utils import logging
from . import OpenAIGPTConfig, OpenAIGPTModel, load_tf_weights_in_openai_gpt
logging.set_verbosity_info() logging.set_verbosity_info()

View File

@@ -22,8 +22,8 @@ import tensorflow as tf
import torch import torch
from tqdm import tqdm from tqdm import tqdm
from transformers import PegasusConfig, PegasusForConditionalGeneration, PegasusTokenizer from . import PegasusConfig, PegasusForConditionalGeneration, PegasusTokenizer
from transformers.models.pegasus.configuration_pegasus import DEFAULTS, task_specific_params from .configuration_pegasus import DEFAULTS, task_specific_params
PATTERNS = [ PATTERNS = [

View File

@@ -19,8 +19,6 @@ import argparse
import torch import torch
from transformers import ProphetNetForConditionalGeneration, XLMProphetNetForConditionalGeneration, logging
# transformers_old should correspond to branch `save_old_prophetnet_model_structure` here # transformers_old should correspond to branch `save_old_prophetnet_model_structure` here
# original prophetnet_checkpoints are saved under `patrickvonplaten/..._old` respectively # original prophetnet_checkpoints are saved under `patrickvonplaten/..._old` respectively
from transformers_old.modeling_prophetnet import ( from transformers_old.modeling_prophetnet import (
@@ -30,6 +28,8 @@ from transformers_old.modeling_xlm_prophetnet import (
XLMProphetNetForConditionalGeneration as XLMProphetNetForConditionalGenerationOld, XLMProphetNetForConditionalGeneration as XLMProphetNetForConditionalGenerationOld,
) )
from . import ProphetNetForConditionalGeneration, XLMProphetNetForConditionalGeneration, logging
logger = logging.get_logger(__name__) logger = logging.get_logger(__name__)
logging.set_verbosity_info() logging.set_verbosity_info()

View File

@@ -21,8 +21,8 @@ import pickle
import numpy as np import numpy as np
import torch import torch
from transformers import ReformerConfig, ReformerModelWithLMHead from ...utils import logging
from transformers.utils import logging from . import ReformerConfig, ReformerModelWithLMHead
logging.set_verbosity_info() logging.set_verbosity_info()

View File

@@ -24,19 +24,9 @@ from fairseq.models.roberta import RobertaModel as FairseqRobertaModel
from fairseq.modules import TransformerSentenceEncoderLayer from fairseq.modules import TransformerSentenceEncoderLayer
from packaging import version from packaging import version
from transformers.models.bert.modeling_bert import ( from ...models.bert.modeling_bert import BertIntermediate, BertLayer, BertOutput, BertSelfAttention, BertSelfOutput
BertIntermediate, from ...utils import logging
BertLayer, from .modeling_roberta import RobertaConfig, RobertaForMaskedLM, RobertaForSequenceClassification
BertOutput,
BertSelfAttention,
BertSelfOutput,
)
from transformers.models.roberta.modeling_roberta import (
RobertaConfig,
RobertaForMaskedLM,
RobertaForSequenceClassification,
)
from transformers.utils import logging
if version.parse(fairseq.__version__) < version.parse("0.9.0"): if version.parse(fairseq.__version__) < version.parse("0.9.0"):

View File

@@ -17,8 +17,8 @@
import argparse import argparse
from transformers import T5Config, T5ForConditionalGeneration, load_tf_weights_in_t5 from ...utils import logging
from transformers.utils import logging from . import T5Config, T5ForConditionalGeneration, load_tf_weights_in_t5
logging.set_verbosity_info() logging.set_verbosity_info()

View File

@@ -22,8 +22,6 @@ from typing import Tuple
import tensorflow as tf import tensorflow as tf
from transformers.modeling_tf_utils import TFWrappedEmbeddings
from ...activations_tf import get_tf_activation from ...activations_tf import get_tf_activation
from ...file_utils import ( from ...file_utils import (
DUMMY_INPUTS, DUMMY_INPUTS,
@@ -42,6 +40,7 @@ from ...modeling_tf_utils import (
TFCausalLanguageModelingLoss, TFCausalLanguageModelingLoss,
TFPreTrainedModel, TFPreTrainedModel,
TFSharedEmbeddings, TFSharedEmbeddings,
TFWrappedEmbeddings,
input_processing, input_processing,
keras_serializable, keras_serializable,
shape_list, shape_list,

View File

@@ -17,16 +17,16 @@
import argparse import argparse
from transformers.models.tapas.modeling_tapas import ( from ...utils import logging
from . import (
TapasConfig, TapasConfig,
TapasForMaskedLM, TapasForMaskedLM,
TapasForQuestionAnswering, TapasForQuestionAnswering,
TapasForSequenceClassification, TapasForSequenceClassification,
TapasModel, TapasModel,
TapasTokenizer,
load_tf_weights_in_tapas, load_tf_weights_in_tapas,
) )
from transformers.models.tapas.tokenization_tapas import TapasTokenizer
from transformers.utils import logging
logging.set_verbosity_info() logging.set_verbosity_info()

View File

@@ -28,9 +28,7 @@ from typing import Callable, Dict, Generator, List, Optional, Text, Tuple, Union
import numpy as np import numpy as np
from transformers import add_end_docstrings from ...file_utils import add_end_docstrings, is_pandas_available
from ...file_utils import is_pandas_available
from ...tokenization_utils import PreTrainedTokenizer, _is_control, _is_punctuation, _is_whitespace from ...tokenization_utils import PreTrainedTokenizer, _is_control, _is_punctuation, _is_whitespace
from ...tokenization_utils_base import ( from ...tokenization_utils_base import (
ENCODE_KWARGS_DOCSTRING, ENCODE_KWARGS_DOCSTRING,

View File

@@ -22,16 +22,11 @@ import sys
import torch import torch
import transformers.models.transfo_xl.tokenization_transfo_xl as data_utils from ...file_utils import CONFIG_NAME, WEIGHTS_NAME
from transformers import ( from ...utils import logging
CONFIG_NAME, from . import TransfoXLConfig, TransfoXLLMHeadModel, load_tf_weights_in_transfo_xl
WEIGHTS_NAME, from . import tokenization_transfo_xl as data_utils
TransfoXLConfig, from .tokenization_transfo_xl import CORPUS_NAME, VOCAB_FILES_NAMES
TransfoXLLMHeadModel,
load_tf_weights_in_transfo_xl,
)
from transformers.models.transfo_xl.tokenization_transfo_xl import CORPUS_NAME, VOCAB_FILES_NAMES
from transformers.utils import logging
logging.set_verbosity_info() logging.set_verbosity_info()

View File

@@ -21,9 +21,9 @@ import json
import numpy import numpy
import torch import torch
from transformers import CONFIG_NAME, WEIGHTS_NAME from ...file_utils import CONFIG_NAME, WEIGHTS_NAME
from transformers.models.xlm.tokenization_xlm import VOCAB_FILES_NAMES from ...utils import logging
from transformers.utils import logging from .tokenization_xlm import VOCAB_FILES_NAMES
logging.set_verbosity_info() logging.set_verbosity_info()

View File

@@ -20,16 +20,15 @@ import os
import torch import torch
from transformers import ( from ...file_utils import CONFIG_NAME, WEIGHTS_NAME
CONFIG_NAME, from ...utils import logging
WEIGHTS_NAME, from . import (
XLNetConfig, XLNetConfig,
XLNetForQuestionAnswering, XLNetForQuestionAnswering,
XLNetForSequenceClassification, XLNetForSequenceClassification,
XLNetLMHeadModel, XLNetLMHeadModel,
load_tf_weights_in_xlnet, load_tf_weights_in_xlnet,
) )
from transformers.utils import logging
GLUE_TASKS_NUM_LABELS = { GLUE_TASKS_NUM_LABELS = {

View File

@@ -24,71 +24,147 @@
## ##
## Put '## COMMENT' to comment on the file. ## Put '## COMMENT' to comment on the file.
# To replace in: "src/transformers/__init__.py"
# Below: " # PyTorch models structure" if generating PyTorch
# Replace with:
{% if cookiecutter.is_encoder_decoder_model == "False" %}
_import_structure["models.{{cookiecutter.lowercase_modelname}}"].extend(
[
"{{cookiecutter.uppercase_modelname}}_PRETRAINED_MODEL_ARCHIVE_LIST",
"{{cookiecutter.camelcase_modelname}}ForMaskedLM",
"{{cookiecutter.camelcase_modelname}}ForCausalLM",
"{{cookiecutter.camelcase_modelname}}ForMultipleChoice",
"{{cookiecutter.camelcase_modelname}}ForQuestionAnswering",
"{{cookiecutter.camelcase_modelname}}ForSequenceClassification",
"{{cookiecutter.camelcase_modelname}}ForTokenClassification",
"{{cookiecutter.camelcase_modelname}}Layer",
"{{cookiecutter.camelcase_modelname}}Model",
"{{cookiecutter.camelcase_modelname}}PreTrainedModel",
"load_tf_weights_in_{{cookiecutter.lowercase_modelname}}",
]
)
{% else %}
_import_structure["models.{{cookiecutter.lowercase_modelname}}"].extend(
[
"{{cookiecutter.uppercase_modelname}}_PRETRAINED_MODEL_ARCHIVE_LIST",
"{{cookiecutter.camelcase_modelname}}ForConditionalGeneration",
"{{cookiecutter.camelcase_modelname}}ForQuestionAnswering",
"{{cookiecutter.camelcase_modelname}}ForSequenceClassification",
"{{cookiecutter.camelcase_modelname}}Model",
]
)
{% endif -%}
# End.
# Below: " # TensorFlow models structure" if generating TensorFlow
# Replace with:
{% if cookiecutter.is_encoder_decoder_model == "False" %}
_import_structure["models.{{cookiecutter.lowercase_modelname}}"].extend(
[
"TF_{{cookiecutter.uppercase_modelname}}_PRETRAINED_MODEL_ARCHIVE_LIST",
"TF{{cookiecutter.camelcase_modelname}}ForMaskedLM",
"TF{{cookiecutter.camelcase_modelname}}ForCausalLM",
"TF{{cookiecutter.camelcase_modelname}}ForMultipleChoice",
"TF{{cookiecutter.camelcase_modelname}}ForQuestionAnswering",
"TF{{cookiecutter.camelcase_modelname}}ForSequenceClassification",
"TF{{cookiecutter.camelcase_modelname}}ForTokenClassification",
"TF{{cookiecutter.camelcase_modelname}}Layer",
"TF{{cookiecutter.camelcase_modelname}}Model",
"TF{{cookiecutter.camelcase_modelname}}PreTrainedModel",
]
)
{% else %}
_import_structure["models.{{cookiecutter.lowercase_modelname}}"].extend(
[
"TF{{cookiecutter.camelcase_modelname}}ForConditionalGeneration",
"TF{{cookiecutter.camelcase_modelname}}Model",
"TF{{cookiecutter.camelcase_modelname}}PreTrainedModel",
]
)
{% endif -%}
# End.
# Below: " # Fast tokenizers"
# Replace with:
_import_structure["models.{{cookiecutter.lowercase_modelname}}"].append("{{cookiecutter.camelcase_modelname}}TokenizerFast")
# End.
# Below: " # Models"
# Replace with:
"models.{{cookiecutter.lowercase_modelname}}": ["{{cookiecutter.uppercase_modelname}}_PRETRAINED_CONFIG_ARCHIVE_MAP", "{{cookiecutter.camelcase_modelname}}Config", "{{cookiecutter.camelcase_modelname}}Tokenizer"],
# End.
# To replace in: "src/transformers/__init__.py" # To replace in: "src/transformers/__init__.py"
# Below: "if is_torch_available():" if generating PyTorch # Below: " if is_torch_available():" if generating PyTorch
# Replace with: # Replace with:
{% if cookiecutter.is_encoder_decoder_model == "False" %} {% if cookiecutter.is_encoder_decoder_model == "False" %}
from .models.{{cookiecutter.lowercase_modelname}} import ( from .models.{{cookiecutter.lowercase_modelname}} import (
{{cookiecutter.uppercase_modelname}}_PRETRAINED_MODEL_ARCHIVE_LIST, {{cookiecutter.uppercase_modelname}}_PRETRAINED_MODEL_ARCHIVE_LIST,
{{cookiecutter.camelcase_modelname}}ForMaskedLM, {{cookiecutter.camelcase_modelname}}ForMaskedLM,
{{cookiecutter.camelcase_modelname}}ForCausalLM, {{cookiecutter.camelcase_modelname}}ForCausalLM,
{{cookiecutter.camelcase_modelname}}ForMultipleChoice, {{cookiecutter.camelcase_modelname}}ForMultipleChoice,
{{cookiecutter.camelcase_modelname}}ForQuestionAnswering, {{cookiecutter.camelcase_modelname}}ForQuestionAnswering,
{{cookiecutter.camelcase_modelname}}ForSequenceClassification, {{cookiecutter.camelcase_modelname}}ForSequenceClassification,
{{cookiecutter.camelcase_modelname}}ForTokenClassification, {{cookiecutter.camelcase_modelname}}ForTokenClassification,
{{cookiecutter.camelcase_modelname}}Layer, {{cookiecutter.camelcase_modelname}}Layer,
{{cookiecutter.camelcase_modelname}}Model, {{cookiecutter.camelcase_modelname}}Model,
{{cookiecutter.camelcase_modelname}}PreTrainedModel, {{cookiecutter.camelcase_modelname}}PreTrainedModel,
load_tf_weights_in_{{cookiecutter.lowercase_modelname}}, load_tf_weights_in_{{cookiecutter.lowercase_modelname}},
) )
{% else %} {% else %}
from .models.{{cookiecutter.lowercase_modelname}} import ( from .models.{{cookiecutter.lowercase_modelname}} import (
{{cookiecutter.uppercase_modelname}}_PRETRAINED_MODEL_ARCHIVE_LIST, {{cookiecutter.uppercase_modelname}}_PRETRAINED_MODEL_ARCHIVE_LIST,
{{cookiecutter.camelcase_modelname}}ForConditionalGeneration, {{cookiecutter.camelcase_modelname}}ForConditionalGeneration,
{{cookiecutter.camelcase_modelname}}ForQuestionAnswering, {{cookiecutter.camelcase_modelname}}ForQuestionAnswering,
{{cookiecutter.camelcase_modelname}}ForSequenceClassification, {{cookiecutter.camelcase_modelname}}ForSequenceClassification,
{{cookiecutter.camelcase_modelname}}Model, {{cookiecutter.camelcase_modelname}}Model,
) )
{% endif -%} {% endif -%}
# End. # End.
# Below: "if is_tf_available():" if generating TensorFlow # Below: " if is_tf_available():" if generating TensorFlow
# Replace with: # Replace with:
{% if cookiecutter.is_encoder_decoder_model == "False" %} {% if cookiecutter.is_encoder_decoder_model == "False" %}
from .models.{{cookiecutter.lowercase_modelname}} import ( from .models.{{cookiecutter.lowercase_modelname}} import (
TF_{{cookiecutter.uppercase_modelname}}_PRETRAINED_MODEL_ARCHIVE_LIST, TF_{{cookiecutter.uppercase_modelname}}_PRETRAINED_MODEL_ARCHIVE_LIST,
TF{{cookiecutter.camelcase_modelname}}ForMaskedLM, TF{{cookiecutter.camelcase_modelname}}ForMaskedLM,
TF{{cookiecutter.camelcase_modelname}}ForCausalLM, TF{{cookiecutter.camelcase_modelname}}ForCausalLM,
TF{{cookiecutter.camelcase_modelname}}ForMultipleChoice, TF{{cookiecutter.camelcase_modelname}}ForMultipleChoice,
TF{{cookiecutter.camelcase_modelname}}ForQuestionAnswering, TF{{cookiecutter.camelcase_modelname}}ForQuestionAnswering,
TF{{cookiecutter.camelcase_modelname}}ForSequenceClassification, TF{{cookiecutter.camelcase_modelname}}ForSequenceClassification,
TF{{cookiecutter.camelcase_modelname}}ForTokenClassification, TF{{cookiecutter.camelcase_modelname}}ForTokenClassification,
TF{{cookiecutter.camelcase_modelname}}Layer, TF{{cookiecutter.camelcase_modelname}}Layer,
TF{{cookiecutter.camelcase_modelname}}Model, TF{{cookiecutter.camelcase_modelname}}Model,
TF{{cookiecutter.camelcase_modelname}}PreTrainedModel, TF{{cookiecutter.camelcase_modelname}}PreTrainedModel,
) )
{% else %} {% else %}
from .models.{{cookiecutter.lowercase_modelname}} import ( from .models.{{cookiecutter.lowercase_modelname}} import (
TF{{cookiecutter.camelcase_modelname}}ForConditionalGeneration, TF{{cookiecutter.camelcase_modelname}}ForConditionalGeneration,
TF{{cookiecutter.camelcase_modelname}}Model, TF{{cookiecutter.camelcase_modelname}}Model,
TF{{cookiecutter.camelcase_modelname}}PreTrainedModel, TF{{cookiecutter.camelcase_modelname}}PreTrainedModel,
) )
{% endif -%} {% endif -%}
# End. # End.
# Below: "if is_tokenizers_available():" # Below: " if is_tokenizers_available():"
# Replace with: # Replace with:
from .models.{{cookiecutter.lowercase_modelname}} import {{cookiecutter.camelcase_modelname}}TokenizerFast from .models.{{cookiecutter.lowercase_modelname}} import {{cookiecutter.camelcase_modelname}}TokenizerFast
# End. # End.
# Below: "from .models.albert import ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, AlbertConfig" # Below: " from .models.albert import ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, AlbertConfig"
# Replace with: # Replace with:
from .models.{{cookiecutter.lowercase_modelname}} import {{cookiecutter.uppercase_modelname}}_PRETRAINED_CONFIG_ARCHIVE_MAP, {{cookiecutter.camelcase_modelname}}Config, {{cookiecutter.camelcase_modelname}}Tokenizer from .models.{{cookiecutter.lowercase_modelname}} import {{cookiecutter.uppercase_modelname}}_PRETRAINED_CONFIG_ARCHIVE_MAP, {{cookiecutter.camelcase_modelname}}Config, {{cookiecutter.camelcase_modelname}}Tokenizer
# End. # End.
# To replace in: "src/transformers/models/__init__.py"
# Below: "from . import ("
# Replace with:
{{cookiecutter.lowercase_modelname}},
# End.
# To replace in: "src/transformers/models/auto/configuration_auto.py" # To replace in: "src/transformers/models/auto/configuration_auto.py"
# Below: "# Add configs here" # Below: "# Add configs here"
# Replace with: # Replace with:

View File

@@ -23,237 +23,79 @@ import re
PATH_TO_TRANSFORMERS = "src/transformers" PATH_TO_TRANSFORMERS = "src/transformers"
_re_single_line_import = re.compile(r"\s+from\s+\S*\s+import\s+([^\(\s].*)\n") _re_single_line_import = re.compile(r"\s+from\s+\S*\s+import\s+([^\(\s].*)\n")
_re_test_backend = re.compile(r"^\s+if\s+is\_([a-z]*)\_available\(\):\s*$")
BACKENDS = ["torch", "tf", "flax", "sentencepiece", "tokenizers"]
DUMMY_CONSTANT = """ DUMMY_CONSTANT = """
{0} = None {0} = None
""" """
DUMMY_PT_PRETRAINED_CLASS = """ DUMMY_PRETRAINED_CLASS = """
class {0}: class {0}:
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
requires_pytorch(self) requires_{1}(self)
@classmethod @classmethod
def from_pretrained(self, *args, **kwargs): def from_pretrained(self, *args, **kwargs):
requires_pytorch(self) requires_{1}(self)
""" """
DUMMY_PT_CLASS = """ DUMMY_CLASS = """
class {0}: class {0}:
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
requires_pytorch(self) requires_{1}(self)
""" """
DUMMY_PT_FUNCTION = """ DUMMY_FUNCTION = """
def {0}(*args, **kwargs): def {0}(*args, **kwargs):
requires_pytorch({0}) requires_{1}({0})
""" """
DUMMY_TF_PRETRAINED_CLASS = """
class {0}:
def __init__(self, *args, **kwargs):
requires_tf(self)
@classmethod
def from_pretrained(self, *args, **kwargs):
requires_tf(self)
"""
DUMMY_TF_CLASS = """
class {0}:
def __init__(self, *args, **kwargs):
requires_tf(self)
"""
DUMMY_TF_FUNCTION = """
def {0}(*args, **kwargs):
requires_tf({0})
"""
DUMMY_FLAX_PRETRAINED_CLASS = """
class {0}:
def __init__(self, *args, **kwargs):
requires_flax(self)
@classmethod
def from_pretrained(self, *args, **kwargs):
requires_flax(self)
"""
DUMMY_FLAX_CLASS = """
class {0}:
def __init__(self, *args, **kwargs):
requires_flax(self)
"""
DUMMY_FLAX_FUNCTION = """
def {0}(*args, **kwargs):
requires_flax({0})
"""
DUMMY_SENTENCEPIECE_PRETRAINED_CLASS = """
class {0}:
def __init__(self, *args, **kwargs):
requires_sentencepiece(self)
@classmethod
def from_pretrained(self, *args, **kwargs):
requires_sentencepiece(self)
"""
DUMMY_SENTENCEPIECE_CLASS = """
class {0}:
def __init__(self, *args, **kwargs):
requires_sentencepiece(self)
"""
DUMMY_SENTENCEPIECE_FUNCTION = """
def {0}(*args, **kwargs):
requires_sentencepiece({0})
"""
DUMMY_TOKENIZERS_PRETRAINED_CLASS = """
class {0}:
def __init__(self, *args, **kwargs):
requires_tokenizers(self)
@classmethod
def from_pretrained(self, *args, **kwargs):
requires_tokenizers(self)
"""
DUMMY_TOKENIZERS_CLASS = """
class {0}:
def __init__(self, *args, **kwargs):
requires_tokenizers(self)
"""
DUMMY_TOKENIZERS_FUNCTION = """
def {0}(*args, **kwargs):
requires_tokenizers({0})
"""
# Map all these to dummy type
DUMMY_PRETRAINED_CLASS = {
"pt": DUMMY_PT_PRETRAINED_CLASS,
"tf": DUMMY_TF_PRETRAINED_CLASS,
"flax": DUMMY_FLAX_PRETRAINED_CLASS,
"sentencepiece": DUMMY_SENTENCEPIECE_PRETRAINED_CLASS,
"tokenizers": DUMMY_TOKENIZERS_PRETRAINED_CLASS,
}
DUMMY_CLASS = {
"pt": DUMMY_PT_CLASS,
"tf": DUMMY_TF_CLASS,
"flax": DUMMY_FLAX_CLASS,
"sentencepiece": DUMMY_SENTENCEPIECE_CLASS,
"tokenizers": DUMMY_TOKENIZERS_CLASS,
}
DUMMY_FUNCTION = {
"pt": DUMMY_PT_FUNCTION,
"tf": DUMMY_TF_FUNCTION,
"flax": DUMMY_FLAX_FUNCTION,
"sentencepiece": DUMMY_SENTENCEPIECE_FUNCTION,
"tokenizers": DUMMY_TOKENIZERS_FUNCTION,
}
def read_init(): def read_init():
""" Read the init and extracts PyTorch, TensorFlow, SentencePiece and Tokenizers objects. """ """ Read the init and extracts PyTorch, TensorFlow, SentencePiece and Tokenizers objects. """
with open(os.path.join(PATH_TO_TRANSFORMERS, "__init__.py"), "r", encoding="utf-8", newline="\n") as f: with open(os.path.join(PATH_TO_TRANSFORMERS, "__init__.py"), "r", encoding="utf-8", newline="\n") as f:
lines = f.readlines() lines = f.readlines()
# Get to the point where we do the actual imports for type checking
line_index = 0 line_index = 0
# Find where the SentencePiece imports begin while not lines[line_index].startswith("if TYPE_CHECKING"):
sentencepiece_objects = []
while not lines[line_index].startswith("if is_sentencepiece_available():"):
line_index += 1
line_index += 1
# Until we unindent, add SentencePiece objects to the list
while len(lines[line_index]) <= 1 or lines[line_index].startswith(" "):
line = lines[line_index]
search = _re_single_line_import.search(line)
if search is not None:
sentencepiece_objects += search.groups()[0].split(", ")
elif line.startswith(" "):
sentencepiece_objects.append(line[8:-2])
line_index += 1 line_index += 1
# Find where the Tokenizers imports begin backend_specific_objects = {}
tokenizers_objects = [] # Go through the end of the file
while not lines[line_index].startswith("if is_tokenizers_available():"): while line_index < len(lines):
line_index += 1 # If the line is an if is_backend_available, we grab all objects associated.
line_index += 1 if _re_test_backend.search(lines[line_index]) is not None:
backend = _re_test_backend.search(lines[line_index]).groups()[0]
line_index += 1
# Until we unindent, add Tokenizers objects to the list # Ignore if backend isn't tracked for dummies.
while len(lines[line_index]) <= 1 or lines[line_index].startswith(" "): if backend not in BACKENDS:
line = lines[line_index] continue
search = _re_single_line_import.search(line)
if search is not None:
tokenizers_objects += search.groups()[0].split(", ")
elif line.startswith(" "):
tokenizers_objects.append(line[8:-2])
line_index += 1
# Find where the PyTorch imports begin objects = []
pt_objects = [] # Until we unindent, add backend objects to the list
while not lines[line_index].startswith("if is_torch_available():"): while len(lines[line_index]) <= 1 or lines[line_index].startswith(" " * 8):
line_index += 1 line = lines[line_index]
line_index += 1 single_line_import_search = _re_single_line_import.search(line)
if single_line_import_search is not None:
objects.extend(single_line_import_search.groups()[0].split(", "))
elif line.startswith(" " * 12):
objects.append(line[12:-2])
line_index += 1
# Until we unindent, add PyTorch objects to the list backend_specific_objects[backend] = objects
while len(lines[line_index]) <= 1 or lines[line_index].startswith(" "): else:
line = lines[line_index] line_index += 1
search = _re_single_line_import.search(line)
if search is not None:
pt_objects += search.groups()[0].split(", ")
elif line.startswith(" "):
pt_objects.append(line[8:-2])
line_index += 1
# Find where the TF imports begin return backend_specific_objects
tf_objects = []
while not lines[line_index].startswith("if is_tf_available():"):
line_index += 1
line_index += 1
# Until we unindent, add PyTorch objects to the list
while len(lines[line_index]) <= 1 or lines[line_index].startswith(" "):
line = lines[line_index]
search = _re_single_line_import.search(line)
if search is not None:
tf_objects += search.groups()[0].split(", ")
elif line.startswith(" "):
tf_objects.append(line[8:-2])
line_index += 1
# Find where the FLAX imports begin
flax_objects = []
while not lines[line_index].startswith("if is_flax_available():"):
line_index += 1
line_index += 1
# Until we unindent, add PyTorch objects to the list
while len(lines[line_index]) <= 1 or lines[line_index].startswith(" "):
line = lines[line_index]
search = _re_single_line_import.search(line)
if search is not None:
flax_objects += search.groups()[0].split(", ")
elif line.startswith(" "):
flax_objects.append(line[8:-2])
line_index += 1
return sentencepiece_objects, tokenizers_objects, pt_objects, tf_objects, flax_objects
def create_dummy_object(name, type="pt"): def create_dummy_object(name, backend_name):
""" Create the code for the dummy object corresponding to `name`.""" """ Create the code for the dummy object corresponding to `name`."""
_pretrained = [ _pretrained = [
"Config" "ForCausalLM", "Config" "ForCausalLM",
@@ -266,11 +108,10 @@ def create_dummy_object(name, type="pt"):
"Model", "Model",
"Tokenizer", "Tokenizer",
] ]
assert type in ["pt", "tf", "sentencepiece", "tokenizers", "flax"]
if name.isupper(): if name.isupper():
return DUMMY_CONSTANT.format(name) return DUMMY_CONSTANT.format(name)
elif name.islower(): elif name.islower():
return (DUMMY_FUNCTION[type]).format(name) return DUMMY_FUNCTION.format(name, backend_name)
else: else:
is_pretrained = False is_pretrained = False
for part in _pretrained: for part in _pretrained:
@@ -278,114 +119,61 @@ def create_dummy_object(name, type="pt"):
is_pretrained = True is_pretrained = True
break break
if is_pretrained: if is_pretrained:
template = DUMMY_PRETRAINED_CLASS[type] return DUMMY_PRETRAINED_CLASS.format(name, backend_name)
else: else:
template = DUMMY_CLASS[type] return DUMMY_CLASS.format(name, backend_name)
return template.format(name)
def create_dummy_files(): def create_dummy_files():
""" Create the content of the dummy files. """ """ Create the content of the dummy files. """
sentencepiece_objects, tokenizers_objects, pt_objects, tf_objects, flax_objects = read_init() backend_specific_objects = read_init()
# Special-case mapping from a backend to the module name used in its `requires_<modulename>` function
module_names = {"torch": "pytorch"}
dummy_files = {}
sentencepiece_dummies = "# This file is autogenerated by the command `make fix-copies`, do not edit.\n" for backend, objects in backend_specific_objects.items():
sentencepiece_dummies += "from ..file_utils import requires_sentencepiece\n\n" backend_name = module_names.get(backend, backend)
sentencepiece_dummies += "\n".join([create_dummy_object(o, type="sentencepiece") for o in sentencepiece_objects]) dummy_file = "# This file is autogenerated by the command `make fix-copies`, do not edit.\n"
dummy_file += f"from ..file_utils import requires_{backend_name}\n\n"
dummy_file += "\n".join([create_dummy_object(o, backend_name) for o in objects])
dummy_files[backend] = dummy_file
tokenizers_dummies = "# This file is autogenerated by the command `make fix-copies`, do not edit.\n" return dummy_files
tokenizers_dummies += "from ..file_utils import requires_tokenizers\n\n"
tokenizers_dummies += "\n".join([create_dummy_object(o, type="tokenizers") for o in tokenizers_objects])
pt_dummies = "# This file is autogenerated by the command `make fix-copies`, do not edit.\n"
pt_dummies += "from ..file_utils import requires_pytorch\n\n"
pt_dummies += "\n".join([create_dummy_object(o, type="pt") for o in pt_objects])
tf_dummies = "# This file is autogenerated by the command `make fix-copies`, do not edit.\n"
tf_dummies += "from ..file_utils import requires_tf\n\n"
tf_dummies += "\n".join([create_dummy_object(o, type="tf") for o in tf_objects])
flax_dummies = "# This file is autogenerated by the command `make fix-copies`, do not edit.\n"
flax_dummies += "from ..file_utils import requires_flax\n\n"
flax_dummies += "\n".join([create_dummy_object(o, type="flax") for o in flax_objects])
return sentencepiece_dummies, tokenizers_dummies, pt_dummies, tf_dummies, flax_dummies
def check_dummies(overwrite=False): def check_dummies(overwrite=False):
""" Check if the dummy files are up to date and maybe `overwrite` with the right content. """ """ Check if the dummy files are up to date and maybe `overwrite` with the right content. """
sentencepiece_dummies, tokenizers_dummies, pt_dummies, tf_dummies, flax_dummies = create_dummy_files() dummy_files = create_dummy_files()
# Special-case mapping from a backend to the short name used in utils/dummy_xxx_objects.py
short_names = {"torch": "pt"}
# Locate actual dummy modules and read their content.
path = os.path.join(PATH_TO_TRANSFORMERS, "utils") path = os.path.join(PATH_TO_TRANSFORMERS, "utils")
sentencepiece_file = os.path.join(path, "dummy_sentencepiece_objects.py") dummy_file_paths = {
tokenizers_file = os.path.join(path, "dummy_tokenizers_objects.py") backend: os.path.join(path, f"dummy_{short_names.get(backend, backend)}_objects.py")
pt_file = os.path.join(path, "dummy_pt_objects.py") for backend in dummy_files.keys()
tf_file = os.path.join(path, "dummy_tf_objects.py") }
flax_file = os.path.join(path, "dummy_flax_objects.py")
with open(sentencepiece_file, "r", encoding="utf-8", newline="\n") as f: actual_dummies = {}
actual_sentencepiece_dummies = f.read() for backend, file_path in dummy_file_paths.items():
with open(tokenizers_file, "r", encoding="utf-8", newline="\n") as f: with open(file_path, "r", encoding="utf-8", newline="\n") as f:
actual_tokenizers_dummies = f.read() actual_dummies[backend] = f.read()
with open(pt_file, "r", encoding="utf-8", newline="\n") as f:
actual_pt_dummies = f.read()
with open(tf_file, "r", encoding="utf-8", newline="\n") as f:
actual_tf_dummies = f.read()
with open(flax_file, "r", encoding="utf-8", newline="\n") as f:
actual_flax_dummies = f.read()
if sentencepiece_dummies != actual_sentencepiece_dummies: for backend in dummy_files.keys():
if overwrite: if dummy_files[backend] != actual_dummies[backend]:
print("Updating transformers.utils.dummy_sentencepiece_objects.py as the main __init__ has new objects.") if overwrite:
with open(sentencepiece_file, "w", encoding="utf-8", newline="\n") as f: print(
f.write(sentencepiece_dummies) f"Updating transformers.utils.dummy_{short_names.get(backend, backend)}_objects.py as the main "
else: "__init__ has new objects."
raise ValueError( )
"The main __init__ has objects that are not present in transformers.utils.dummy_sentencepiece_objects.py.", with open(dummy_file_paths[backend], "w", encoding="utf-8", newline="\n") as f:
"Run `make fix-copies` to fix this.", f.write(dummy_files[backend])
) else:
raise ValueError(
if tokenizers_dummies != actual_tokenizers_dummies: "The main __init__ has objects that are not present in "
if overwrite: f"transformers.utils.dummy_{short_names.get(backend, backend)}_objects.py. Run `make fix-copies` "
print("Updating transformers.utils.dummy_tokenizers_objects.py as the main __init__ has new objects.") "to fix this."
with open(tokenizers_file, "w", encoding="utf-8", newline="\n") as f: )
f.write(tokenizers_dummies)
else:
raise ValueError(
"The main __init__ has objects that are not present in transformers.utils.dummy_tokenizers_objects.py.",
"Run `make fix-copies` to fix this.",
)
if pt_dummies != actual_pt_dummies:
if overwrite:
print("Updating transformers.utils.dummy_pt_objects.py as the main __init__ has new objects.")
with open(pt_file, "w", encoding="utf-8", newline="\n") as f:
f.write(pt_dummies)
else:
raise ValueError(
"The main __init__ has objects that are not present in transformers.utils.dummy_pt_objects.py.",
"Run `make fix-copies` to fix this.",
)
if tf_dummies != actual_tf_dummies:
if overwrite:
print("Updating transformers.utils.dummy_tf_objects.py as the main __init__ has new objects.")
with open(tf_file, "w", encoding="utf-8", newline="\n") as f:
f.write(tf_dummies)
else:
raise ValueError(
"The main __init__ has objects that are not present in transformers.utils.dummy_pt_objects.py.",
"Run `make fix-copies` to fix this.",
)
if flax_dummies != actual_flax_dummies:
if overwrite:
print("Updating transformers.utils.dummy_flax_objects.py as the main __init__ has new objects.")
with open(flax_file, "w", encoding="utf-8", newline="\n") as f:
f.write(flax_dummies)
else:
raise ValueError(
"The main __init__ has objects that are not present in transformers.utils.dummy_flax_objects.py.",
"Run `make fix-copies` to fix this.",
)
if __name__ == "__main__": if __name__ == "__main__":

View File

@@ -413,9 +413,6 @@ SHOULD_HAVE_THEIR_OWN_PAGE = [
def ignore_undocumented(name): def ignore_undocumented(name):
"""Rules to determine if `name` should be undocumented.""" """Rules to determine if `name` should be undocumented."""
# NOT DOCUMENTED ON PURPOSE. # NOT DOCUMENTED ON PURPOSE.
# Magic attributes are not documented.
if name.startswith("__"):
return True
# Constants uppercase are not documented. # Constants uppercase are not documented.
if name.isupper(): if name.isupper():
return True return True
@@ -459,7 +456,9 @@ def ignore_undocumented(name):
def check_all_objects_are_documented(): def check_all_objects_are_documented():
""" Check all models are properly documented.""" """ Check all models are properly documented."""
documented_objs = find_all_documented_objects() documented_objs = find_all_documented_objects()
undocumented_objs = [c for c in dir(transformers) if c not in documented_objs and not ignore_undocumented(c)] modules = transformers._modules
objects = [c for c in dir(transformers) if c not in modules and not c.startswith("_")]
undocumented_objs = [c for c in objects if c not in documented_objs and not ignore_undocumented(c)]
if len(undocumented_objs) > 0: if len(undocumented_objs) > 0:
raise Exception( raise Exception(
"The following objects are in the public init so should be documented:\n - " "The following objects are in the public init so should be documented:\n - "