Transformers fast import part 2 (#9446)
* Main init work * Add version * Change from absolute to relative imports * Fix imports * One more typo * More typos * Styling * Make quality script pass * Add necessary replace in template * Fix typos * Spaces are ignored in replace for some reason * Forgot one models. * Fixes for import Co-authored-by: LysandreJik <lysandre.debut@reseau.eseo.fr> * Add documentation * Styling Co-authored-by: LysandreJik <lysandre.debut@reseau.eseo.fr>
This commit is contained in:
File diff suppressed because it is too large
Load Diff
@@ -30,9 +30,8 @@ from multiprocessing import Pipe, Process, Queue
|
||||
from multiprocessing.connection import Connection
|
||||
from typing import Callable, Iterable, List, NamedTuple, Optional, Union
|
||||
|
||||
from transformers import AutoConfig, PretrainedConfig
|
||||
from transformers import __version__ as version
|
||||
|
||||
from .. import AutoConfig, PretrainedConfig
|
||||
from .. import __version__ as version
|
||||
from ..file_utils import is_psutil_available, is_py3nvml_available, is_tf_available, is_torch_available
|
||||
from ..utils import logging
|
||||
from .benchmark_args_utils import BenchmarkArguments
|
||||
|
||||
@@ -19,9 +19,8 @@ from argparse import ArgumentParser, Namespace
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
from transformers.commands import BaseTransformersCLICommand
|
||||
|
||||
from ..utils import logging
|
||||
from . import BaseTransformersCLICommand
|
||||
|
||||
|
||||
try:
|
||||
|
||||
@@ -14,9 +14,8 @@
|
||||
|
||||
from argparse import ArgumentParser, Namespace
|
||||
|
||||
from transformers.commands import BaseTransformersCLICommand
|
||||
|
||||
from ..utils import logging
|
||||
from . import BaseTransformersCLICommand
|
||||
|
||||
|
||||
def convert_command_factory(args: Namespace):
|
||||
@@ -87,7 +86,7 @@ class ConvertCommand(BaseTransformersCLICommand):
|
||||
def run(self):
|
||||
if self._model_type == "albert":
|
||||
try:
|
||||
from transformers.models.albert.convert_albert_original_tf_checkpoint_to_pytorch import (
|
||||
from ..models.albert.convert_albert_original_tf_checkpoint_to_pytorch import (
|
||||
convert_tf_checkpoint_to_pytorch,
|
||||
)
|
||||
except ImportError:
|
||||
@@ -96,7 +95,7 @@ class ConvertCommand(BaseTransformersCLICommand):
|
||||
convert_tf_checkpoint_to_pytorch(self._tf_checkpoint, self._config, self._pytorch_dump_output)
|
||||
elif self._model_type == "bert":
|
||||
try:
|
||||
from transformers.models.bert.convert_bert_original_tf_checkpoint_to_pytorch import (
|
||||
from ..models.bert.convert_bert_original_tf_checkpoint_to_pytorch import (
|
||||
convert_tf_checkpoint_to_pytorch,
|
||||
)
|
||||
except ImportError:
|
||||
@@ -105,7 +104,7 @@ class ConvertCommand(BaseTransformersCLICommand):
|
||||
convert_tf_checkpoint_to_pytorch(self._tf_checkpoint, self._config, self._pytorch_dump_output)
|
||||
elif self._model_type == "funnel":
|
||||
try:
|
||||
from transformers.models.funnel.convert_funnel_original_tf_checkpoint_to_pytorch import (
|
||||
from ..models.funnel.convert_funnel_original_tf_checkpoint_to_pytorch import (
|
||||
convert_tf_checkpoint_to_pytorch,
|
||||
)
|
||||
except ImportError:
|
||||
@@ -113,14 +112,14 @@ class ConvertCommand(BaseTransformersCLICommand):
|
||||
|
||||
convert_tf_checkpoint_to_pytorch(self._tf_checkpoint, self._config, self._pytorch_dump_output)
|
||||
elif self._model_type == "gpt":
|
||||
from transformers.models.openai.convert_openai_original_tf_checkpoint_to_pytorch import (
|
||||
from ..models.openai.convert_openai_original_tf_checkpoint_to_pytorch import (
|
||||
convert_openai_checkpoint_to_pytorch,
|
||||
)
|
||||
|
||||
convert_openai_checkpoint_to_pytorch(self._tf_checkpoint, self._config, self._pytorch_dump_output)
|
||||
elif self._model_type == "transfo_xl":
|
||||
try:
|
||||
from transformers.models.transfo_xl.convert_transfo_xl_original_tf_checkpoint_to_pytorch import (
|
||||
from ..models.transfo_xl.convert_transfo_xl_original_tf_checkpoint_to_pytorch import (
|
||||
convert_transfo_xl_checkpoint_to_pytorch,
|
||||
)
|
||||
except ImportError:
|
||||
@@ -137,7 +136,7 @@ class ConvertCommand(BaseTransformersCLICommand):
|
||||
)
|
||||
elif self._model_type == "gpt2":
|
||||
try:
|
||||
from transformers.models.gpt2.convert_gpt2_original_tf_checkpoint_to_pytorch import (
|
||||
from ..models.gpt2.convert_gpt2_original_tf_checkpoint_to_pytorch import (
|
||||
convert_gpt2_checkpoint_to_pytorch,
|
||||
)
|
||||
except ImportError:
|
||||
@@ -146,7 +145,7 @@ class ConvertCommand(BaseTransformersCLICommand):
|
||||
convert_gpt2_checkpoint_to_pytorch(self._tf_checkpoint, self._config, self._pytorch_dump_output)
|
||||
elif self._model_type == "xlnet":
|
||||
try:
|
||||
from transformers.models.xlnet.convert_xlnet_original_tf_checkpoint_to_pytorch import (
|
||||
from ..models.xlnet.convert_xlnet_original_tf_checkpoint_to_pytorch import (
|
||||
convert_xlnet_checkpoint_to_pytorch,
|
||||
)
|
||||
except ImportError:
|
||||
@@ -156,13 +155,13 @@ class ConvertCommand(BaseTransformersCLICommand):
|
||||
self._tf_checkpoint, self._config, self._pytorch_dump_output, self._finetuning_task_name
|
||||
)
|
||||
elif self._model_type == "xlm":
|
||||
from transformers.models.xlm.convert_xlm_original_pytorch_checkpoint_to_pytorch import (
|
||||
from ..models.xlm.convert_xlm_original_pytorch_checkpoint_to_pytorch import (
|
||||
convert_xlm_checkpoint_to_pytorch,
|
||||
)
|
||||
|
||||
convert_xlm_checkpoint_to_pytorch(self._tf_checkpoint, self._pytorch_dump_output)
|
||||
elif self._model_type == "lxmert":
|
||||
from transformers.models.lxmert.convert_lxmert_original_pytorch_checkpoint_to_pytorch import (
|
||||
from ..models.lxmert.convert_lxmert_original_pytorch_checkpoint_to_pytorch import (
|
||||
convert_lxmert_checkpoint_to_pytorch,
|
||||
)
|
||||
|
||||
|
||||
@@ -14,7 +14,7 @@
|
||||
|
||||
from argparse import ArgumentParser
|
||||
|
||||
from transformers.commands import BaseTransformersCLICommand
|
||||
from . import BaseTransformersCLICommand
|
||||
|
||||
|
||||
def download_command_factory(args):
|
||||
@@ -40,7 +40,7 @@ class DownloadCommand(BaseTransformersCLICommand):
|
||||
self._force = force
|
||||
|
||||
def run(self):
|
||||
from transformers import AutoModel, AutoTokenizer
|
||||
from ..models.auto import AutoModel, AutoTokenizer
|
||||
|
||||
AutoModel.from_pretrained(self._model, cache_dir=self._cache, force_download=self._force)
|
||||
AutoTokenizer.from_pretrained(self._model, cache_dir=self._cache, force_download=self._force)
|
||||
|
||||
@@ -15,9 +15,9 @@
|
||||
import platform
|
||||
from argparse import ArgumentParser
|
||||
|
||||
from transformers import __version__ as version
|
||||
from transformers import is_tf_available, is_torch_available
|
||||
from transformers.commands import BaseTransformersCLICommand
|
||||
from .. import __version__ as version
|
||||
from ..file_utils import is_tf_available, is_torch_available
|
||||
from . import BaseTransformersCLICommand
|
||||
|
||||
|
||||
def info_command_factory(_):
|
||||
|
||||
@@ -25,9 +25,9 @@ from contextlib import AbstractContextManager
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
import requests
|
||||
from transformers.commands import BaseTransformersCLICommand
|
||||
|
||||
from ..utils import logging
|
||||
from . import BaseTransformersCLICommand
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
@@ -14,10 +14,9 @@
|
||||
|
||||
from argparse import ArgumentParser
|
||||
|
||||
from transformers.commands import BaseTransformersCLICommand
|
||||
from transformers.pipelines import SUPPORTED_TASKS, Pipeline, PipelineDataFormat, pipeline
|
||||
|
||||
from ..pipelines import SUPPORTED_TASKS, Pipeline, PipelineDataFormat, pipeline
|
||||
from ..utils import logging
|
||||
from . import BaseTransformersCLICommand
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
@@ -15,11 +15,9 @@
|
||||
from argparse import ArgumentParser, Namespace
|
||||
from typing import Any, List, Optional
|
||||
|
||||
from transformers import Pipeline
|
||||
from transformers.commands import BaseTransformersCLICommand
|
||||
from transformers.pipelines import SUPPORTED_TASKS, pipeline
|
||||
|
||||
from ..pipelines import SUPPORTED_TASKS, Pipeline, pipeline
|
||||
from ..utils import logging
|
||||
from . import BaseTransformersCLICommand
|
||||
|
||||
|
||||
try:
|
||||
|
||||
@@ -15,11 +15,11 @@
|
||||
import os
|
||||
from argparse import ArgumentParser, Namespace
|
||||
|
||||
from transformers import SingleSentenceClassificationProcessor as Processor
|
||||
from transformers import TextClassificationPipeline, is_tf_available, is_torch_available
|
||||
from transformers.commands import BaseTransformersCLICommand
|
||||
|
||||
from ..data import SingleSentenceClassificationProcessor as Processor
|
||||
from ..file_utils import is_tf_available, is_torch_available
|
||||
from ..pipelines import TextClassificationPipeline
|
||||
from ..utils import logging
|
||||
from . import BaseTransformersCLICommand
|
||||
|
||||
|
||||
if not is_tf_available() and not is_torch_available():
|
||||
|
||||
@@ -15,14 +15,14 @@
|
||||
|
||||
from argparse import ArgumentParser
|
||||
|
||||
from transformers.commands.add_new_model import AddNewModelCommand
|
||||
from transformers.commands.convert import ConvertCommand
|
||||
from transformers.commands.download import DownloadCommand
|
||||
from transformers.commands.env import EnvironmentCommand
|
||||
from transformers.commands.lfs import LfsCommands
|
||||
from transformers.commands.run import RunCommand
|
||||
from transformers.commands.serving import ServeCommand
|
||||
from transformers.commands.user import UserCommands
|
||||
from .add_new_model import AddNewModelCommand
|
||||
from .convert import ConvertCommand
|
||||
from .download import DownloadCommand
|
||||
from .env import EnvironmentCommand
|
||||
from .lfs import LfsCommands
|
||||
from .run import RunCommand
|
||||
from .serving import ServeCommand
|
||||
from .user import UserCommands
|
||||
|
||||
|
||||
def main():
|
||||
|
||||
@@ -20,8 +20,9 @@ from getpass import getpass
|
||||
from typing import List, Union
|
||||
|
||||
from requests.exceptions import HTTPError
|
||||
from transformers.commands import BaseTransformersCLICommand
|
||||
from transformers.hf_api import HfApi, HfFolder
|
||||
|
||||
from ..hf_api import HfApi, HfFolder
|
||||
from . import BaseTransformersCLICommand
|
||||
|
||||
|
||||
UPLOAD_MAX_FILES = 15
|
||||
|
||||
@@ -19,10 +19,9 @@ from typing import Dict, List, Optional, Tuple
|
||||
|
||||
from packaging.version import Version, parse
|
||||
|
||||
from transformers import is_tf_available, is_torch_available
|
||||
from transformers.file_utils import ModelOutput
|
||||
from transformers.pipelines import Pipeline, pipeline
|
||||
from transformers.tokenization_utils import BatchEncoding
|
||||
from .file_utils import ModelOutput, is_tf_available, is_torch_available
|
||||
from .pipelines import Pipeline, pipeline
|
||||
from .tokenization_utils import BatchEncoding
|
||||
|
||||
|
||||
# This is the minimal required version to
|
||||
|
||||
@@ -18,7 +18,7 @@
|
||||
import argparse
|
||||
import os
|
||||
|
||||
from transformers import (
|
||||
from . import (
|
||||
ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||
BART_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
BERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||
@@ -87,15 +87,15 @@ from transformers import (
|
||||
is_torch_available,
|
||||
load_pytorch_checkpoint_in_tf2_model,
|
||||
)
|
||||
from transformers.file_utils import hf_bucket_url
|
||||
from transformers.utils import logging
|
||||
from .file_utils import hf_bucket_url
|
||||
from .utils import logging
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from transformers import (
|
||||
from . import (
|
||||
AlbertForPreTraining,
|
||||
BartForConditionalGeneration,
|
||||
BertForPreTraining,
|
||||
|
||||
@@ -18,8 +18,9 @@ import argparse
|
||||
import os
|
||||
|
||||
import transformers
|
||||
from transformers.convert_slow_tokenizer import SLOW_TO_FAST_CONVERTERS
|
||||
from transformers.utils import logging
|
||||
|
||||
from .convert_slow_tokenizer import SLOW_TO_FAST_CONVERTERS
|
||||
from .utils import logging
|
||||
|
||||
|
||||
logging.set_verbosity_info()
|
||||
|
||||
@@ -17,7 +17,7 @@
|
||||
|
||||
import argparse
|
||||
|
||||
from transformers import (
|
||||
from . import (
|
||||
BertConfig,
|
||||
BertGenerationConfig,
|
||||
BertGenerationDecoder,
|
||||
|
||||
@@ -27,8 +27,7 @@ import math
|
||||
import re
|
||||
import string
|
||||
|
||||
from transformers import BasicTokenizer
|
||||
|
||||
from ...models.bert import BasicTokenizer
|
||||
from ...utils import logging
|
||||
|
||||
|
||||
|
||||
@@ -17,15 +17,14 @@ import unittest
|
||||
|
||||
import timeout_decorator
|
||||
|
||||
from transformers import is_torch_available
|
||||
from transformers.file_utils import cached_property
|
||||
from transformers.testing_utils import require_torch
|
||||
from ..file_utils import cached_property, is_torch_available
|
||||
from ..testing_utils import require_torch
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
from transformers import MarianConfig, MarianMTModel
|
||||
from ..models.marian import MarianConfig, MarianMTModel
|
||||
|
||||
|
||||
@require_torch
|
||||
|
||||
@@ -33,6 +33,7 @@ from dataclasses import fields
|
||||
from functools import partial, wraps
|
||||
from hashlib import sha256
|
||||
from pathlib import Path
|
||||
from types import ModuleType
|
||||
from typing import Any, BinaryIO, Dict, List, Optional, Tuple, Union
|
||||
from urllib.parse import urlparse
|
||||
from zipfile import ZipFile, is_zipfile
|
||||
@@ -41,7 +42,6 @@ import numpy as np
|
||||
from packaging import version
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
import importlib_metadata
|
||||
import requests
|
||||
from filelock import FileLock
|
||||
|
||||
@@ -50,6 +50,13 @@ from .hf_api import HfFolder
|
||||
from .utils import logging
|
||||
|
||||
|
||||
# The package importlib_metadata is in a different place, depending on the python version.
|
||||
if version.parse(sys.version) < version.parse("3.8"):
|
||||
import importlib_metadata
|
||||
else:
|
||||
import importlib.metadata as importlib_metadata
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
ENV_VARS_TRUE_VALUES = {"1", "ON", "YES"}
|
||||
@@ -130,7 +137,7 @@ except importlib_metadata.PackageNotFoundError:
|
||||
|
||||
_scatter_available = importlib.util.find_spec("torch_scatter") is not None
|
||||
try:
|
||||
_scatter_version = importlib_metadata.version("torch_scatterr")
|
||||
_scatter_version = importlib_metadata.version("torch_scatter")
|
||||
logger.debug(f"Successfully imported torch-scatter version {_scatter_version}")
|
||||
except importlib_metadata.PackageNotFoundError:
|
||||
_scatter_available = False
|
||||
@@ -1415,3 +1422,40 @@ class ModelOutput(OrderedDict):
|
||||
Convert self to a tuple containing all the attributes/keys that are not ``None``.
|
||||
"""
|
||||
return tuple(self[k] for k in self.keys())
|
||||
|
||||
|
||||
class _BaseLazyModule(ModuleType):
|
||||
"""
|
||||
Module class that surfaces all objects but only performs associated imports when the objects are requested.
|
||||
"""
|
||||
|
||||
# Very heavily inspired by optuna.integration._IntegrationModule
|
||||
# https://github.com/optuna/optuna/blob/master/optuna/integration/__init__.py
|
||||
def __init__(self, name, import_structure):
|
||||
super().__init__(name)
|
||||
self._modules = set(import_structure.keys())
|
||||
self._class_to_module = {}
|
||||
for key, values in import_structure.items():
|
||||
for value in values:
|
||||
self._class_to_module[value] = key
|
||||
# Needed for autocompletion in an IDE
|
||||
self.__all__ = list(import_structure.keys()) + sum(import_structure.values(), [])
|
||||
|
||||
# Needed for autocompletion in an IDE
|
||||
def __dir__(self):
|
||||
return super().__dir__() + self.__all__
|
||||
|
||||
def __getattr__(self, name: str) -> Any:
|
||||
if name in self._modules:
|
||||
value = self._get_module(name)
|
||||
elif name in self._class_to_module.keys():
|
||||
module = self._get_module(self._class_to_module[name])
|
||||
value = getattr(module, name)
|
||||
else:
|
||||
raise AttributeError(f"module {self.__name__} has no attribute {name}")
|
||||
|
||||
setattr(self, name, value)
|
||||
return value
|
||||
|
||||
def _get_module(self, module_name: str) -> ModuleType:
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -29,7 +29,7 @@ logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
# comet_ml requires to be imported before any ML frameworks
|
||||
_has_comet = importlib.util.find_spec("comet_ml") and os.getenv("COMET_MODE", "").upper() != "DISABLED"
|
||||
_has_comet = importlib.util.find_spec("comet_ml") is not None and os.getenv("COMET_MODE", "").upper() != "DISABLED"
|
||||
if _has_comet:
|
||||
try:
|
||||
import comet_ml # noqa: F401
|
||||
|
||||
@@ -0,0 +1,68 @@
|
||||
# flake8: noqa
|
||||
# There's no way to ignore "F401 '...' imported but unused" warnings in this
|
||||
# module, but to preserve other warnings. So, don't check this module at all.
|
||||
|
||||
# Copyright 2020 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from . import (
|
||||
albert,
|
||||
auto,
|
||||
bart,
|
||||
barthez,
|
||||
bert,
|
||||
bert_generation,
|
||||
bert_japanese,
|
||||
bertweet,
|
||||
blenderbot,
|
||||
blenderbot_small,
|
||||
camembert,
|
||||
ctrl,
|
||||
deberta,
|
||||
dialogpt,
|
||||
distilbert,
|
||||
dpr,
|
||||
electra,
|
||||
encoder_decoder,
|
||||
flaubert,
|
||||
fsmt,
|
||||
funnel,
|
||||
gpt2,
|
||||
herbert,
|
||||
layoutlm,
|
||||
led,
|
||||
longformer,
|
||||
lxmert,
|
||||
marian,
|
||||
mbart,
|
||||
mmbt,
|
||||
mobilebert,
|
||||
mpnet,
|
||||
mt5,
|
||||
openai,
|
||||
pegasus,
|
||||
phobert,
|
||||
prophetnet,
|
||||
rag,
|
||||
reformer,
|
||||
retribert,
|
||||
roberta,
|
||||
squeezebert,
|
||||
t5,
|
||||
tapas,
|
||||
transfo_xl,
|
||||
xlm,
|
||||
xlm_roberta,
|
||||
xlnet,
|
||||
)
|
||||
|
||||
@@ -19,8 +19,8 @@ import argparse
|
||||
|
||||
import torch
|
||||
|
||||
from transformers import AlbertConfig, AlbertForPreTraining, load_tf_weights_in_albert
|
||||
from transformers.utils import logging
|
||||
from ...utils import logging
|
||||
from . import AlbertConfig, AlbertForPreTraining, load_tf_weights_in_albert
|
||||
|
||||
|
||||
logging.set_verbosity_info()
|
||||
|
||||
@@ -23,15 +23,9 @@ import fairseq
|
||||
import torch
|
||||
from packaging import version
|
||||
|
||||
from transformers import (
|
||||
BartConfig,
|
||||
BartForConditionalGeneration,
|
||||
BartForSequenceClassification,
|
||||
BartModel,
|
||||
BartTokenizer,
|
||||
)
|
||||
from transformers.models.bart.modeling_bart import _make_linear_from_emb
|
||||
from transformers.utils import logging
|
||||
from ...utils import logging
|
||||
from . import BartConfig, BartForConditionalGeneration, BartForSequenceClassification, BartModel, BartTokenizer
|
||||
from .modeling_bart import _make_linear_from_emb
|
||||
|
||||
|
||||
FAIRSEQ_MODELS = ["bart.large", "bart.large.mnli", "bart.large.cnn", "bart_xsum/model.pt"]
|
||||
|
||||
@@ -15,8 +15,7 @@
|
||||
|
||||
from typing import List, Optional
|
||||
|
||||
from transformers import add_start_docstrings
|
||||
|
||||
from ...file_utils import add_start_docstrings
|
||||
from ...tokenization_utils_base import PREPARE_SEQ2SEQ_BATCH_DOCSTRING, BatchEncoding
|
||||
from ...utils import logging
|
||||
from ..roberta.tokenization_roberta import RobertaTokenizer
|
||||
|
||||
@@ -15,8 +15,7 @@
|
||||
|
||||
from typing import List, Optional
|
||||
|
||||
from transformers import add_start_docstrings
|
||||
|
||||
from ...file_utils import add_start_docstrings
|
||||
from ...tokenization_utils_base import PREPARE_SEQ2SEQ_BATCH_DOCSTRING, BatchEncoding
|
||||
from ...utils import logging
|
||||
from ..roberta.tokenization_roberta_fast import RobertaTokenizerFast
|
||||
|
||||
@@ -28,8 +28,8 @@ import re
|
||||
import tensorflow as tf
|
||||
import torch
|
||||
|
||||
from transformers import BertConfig, BertModel
|
||||
from transformers.utils import logging
|
||||
from ...utils import logging
|
||||
from . import BertConfig, BertModel
|
||||
|
||||
|
||||
logging.set_verbosity_info()
|
||||
|
||||
@@ -19,8 +19,8 @@ import argparse
|
||||
|
||||
import torch
|
||||
|
||||
from transformers import BertConfig, BertForPreTraining, load_tf_weights_in_bert
|
||||
from transformers.utils import logging
|
||||
from ...utils import logging
|
||||
from . import BertConfig, BertForPreTraining, load_tf_weights_in_bert
|
||||
|
||||
|
||||
logging.set_verbosity_info()
|
||||
|
||||
@@ -22,7 +22,7 @@ import numpy as np
|
||||
import tensorflow as tf
|
||||
import torch
|
||||
|
||||
from transformers import BertModel
|
||||
from . import BertModel
|
||||
|
||||
|
||||
def convert_pytorch_checkpoint_to_tf(model: BertModel, ckpt_dir: str, model_name: str):
|
||||
|
||||
@@ -18,8 +18,8 @@ import argparse
|
||||
|
||||
import torch
|
||||
|
||||
from transformers import BartConfig, BartForConditionalGeneration
|
||||
from transformers.utils import logging
|
||||
from ...models.bart import BartConfig, BartForConditionalGeneration
|
||||
from ...utils import logging
|
||||
|
||||
|
||||
logging.set_verbosity_info()
|
||||
|
||||
@@ -17,7 +17,7 @@ import os
|
||||
|
||||
import torch
|
||||
|
||||
from transformers.file_utils import WEIGHTS_NAME
|
||||
from ...file_utils import WEIGHTS_NAME
|
||||
|
||||
|
||||
DIALOGPT_MODELS = ["small", "medium", "large"]
|
||||
|
||||
@@ -19,7 +19,8 @@ from pathlib import Path
|
||||
import torch
|
||||
from torch.serialization import default_restore_location
|
||||
|
||||
from transformers import BertConfig, DPRConfig, DPRContextEncoder, DPRQuestionEncoder, DPRReader
|
||||
from ...models.bert import BertConfig
|
||||
from . import DPRConfig, DPRContextEncoder, DPRQuestionEncoder, DPRReader
|
||||
|
||||
|
||||
CheckpointState = collections.namedtuple(
|
||||
|
||||
@@ -19,8 +19,8 @@ import argparse
|
||||
|
||||
import torch
|
||||
|
||||
from transformers import ElectraConfig, ElectraForMaskedLM, ElectraForPreTraining, load_tf_weights_in_electra
|
||||
from transformers.utils import logging
|
||||
from ...utils import logging
|
||||
from . import ElectraConfig, ElectraForMaskedLM, ElectraForPreTraining, load_tf_weights_in_electra
|
||||
|
||||
|
||||
logging.set_verbosity_info()
|
||||
|
||||
@@ -31,9 +31,10 @@ import torch
|
||||
from fairseq import hub_utils
|
||||
from fairseq.data.dictionary import Dictionary
|
||||
|
||||
from transformers import WEIGHTS_NAME, logging
|
||||
from transformers.models.fsmt import VOCAB_FILES_NAMES, FSMTConfig, FSMTForConditionalGeneration
|
||||
from transformers.tokenization_utils_base import TOKENIZER_CONFIG_FILE
|
||||
from ...file_utils import WEIGHTS_NAME
|
||||
from ...tokenization_utils_base import TOKENIZER_CONFIG_FILE
|
||||
from ...utils import logging
|
||||
from . import VOCAB_FILES_NAMES, FSMTConfig, FSMTForConditionalGeneration
|
||||
|
||||
|
||||
logging.set_verbosity_warning()
|
||||
|
||||
@@ -20,7 +20,7 @@ import logging
|
||||
|
||||
import torch
|
||||
|
||||
from transformers import FunnelConfig, FunnelForPreTraining, load_tf_weights_in_funnel
|
||||
from . import FunnelConfig, FunnelForPreTraining, load_tf_weights_in_funnel
|
||||
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
@@ -19,8 +19,9 @@ import argparse
|
||||
|
||||
import torch
|
||||
|
||||
from transformers import CONFIG_NAME, WEIGHTS_NAME, GPT2Config, GPT2Model, load_tf_weights_in_gpt2
|
||||
from transformers.utils import logging
|
||||
from ...file_utils import CONFIG_NAME, WEIGHTS_NAME
|
||||
from ...utils import logging
|
||||
from . import GPT2Config, GPT2Model, load_tf_weights_in_gpt2
|
||||
|
||||
|
||||
logging.set_verbosity_info()
|
||||
|
||||
@@ -20,7 +20,7 @@ import argparse
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
|
||||
from transformers import LongformerForQuestionAnswering, LongformerModel
|
||||
from . import LongformerForQuestionAnswering, LongformerModel
|
||||
|
||||
|
||||
class LightningModel(pl.LightningModule):
|
||||
|
||||
@@ -20,7 +20,7 @@ import logging
|
||||
|
||||
import torch
|
||||
|
||||
from transformers import LxmertConfig, LxmertForPreTraining, load_tf_weights_in_lxmert
|
||||
from . import LxmertConfig, LxmertForPreTraining, load_tf_weights_in_lxmert
|
||||
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
@@ -17,7 +17,7 @@ import os
|
||||
from pathlib import Path
|
||||
from typing import List, Tuple
|
||||
|
||||
from transformers.models.marian.convert_marian_to_pytorch import (
|
||||
from .convert_marian_to_pytorch import (
|
||||
FRONT_MATTER_TEMPLATE,
|
||||
_parse_readme,
|
||||
convert_all_sentencepiece_models,
|
||||
|
||||
@@ -26,8 +26,8 @@ import numpy as np
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
from transformers import MarianConfig, MarianMTModel, MarianTokenizer
|
||||
from transformers.hf_api import HfApi
|
||||
from ...hf_api import HfApi
|
||||
from . import MarianConfig, MarianMTModel, MarianTokenizer
|
||||
|
||||
|
||||
def remove_suffix(text: str, suffix: str):
|
||||
|
||||
@@ -16,9 +16,9 @@ import argparse
|
||||
|
||||
import torch
|
||||
|
||||
from transformers import BartForConditionalGeneration, MBartConfig
|
||||
|
||||
from ..bart import BartForConditionalGeneration
|
||||
from ..bart.convert_bart_original_pytorch_checkpoint_to_pytorch import remove_ignore_keys_
|
||||
from . import MBartConfig
|
||||
|
||||
|
||||
def convert_fairseq_mbart_checkpoint_from_disk(checkpoint_path, hf_config_path="facebook/mbart-large-en-ro"):
|
||||
|
||||
@@ -16,8 +16,8 @@ import argparse
|
||||
|
||||
import torch
|
||||
|
||||
from transformers import MobileBertConfig, MobileBertForPreTraining, load_tf_weights_in_mobilebert
|
||||
from transformers.utils import logging
|
||||
from ...utils import logging
|
||||
from . import MobileBertConfig, MobileBertForPreTraining, load_tf_weights_in_mobilebert
|
||||
|
||||
|
||||
logging.set_verbosity_info()
|
||||
|
||||
@@ -19,8 +19,9 @@ import argparse
|
||||
|
||||
import torch
|
||||
|
||||
from transformers import CONFIG_NAME, WEIGHTS_NAME, OpenAIGPTConfig, OpenAIGPTModel, load_tf_weights_in_openai_gpt
|
||||
from transformers.utils import logging
|
||||
from ...file_utils import CONFIG_NAME, WEIGHTS_NAME
|
||||
from ...utils import logging
|
||||
from . import OpenAIGPTConfig, OpenAIGPTModel, load_tf_weights_in_openai_gpt
|
||||
|
||||
|
||||
logging.set_verbosity_info()
|
||||
|
||||
@@ -22,8 +22,8 @@ import tensorflow as tf
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
from transformers import PegasusConfig, PegasusForConditionalGeneration, PegasusTokenizer
|
||||
from transformers.models.pegasus.configuration_pegasus import DEFAULTS, task_specific_params
|
||||
from . import PegasusConfig, PegasusForConditionalGeneration, PegasusTokenizer
|
||||
from .configuration_pegasus import DEFAULTS, task_specific_params
|
||||
|
||||
|
||||
PATTERNS = [
|
||||
|
||||
@@ -19,8 +19,6 @@ import argparse
|
||||
|
||||
import torch
|
||||
|
||||
from transformers import ProphetNetForConditionalGeneration, XLMProphetNetForConditionalGeneration, logging
|
||||
|
||||
# transformers_old should correspond to branch `save_old_prophetnet_model_structure` here
|
||||
# original prophetnet_checkpoints are saved under `patrickvonplaten/..._old` respectively
|
||||
from transformers_old.modeling_prophetnet import (
|
||||
@@ -30,6 +28,8 @@ from transformers_old.modeling_xlm_prophetnet import (
|
||||
XLMProphetNetForConditionalGeneration as XLMProphetNetForConditionalGenerationOld,
|
||||
)
|
||||
|
||||
from . import ProphetNetForConditionalGeneration, XLMProphetNetForConditionalGeneration, logging
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
logging.set_verbosity_info()
|
||||
|
||||
@@ -21,8 +21,8 @@ import pickle
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from transformers import ReformerConfig, ReformerModelWithLMHead
|
||||
from transformers.utils import logging
|
||||
from ...utils import logging
|
||||
from . import ReformerConfig, ReformerModelWithLMHead
|
||||
|
||||
|
||||
logging.set_verbosity_info()
|
||||
|
||||
@@ -24,19 +24,9 @@ from fairseq.models.roberta import RobertaModel as FairseqRobertaModel
|
||||
from fairseq.modules import TransformerSentenceEncoderLayer
|
||||
from packaging import version
|
||||
|
||||
from transformers.models.bert.modeling_bert import (
|
||||
BertIntermediate,
|
||||
BertLayer,
|
||||
BertOutput,
|
||||
BertSelfAttention,
|
||||
BertSelfOutput,
|
||||
)
|
||||
from transformers.models.roberta.modeling_roberta import (
|
||||
RobertaConfig,
|
||||
RobertaForMaskedLM,
|
||||
RobertaForSequenceClassification,
|
||||
)
|
||||
from transformers.utils import logging
|
||||
from ...models.bert.modeling_bert import BertIntermediate, BertLayer, BertOutput, BertSelfAttention, BertSelfOutput
|
||||
from ...utils import logging
|
||||
from .modeling_roberta import RobertaConfig, RobertaForMaskedLM, RobertaForSequenceClassification
|
||||
|
||||
|
||||
if version.parse(fairseq.__version__) < version.parse("0.9.0"):
|
||||
|
||||
@@ -17,8 +17,8 @@
|
||||
|
||||
import argparse
|
||||
|
||||
from transformers import T5Config, T5ForConditionalGeneration, load_tf_weights_in_t5
|
||||
from transformers.utils import logging
|
||||
from ...utils import logging
|
||||
from . import T5Config, T5ForConditionalGeneration, load_tf_weights_in_t5
|
||||
|
||||
|
||||
logging.set_verbosity_info()
|
||||
|
||||
@@ -22,8 +22,6 @@ from typing import Tuple
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
from transformers.modeling_tf_utils import TFWrappedEmbeddings
|
||||
|
||||
from ...activations_tf import get_tf_activation
|
||||
from ...file_utils import (
|
||||
DUMMY_INPUTS,
|
||||
@@ -42,6 +40,7 @@ from ...modeling_tf_utils import (
|
||||
TFCausalLanguageModelingLoss,
|
||||
TFPreTrainedModel,
|
||||
TFSharedEmbeddings,
|
||||
TFWrappedEmbeddings,
|
||||
input_processing,
|
||||
keras_serializable,
|
||||
shape_list,
|
||||
|
||||
@@ -17,16 +17,16 @@
|
||||
|
||||
import argparse
|
||||
|
||||
from transformers.models.tapas.modeling_tapas import (
|
||||
from ...utils import logging
|
||||
from . import (
|
||||
TapasConfig,
|
||||
TapasForMaskedLM,
|
||||
TapasForQuestionAnswering,
|
||||
TapasForSequenceClassification,
|
||||
TapasModel,
|
||||
TapasTokenizer,
|
||||
load_tf_weights_in_tapas,
|
||||
)
|
||||
from transformers.models.tapas.tokenization_tapas import TapasTokenizer
|
||||
from transformers.utils import logging
|
||||
|
||||
|
||||
logging.set_verbosity_info()
|
||||
|
||||
@@ -28,9 +28,7 @@ from typing import Callable, Dict, Generator, List, Optional, Text, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
from transformers import add_end_docstrings
|
||||
|
||||
from ...file_utils import is_pandas_available
|
||||
from ...file_utils import add_end_docstrings, is_pandas_available
|
||||
from ...tokenization_utils import PreTrainedTokenizer, _is_control, _is_punctuation, _is_whitespace
|
||||
from ...tokenization_utils_base import (
|
||||
ENCODE_KWARGS_DOCSTRING,
|
||||
|
||||
@@ -22,16 +22,11 @@ import sys
|
||||
|
||||
import torch
|
||||
|
||||
import transformers.models.transfo_xl.tokenization_transfo_xl as data_utils
|
||||
from transformers import (
|
||||
CONFIG_NAME,
|
||||
WEIGHTS_NAME,
|
||||
TransfoXLConfig,
|
||||
TransfoXLLMHeadModel,
|
||||
load_tf_weights_in_transfo_xl,
|
||||
)
|
||||
from transformers.models.transfo_xl.tokenization_transfo_xl import CORPUS_NAME, VOCAB_FILES_NAMES
|
||||
from transformers.utils import logging
|
||||
from ...file_utils import CONFIG_NAME, WEIGHTS_NAME
|
||||
from ...utils import logging
|
||||
from . import TransfoXLConfig, TransfoXLLMHeadModel, load_tf_weights_in_transfo_xl
|
||||
from . import tokenization_transfo_xl as data_utils
|
||||
from .tokenization_transfo_xl import CORPUS_NAME, VOCAB_FILES_NAMES
|
||||
|
||||
|
||||
logging.set_verbosity_info()
|
||||
|
||||
@@ -21,9 +21,9 @@ import json
|
||||
import numpy
|
||||
import torch
|
||||
|
||||
from transformers import CONFIG_NAME, WEIGHTS_NAME
|
||||
from transformers.models.xlm.tokenization_xlm import VOCAB_FILES_NAMES
|
||||
from transformers.utils import logging
|
||||
from ...file_utils import CONFIG_NAME, WEIGHTS_NAME
|
||||
from ...utils import logging
|
||||
from .tokenization_xlm import VOCAB_FILES_NAMES
|
||||
|
||||
|
||||
logging.set_verbosity_info()
|
||||
|
||||
@@ -20,16 +20,15 @@ import os
|
||||
|
||||
import torch
|
||||
|
||||
from transformers import (
|
||||
CONFIG_NAME,
|
||||
WEIGHTS_NAME,
|
||||
from ...file_utils import CONFIG_NAME, WEIGHTS_NAME
|
||||
from ...utils import logging
|
||||
from . import (
|
||||
XLNetConfig,
|
||||
XLNetForQuestionAnswering,
|
||||
XLNetForSequenceClassification,
|
||||
XLNetLMHeadModel,
|
||||
load_tf_weights_in_xlnet,
|
||||
)
|
||||
from transformers.utils import logging
|
||||
|
||||
|
||||
GLUE_TASKS_NUM_LABELS = {
|
||||
|
||||
@@ -24,71 +24,147 @@
|
||||
##
|
||||
## Put '## COMMENT' to comment on the file.
|
||||
|
||||
# To replace in: "src/transformers/__init__.py"
|
||||
# Below: " # PyTorch models structure" if generating PyTorch
|
||||
# Replace with:
|
||||
{% if cookiecutter.is_encoder_decoder_model == "False" %}
|
||||
_import_structure["models.{{cookiecutter.lowercase_modelname}}"].extend(
|
||||
[
|
||||
"{{cookiecutter.uppercase_modelname}}_PRETRAINED_MODEL_ARCHIVE_LIST",
|
||||
"{{cookiecutter.camelcase_modelname}}ForMaskedLM",
|
||||
"{{cookiecutter.camelcase_modelname}}ForCausalLM",
|
||||
"{{cookiecutter.camelcase_modelname}}ForMultipleChoice",
|
||||
"{{cookiecutter.camelcase_modelname}}ForQuestionAnswering",
|
||||
"{{cookiecutter.camelcase_modelname}}ForSequenceClassification",
|
||||
"{{cookiecutter.camelcase_modelname}}ForTokenClassification",
|
||||
"{{cookiecutter.camelcase_modelname}}Layer",
|
||||
"{{cookiecutter.camelcase_modelname}}Model",
|
||||
"{{cookiecutter.camelcase_modelname}}PreTrainedModel",
|
||||
"load_tf_weights_in_{{cookiecutter.lowercase_modelname}}",
|
||||
]
|
||||
)
|
||||
{% else %}
|
||||
_import_structure["models.{{cookiecutter.lowercase_modelname}}"].extend(
|
||||
[
|
||||
"{{cookiecutter.uppercase_modelname}}_PRETRAINED_MODEL_ARCHIVE_LIST",
|
||||
"{{cookiecutter.camelcase_modelname}}ForConditionalGeneration",
|
||||
"{{cookiecutter.camelcase_modelname}}ForQuestionAnswering",
|
||||
"{{cookiecutter.camelcase_modelname}}ForSequenceClassification",
|
||||
"{{cookiecutter.camelcase_modelname}}Model",
|
||||
]
|
||||
)
|
||||
{% endif -%}
|
||||
# End.
|
||||
|
||||
# Below: " # TensorFlow models structure" if generating TensorFlow
|
||||
# Replace with:
|
||||
{% if cookiecutter.is_encoder_decoder_model == "False" %}
|
||||
_import_structure["models.{{cookiecutter.lowercase_modelname}}"].extend(
|
||||
[
|
||||
"TF_{{cookiecutter.uppercase_modelname}}_PRETRAINED_MODEL_ARCHIVE_LIST",
|
||||
"TF{{cookiecutter.camelcase_modelname}}ForMaskedLM",
|
||||
"TF{{cookiecutter.camelcase_modelname}}ForCausalLM",
|
||||
"TF{{cookiecutter.camelcase_modelname}}ForMultipleChoice",
|
||||
"TF{{cookiecutter.camelcase_modelname}}ForQuestionAnswering",
|
||||
"TF{{cookiecutter.camelcase_modelname}}ForSequenceClassification",
|
||||
"TF{{cookiecutter.camelcase_modelname}}ForTokenClassification",
|
||||
"TF{{cookiecutter.camelcase_modelname}}Layer",
|
||||
"TF{{cookiecutter.camelcase_modelname}}Model",
|
||||
"TF{{cookiecutter.camelcase_modelname}}PreTrainedModel",
|
||||
]
|
||||
)
|
||||
{% else %}
|
||||
_import_structure["models.{{cookiecutter.lowercase_modelname}}"].extend(
|
||||
[
|
||||
"TF{{cookiecutter.camelcase_modelname}}ForConditionalGeneration",
|
||||
"TF{{cookiecutter.camelcase_modelname}}Model",
|
||||
"TF{{cookiecutter.camelcase_modelname}}PreTrainedModel",
|
||||
]
|
||||
)
|
||||
{% endif -%}
|
||||
# End.
|
||||
|
||||
# Below: " # Fast tokenizers"
|
||||
# Replace with:
|
||||
_import_structure["models.{{cookiecutter.lowercase_modelname}}"].append("{{cookiecutter.camelcase_modelname}}TokenizerFast")
|
||||
# End.
|
||||
|
||||
# Below: " # Models"
|
||||
# Replace with:
|
||||
"models.{{cookiecutter.lowercase_modelname}}": ["{{cookiecutter.uppercase_modelname}}_PRETRAINED_CONFIG_ARCHIVE_MAP", "{{cookiecutter.camelcase_modelname}}Config", "{{cookiecutter.camelcase_modelname}}Tokenizer"],
|
||||
# End.
|
||||
|
||||
# To replace in: "src/transformers/__init__.py"
|
||||
# Below: "if is_torch_available():" if generating PyTorch
|
||||
# Below: " if is_torch_available():" if generating PyTorch
|
||||
# Replace with:
|
||||
{% if cookiecutter.is_encoder_decoder_model == "False" %}
|
||||
from .models.{{cookiecutter.lowercase_modelname}} import (
|
||||
{{cookiecutter.uppercase_modelname}}_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
{{cookiecutter.camelcase_modelname}}ForMaskedLM,
|
||||
{{cookiecutter.camelcase_modelname}}ForCausalLM,
|
||||
{{cookiecutter.camelcase_modelname}}ForMultipleChoice,
|
||||
{{cookiecutter.camelcase_modelname}}ForQuestionAnswering,
|
||||
{{cookiecutter.camelcase_modelname}}ForSequenceClassification,
|
||||
{{cookiecutter.camelcase_modelname}}ForTokenClassification,
|
||||
{{cookiecutter.camelcase_modelname}}Layer,
|
||||
{{cookiecutter.camelcase_modelname}}Model,
|
||||
{{cookiecutter.camelcase_modelname}}PreTrainedModel,
|
||||
load_tf_weights_in_{{cookiecutter.lowercase_modelname}},
|
||||
)
|
||||
from .models.{{cookiecutter.lowercase_modelname}} import (
|
||||
{{cookiecutter.uppercase_modelname}}_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
{{cookiecutter.camelcase_modelname}}ForMaskedLM,
|
||||
{{cookiecutter.camelcase_modelname}}ForCausalLM,
|
||||
{{cookiecutter.camelcase_modelname}}ForMultipleChoice,
|
||||
{{cookiecutter.camelcase_modelname}}ForQuestionAnswering,
|
||||
{{cookiecutter.camelcase_modelname}}ForSequenceClassification,
|
||||
{{cookiecutter.camelcase_modelname}}ForTokenClassification,
|
||||
{{cookiecutter.camelcase_modelname}}Layer,
|
||||
{{cookiecutter.camelcase_modelname}}Model,
|
||||
{{cookiecutter.camelcase_modelname}}PreTrainedModel,
|
||||
load_tf_weights_in_{{cookiecutter.lowercase_modelname}},
|
||||
)
|
||||
{% else %}
|
||||
from .models.{{cookiecutter.lowercase_modelname}} import (
|
||||
{{cookiecutter.uppercase_modelname}}_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
{{cookiecutter.camelcase_modelname}}ForConditionalGeneration,
|
||||
{{cookiecutter.camelcase_modelname}}ForQuestionAnswering,
|
||||
{{cookiecutter.camelcase_modelname}}ForSequenceClassification,
|
||||
{{cookiecutter.camelcase_modelname}}Model,
|
||||
)
|
||||
from .models.{{cookiecutter.lowercase_modelname}} import (
|
||||
{{cookiecutter.uppercase_modelname}}_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
{{cookiecutter.camelcase_modelname}}ForConditionalGeneration,
|
||||
{{cookiecutter.camelcase_modelname}}ForQuestionAnswering,
|
||||
{{cookiecutter.camelcase_modelname}}ForSequenceClassification,
|
||||
{{cookiecutter.camelcase_modelname}}Model,
|
||||
)
|
||||
{% endif -%}
|
||||
# End.
|
||||
|
||||
# Below: "if is_tf_available():" if generating TensorFlow
|
||||
# Below: " if is_tf_available():" if generating TensorFlow
|
||||
# Replace with:
|
||||
{% if cookiecutter.is_encoder_decoder_model == "False" %}
|
||||
from .models.{{cookiecutter.lowercase_modelname}} import (
|
||||
TF_{{cookiecutter.uppercase_modelname}}_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
TF{{cookiecutter.camelcase_modelname}}ForMaskedLM,
|
||||
TF{{cookiecutter.camelcase_modelname}}ForCausalLM,
|
||||
TF{{cookiecutter.camelcase_modelname}}ForMultipleChoice,
|
||||
TF{{cookiecutter.camelcase_modelname}}ForQuestionAnswering,
|
||||
TF{{cookiecutter.camelcase_modelname}}ForSequenceClassification,
|
||||
TF{{cookiecutter.camelcase_modelname}}ForTokenClassification,
|
||||
TF{{cookiecutter.camelcase_modelname}}Layer,
|
||||
TF{{cookiecutter.camelcase_modelname}}Model,
|
||||
TF{{cookiecutter.camelcase_modelname}}PreTrainedModel,
|
||||
)
|
||||
from .models.{{cookiecutter.lowercase_modelname}} import (
|
||||
TF_{{cookiecutter.uppercase_modelname}}_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
TF{{cookiecutter.camelcase_modelname}}ForMaskedLM,
|
||||
TF{{cookiecutter.camelcase_modelname}}ForCausalLM,
|
||||
TF{{cookiecutter.camelcase_modelname}}ForMultipleChoice,
|
||||
TF{{cookiecutter.camelcase_modelname}}ForQuestionAnswering,
|
||||
TF{{cookiecutter.camelcase_modelname}}ForSequenceClassification,
|
||||
TF{{cookiecutter.camelcase_modelname}}ForTokenClassification,
|
||||
TF{{cookiecutter.camelcase_modelname}}Layer,
|
||||
TF{{cookiecutter.camelcase_modelname}}Model,
|
||||
TF{{cookiecutter.camelcase_modelname}}PreTrainedModel,
|
||||
)
|
||||
{% else %}
|
||||
from .models.{{cookiecutter.lowercase_modelname}} import (
|
||||
TF{{cookiecutter.camelcase_modelname}}ForConditionalGeneration,
|
||||
TF{{cookiecutter.camelcase_modelname}}Model,
|
||||
TF{{cookiecutter.camelcase_modelname}}PreTrainedModel,
|
||||
)
|
||||
from .models.{{cookiecutter.lowercase_modelname}} import (
|
||||
TF{{cookiecutter.camelcase_modelname}}ForConditionalGeneration,
|
||||
TF{{cookiecutter.camelcase_modelname}}Model,
|
||||
TF{{cookiecutter.camelcase_modelname}}PreTrainedModel,
|
||||
)
|
||||
{% endif -%}
|
||||
# End.
|
||||
|
||||
# Below: "if is_tokenizers_available():"
|
||||
# Below: " if is_tokenizers_available():"
|
||||
# Replace with:
|
||||
from .models.{{cookiecutter.lowercase_modelname}} import {{cookiecutter.camelcase_modelname}}TokenizerFast
|
||||
from .models.{{cookiecutter.lowercase_modelname}} import {{cookiecutter.camelcase_modelname}}TokenizerFast
|
||||
# End.
|
||||
|
||||
# Below: "from .models.albert import ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, AlbertConfig"
|
||||
# Below: " from .models.albert import ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, AlbertConfig"
|
||||
# Replace with:
|
||||
from .models.{{cookiecutter.lowercase_modelname}} import {{cookiecutter.uppercase_modelname}}_PRETRAINED_CONFIG_ARCHIVE_MAP, {{cookiecutter.camelcase_modelname}}Config, {{cookiecutter.camelcase_modelname}}Tokenizer
|
||||
from .models.{{cookiecutter.lowercase_modelname}} import {{cookiecutter.uppercase_modelname}}_PRETRAINED_CONFIG_ARCHIVE_MAP, {{cookiecutter.camelcase_modelname}}Config, {{cookiecutter.camelcase_modelname}}Tokenizer
|
||||
# End.
|
||||
|
||||
|
||||
|
||||
# To replace in: "src/transformers/models/__init__.py"
|
||||
# Below: "from . import ("
|
||||
# Replace with:
|
||||
{{cookiecutter.lowercase_modelname}},
|
||||
# End.
|
||||
|
||||
|
||||
# To replace in: "src/transformers/models/auto/configuration_auto.py"
|
||||
# Below: "# Add configs here"
|
||||
# Replace with:
|
||||
|
||||
@@ -23,237 +23,79 @@ import re
|
||||
PATH_TO_TRANSFORMERS = "src/transformers"
|
||||
|
||||
_re_single_line_import = re.compile(r"\s+from\s+\S*\s+import\s+([^\(\s].*)\n")
|
||||
_re_test_backend = re.compile(r"^\s+if\s+is\_([a-z]*)\_available\(\):\s*$")
|
||||
|
||||
|
||||
BACKENDS = ["torch", "tf", "flax", "sentencepiece", "tokenizers"]
|
||||
|
||||
|
||||
DUMMY_CONSTANT = """
|
||||
{0} = None
|
||||
"""
|
||||
|
||||
DUMMY_PT_PRETRAINED_CLASS = """
|
||||
DUMMY_PRETRAINED_CLASS = """
|
||||
class {0}:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_pytorch(self)
|
||||
requires_{1}(self)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(self, *args, **kwargs):
|
||||
requires_pytorch(self)
|
||||
requires_{1}(self)
|
||||
"""
|
||||
|
||||
DUMMY_PT_CLASS = """
|
||||
DUMMY_CLASS = """
|
||||
class {0}:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_pytorch(self)
|
||||
requires_{1}(self)
|
||||
"""
|
||||
|
||||
DUMMY_PT_FUNCTION = """
|
||||
DUMMY_FUNCTION = """
|
||||
def {0}(*args, **kwargs):
|
||||
requires_pytorch({0})
|
||||
requires_{1}({0})
|
||||
"""
|
||||
|
||||
|
||||
DUMMY_TF_PRETRAINED_CLASS = """
|
||||
class {0}:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_tf(self)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(self, *args, **kwargs):
|
||||
requires_tf(self)
|
||||
"""
|
||||
|
||||
DUMMY_TF_CLASS = """
|
||||
class {0}:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_tf(self)
|
||||
"""
|
||||
|
||||
DUMMY_TF_FUNCTION = """
|
||||
def {0}(*args, **kwargs):
|
||||
requires_tf({0})
|
||||
"""
|
||||
|
||||
|
||||
DUMMY_FLAX_PRETRAINED_CLASS = """
|
||||
class {0}:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_flax(self)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(self, *args, **kwargs):
|
||||
requires_flax(self)
|
||||
"""
|
||||
|
||||
DUMMY_FLAX_CLASS = """
|
||||
class {0}:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_flax(self)
|
||||
"""
|
||||
|
||||
DUMMY_FLAX_FUNCTION = """
|
||||
def {0}(*args, **kwargs):
|
||||
requires_flax({0})
|
||||
"""
|
||||
|
||||
|
||||
DUMMY_SENTENCEPIECE_PRETRAINED_CLASS = """
|
||||
class {0}:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_sentencepiece(self)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(self, *args, **kwargs):
|
||||
requires_sentencepiece(self)
|
||||
"""
|
||||
|
||||
DUMMY_SENTENCEPIECE_CLASS = """
|
||||
class {0}:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_sentencepiece(self)
|
||||
"""
|
||||
|
||||
DUMMY_SENTENCEPIECE_FUNCTION = """
|
||||
def {0}(*args, **kwargs):
|
||||
requires_sentencepiece({0})
|
||||
"""
|
||||
|
||||
|
||||
DUMMY_TOKENIZERS_PRETRAINED_CLASS = """
|
||||
class {0}:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_tokenizers(self)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(self, *args, **kwargs):
|
||||
requires_tokenizers(self)
|
||||
"""
|
||||
|
||||
DUMMY_TOKENIZERS_CLASS = """
|
||||
class {0}:
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_tokenizers(self)
|
||||
"""
|
||||
|
||||
DUMMY_TOKENIZERS_FUNCTION = """
|
||||
def {0}(*args, **kwargs):
|
||||
requires_tokenizers({0})
|
||||
"""
|
||||
|
||||
# Map all these to dummy type
|
||||
|
||||
DUMMY_PRETRAINED_CLASS = {
|
||||
"pt": DUMMY_PT_PRETRAINED_CLASS,
|
||||
"tf": DUMMY_TF_PRETRAINED_CLASS,
|
||||
"flax": DUMMY_FLAX_PRETRAINED_CLASS,
|
||||
"sentencepiece": DUMMY_SENTENCEPIECE_PRETRAINED_CLASS,
|
||||
"tokenizers": DUMMY_TOKENIZERS_PRETRAINED_CLASS,
|
||||
}
|
||||
|
||||
DUMMY_CLASS = {
|
||||
"pt": DUMMY_PT_CLASS,
|
||||
"tf": DUMMY_TF_CLASS,
|
||||
"flax": DUMMY_FLAX_CLASS,
|
||||
"sentencepiece": DUMMY_SENTENCEPIECE_CLASS,
|
||||
"tokenizers": DUMMY_TOKENIZERS_CLASS,
|
||||
}
|
||||
|
||||
DUMMY_FUNCTION = {
|
||||
"pt": DUMMY_PT_FUNCTION,
|
||||
"tf": DUMMY_TF_FUNCTION,
|
||||
"flax": DUMMY_FLAX_FUNCTION,
|
||||
"sentencepiece": DUMMY_SENTENCEPIECE_FUNCTION,
|
||||
"tokenizers": DUMMY_TOKENIZERS_FUNCTION,
|
||||
}
|
||||
|
||||
|
||||
def read_init():
|
||||
""" Read the init and extracts PyTorch, TensorFlow, SentencePiece and Tokenizers objects. """
|
||||
with open(os.path.join(PATH_TO_TRANSFORMERS, "__init__.py"), "r", encoding="utf-8", newline="\n") as f:
|
||||
lines = f.readlines()
|
||||
|
||||
# Get to the point we do the actual imports for type checking
|
||||
line_index = 0
|
||||
# Find where the SentencePiece imports begin
|
||||
sentencepiece_objects = []
|
||||
while not lines[line_index].startswith("if is_sentencepiece_available():"):
|
||||
line_index += 1
|
||||
line_index += 1
|
||||
|
||||
# Until we unindent, add SentencePiece objects to the list
|
||||
while len(lines[line_index]) <= 1 or lines[line_index].startswith(" "):
|
||||
line = lines[line_index]
|
||||
search = _re_single_line_import.search(line)
|
||||
if search is not None:
|
||||
sentencepiece_objects += search.groups()[0].split(", ")
|
||||
elif line.startswith(" "):
|
||||
sentencepiece_objects.append(line[8:-2])
|
||||
while not lines[line_index].startswith("if TYPE_CHECKING"):
|
||||
line_index += 1
|
||||
|
||||
# Find where the Tokenizers imports begin
|
||||
tokenizers_objects = []
|
||||
while not lines[line_index].startswith("if is_tokenizers_available():"):
|
||||
line_index += 1
|
||||
line_index += 1
|
||||
backend_specific_objects = {}
|
||||
# Go through the end of the file
|
||||
while line_index < len(lines):
|
||||
# If the line is an if is_backemd_available, we grab all objects associated.
|
||||
if _re_test_backend.search(lines[line_index]) is not None:
|
||||
backend = _re_test_backend.search(lines[line_index]).groups()[0]
|
||||
line_index += 1
|
||||
|
||||
# Until we unindent, add Tokenizers objects to the list
|
||||
while len(lines[line_index]) <= 1 or lines[line_index].startswith(" "):
|
||||
line = lines[line_index]
|
||||
search = _re_single_line_import.search(line)
|
||||
if search is not None:
|
||||
tokenizers_objects += search.groups()[0].split(", ")
|
||||
elif line.startswith(" "):
|
||||
tokenizers_objects.append(line[8:-2])
|
||||
line_index += 1
|
||||
# Ignore if backend isn't tracked for dummies.
|
||||
if backend not in BACKENDS:
|
||||
continue
|
||||
|
||||
# Find where the PyTorch imports begin
|
||||
pt_objects = []
|
||||
while not lines[line_index].startswith("if is_torch_available():"):
|
||||
line_index += 1
|
||||
line_index += 1
|
||||
objects = []
|
||||
# Until we unindent, add backend objects to the list
|
||||
while len(lines[line_index]) <= 1 or lines[line_index].startswith(" " * 8):
|
||||
line = lines[line_index]
|
||||
single_line_import_search = _re_single_line_import.search(line)
|
||||
if single_line_import_search is not None:
|
||||
objects.extend(single_line_import_search.groups()[0].split(", "))
|
||||
elif line.startswith(" " * 12):
|
||||
objects.append(line[12:-2])
|
||||
line_index += 1
|
||||
|
||||
# Until we unindent, add PyTorch objects to the list
|
||||
while len(lines[line_index]) <= 1 or lines[line_index].startswith(" "):
|
||||
line = lines[line_index]
|
||||
search = _re_single_line_import.search(line)
|
||||
if search is not None:
|
||||
pt_objects += search.groups()[0].split(", ")
|
||||
elif line.startswith(" "):
|
||||
pt_objects.append(line[8:-2])
|
||||
line_index += 1
|
||||
backend_specific_objects[backend] = objects
|
||||
else:
|
||||
line_index += 1
|
||||
|
||||
# Find where the TF imports begin
|
||||
tf_objects = []
|
||||
while not lines[line_index].startswith("if is_tf_available():"):
|
||||
line_index += 1
|
||||
line_index += 1
|
||||
|
||||
# Until we unindent, add PyTorch objects to the list
|
||||
while len(lines[line_index]) <= 1 or lines[line_index].startswith(" "):
|
||||
line = lines[line_index]
|
||||
search = _re_single_line_import.search(line)
|
||||
if search is not None:
|
||||
tf_objects += search.groups()[0].split(", ")
|
||||
elif line.startswith(" "):
|
||||
tf_objects.append(line[8:-2])
|
||||
line_index += 1
|
||||
|
||||
# Find where the FLAX imports begin
|
||||
flax_objects = []
|
||||
while not lines[line_index].startswith("if is_flax_available():"):
|
||||
line_index += 1
|
||||
line_index += 1
|
||||
|
||||
# Until we unindent, add PyTorch objects to the list
|
||||
while len(lines[line_index]) <= 1 or lines[line_index].startswith(" "):
|
||||
line = lines[line_index]
|
||||
search = _re_single_line_import.search(line)
|
||||
if search is not None:
|
||||
flax_objects += search.groups()[0].split(", ")
|
||||
elif line.startswith(" "):
|
||||
flax_objects.append(line[8:-2])
|
||||
line_index += 1
|
||||
|
||||
return sentencepiece_objects, tokenizers_objects, pt_objects, tf_objects, flax_objects
|
||||
return backend_specific_objects
|
||||
|
||||
|
||||
def create_dummy_object(name, type="pt"):
|
||||
def create_dummy_object(name, backend_name):
|
||||
""" Create the code for the dummy object corresponding to `name`."""
|
||||
_pretrained = [
|
||||
"Config" "ForCausalLM",
|
||||
@@ -266,11 +108,10 @@ def create_dummy_object(name, type="pt"):
|
||||
"Model",
|
||||
"Tokenizer",
|
||||
]
|
||||
assert type in ["pt", "tf", "sentencepiece", "tokenizers", "flax"]
|
||||
if name.isupper():
|
||||
return DUMMY_CONSTANT.format(name)
|
||||
elif name.islower():
|
||||
return (DUMMY_FUNCTION[type]).format(name)
|
||||
return DUMMY_FUNCTION.format(name, backend_name)
|
||||
else:
|
||||
is_pretrained = False
|
||||
for part in _pretrained:
|
||||
@@ -278,114 +119,61 @@ def create_dummy_object(name, type="pt"):
|
||||
is_pretrained = True
|
||||
break
|
||||
if is_pretrained:
|
||||
template = DUMMY_PRETRAINED_CLASS[type]
|
||||
return DUMMY_PRETRAINED_CLASS.format(name, backend_name)
|
||||
else:
|
||||
template = DUMMY_CLASS[type]
|
||||
return template.format(name)
|
||||
return DUMMY_CLASS.format(name, backend_name)
|
||||
|
||||
|
||||
def create_dummy_files():
|
||||
""" Create the content of the dummy files. """
|
||||
sentencepiece_objects, tokenizers_objects, pt_objects, tf_objects, flax_objects = read_init()
|
||||
backend_specific_objects = read_init()
|
||||
# For special correspondence backend to module name as used in the function requires_modulename
|
||||
module_names = {"torch": "pytorch"}
|
||||
dummy_files = {}
|
||||
|
||||
sentencepiece_dummies = "# This file is autogenerated by the command `make fix-copies`, do not edit.\n"
|
||||
sentencepiece_dummies += "from ..file_utils import requires_sentencepiece\n\n"
|
||||
sentencepiece_dummies += "\n".join([create_dummy_object(o, type="sentencepiece") for o in sentencepiece_objects])
|
||||
for backend, objects in backend_specific_objects.items():
|
||||
backend_name = module_names.get(backend, backend)
|
||||
dummy_file = "# This file is autogenerated by the command `make fix-copies`, do not edit.\n"
|
||||
dummy_file += f"from ..file_utils import requires_{backend_name}\n\n"
|
||||
dummy_file += "\n".join([create_dummy_object(o, backend_name) for o in objects])
|
||||
dummy_files[backend] = dummy_file
|
||||
|
||||
tokenizers_dummies = "# This file is autogenerated by the command `make fix-copies`, do not edit.\n"
|
||||
tokenizers_dummies += "from ..file_utils import requires_tokenizers\n\n"
|
||||
tokenizers_dummies += "\n".join([create_dummy_object(o, type="tokenizers") for o in tokenizers_objects])
|
||||
|
||||
pt_dummies = "# This file is autogenerated by the command `make fix-copies`, do not edit.\n"
|
||||
pt_dummies += "from ..file_utils import requires_pytorch\n\n"
|
||||
pt_dummies += "\n".join([create_dummy_object(o, type="pt") for o in pt_objects])
|
||||
|
||||
tf_dummies = "# This file is autogenerated by the command `make fix-copies`, do not edit.\n"
|
||||
tf_dummies += "from ..file_utils import requires_tf\n\n"
|
||||
tf_dummies += "\n".join([create_dummy_object(o, type="tf") for o in tf_objects])
|
||||
|
||||
flax_dummies = "# This file is autogenerated by the command `make fix-copies`, do not edit.\n"
|
||||
flax_dummies += "from ..file_utils import requires_flax\n\n"
|
||||
flax_dummies += "\n".join([create_dummy_object(o, type="flax") for o in flax_objects])
|
||||
|
||||
return sentencepiece_dummies, tokenizers_dummies, pt_dummies, tf_dummies, flax_dummies
|
||||
return dummy_files
|
||||
|
||||
|
||||
def check_dummies(overwrite=False):
|
||||
""" Check if the dummy files are up to date and maybe `overwrite` with the right content. """
|
||||
sentencepiece_dummies, tokenizers_dummies, pt_dummies, tf_dummies, flax_dummies = create_dummy_files()
|
||||
dummy_files = create_dummy_files()
|
||||
# For special correspondence backend to shortcut as used in utils/dummy_xxx_objects.py
|
||||
short_names = {"torch": "pt"}
|
||||
|
||||
# Locate actual dummy modules and read their content.
|
||||
path = os.path.join(PATH_TO_TRANSFORMERS, "utils")
|
||||
sentencepiece_file = os.path.join(path, "dummy_sentencepiece_objects.py")
|
||||
tokenizers_file = os.path.join(path, "dummy_tokenizers_objects.py")
|
||||
pt_file = os.path.join(path, "dummy_pt_objects.py")
|
||||
tf_file = os.path.join(path, "dummy_tf_objects.py")
|
||||
flax_file = os.path.join(path, "dummy_flax_objects.py")
|
||||
dummy_file_paths = {
|
||||
backend: os.path.join(path, f"dummy_{short_names.get(backend, backend)}_objects.py")
|
||||
for backend in dummy_files.keys()
|
||||
}
|
||||
|
||||
with open(sentencepiece_file, "r", encoding="utf-8", newline="\n") as f:
|
||||
actual_sentencepiece_dummies = f.read()
|
||||
with open(tokenizers_file, "r", encoding="utf-8", newline="\n") as f:
|
||||
actual_tokenizers_dummies = f.read()
|
||||
with open(pt_file, "r", encoding="utf-8", newline="\n") as f:
|
||||
actual_pt_dummies = f.read()
|
||||
with open(tf_file, "r", encoding="utf-8", newline="\n") as f:
|
||||
actual_tf_dummies = f.read()
|
||||
with open(flax_file, "r", encoding="utf-8", newline="\n") as f:
|
||||
actual_flax_dummies = f.read()
|
||||
actual_dummies = {}
|
||||
for backend, file_path in dummy_file_paths.items():
|
||||
with open(file_path, "r", encoding="utf-8", newline="\n") as f:
|
||||
actual_dummies[backend] = f.read()
|
||||
|
||||
if sentencepiece_dummies != actual_sentencepiece_dummies:
|
||||
if overwrite:
|
||||
print("Updating transformers.utils.dummy_sentencepiece_objects.py as the main __init__ has new objects.")
|
||||
with open(sentencepiece_file, "w", encoding="utf-8", newline="\n") as f:
|
||||
f.write(sentencepiece_dummies)
|
||||
else:
|
||||
raise ValueError(
|
||||
"The main __init__ has objects that are not present in transformers.utils.dummy_sentencepiece_objects.py.",
|
||||
"Run `make fix-copies` to fix this.",
|
||||
)
|
||||
|
||||
if tokenizers_dummies != actual_tokenizers_dummies:
|
||||
if overwrite:
|
||||
print("Updating transformers.utils.dummy_tokenizers_objects.py as the main __init__ has new objects.")
|
||||
with open(tokenizers_file, "w", encoding="utf-8", newline="\n") as f:
|
||||
f.write(tokenizers_dummies)
|
||||
else:
|
||||
raise ValueError(
|
||||
"The main __init__ has objects that are not present in transformers.utils.dummy_tokenizers_objects.py.",
|
||||
"Run `make fix-copies` to fix this.",
|
||||
)
|
||||
|
||||
if pt_dummies != actual_pt_dummies:
|
||||
if overwrite:
|
||||
print("Updating transformers.utils.dummy_pt_objects.py as the main __init__ has new objects.")
|
||||
with open(pt_file, "w", encoding="utf-8", newline="\n") as f:
|
||||
f.write(pt_dummies)
|
||||
else:
|
||||
raise ValueError(
|
||||
"The main __init__ has objects that are not present in transformers.utils.dummy_pt_objects.py.",
|
||||
"Run `make fix-copies` to fix this.",
|
||||
)
|
||||
|
||||
if tf_dummies != actual_tf_dummies:
|
||||
if overwrite:
|
||||
print("Updating transformers.utils.dummy_tf_objects.py as the main __init__ has new objects.")
|
||||
with open(tf_file, "w", encoding="utf-8", newline="\n") as f:
|
||||
f.write(tf_dummies)
|
||||
else:
|
||||
raise ValueError(
|
||||
"The main __init__ has objects that are not present in transformers.utils.dummy_pt_objects.py.",
|
||||
"Run `make fix-copies` to fix this.",
|
||||
)
|
||||
|
||||
if flax_dummies != actual_flax_dummies:
|
||||
if overwrite:
|
||||
print("Updating transformers.utils.dummy_flax_objects.py as the main __init__ has new objects.")
|
||||
with open(flax_file, "w", encoding="utf-8", newline="\n") as f:
|
||||
f.write(flax_dummies)
|
||||
else:
|
||||
raise ValueError(
|
||||
"The main __init__ has objects that are not present in transformers.utils.dummy_flax_objects.py.",
|
||||
"Run `make fix-copies` to fix this.",
|
||||
)
|
||||
for backend in dummy_files.keys():
|
||||
if dummy_files[backend] != actual_dummies[backend]:
|
||||
if overwrite:
|
||||
print(
|
||||
f"Updating transformers.utils.dummy_{short_names.get(backend, backend)}_objects.py as the main "
|
||||
"__init__ has new objects."
|
||||
)
|
||||
with open(dummy_file_paths[backend], "w", encoding="utf-8", newline="\n") as f:
|
||||
f.write(dummy_files[backend])
|
||||
else:
|
||||
raise ValueError(
|
||||
"The main __init__ has objects that are not present in "
|
||||
f"transformers.utils.dummy_{short_names.get(backend, backend)}_objects.py. Run `make fix-copies` "
|
||||
"to fix this."
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -413,9 +413,6 @@ SHOULD_HAVE_THEIR_OWN_PAGE = [
|
||||
def ignore_undocumented(name):
|
||||
"""Rules to determine if `name` should be undocumented."""
|
||||
# NOT DOCUMENTED ON PURPOSE.
|
||||
# Magic attributes are not documented.
|
||||
if name.startswith("__"):
|
||||
return True
|
||||
# Constants uppercase are not documented.
|
||||
if name.isupper():
|
||||
return True
|
||||
@@ -459,7 +456,9 @@ def ignore_undocumented(name):
|
||||
def check_all_objects_are_documented():
|
||||
""" Check all models are properly documented."""
|
||||
documented_objs = find_all_documented_objects()
|
||||
undocumented_objs = [c for c in dir(transformers) if c not in documented_objs and not ignore_undocumented(c)]
|
||||
modules = transformers._modules
|
||||
objects = [c for c in dir(transformers) if c not in modules and not c.startswith("_")]
|
||||
undocumented_objs = [c for c in objects if c not in documented_objs and not ignore_undocumented(c)]
|
||||
if len(undocumented_objs) > 0:
|
||||
raise Exception(
|
||||
"The following objects are in the public init so should be documented:\n - "
|
||||
|
||||
Reference in New Issue
Block a user