Transformers fast import part 2 (#9446)
* Main init work * Add version * Change from absolute to relative imports * Fix imports * One more typo * More typos * Styling * Make quality script pass * Add necessary replace in template * Fix typos * Spaces are ignored in replace for some reason * Forgot one models. * Fixes for import Co-authored-by: LysandreJik <lysandre.debut@reseau.eseo.fr> * Add documentation * Styling Co-authored-by: LysandreJik <lysandre.debut@reseau.eseo.fr>
This commit is contained in:
File diff suppressed because it is too large
Load Diff
@@ -30,9 +30,8 @@ from multiprocessing import Pipe, Process, Queue
|
|||||||
from multiprocessing.connection import Connection
|
from multiprocessing.connection import Connection
|
||||||
from typing import Callable, Iterable, List, NamedTuple, Optional, Union
|
from typing import Callable, Iterable, List, NamedTuple, Optional, Union
|
||||||
|
|
||||||
from transformers import AutoConfig, PretrainedConfig
|
from .. import AutoConfig, PretrainedConfig
|
||||||
from transformers import __version__ as version
|
from .. import __version__ as version
|
||||||
|
|
||||||
from ..file_utils import is_psutil_available, is_py3nvml_available, is_tf_available, is_torch_available
|
from ..file_utils import is_psutil_available, is_py3nvml_available, is_tf_available, is_torch_available
|
||||||
from ..utils import logging
|
from ..utils import logging
|
||||||
from .benchmark_args_utils import BenchmarkArguments
|
from .benchmark_args_utils import BenchmarkArguments
|
||||||
|
|||||||
@@ -19,9 +19,8 @@ from argparse import ArgumentParser, Namespace
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
from transformers.commands import BaseTransformersCLICommand
|
|
||||||
|
|
||||||
from ..utils import logging
|
from ..utils import logging
|
||||||
|
from . import BaseTransformersCLICommand
|
||||||
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -14,9 +14,8 @@
|
|||||||
|
|
||||||
from argparse import ArgumentParser, Namespace
|
from argparse import ArgumentParser, Namespace
|
||||||
|
|
||||||
from transformers.commands import BaseTransformersCLICommand
|
|
||||||
|
|
||||||
from ..utils import logging
|
from ..utils import logging
|
||||||
|
from . import BaseTransformersCLICommand
|
||||||
|
|
||||||
|
|
||||||
def convert_command_factory(args: Namespace):
|
def convert_command_factory(args: Namespace):
|
||||||
@@ -87,7 +86,7 @@ class ConvertCommand(BaseTransformersCLICommand):
|
|||||||
def run(self):
|
def run(self):
|
||||||
if self._model_type == "albert":
|
if self._model_type == "albert":
|
||||||
try:
|
try:
|
||||||
from transformers.models.albert.convert_albert_original_tf_checkpoint_to_pytorch import (
|
from ..models.albert.convert_albert_original_tf_checkpoint_to_pytorch import (
|
||||||
convert_tf_checkpoint_to_pytorch,
|
convert_tf_checkpoint_to_pytorch,
|
||||||
)
|
)
|
||||||
except ImportError:
|
except ImportError:
|
||||||
@@ -96,7 +95,7 @@ class ConvertCommand(BaseTransformersCLICommand):
|
|||||||
convert_tf_checkpoint_to_pytorch(self._tf_checkpoint, self._config, self._pytorch_dump_output)
|
convert_tf_checkpoint_to_pytorch(self._tf_checkpoint, self._config, self._pytorch_dump_output)
|
||||||
elif self._model_type == "bert":
|
elif self._model_type == "bert":
|
||||||
try:
|
try:
|
||||||
from transformers.models.bert.convert_bert_original_tf_checkpoint_to_pytorch import (
|
from ..models.bert.convert_bert_original_tf_checkpoint_to_pytorch import (
|
||||||
convert_tf_checkpoint_to_pytorch,
|
convert_tf_checkpoint_to_pytorch,
|
||||||
)
|
)
|
||||||
except ImportError:
|
except ImportError:
|
||||||
@@ -105,7 +104,7 @@ class ConvertCommand(BaseTransformersCLICommand):
|
|||||||
convert_tf_checkpoint_to_pytorch(self._tf_checkpoint, self._config, self._pytorch_dump_output)
|
convert_tf_checkpoint_to_pytorch(self._tf_checkpoint, self._config, self._pytorch_dump_output)
|
||||||
elif self._model_type == "funnel":
|
elif self._model_type == "funnel":
|
||||||
try:
|
try:
|
||||||
from transformers.models.funnel.convert_funnel_original_tf_checkpoint_to_pytorch import (
|
from ..models.funnel.convert_funnel_original_tf_checkpoint_to_pytorch import (
|
||||||
convert_tf_checkpoint_to_pytorch,
|
convert_tf_checkpoint_to_pytorch,
|
||||||
)
|
)
|
||||||
except ImportError:
|
except ImportError:
|
||||||
@@ -113,14 +112,14 @@ class ConvertCommand(BaseTransformersCLICommand):
|
|||||||
|
|
||||||
convert_tf_checkpoint_to_pytorch(self._tf_checkpoint, self._config, self._pytorch_dump_output)
|
convert_tf_checkpoint_to_pytorch(self._tf_checkpoint, self._config, self._pytorch_dump_output)
|
||||||
elif self._model_type == "gpt":
|
elif self._model_type == "gpt":
|
||||||
from transformers.models.openai.convert_openai_original_tf_checkpoint_to_pytorch import (
|
from ..models.openai.convert_openai_original_tf_checkpoint_to_pytorch import (
|
||||||
convert_openai_checkpoint_to_pytorch,
|
convert_openai_checkpoint_to_pytorch,
|
||||||
)
|
)
|
||||||
|
|
||||||
convert_openai_checkpoint_to_pytorch(self._tf_checkpoint, self._config, self._pytorch_dump_output)
|
convert_openai_checkpoint_to_pytorch(self._tf_checkpoint, self._config, self._pytorch_dump_output)
|
||||||
elif self._model_type == "transfo_xl":
|
elif self._model_type == "transfo_xl":
|
||||||
try:
|
try:
|
||||||
from transformers.models.transfo_xl.convert_transfo_xl_original_tf_checkpoint_to_pytorch import (
|
from ..models.transfo_xl.convert_transfo_xl_original_tf_checkpoint_to_pytorch import (
|
||||||
convert_transfo_xl_checkpoint_to_pytorch,
|
convert_transfo_xl_checkpoint_to_pytorch,
|
||||||
)
|
)
|
||||||
except ImportError:
|
except ImportError:
|
||||||
@@ -137,7 +136,7 @@ class ConvertCommand(BaseTransformersCLICommand):
|
|||||||
)
|
)
|
||||||
elif self._model_type == "gpt2":
|
elif self._model_type == "gpt2":
|
||||||
try:
|
try:
|
||||||
from transformers.models.gpt2.convert_gpt2_original_tf_checkpoint_to_pytorch import (
|
from ..models.gpt2.convert_gpt2_original_tf_checkpoint_to_pytorch import (
|
||||||
convert_gpt2_checkpoint_to_pytorch,
|
convert_gpt2_checkpoint_to_pytorch,
|
||||||
)
|
)
|
||||||
except ImportError:
|
except ImportError:
|
||||||
@@ -146,7 +145,7 @@ class ConvertCommand(BaseTransformersCLICommand):
|
|||||||
convert_gpt2_checkpoint_to_pytorch(self._tf_checkpoint, self._config, self._pytorch_dump_output)
|
convert_gpt2_checkpoint_to_pytorch(self._tf_checkpoint, self._config, self._pytorch_dump_output)
|
||||||
elif self._model_type == "xlnet":
|
elif self._model_type == "xlnet":
|
||||||
try:
|
try:
|
||||||
from transformers.models.xlnet.convert_xlnet_original_tf_checkpoint_to_pytorch import (
|
from ..models.xlnet.convert_xlnet_original_tf_checkpoint_to_pytorch import (
|
||||||
convert_xlnet_checkpoint_to_pytorch,
|
convert_xlnet_checkpoint_to_pytorch,
|
||||||
)
|
)
|
||||||
except ImportError:
|
except ImportError:
|
||||||
@@ -156,13 +155,13 @@ class ConvertCommand(BaseTransformersCLICommand):
|
|||||||
self._tf_checkpoint, self._config, self._pytorch_dump_output, self._finetuning_task_name
|
self._tf_checkpoint, self._config, self._pytorch_dump_output, self._finetuning_task_name
|
||||||
)
|
)
|
||||||
elif self._model_type == "xlm":
|
elif self._model_type == "xlm":
|
||||||
from transformers.models.xlm.convert_xlm_original_pytorch_checkpoint_to_pytorch import (
|
from ..models.xlm.convert_xlm_original_pytorch_checkpoint_to_pytorch import (
|
||||||
convert_xlm_checkpoint_to_pytorch,
|
convert_xlm_checkpoint_to_pytorch,
|
||||||
)
|
)
|
||||||
|
|
||||||
convert_xlm_checkpoint_to_pytorch(self._tf_checkpoint, self._pytorch_dump_output)
|
convert_xlm_checkpoint_to_pytorch(self._tf_checkpoint, self._pytorch_dump_output)
|
||||||
elif self._model_type == "lxmert":
|
elif self._model_type == "lxmert":
|
||||||
from transformers.models.lxmert.convert_lxmert_original_pytorch_checkpoint_to_pytorch import (
|
from ..models.lxmert.convert_lxmert_original_pytorch_checkpoint_to_pytorch import (
|
||||||
convert_lxmert_checkpoint_to_pytorch,
|
convert_lxmert_checkpoint_to_pytorch,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -14,7 +14,7 @@
|
|||||||
|
|
||||||
from argparse import ArgumentParser
|
from argparse import ArgumentParser
|
||||||
|
|
||||||
from transformers.commands import BaseTransformersCLICommand
|
from . import BaseTransformersCLICommand
|
||||||
|
|
||||||
|
|
||||||
def download_command_factory(args):
|
def download_command_factory(args):
|
||||||
@@ -40,7 +40,7 @@ class DownloadCommand(BaseTransformersCLICommand):
|
|||||||
self._force = force
|
self._force = force
|
||||||
|
|
||||||
def run(self):
|
def run(self):
|
||||||
from transformers import AutoModel, AutoTokenizer
|
from ..models.auto import AutoModel, AutoTokenizer
|
||||||
|
|
||||||
AutoModel.from_pretrained(self._model, cache_dir=self._cache, force_download=self._force)
|
AutoModel.from_pretrained(self._model, cache_dir=self._cache, force_download=self._force)
|
||||||
AutoTokenizer.from_pretrained(self._model, cache_dir=self._cache, force_download=self._force)
|
AutoTokenizer.from_pretrained(self._model, cache_dir=self._cache, force_download=self._force)
|
||||||
|
|||||||
@@ -15,9 +15,9 @@
|
|||||||
import platform
|
import platform
|
||||||
from argparse import ArgumentParser
|
from argparse import ArgumentParser
|
||||||
|
|
||||||
from transformers import __version__ as version
|
from .. import __version__ as version
|
||||||
from transformers import is_tf_available, is_torch_available
|
from ..file_utils import is_tf_available, is_torch_available
|
||||||
from transformers.commands import BaseTransformersCLICommand
|
from . import BaseTransformersCLICommand
|
||||||
|
|
||||||
|
|
||||||
def info_command_factory(_):
|
def info_command_factory(_):
|
||||||
|
|||||||
@@ -25,9 +25,9 @@ from contextlib import AbstractContextManager
|
|||||||
from typing import Dict, List, Optional
|
from typing import Dict, List, Optional
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
from transformers.commands import BaseTransformersCLICommand
|
|
||||||
|
|
||||||
from ..utils import logging
|
from ..utils import logging
|
||||||
|
from . import BaseTransformersCLICommand
|
||||||
|
|
||||||
|
|
||||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||||
|
|||||||
@@ -14,10 +14,9 @@
|
|||||||
|
|
||||||
from argparse import ArgumentParser
|
from argparse import ArgumentParser
|
||||||
|
|
||||||
from transformers.commands import BaseTransformersCLICommand
|
from ..pipelines import SUPPORTED_TASKS, Pipeline, PipelineDataFormat, pipeline
|
||||||
from transformers.pipelines import SUPPORTED_TASKS, Pipeline, PipelineDataFormat, pipeline
|
|
||||||
|
|
||||||
from ..utils import logging
|
from ..utils import logging
|
||||||
|
from . import BaseTransformersCLICommand
|
||||||
|
|
||||||
|
|
||||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||||
|
|||||||
@@ -15,11 +15,9 @@
|
|||||||
from argparse import ArgumentParser, Namespace
|
from argparse import ArgumentParser, Namespace
|
||||||
from typing import Any, List, Optional
|
from typing import Any, List, Optional
|
||||||
|
|
||||||
from transformers import Pipeline
|
from ..pipelines import SUPPORTED_TASKS, Pipeline, pipeline
|
||||||
from transformers.commands import BaseTransformersCLICommand
|
|
||||||
from transformers.pipelines import SUPPORTED_TASKS, pipeline
|
|
||||||
|
|
||||||
from ..utils import logging
|
from ..utils import logging
|
||||||
|
from . import BaseTransformersCLICommand
|
||||||
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -15,11 +15,11 @@
|
|||||||
import os
|
import os
|
||||||
from argparse import ArgumentParser, Namespace
|
from argparse import ArgumentParser, Namespace
|
||||||
|
|
||||||
from transformers import SingleSentenceClassificationProcessor as Processor
|
from ..data import SingleSentenceClassificationProcessor as Processor
|
||||||
from transformers import TextClassificationPipeline, is_tf_available, is_torch_available
|
from ..file_utils import is_tf_available, is_torch_available
|
||||||
from transformers.commands import BaseTransformersCLICommand
|
from ..pipelines import TextClassificationPipeline
|
||||||
|
|
||||||
from ..utils import logging
|
from ..utils import logging
|
||||||
|
from . import BaseTransformersCLICommand
|
||||||
|
|
||||||
|
|
||||||
if not is_tf_available() and not is_torch_available():
|
if not is_tf_available() and not is_torch_available():
|
||||||
|
|||||||
@@ -15,14 +15,14 @@
|
|||||||
|
|
||||||
from argparse import ArgumentParser
|
from argparse import ArgumentParser
|
||||||
|
|
||||||
from transformers.commands.add_new_model import AddNewModelCommand
|
from .add_new_model import AddNewModelCommand
|
||||||
from transformers.commands.convert import ConvertCommand
|
from .convert import ConvertCommand
|
||||||
from transformers.commands.download import DownloadCommand
|
from .download import DownloadCommand
|
||||||
from transformers.commands.env import EnvironmentCommand
|
from .env import EnvironmentCommand
|
||||||
from transformers.commands.lfs import LfsCommands
|
from .lfs import LfsCommands
|
||||||
from transformers.commands.run import RunCommand
|
from .run import RunCommand
|
||||||
from transformers.commands.serving import ServeCommand
|
from .serving import ServeCommand
|
||||||
from transformers.commands.user import UserCommands
|
from .user import UserCommands
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
|
|||||||
@@ -20,8 +20,9 @@ from getpass import getpass
|
|||||||
from typing import List, Union
|
from typing import List, Union
|
||||||
|
|
||||||
from requests.exceptions import HTTPError
|
from requests.exceptions import HTTPError
|
||||||
from transformers.commands import BaseTransformersCLICommand
|
|
||||||
from transformers.hf_api import HfApi, HfFolder
|
from ..hf_api import HfApi, HfFolder
|
||||||
|
from . import BaseTransformersCLICommand
|
||||||
|
|
||||||
|
|
||||||
UPLOAD_MAX_FILES = 15
|
UPLOAD_MAX_FILES = 15
|
||||||
|
|||||||
@@ -19,10 +19,9 @@ from typing import Dict, List, Optional, Tuple
|
|||||||
|
|
||||||
from packaging.version import Version, parse
|
from packaging.version import Version, parse
|
||||||
|
|
||||||
from transformers import is_tf_available, is_torch_available
|
from .file_utils import ModelOutput, is_tf_available, is_torch_available
|
||||||
from transformers.file_utils import ModelOutput
|
from .pipelines import Pipeline, pipeline
|
||||||
from transformers.pipelines import Pipeline, pipeline
|
from .tokenization_utils import BatchEncoding
|
||||||
from transformers.tokenization_utils import BatchEncoding
|
|
||||||
|
|
||||||
|
|
||||||
# This is the minimal required version to
|
# This is the minimal required version to
|
||||||
|
|||||||
@@ -18,7 +18,7 @@
|
|||||||
import argparse
|
import argparse
|
||||||
import os
|
import os
|
||||||
|
|
||||||
from transformers import (
|
from . import (
|
||||||
ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||||
BART_PRETRAINED_MODEL_ARCHIVE_LIST,
|
BART_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||||
BERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
BERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
|
||||||
@@ -87,15 +87,15 @@ from transformers import (
|
|||||||
is_torch_available,
|
is_torch_available,
|
||||||
load_pytorch_checkpoint_in_tf2_model,
|
load_pytorch_checkpoint_in_tf2_model,
|
||||||
)
|
)
|
||||||
from transformers.file_utils import hf_bucket_url
|
from .file_utils import hf_bucket_url
|
||||||
from transformers.utils import logging
|
from .utils import logging
|
||||||
|
|
||||||
|
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from transformers import (
|
from . import (
|
||||||
AlbertForPreTraining,
|
AlbertForPreTraining,
|
||||||
BartForConditionalGeneration,
|
BartForConditionalGeneration,
|
||||||
BertForPreTraining,
|
BertForPreTraining,
|
||||||
|
|||||||
@@ -18,8 +18,9 @@ import argparse
|
|||||||
import os
|
import os
|
||||||
|
|
||||||
import transformers
|
import transformers
|
||||||
from transformers.convert_slow_tokenizer import SLOW_TO_FAST_CONVERTERS
|
|
||||||
from transformers.utils import logging
|
from .convert_slow_tokenizer import SLOW_TO_FAST_CONVERTERS
|
||||||
|
from .utils import logging
|
||||||
|
|
||||||
|
|
||||||
logging.set_verbosity_info()
|
logging.set_verbosity_info()
|
||||||
|
|||||||
@@ -17,7 +17,7 @@
|
|||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
|
|
||||||
from transformers import (
|
from . import (
|
||||||
BertConfig,
|
BertConfig,
|
||||||
BertGenerationConfig,
|
BertGenerationConfig,
|
||||||
BertGenerationDecoder,
|
BertGenerationDecoder,
|
||||||
|
|||||||
@@ -27,8 +27,7 @@ import math
|
|||||||
import re
|
import re
|
||||||
import string
|
import string
|
||||||
|
|
||||||
from transformers import BasicTokenizer
|
from ...models.bert import BasicTokenizer
|
||||||
|
|
||||||
from ...utils import logging
|
from ...utils import logging
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -17,15 +17,14 @@ import unittest
|
|||||||
|
|
||||||
import timeout_decorator
|
import timeout_decorator
|
||||||
|
|
||||||
from transformers import is_torch_available
|
from ..file_utils import cached_property, is_torch_available
|
||||||
from transformers.file_utils import cached_property
|
from ..testing_utils import require_torch
|
||||||
from transformers.testing_utils import require_torch
|
|
||||||
|
|
||||||
|
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from transformers import MarianConfig, MarianMTModel
|
from ..models.marian import MarianConfig, MarianMTModel
|
||||||
|
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
|
|||||||
@@ -33,6 +33,7 @@ from dataclasses import fields
|
|||||||
from functools import partial, wraps
|
from functools import partial, wraps
|
||||||
from hashlib import sha256
|
from hashlib import sha256
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from types import ModuleType
|
||||||
from typing import Any, BinaryIO, Dict, List, Optional, Tuple, Union
|
from typing import Any, BinaryIO, Dict, List, Optional, Tuple, Union
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
from zipfile import ZipFile, is_zipfile
|
from zipfile import ZipFile, is_zipfile
|
||||||
@@ -41,7 +42,6 @@ import numpy as np
|
|||||||
from packaging import version
|
from packaging import version
|
||||||
from tqdm.auto import tqdm
|
from tqdm.auto import tqdm
|
||||||
|
|
||||||
import importlib_metadata
|
|
||||||
import requests
|
import requests
|
||||||
from filelock import FileLock
|
from filelock import FileLock
|
||||||
|
|
||||||
@@ -50,6 +50,13 @@ from .hf_api import HfFolder
|
|||||||
from .utils import logging
|
from .utils import logging
|
||||||
|
|
||||||
|
|
||||||
|
# The package importlib_metadata is in a different place, depending on the python version.
|
||||||
|
if version.parse(sys.version) < version.parse("3.8"):
|
||||||
|
import importlib_metadata
|
||||||
|
else:
|
||||||
|
import importlib.metadata as importlib_metadata
|
||||||
|
|
||||||
|
|
||||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||||
|
|
||||||
ENV_VARS_TRUE_VALUES = {"1", "ON", "YES"}
|
ENV_VARS_TRUE_VALUES = {"1", "ON", "YES"}
|
||||||
@@ -130,7 +137,7 @@ except importlib_metadata.PackageNotFoundError:
|
|||||||
|
|
||||||
_scatter_available = importlib.util.find_spec("torch_scatter") is not None
|
_scatter_available = importlib.util.find_spec("torch_scatter") is not None
|
||||||
try:
|
try:
|
||||||
_scatter_version = importlib_metadata.version("torch_scatterr")
|
_scatter_version = importlib_metadata.version("torch_scatter")
|
||||||
logger.debug(f"Successfully imported torch-scatter version {_scatter_version}")
|
logger.debug(f"Successfully imported torch-scatter version {_scatter_version}")
|
||||||
except importlib_metadata.PackageNotFoundError:
|
except importlib_metadata.PackageNotFoundError:
|
||||||
_scatter_available = False
|
_scatter_available = False
|
||||||
@@ -1415,3 +1422,40 @@ class ModelOutput(OrderedDict):
|
|||||||
Convert self to a tuple containing all the attributes/keys that are not ``None``.
|
Convert self to a tuple containing all the attributes/keys that are not ``None``.
|
||||||
"""
|
"""
|
||||||
return tuple(self[k] for k in self.keys())
|
return tuple(self[k] for k in self.keys())
|
||||||
|
|
||||||
|
|
||||||
|
class _BaseLazyModule(ModuleType):
|
||||||
|
"""
|
||||||
|
Module class that surfaces all objects but only performs associated imports when the objects are requested.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Very heavily inspired by optuna.integration._IntegrationModule
|
||||||
|
# https://github.com/optuna/optuna/blob/master/optuna/integration/__init__.py
|
||||||
|
def __init__(self, name, import_structure):
|
||||||
|
super().__init__(name)
|
||||||
|
self._modules = set(import_structure.keys())
|
||||||
|
self._class_to_module = {}
|
||||||
|
for key, values in import_structure.items():
|
||||||
|
for value in values:
|
||||||
|
self._class_to_module[value] = key
|
||||||
|
# Needed for autocompletion in an IDE
|
||||||
|
self.__all__ = list(import_structure.keys()) + sum(import_structure.values(), [])
|
||||||
|
|
||||||
|
# Needed for autocompletion in an IDE
|
||||||
|
def __dir__(self):
|
||||||
|
return super().__dir__() + self.__all__
|
||||||
|
|
||||||
|
def __getattr__(self, name: str) -> Any:
|
||||||
|
if name in self._modules:
|
||||||
|
value = self._get_module(name)
|
||||||
|
elif name in self._class_to_module.keys():
|
||||||
|
module = self._get_module(self._class_to_module[name])
|
||||||
|
value = getattr(module, name)
|
||||||
|
else:
|
||||||
|
raise AttributeError(f"module {self.__name__} has no attribute {name}")
|
||||||
|
|
||||||
|
setattr(self, name, value)
|
||||||
|
return value
|
||||||
|
|
||||||
|
def _get_module(self, module_name: str) -> ModuleType:
|
||||||
|
raise NotImplementedError
|
||||||
|
|||||||
@@ -29,7 +29,7 @@ logger = logging.get_logger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
# comet_ml requires to be imported before any ML frameworks
|
# comet_ml requires to be imported before any ML frameworks
|
||||||
_has_comet = importlib.util.find_spec("comet_ml") and os.getenv("COMET_MODE", "").upper() != "DISABLED"
|
_has_comet = importlib.util.find_spec("comet_ml") is not None and os.getenv("COMET_MODE", "").upper() != "DISABLED"
|
||||||
if _has_comet:
|
if _has_comet:
|
||||||
try:
|
try:
|
||||||
import comet_ml # noqa: F401
|
import comet_ml # noqa: F401
|
||||||
|
|||||||
@@ -0,0 +1,68 @@
|
|||||||
|
# flake8: noqa
|
||||||
|
# There's no way to ignore "F401 '...' imported but unused" warnings in this
|
||||||
|
# module, but to preserve other warnings. So, don't check this module at all.
|
||||||
|
|
||||||
|
# Copyright 2020 The HuggingFace Team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from . import (
|
||||||
|
albert,
|
||||||
|
auto,
|
||||||
|
bart,
|
||||||
|
barthez,
|
||||||
|
bert,
|
||||||
|
bert_generation,
|
||||||
|
bert_japanese,
|
||||||
|
bertweet,
|
||||||
|
blenderbot,
|
||||||
|
blenderbot_small,
|
||||||
|
camembert,
|
||||||
|
ctrl,
|
||||||
|
deberta,
|
||||||
|
dialogpt,
|
||||||
|
distilbert,
|
||||||
|
dpr,
|
||||||
|
electra,
|
||||||
|
encoder_decoder,
|
||||||
|
flaubert,
|
||||||
|
fsmt,
|
||||||
|
funnel,
|
||||||
|
gpt2,
|
||||||
|
herbert,
|
||||||
|
layoutlm,
|
||||||
|
led,
|
||||||
|
longformer,
|
||||||
|
lxmert,
|
||||||
|
marian,
|
||||||
|
mbart,
|
||||||
|
mmbt,
|
||||||
|
mobilebert,
|
||||||
|
mpnet,
|
||||||
|
mt5,
|
||||||
|
openai,
|
||||||
|
pegasus,
|
||||||
|
phobert,
|
||||||
|
prophetnet,
|
||||||
|
rag,
|
||||||
|
reformer,
|
||||||
|
retribert,
|
||||||
|
roberta,
|
||||||
|
squeezebert,
|
||||||
|
t5,
|
||||||
|
tapas,
|
||||||
|
transfo_xl,
|
||||||
|
xlm,
|
||||||
|
xlm_roberta,
|
||||||
|
xlnet,
|
||||||
|
)
|
||||||
|
|||||||
@@ -19,8 +19,8 @@ import argparse
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from transformers import AlbertConfig, AlbertForPreTraining, load_tf_weights_in_albert
|
from ...utils import logging
|
||||||
from transformers.utils import logging
|
from . import AlbertConfig, AlbertForPreTraining, load_tf_weights_in_albert
|
||||||
|
|
||||||
|
|
||||||
logging.set_verbosity_info()
|
logging.set_verbosity_info()
|
||||||
|
|||||||
@@ -23,15 +23,9 @@ import fairseq
|
|||||||
import torch
|
import torch
|
||||||
from packaging import version
|
from packaging import version
|
||||||
|
|
||||||
from transformers import (
|
from ...utils import logging
|
||||||
BartConfig,
|
from . import BartConfig, BartForConditionalGeneration, BartForSequenceClassification, BartModel, BartTokenizer
|
||||||
BartForConditionalGeneration,
|
from .modeling_bart import _make_linear_from_emb
|
||||||
BartForSequenceClassification,
|
|
||||||
BartModel,
|
|
||||||
BartTokenizer,
|
|
||||||
)
|
|
||||||
from transformers.models.bart.modeling_bart import _make_linear_from_emb
|
|
||||||
from transformers.utils import logging
|
|
||||||
|
|
||||||
|
|
||||||
FAIRSEQ_MODELS = ["bart.large", "bart.large.mnli", "bart.large.cnn", "bart_xsum/model.pt"]
|
FAIRSEQ_MODELS = ["bart.large", "bart.large.mnli", "bart.large.cnn", "bart_xsum/model.pt"]
|
||||||
|
|||||||
@@ -15,8 +15,7 @@
|
|||||||
|
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
from transformers import add_start_docstrings
|
from ...file_utils import add_start_docstrings
|
||||||
|
|
||||||
from ...tokenization_utils_base import PREPARE_SEQ2SEQ_BATCH_DOCSTRING, BatchEncoding
|
from ...tokenization_utils_base import PREPARE_SEQ2SEQ_BATCH_DOCSTRING, BatchEncoding
|
||||||
from ...utils import logging
|
from ...utils import logging
|
||||||
from ..roberta.tokenization_roberta import RobertaTokenizer
|
from ..roberta.tokenization_roberta import RobertaTokenizer
|
||||||
|
|||||||
@@ -15,8 +15,7 @@
|
|||||||
|
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
||||||
from transformers import add_start_docstrings
|
from ...file_utils import add_start_docstrings
|
||||||
|
|
||||||
from ...tokenization_utils_base import PREPARE_SEQ2SEQ_BATCH_DOCSTRING, BatchEncoding
|
from ...tokenization_utils_base import PREPARE_SEQ2SEQ_BATCH_DOCSTRING, BatchEncoding
|
||||||
from ...utils import logging
|
from ...utils import logging
|
||||||
from ..roberta.tokenization_roberta_fast import RobertaTokenizerFast
|
from ..roberta.tokenization_roberta_fast import RobertaTokenizerFast
|
||||||
|
|||||||
@@ -28,8 +28,8 @@ import re
|
|||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from transformers import BertConfig, BertModel
|
from ...utils import logging
|
||||||
from transformers.utils import logging
|
from . import BertConfig, BertModel
|
||||||
|
|
||||||
|
|
||||||
logging.set_verbosity_info()
|
logging.set_verbosity_info()
|
||||||
|
|||||||
@@ -19,8 +19,8 @@ import argparse
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from transformers import BertConfig, BertForPreTraining, load_tf_weights_in_bert
|
from ...utils import logging
|
||||||
from transformers.utils import logging
|
from . import BertConfig, BertForPreTraining, load_tf_weights_in_bert
|
||||||
|
|
||||||
|
|
||||||
logging.set_verbosity_info()
|
logging.set_verbosity_info()
|
||||||
|
|||||||
@@ -22,7 +22,7 @@ import numpy as np
|
|||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from transformers import BertModel
|
from . import BertModel
|
||||||
|
|
||||||
|
|
||||||
def convert_pytorch_checkpoint_to_tf(model: BertModel, ckpt_dir: str, model_name: str):
|
def convert_pytorch_checkpoint_to_tf(model: BertModel, ckpt_dir: str, model_name: str):
|
||||||
|
|||||||
@@ -18,8 +18,8 @@ import argparse
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from transformers import BartConfig, BartForConditionalGeneration
|
from ...models.bart import BartConfig, BartForConditionalGeneration
|
||||||
from transformers.utils import logging
|
from ...utils import logging
|
||||||
|
|
||||||
|
|
||||||
logging.set_verbosity_info()
|
logging.set_verbosity_info()
|
||||||
|
|||||||
@@ -17,7 +17,7 @@ import os
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from transformers.file_utils import WEIGHTS_NAME
|
from ...file_utils import WEIGHTS_NAME
|
||||||
|
|
||||||
|
|
||||||
DIALOGPT_MODELS = ["small", "medium", "large"]
|
DIALOGPT_MODELS = ["small", "medium", "large"]
|
||||||
|
|||||||
@@ -19,7 +19,8 @@ from pathlib import Path
|
|||||||
import torch
|
import torch
|
||||||
from torch.serialization import default_restore_location
|
from torch.serialization import default_restore_location
|
||||||
|
|
||||||
from transformers import BertConfig, DPRConfig, DPRContextEncoder, DPRQuestionEncoder, DPRReader
|
from ...models.bert import BertConfig
|
||||||
|
from . import DPRConfig, DPRContextEncoder, DPRQuestionEncoder, DPRReader
|
||||||
|
|
||||||
|
|
||||||
CheckpointState = collections.namedtuple(
|
CheckpointState = collections.namedtuple(
|
||||||
|
|||||||
@@ -19,8 +19,8 @@ import argparse
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from transformers import ElectraConfig, ElectraForMaskedLM, ElectraForPreTraining, load_tf_weights_in_electra
|
from ...utils import logging
|
||||||
from transformers.utils import logging
|
from . import ElectraConfig, ElectraForMaskedLM, ElectraForPreTraining, load_tf_weights_in_electra
|
||||||
|
|
||||||
|
|
||||||
logging.set_verbosity_info()
|
logging.set_verbosity_info()
|
||||||
|
|||||||
@@ -31,9 +31,10 @@ import torch
|
|||||||
from fairseq import hub_utils
|
from fairseq import hub_utils
|
||||||
from fairseq.data.dictionary import Dictionary
|
from fairseq.data.dictionary import Dictionary
|
||||||
|
|
||||||
from transformers import WEIGHTS_NAME, logging
|
from ...file_utils import WEIGHTS_NAME
|
||||||
from transformers.models.fsmt import VOCAB_FILES_NAMES, FSMTConfig, FSMTForConditionalGeneration
|
from ...tokenization_utils_base import TOKENIZER_CONFIG_FILE
|
||||||
from transformers.tokenization_utils_base import TOKENIZER_CONFIG_FILE
|
from ...utils import logging
|
||||||
|
from . import VOCAB_FILES_NAMES, FSMTConfig, FSMTForConditionalGeneration
|
||||||
|
|
||||||
|
|
||||||
logging.set_verbosity_warning()
|
logging.set_verbosity_warning()
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ import logging
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from transformers import FunnelConfig, FunnelForPreTraining, load_tf_weights_in_funnel
|
from . import FunnelConfig, FunnelForPreTraining, load_tf_weights_in_funnel
|
||||||
|
|
||||||
|
|
||||||
logging.basicConfig(level=logging.INFO)
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
|||||||
@@ -19,8 +19,9 @@ import argparse
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from transformers import CONFIG_NAME, WEIGHTS_NAME, GPT2Config, GPT2Model, load_tf_weights_in_gpt2
|
from ...file_utils import CONFIG_NAME, WEIGHTS_NAME
|
||||||
from transformers.utils import logging
|
from ...utils import logging
|
||||||
|
from . import GPT2Config, GPT2Model, load_tf_weights_in_gpt2
|
||||||
|
|
||||||
|
|
||||||
logging.set_verbosity_info()
|
logging.set_verbosity_info()
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ import argparse
|
|||||||
import pytorch_lightning as pl
|
import pytorch_lightning as pl
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from transformers import LongformerForQuestionAnswering, LongformerModel
|
from . import LongformerForQuestionAnswering, LongformerModel
|
||||||
|
|
||||||
|
|
||||||
class LightningModel(pl.LightningModule):
|
class LightningModel(pl.LightningModule):
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ import logging
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from transformers import LxmertConfig, LxmertForPreTraining, load_tf_weights_in_lxmert
|
from . import LxmertConfig, LxmertForPreTraining, load_tf_weights_in_lxmert
|
||||||
|
|
||||||
|
|
||||||
logging.basicConfig(level=logging.INFO)
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
|||||||
@@ -17,7 +17,7 @@ import os
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import List, Tuple
|
from typing import List, Tuple
|
||||||
|
|
||||||
from transformers.models.marian.convert_marian_to_pytorch import (
|
from .convert_marian_to_pytorch import (
|
||||||
FRONT_MATTER_TEMPLATE,
|
FRONT_MATTER_TEMPLATE,
|
||||||
_parse_readme,
|
_parse_readme,
|
||||||
convert_all_sentencepiece_models,
|
convert_all_sentencepiece_models,
|
||||||
|
|||||||
@@ -26,8 +26,8 @@ import numpy as np
|
|||||||
import torch
|
import torch
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
from transformers import MarianConfig, MarianMTModel, MarianTokenizer
|
from ...hf_api import HfApi
|
||||||
from transformers.hf_api import HfApi
|
from . import MarianConfig, MarianMTModel, MarianTokenizer
|
||||||
|
|
||||||
|
|
||||||
def remove_suffix(text: str, suffix: str):
|
def remove_suffix(text: str, suffix: str):
|
||||||
|
|||||||
@@ -16,9 +16,9 @@ import argparse
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from transformers import BartForConditionalGeneration, MBartConfig
|
from ..bart import BartForConditionalGeneration
|
||||||
|
|
||||||
from ..bart.convert_bart_original_pytorch_checkpoint_to_pytorch import remove_ignore_keys_
|
from ..bart.convert_bart_original_pytorch_checkpoint_to_pytorch import remove_ignore_keys_
|
||||||
|
from . import MBartConfig
|
||||||
|
|
||||||
|
|
||||||
def convert_fairseq_mbart_checkpoint_from_disk(checkpoint_path, hf_config_path="facebook/mbart-large-en-ro"):
|
def convert_fairseq_mbart_checkpoint_from_disk(checkpoint_path, hf_config_path="facebook/mbart-large-en-ro"):
|
||||||
|
|||||||
@@ -16,8 +16,8 @@ import argparse
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from transformers import MobileBertConfig, MobileBertForPreTraining, load_tf_weights_in_mobilebert
|
from ...utils import logging
|
||||||
from transformers.utils import logging
|
from . import MobileBertConfig, MobileBertForPreTraining, load_tf_weights_in_mobilebert
|
||||||
|
|
||||||
|
|
||||||
logging.set_verbosity_info()
|
logging.set_verbosity_info()
|
||||||
|
|||||||
@@ -19,8 +19,9 @@ import argparse
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from transformers import CONFIG_NAME, WEIGHTS_NAME, OpenAIGPTConfig, OpenAIGPTModel, load_tf_weights_in_openai_gpt
|
from ...file_utils import CONFIG_NAME, WEIGHTS_NAME
|
||||||
from transformers.utils import logging
|
from ...utils import logging
|
||||||
|
from . import OpenAIGPTConfig, OpenAIGPTModel, load_tf_weights_in_openai_gpt
|
||||||
|
|
||||||
|
|
||||||
logging.set_verbosity_info()
|
logging.set_verbosity_info()
|
||||||
|
|||||||
@@ -22,8 +22,8 @@ import tensorflow as tf
|
|||||||
import torch
|
import torch
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
from transformers import PegasusConfig, PegasusForConditionalGeneration, PegasusTokenizer
|
from . import PegasusConfig, PegasusForConditionalGeneration, PegasusTokenizer
|
||||||
from transformers.models.pegasus.configuration_pegasus import DEFAULTS, task_specific_params
|
from .configuration_pegasus import DEFAULTS, task_specific_params
|
||||||
|
|
||||||
|
|
||||||
PATTERNS = [
|
PATTERNS = [
|
||||||
|
|||||||
@@ -19,8 +19,6 @@ import argparse
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from transformers import ProphetNetForConditionalGeneration, XLMProphetNetForConditionalGeneration, logging
|
|
||||||
|
|
||||||
# transformers_old should correspond to branch `save_old_prophetnet_model_structure` here
|
# transformers_old should correspond to branch `save_old_prophetnet_model_structure` here
|
||||||
# original prophetnet_checkpoints are saved under `patrickvonplaten/..._old` respectively
|
# original prophetnet_checkpoints are saved under `patrickvonplaten/..._old` respectively
|
||||||
from transformers_old.modeling_prophetnet import (
|
from transformers_old.modeling_prophetnet import (
|
||||||
@@ -30,6 +28,8 @@ from transformers_old.modeling_xlm_prophetnet import (
|
|||||||
XLMProphetNetForConditionalGeneration as XLMProphetNetForConditionalGenerationOld,
|
XLMProphetNetForConditionalGeneration as XLMProphetNetForConditionalGenerationOld,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from . import ProphetNetForConditionalGeneration, XLMProphetNetForConditionalGeneration, logging
|
||||||
|
|
||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
logging.set_verbosity_info()
|
logging.set_verbosity_info()
|
||||||
|
|||||||
@@ -21,8 +21,8 @@ import pickle
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from transformers import ReformerConfig, ReformerModelWithLMHead
|
from ...utils import logging
|
||||||
from transformers.utils import logging
|
from . import ReformerConfig, ReformerModelWithLMHead
|
||||||
|
|
||||||
|
|
||||||
logging.set_verbosity_info()
|
logging.set_verbosity_info()
|
||||||
|
|||||||
@@ -24,19 +24,9 @@ from fairseq.models.roberta import RobertaModel as FairseqRobertaModel
|
|||||||
from fairseq.modules import TransformerSentenceEncoderLayer
|
from fairseq.modules import TransformerSentenceEncoderLayer
|
||||||
from packaging import version
|
from packaging import version
|
||||||
|
|
||||||
from transformers.models.bert.modeling_bert import (
|
from ...models.bert.modeling_bert import BertIntermediate, BertLayer, BertOutput, BertSelfAttention, BertSelfOutput
|
||||||
BertIntermediate,
|
from ...utils import logging
|
||||||
BertLayer,
|
from .modeling_roberta import RobertaConfig, RobertaForMaskedLM, RobertaForSequenceClassification
|
||||||
BertOutput,
|
|
||||||
BertSelfAttention,
|
|
||||||
BertSelfOutput,
|
|
||||||
)
|
|
||||||
from transformers.models.roberta.modeling_roberta import (
|
|
||||||
RobertaConfig,
|
|
||||||
RobertaForMaskedLM,
|
|
||||||
RobertaForSequenceClassification,
|
|
||||||
)
|
|
||||||
from transformers.utils import logging
|
|
||||||
|
|
||||||
|
|
||||||
if version.parse(fairseq.__version__) < version.parse("0.9.0"):
|
if version.parse(fairseq.__version__) < version.parse("0.9.0"):
|
||||||
|
|||||||
@@ -17,8 +17,8 @@
|
|||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
|
|
||||||
from transformers import T5Config, T5ForConditionalGeneration, load_tf_weights_in_t5
|
from ...utils import logging
|
||||||
from transformers.utils import logging
|
from . import T5Config, T5ForConditionalGeneration, load_tf_weights_in_t5
|
||||||
|
|
||||||
|
|
||||||
logging.set_verbosity_info()
|
logging.set_verbosity_info()
|
||||||
|
|||||||
@@ -22,8 +22,6 @@ from typing import Tuple
|
|||||||
|
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
from transformers.modeling_tf_utils import TFWrappedEmbeddings
|
|
||||||
|
|
||||||
from ...activations_tf import get_tf_activation
|
from ...activations_tf import get_tf_activation
|
||||||
from ...file_utils import (
|
from ...file_utils import (
|
||||||
DUMMY_INPUTS,
|
DUMMY_INPUTS,
|
||||||
@@ -42,6 +40,7 @@ from ...modeling_tf_utils import (
|
|||||||
TFCausalLanguageModelingLoss,
|
TFCausalLanguageModelingLoss,
|
||||||
TFPreTrainedModel,
|
TFPreTrainedModel,
|
||||||
TFSharedEmbeddings,
|
TFSharedEmbeddings,
|
||||||
|
TFWrappedEmbeddings,
|
||||||
input_processing,
|
input_processing,
|
||||||
keras_serializable,
|
keras_serializable,
|
||||||
shape_list,
|
shape_list,
|
||||||
|
|||||||
@@ -17,16 +17,16 @@
|
|||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
|
|
||||||
from transformers.models.tapas.modeling_tapas import (
|
from ...utils import logging
|
||||||
|
from . import (
|
||||||
TapasConfig,
|
TapasConfig,
|
||||||
TapasForMaskedLM,
|
TapasForMaskedLM,
|
||||||
TapasForQuestionAnswering,
|
TapasForQuestionAnswering,
|
||||||
TapasForSequenceClassification,
|
TapasForSequenceClassification,
|
||||||
TapasModel,
|
TapasModel,
|
||||||
|
TapasTokenizer,
|
||||||
load_tf_weights_in_tapas,
|
load_tf_weights_in_tapas,
|
||||||
)
|
)
|
||||||
from transformers.models.tapas.tokenization_tapas import TapasTokenizer
|
|
||||||
from transformers.utils import logging
|
|
||||||
|
|
||||||
|
|
||||||
logging.set_verbosity_info()
|
logging.set_verbosity_info()
|
||||||
|
|||||||
@@ -28,9 +28,7 @@ from typing import Callable, Dict, Generator, List, Optional, Text, Tuple, Union
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from transformers import add_end_docstrings
|
from ...file_utils import add_end_docstrings, is_pandas_available
|
||||||
|
|
||||||
from ...file_utils import is_pandas_available
|
|
||||||
from ...tokenization_utils import PreTrainedTokenizer, _is_control, _is_punctuation, _is_whitespace
|
from ...tokenization_utils import PreTrainedTokenizer, _is_control, _is_punctuation, _is_whitespace
|
||||||
from ...tokenization_utils_base import (
|
from ...tokenization_utils_base import (
|
||||||
ENCODE_KWARGS_DOCSTRING,
|
ENCODE_KWARGS_DOCSTRING,
|
||||||
|
|||||||
@@ -22,16 +22,11 @@ import sys
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
import transformers.models.transfo_xl.tokenization_transfo_xl as data_utils
|
from ...file_utils import CONFIG_NAME, WEIGHTS_NAME
|
||||||
from transformers import (
|
from ...utils import logging
|
||||||
CONFIG_NAME,
|
from . import TransfoXLConfig, TransfoXLLMHeadModel, load_tf_weights_in_transfo_xl
|
||||||
WEIGHTS_NAME,
|
from . import tokenization_transfo_xl as data_utils
|
||||||
TransfoXLConfig,
|
from .tokenization_transfo_xl import CORPUS_NAME, VOCAB_FILES_NAMES
|
||||||
TransfoXLLMHeadModel,
|
|
||||||
load_tf_weights_in_transfo_xl,
|
|
||||||
)
|
|
||||||
from transformers.models.transfo_xl.tokenization_transfo_xl import CORPUS_NAME, VOCAB_FILES_NAMES
|
|
||||||
from transformers.utils import logging
|
|
||||||
|
|
||||||
|
|
||||||
logging.set_verbosity_info()
|
logging.set_verbosity_info()
|
||||||
|
|||||||
@@ -21,9 +21,9 @@ import json
|
|||||||
import numpy
|
import numpy
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from transformers import CONFIG_NAME, WEIGHTS_NAME
|
from ...file_utils import CONFIG_NAME, WEIGHTS_NAME
|
||||||
from transformers.models.xlm.tokenization_xlm import VOCAB_FILES_NAMES
|
from ...utils import logging
|
||||||
from transformers.utils import logging
|
from .tokenization_xlm import VOCAB_FILES_NAMES
|
||||||
|
|
||||||
|
|
||||||
logging.set_verbosity_info()
|
logging.set_verbosity_info()
|
||||||
|
|||||||
@@ -20,16 +20,15 @@ import os
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from transformers import (
|
from ...file_utils import CONFIG_NAME, WEIGHTS_NAME
|
||||||
CONFIG_NAME,
|
from ...utils import logging
|
||||||
WEIGHTS_NAME,
|
from . import (
|
||||||
XLNetConfig,
|
XLNetConfig,
|
||||||
XLNetForQuestionAnswering,
|
XLNetForQuestionAnswering,
|
||||||
XLNetForSequenceClassification,
|
XLNetForSequenceClassification,
|
||||||
XLNetLMHeadModel,
|
XLNetLMHeadModel,
|
||||||
load_tf_weights_in_xlnet,
|
load_tf_weights_in_xlnet,
|
||||||
)
|
)
|
||||||
from transformers.utils import logging
|
|
||||||
|
|
||||||
|
|
||||||
GLUE_TASKS_NUM_LABELS = {
|
GLUE_TASKS_NUM_LABELS = {
|
||||||
|
|||||||
@@ -24,71 +24,147 @@
|
|||||||
##
|
##
|
||||||
## Put '## COMMENT' to comment on the file.
|
## Put '## COMMENT' to comment on the file.
|
||||||
|
|
||||||
|
# To replace in: "src/transformers/__init__.py"
|
||||||
|
# Below: " # PyTorch models structure" if generating PyTorch
|
||||||
|
# Replace with:
|
||||||
|
{% if cookiecutter.is_encoder_decoder_model == "False" %}
|
||||||
|
_import_structure["models.{{cookiecutter.lowercase_modelname}}"].extend(
|
||||||
|
[
|
||||||
|
"{{cookiecutter.uppercase_modelname}}_PRETRAINED_MODEL_ARCHIVE_LIST",
|
||||||
|
"{{cookiecutter.camelcase_modelname}}ForMaskedLM",
|
||||||
|
"{{cookiecutter.camelcase_modelname}}ForCausalLM",
|
||||||
|
"{{cookiecutter.camelcase_modelname}}ForMultipleChoice",
|
||||||
|
"{{cookiecutter.camelcase_modelname}}ForQuestionAnswering",
|
||||||
|
"{{cookiecutter.camelcase_modelname}}ForSequenceClassification",
|
||||||
|
"{{cookiecutter.camelcase_modelname}}ForTokenClassification",
|
||||||
|
"{{cookiecutter.camelcase_modelname}}Layer",
|
||||||
|
"{{cookiecutter.camelcase_modelname}}Model",
|
||||||
|
"{{cookiecutter.camelcase_modelname}}PreTrainedModel",
|
||||||
|
"load_tf_weights_in_{{cookiecutter.lowercase_modelname}}",
|
||||||
|
]
|
||||||
|
)
|
||||||
|
{% else %}
|
||||||
|
_import_structure["models.{{cookiecutter.lowercase_modelname}}"].extend(
|
||||||
|
[
|
||||||
|
"{{cookiecutter.uppercase_modelname}}_PRETRAINED_MODEL_ARCHIVE_LIST",
|
||||||
|
"{{cookiecutter.camelcase_modelname}}ForConditionalGeneration",
|
||||||
|
"{{cookiecutter.camelcase_modelname}}ForQuestionAnswering",
|
||||||
|
"{{cookiecutter.camelcase_modelname}}ForSequenceClassification",
|
||||||
|
"{{cookiecutter.camelcase_modelname}}Model",
|
||||||
|
]
|
||||||
|
)
|
||||||
|
{% endif -%}
|
||||||
|
# End.
|
||||||
|
|
||||||
|
# Below: " # TensorFlow models structure" if generating TensorFlow
|
||||||
|
# Replace with:
|
||||||
|
{% if cookiecutter.is_encoder_decoder_model == "False" %}
|
||||||
|
_import_structure["models.{{cookiecutter.lowercase_modelname}}"].extend(
|
||||||
|
[
|
||||||
|
"TF_{{cookiecutter.uppercase_modelname}}_PRETRAINED_MODEL_ARCHIVE_LIST",
|
||||||
|
"TF{{cookiecutter.camelcase_modelname}}ForMaskedLM",
|
||||||
|
"TF{{cookiecutter.camelcase_modelname}}ForCausalLM",
|
||||||
|
"TF{{cookiecutter.camelcase_modelname}}ForMultipleChoice",
|
||||||
|
"TF{{cookiecutter.camelcase_modelname}}ForQuestionAnswering",
|
||||||
|
"TF{{cookiecutter.camelcase_modelname}}ForSequenceClassification",
|
||||||
|
"TF{{cookiecutter.camelcase_modelname}}ForTokenClassification",
|
||||||
|
"TF{{cookiecutter.camelcase_modelname}}Layer",
|
||||||
|
"TF{{cookiecutter.camelcase_modelname}}Model",
|
||||||
|
"TF{{cookiecutter.camelcase_modelname}}PreTrainedModel",
|
||||||
|
]
|
||||||
|
)
|
||||||
|
{% else %}
|
||||||
|
_import_structure["models.{{cookiecutter.lowercase_modelname}}"].extend(
|
||||||
|
[
|
||||||
|
"TF{{cookiecutter.camelcase_modelname}}ForConditionalGeneration",
|
||||||
|
"TF{{cookiecutter.camelcase_modelname}}Model",
|
||||||
|
"TF{{cookiecutter.camelcase_modelname}}PreTrainedModel",
|
||||||
|
]
|
||||||
|
)
|
||||||
|
{% endif -%}
|
||||||
|
# End.
|
||||||
|
|
||||||
|
# Below: " # Fast tokenizers"
|
||||||
|
# Replace with:
|
||||||
|
_import_structure["models.{{cookiecutter.lowercase_modelname}}"].append("{{cookiecutter.camelcase_modelname}}TokenizerFast")
|
||||||
|
# End.
|
||||||
|
|
||||||
|
# Below: " # Models"
|
||||||
|
# Replace with:
|
||||||
|
"models.{{cookiecutter.lowercase_modelname}}": ["{{cookiecutter.uppercase_modelname}}_PRETRAINED_CONFIG_ARCHIVE_MAP", "{{cookiecutter.camelcase_modelname}}Config", "{{cookiecutter.camelcase_modelname}}Tokenizer"],
|
||||||
|
# End.
|
||||||
|
|
||||||
# To replace in: "src/transformers/__init__.py"
|
# To replace in: "src/transformers/__init__.py"
|
||||||
# Below: "if is_torch_available():" if generating PyTorch
|
# Below: " if is_torch_available():" if generating PyTorch
|
||||||
# Replace with:
|
# Replace with:
|
||||||
{% if cookiecutter.is_encoder_decoder_model == "False" %}
|
{% if cookiecutter.is_encoder_decoder_model == "False" %}
|
||||||
from .models.{{cookiecutter.lowercase_modelname}} import (
|
from .models.{{cookiecutter.lowercase_modelname}} import (
|
||||||
{{cookiecutter.uppercase_modelname}}_PRETRAINED_MODEL_ARCHIVE_LIST,
|
{{cookiecutter.uppercase_modelname}}_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||||
{{cookiecutter.camelcase_modelname}}ForMaskedLM,
|
{{cookiecutter.camelcase_modelname}}ForMaskedLM,
|
||||||
{{cookiecutter.camelcase_modelname}}ForCausalLM,
|
{{cookiecutter.camelcase_modelname}}ForCausalLM,
|
||||||
{{cookiecutter.camelcase_modelname}}ForMultipleChoice,
|
{{cookiecutter.camelcase_modelname}}ForMultipleChoice,
|
||||||
{{cookiecutter.camelcase_modelname}}ForQuestionAnswering,
|
{{cookiecutter.camelcase_modelname}}ForQuestionAnswering,
|
||||||
{{cookiecutter.camelcase_modelname}}ForSequenceClassification,
|
{{cookiecutter.camelcase_modelname}}ForSequenceClassification,
|
||||||
{{cookiecutter.camelcase_modelname}}ForTokenClassification,
|
{{cookiecutter.camelcase_modelname}}ForTokenClassification,
|
||||||
{{cookiecutter.camelcase_modelname}}Layer,
|
{{cookiecutter.camelcase_modelname}}Layer,
|
||||||
{{cookiecutter.camelcase_modelname}}Model,
|
{{cookiecutter.camelcase_modelname}}Model,
|
||||||
{{cookiecutter.camelcase_modelname}}PreTrainedModel,
|
{{cookiecutter.camelcase_modelname}}PreTrainedModel,
|
||||||
load_tf_weights_in_{{cookiecutter.lowercase_modelname}},
|
load_tf_weights_in_{{cookiecutter.lowercase_modelname}},
|
||||||
)
|
)
|
||||||
{% else %}
|
{% else %}
|
||||||
from .models.{{cookiecutter.lowercase_modelname}} import (
|
from .models.{{cookiecutter.lowercase_modelname}} import (
|
||||||
{{cookiecutter.uppercase_modelname}}_PRETRAINED_MODEL_ARCHIVE_LIST,
|
{{cookiecutter.uppercase_modelname}}_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||||
{{cookiecutter.camelcase_modelname}}ForConditionalGeneration,
|
{{cookiecutter.camelcase_modelname}}ForConditionalGeneration,
|
||||||
{{cookiecutter.camelcase_modelname}}ForQuestionAnswering,
|
{{cookiecutter.camelcase_modelname}}ForQuestionAnswering,
|
||||||
{{cookiecutter.camelcase_modelname}}ForSequenceClassification,
|
{{cookiecutter.camelcase_modelname}}ForSequenceClassification,
|
||||||
{{cookiecutter.camelcase_modelname}}Model,
|
{{cookiecutter.camelcase_modelname}}Model,
|
||||||
)
|
)
|
||||||
{% endif -%}
|
{% endif -%}
|
||||||
# End.
|
# End.
|
||||||
|
|
||||||
# Below: "if is_tf_available():" if generating TensorFlow
|
# Below: " if is_tf_available():" if generating TensorFlow
|
||||||
# Replace with:
|
# Replace with:
|
||||||
{% if cookiecutter.is_encoder_decoder_model == "False" %}
|
{% if cookiecutter.is_encoder_decoder_model == "False" %}
|
||||||
from .models.{{cookiecutter.lowercase_modelname}} import (
|
from .models.{{cookiecutter.lowercase_modelname}} import (
|
||||||
TF_{{cookiecutter.uppercase_modelname}}_PRETRAINED_MODEL_ARCHIVE_LIST,
|
TF_{{cookiecutter.uppercase_modelname}}_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||||
TF{{cookiecutter.camelcase_modelname}}ForMaskedLM,
|
TF{{cookiecutter.camelcase_modelname}}ForMaskedLM,
|
||||||
TF{{cookiecutter.camelcase_modelname}}ForCausalLM,
|
TF{{cookiecutter.camelcase_modelname}}ForCausalLM,
|
||||||
TF{{cookiecutter.camelcase_modelname}}ForMultipleChoice,
|
TF{{cookiecutter.camelcase_modelname}}ForMultipleChoice,
|
||||||
TF{{cookiecutter.camelcase_modelname}}ForQuestionAnswering,
|
TF{{cookiecutter.camelcase_modelname}}ForQuestionAnswering,
|
||||||
TF{{cookiecutter.camelcase_modelname}}ForSequenceClassification,
|
TF{{cookiecutter.camelcase_modelname}}ForSequenceClassification,
|
||||||
TF{{cookiecutter.camelcase_modelname}}ForTokenClassification,
|
TF{{cookiecutter.camelcase_modelname}}ForTokenClassification,
|
||||||
TF{{cookiecutter.camelcase_modelname}}Layer,
|
TF{{cookiecutter.camelcase_modelname}}Layer,
|
||||||
TF{{cookiecutter.camelcase_modelname}}Model,
|
TF{{cookiecutter.camelcase_modelname}}Model,
|
||||||
TF{{cookiecutter.camelcase_modelname}}PreTrainedModel,
|
TF{{cookiecutter.camelcase_modelname}}PreTrainedModel,
|
||||||
)
|
)
|
||||||
{% else %}
|
{% else %}
|
||||||
from .models.{{cookiecutter.lowercase_modelname}} import (
|
from .models.{{cookiecutter.lowercase_modelname}} import (
|
||||||
TF{{cookiecutter.camelcase_modelname}}ForConditionalGeneration,
|
TF{{cookiecutter.camelcase_modelname}}ForConditionalGeneration,
|
||||||
TF{{cookiecutter.camelcase_modelname}}Model,
|
TF{{cookiecutter.camelcase_modelname}}Model,
|
||||||
TF{{cookiecutter.camelcase_modelname}}PreTrainedModel,
|
TF{{cookiecutter.camelcase_modelname}}PreTrainedModel,
|
||||||
)
|
)
|
||||||
{% endif -%}
|
{% endif -%}
|
||||||
# End.
|
# End.
|
||||||
|
|
||||||
# Below: "if is_tokenizers_available():"
|
# Below: " if is_tokenizers_available():"
|
||||||
# Replace with:
|
# Replace with:
|
||||||
from .models.{{cookiecutter.lowercase_modelname}} import {{cookiecutter.camelcase_modelname}}TokenizerFast
|
from .models.{{cookiecutter.lowercase_modelname}} import {{cookiecutter.camelcase_modelname}}TokenizerFast
|
||||||
# End.
|
# End.
|
||||||
|
|
||||||
# Below: "from .models.albert import ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, AlbertConfig"
|
# Below: " from .models.albert import ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, AlbertConfig"
|
||||||
# Replace with:
|
# Replace with:
|
||||||
from .models.{{cookiecutter.lowercase_modelname}} import {{cookiecutter.uppercase_modelname}}_PRETRAINED_CONFIG_ARCHIVE_MAP, {{cookiecutter.camelcase_modelname}}Config, {{cookiecutter.camelcase_modelname}}Tokenizer
|
from .models.{{cookiecutter.lowercase_modelname}} import {{cookiecutter.uppercase_modelname}}_PRETRAINED_CONFIG_ARCHIVE_MAP, {{cookiecutter.camelcase_modelname}}Config, {{cookiecutter.camelcase_modelname}}Tokenizer
|
||||||
# End.
|
# End.
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# To replace in: "src/transformers/models/__init__.py"
|
||||||
|
# Below: "from . import ("
|
||||||
|
# Replace with:
|
||||||
|
{{cookiecutter.lowercase_modelname}},
|
||||||
|
# End.
|
||||||
|
|
||||||
|
|
||||||
# To replace in: "src/transformers/models/auto/configuration_auto.py"
|
# To replace in: "src/transformers/models/auto/configuration_auto.py"
|
||||||
# Below: "# Add configs here"
|
# Below: "# Add configs here"
|
||||||
# Replace with:
|
# Replace with:
|
||||||
|
|||||||
@@ -23,237 +23,79 @@ import re
|
|||||||
PATH_TO_TRANSFORMERS = "src/transformers"
|
PATH_TO_TRANSFORMERS = "src/transformers"
|
||||||
|
|
||||||
_re_single_line_import = re.compile(r"\s+from\s+\S*\s+import\s+([^\(\s].*)\n")
|
_re_single_line_import = re.compile(r"\s+from\s+\S*\s+import\s+([^\(\s].*)\n")
|
||||||
|
_re_test_backend = re.compile(r"^\s+if\s+is\_([a-z]*)\_available\(\):\s*$")
|
||||||
|
|
||||||
|
|
||||||
|
BACKENDS = ["torch", "tf", "flax", "sentencepiece", "tokenizers"]
|
||||||
|
|
||||||
|
|
||||||
DUMMY_CONSTANT = """
|
DUMMY_CONSTANT = """
|
||||||
{0} = None
|
{0} = None
|
||||||
"""
|
"""
|
||||||
|
|
||||||
DUMMY_PT_PRETRAINED_CLASS = """
|
DUMMY_PRETRAINED_CLASS = """
|
||||||
class {0}:
|
class {0}:
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
requires_pytorch(self)
|
requires_{1}(self)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_pretrained(self, *args, **kwargs):
|
def from_pretrained(self, *args, **kwargs):
|
||||||
requires_pytorch(self)
|
requires_{1}(self)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
DUMMY_PT_CLASS = """
|
DUMMY_CLASS = """
|
||||||
class {0}:
|
class {0}:
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
requires_pytorch(self)
|
requires_{1}(self)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
DUMMY_PT_FUNCTION = """
|
DUMMY_FUNCTION = """
|
||||||
def {0}(*args, **kwargs):
|
def {0}(*args, **kwargs):
|
||||||
requires_pytorch({0})
|
requires_{1}({0})
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
DUMMY_TF_PRETRAINED_CLASS = """
|
|
||||||
class {0}:
|
|
||||||
def __init__(self, *args, **kwargs):
|
|
||||||
requires_tf(self)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_pretrained(self, *args, **kwargs):
|
|
||||||
requires_tf(self)
|
|
||||||
"""
|
|
||||||
|
|
||||||
DUMMY_TF_CLASS = """
|
|
||||||
class {0}:
|
|
||||||
def __init__(self, *args, **kwargs):
|
|
||||||
requires_tf(self)
|
|
||||||
"""
|
|
||||||
|
|
||||||
DUMMY_TF_FUNCTION = """
|
|
||||||
def {0}(*args, **kwargs):
|
|
||||||
requires_tf({0})
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
DUMMY_FLAX_PRETRAINED_CLASS = """
|
|
||||||
class {0}:
|
|
||||||
def __init__(self, *args, **kwargs):
|
|
||||||
requires_flax(self)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_pretrained(self, *args, **kwargs):
|
|
||||||
requires_flax(self)
|
|
||||||
"""
|
|
||||||
|
|
||||||
DUMMY_FLAX_CLASS = """
|
|
||||||
class {0}:
|
|
||||||
def __init__(self, *args, **kwargs):
|
|
||||||
requires_flax(self)
|
|
||||||
"""
|
|
||||||
|
|
||||||
DUMMY_FLAX_FUNCTION = """
|
|
||||||
def {0}(*args, **kwargs):
|
|
||||||
requires_flax({0})
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
DUMMY_SENTENCEPIECE_PRETRAINED_CLASS = """
|
|
||||||
class {0}:
|
|
||||||
def __init__(self, *args, **kwargs):
|
|
||||||
requires_sentencepiece(self)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_pretrained(self, *args, **kwargs):
|
|
||||||
requires_sentencepiece(self)
|
|
||||||
"""
|
|
||||||
|
|
||||||
DUMMY_SENTENCEPIECE_CLASS = """
|
|
||||||
class {0}:
|
|
||||||
def __init__(self, *args, **kwargs):
|
|
||||||
requires_sentencepiece(self)
|
|
||||||
"""
|
|
||||||
|
|
||||||
DUMMY_SENTENCEPIECE_FUNCTION = """
|
|
||||||
def {0}(*args, **kwargs):
|
|
||||||
requires_sentencepiece({0})
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
DUMMY_TOKENIZERS_PRETRAINED_CLASS = """
|
|
||||||
class {0}:
|
|
||||||
def __init__(self, *args, **kwargs):
|
|
||||||
requires_tokenizers(self)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_pretrained(self, *args, **kwargs):
|
|
||||||
requires_tokenizers(self)
|
|
||||||
"""
|
|
||||||
|
|
||||||
DUMMY_TOKENIZERS_CLASS = """
|
|
||||||
class {0}:
|
|
||||||
def __init__(self, *args, **kwargs):
|
|
||||||
requires_tokenizers(self)
|
|
||||||
"""
|
|
||||||
|
|
||||||
DUMMY_TOKENIZERS_FUNCTION = """
|
|
||||||
def {0}(*args, **kwargs):
|
|
||||||
requires_tokenizers({0})
|
|
||||||
"""
|
|
||||||
|
|
||||||
# Map all these to dummy type
|
|
||||||
|
|
||||||
DUMMY_PRETRAINED_CLASS = {
|
|
||||||
"pt": DUMMY_PT_PRETRAINED_CLASS,
|
|
||||||
"tf": DUMMY_TF_PRETRAINED_CLASS,
|
|
||||||
"flax": DUMMY_FLAX_PRETRAINED_CLASS,
|
|
||||||
"sentencepiece": DUMMY_SENTENCEPIECE_PRETRAINED_CLASS,
|
|
||||||
"tokenizers": DUMMY_TOKENIZERS_PRETRAINED_CLASS,
|
|
||||||
}
|
|
||||||
|
|
||||||
DUMMY_CLASS = {
|
|
||||||
"pt": DUMMY_PT_CLASS,
|
|
||||||
"tf": DUMMY_TF_CLASS,
|
|
||||||
"flax": DUMMY_FLAX_CLASS,
|
|
||||||
"sentencepiece": DUMMY_SENTENCEPIECE_CLASS,
|
|
||||||
"tokenizers": DUMMY_TOKENIZERS_CLASS,
|
|
||||||
}
|
|
||||||
|
|
||||||
DUMMY_FUNCTION = {
|
|
||||||
"pt": DUMMY_PT_FUNCTION,
|
|
||||||
"tf": DUMMY_TF_FUNCTION,
|
|
||||||
"flax": DUMMY_FLAX_FUNCTION,
|
|
||||||
"sentencepiece": DUMMY_SENTENCEPIECE_FUNCTION,
|
|
||||||
"tokenizers": DUMMY_TOKENIZERS_FUNCTION,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def read_init():
|
def read_init():
|
||||||
""" Read the init and extracts PyTorch, TensorFlow, SentencePiece and Tokenizers objects. """
|
""" Read the init and extracts PyTorch, TensorFlow, SentencePiece and Tokenizers objects. """
|
||||||
with open(os.path.join(PATH_TO_TRANSFORMERS, "__init__.py"), "r", encoding="utf-8", newline="\n") as f:
|
with open(os.path.join(PATH_TO_TRANSFORMERS, "__init__.py"), "r", encoding="utf-8", newline="\n") as f:
|
||||||
lines = f.readlines()
|
lines = f.readlines()
|
||||||
|
|
||||||
|
# Get to the point we do the actual imports for type checking
|
||||||
line_index = 0
|
line_index = 0
|
||||||
# Find where the SentencePiece imports begin
|
while not lines[line_index].startswith("if TYPE_CHECKING"):
|
||||||
sentencepiece_objects = []
|
|
||||||
while not lines[line_index].startswith("if is_sentencepiece_available():"):
|
|
||||||
line_index += 1
|
|
||||||
line_index += 1
|
|
||||||
|
|
||||||
# Until we unindent, add SentencePiece objects to the list
|
|
||||||
while len(lines[line_index]) <= 1 or lines[line_index].startswith(" "):
|
|
||||||
line = lines[line_index]
|
|
||||||
search = _re_single_line_import.search(line)
|
|
||||||
if search is not None:
|
|
||||||
sentencepiece_objects += search.groups()[0].split(", ")
|
|
||||||
elif line.startswith(" "):
|
|
||||||
sentencepiece_objects.append(line[8:-2])
|
|
||||||
line_index += 1
|
line_index += 1
|
||||||
|
|
||||||
# Find where the Tokenizers imports begin
|
backend_specific_objects = {}
|
||||||
tokenizers_objects = []
|
# Go through the end of the file
|
||||||
while not lines[line_index].startswith("if is_tokenizers_available():"):
|
while line_index < len(lines):
|
||||||
line_index += 1
|
# If the line is an if is_backemd_available, we grab all objects associated.
|
||||||
line_index += 1
|
if _re_test_backend.search(lines[line_index]) is not None:
|
||||||
|
backend = _re_test_backend.search(lines[line_index]).groups()[0]
|
||||||
|
line_index += 1
|
||||||
|
|
||||||
# Until we unindent, add Tokenizers objects to the list
|
# Ignore if backend isn't tracked for dummies.
|
||||||
while len(lines[line_index]) <= 1 or lines[line_index].startswith(" "):
|
if backend not in BACKENDS:
|
||||||
line = lines[line_index]
|
continue
|
||||||
search = _re_single_line_import.search(line)
|
|
||||||
if search is not None:
|
|
||||||
tokenizers_objects += search.groups()[0].split(", ")
|
|
||||||
elif line.startswith(" "):
|
|
||||||
tokenizers_objects.append(line[8:-2])
|
|
||||||
line_index += 1
|
|
||||||
|
|
||||||
# Find where the PyTorch imports begin
|
objects = []
|
||||||
pt_objects = []
|
# Until we unindent, add backend objects to the list
|
||||||
while not lines[line_index].startswith("if is_torch_available():"):
|
while len(lines[line_index]) <= 1 or lines[line_index].startswith(" " * 8):
|
||||||
line_index += 1
|
line = lines[line_index]
|
||||||
line_index += 1
|
single_line_import_search = _re_single_line_import.search(line)
|
||||||
|
if single_line_import_search is not None:
|
||||||
|
objects.extend(single_line_import_search.groups()[0].split(", "))
|
||||||
|
elif line.startswith(" " * 12):
|
||||||
|
objects.append(line[12:-2])
|
||||||
|
line_index += 1
|
||||||
|
|
||||||
# Until we unindent, add PyTorch objects to the list
|
backend_specific_objects[backend] = objects
|
||||||
while len(lines[line_index]) <= 1 or lines[line_index].startswith(" "):
|
else:
|
||||||
line = lines[line_index]
|
line_index += 1
|
||||||
search = _re_single_line_import.search(line)
|
|
||||||
if search is not None:
|
|
||||||
pt_objects += search.groups()[0].split(", ")
|
|
||||||
elif line.startswith(" "):
|
|
||||||
pt_objects.append(line[8:-2])
|
|
||||||
line_index += 1
|
|
||||||
|
|
||||||
# Find where the TF imports begin
|
return backend_specific_objects
|
||||||
tf_objects = []
|
|
||||||
while not lines[line_index].startswith("if is_tf_available():"):
|
|
||||||
line_index += 1
|
|
||||||
line_index += 1
|
|
||||||
|
|
||||||
# Until we unindent, add PyTorch objects to the list
|
|
||||||
while len(lines[line_index]) <= 1 or lines[line_index].startswith(" "):
|
|
||||||
line = lines[line_index]
|
|
||||||
search = _re_single_line_import.search(line)
|
|
||||||
if search is not None:
|
|
||||||
tf_objects += search.groups()[0].split(", ")
|
|
||||||
elif line.startswith(" "):
|
|
||||||
tf_objects.append(line[8:-2])
|
|
||||||
line_index += 1
|
|
||||||
|
|
||||||
# Find where the FLAX imports begin
|
|
||||||
flax_objects = []
|
|
||||||
while not lines[line_index].startswith("if is_flax_available():"):
|
|
||||||
line_index += 1
|
|
||||||
line_index += 1
|
|
||||||
|
|
||||||
# Until we unindent, add PyTorch objects to the list
|
|
||||||
while len(lines[line_index]) <= 1 or lines[line_index].startswith(" "):
|
|
||||||
line = lines[line_index]
|
|
||||||
search = _re_single_line_import.search(line)
|
|
||||||
if search is not None:
|
|
||||||
flax_objects += search.groups()[0].split(", ")
|
|
||||||
elif line.startswith(" "):
|
|
||||||
flax_objects.append(line[8:-2])
|
|
||||||
line_index += 1
|
|
||||||
|
|
||||||
return sentencepiece_objects, tokenizers_objects, pt_objects, tf_objects, flax_objects
|
|
||||||
|
|
||||||
|
|
||||||
def create_dummy_object(name, type="pt"):
|
def create_dummy_object(name, backend_name):
|
||||||
""" Create the code for the dummy object corresponding to `name`."""
|
""" Create the code for the dummy object corresponding to `name`."""
|
||||||
_pretrained = [
|
_pretrained = [
|
||||||
"Config" "ForCausalLM",
|
"Config" "ForCausalLM",
|
||||||
@@ -266,11 +108,10 @@ def create_dummy_object(name, type="pt"):
|
|||||||
"Model",
|
"Model",
|
||||||
"Tokenizer",
|
"Tokenizer",
|
||||||
]
|
]
|
||||||
assert type in ["pt", "tf", "sentencepiece", "tokenizers", "flax"]
|
|
||||||
if name.isupper():
|
if name.isupper():
|
||||||
return DUMMY_CONSTANT.format(name)
|
return DUMMY_CONSTANT.format(name)
|
||||||
elif name.islower():
|
elif name.islower():
|
||||||
return (DUMMY_FUNCTION[type]).format(name)
|
return DUMMY_FUNCTION.format(name, backend_name)
|
||||||
else:
|
else:
|
||||||
is_pretrained = False
|
is_pretrained = False
|
||||||
for part in _pretrained:
|
for part in _pretrained:
|
||||||
@@ -278,114 +119,61 @@ def create_dummy_object(name, type="pt"):
|
|||||||
is_pretrained = True
|
is_pretrained = True
|
||||||
break
|
break
|
||||||
if is_pretrained:
|
if is_pretrained:
|
||||||
template = DUMMY_PRETRAINED_CLASS[type]
|
return DUMMY_PRETRAINED_CLASS.format(name, backend_name)
|
||||||
else:
|
else:
|
||||||
template = DUMMY_CLASS[type]
|
return DUMMY_CLASS.format(name, backend_name)
|
||||||
return template.format(name)
|
|
||||||
|
|
||||||
|
|
||||||
def create_dummy_files():
|
def create_dummy_files():
|
||||||
""" Create the content of the dummy files. """
|
""" Create the content of the dummy files. """
|
||||||
sentencepiece_objects, tokenizers_objects, pt_objects, tf_objects, flax_objects = read_init()
|
backend_specific_objects = read_init()
|
||||||
|
# For special correspondence backend to module name as used in the function requires_modulename
|
||||||
|
module_names = {"torch": "pytorch"}
|
||||||
|
dummy_files = {}
|
||||||
|
|
||||||
sentencepiece_dummies = "# This file is autogenerated by the command `make fix-copies`, do not edit.\n"
|
for backend, objects in backend_specific_objects.items():
|
||||||
sentencepiece_dummies += "from ..file_utils import requires_sentencepiece\n\n"
|
backend_name = module_names.get(backend, backend)
|
||||||
sentencepiece_dummies += "\n".join([create_dummy_object(o, type="sentencepiece") for o in sentencepiece_objects])
|
dummy_file = "# This file is autogenerated by the command `make fix-copies`, do not edit.\n"
|
||||||
|
dummy_file += f"from ..file_utils import requires_{backend_name}\n\n"
|
||||||
|
dummy_file += "\n".join([create_dummy_object(o, backend_name) for o in objects])
|
||||||
|
dummy_files[backend] = dummy_file
|
||||||
|
|
||||||
tokenizers_dummies = "# This file is autogenerated by the command `make fix-copies`, do not edit.\n"
|
return dummy_files
|
||||||
tokenizers_dummies += "from ..file_utils import requires_tokenizers\n\n"
|
|
||||||
tokenizers_dummies += "\n".join([create_dummy_object(o, type="tokenizers") for o in tokenizers_objects])
|
|
||||||
|
|
||||||
pt_dummies = "# This file is autogenerated by the command `make fix-copies`, do not edit.\n"
|
|
||||||
pt_dummies += "from ..file_utils import requires_pytorch\n\n"
|
|
||||||
pt_dummies += "\n".join([create_dummy_object(o, type="pt") for o in pt_objects])
|
|
||||||
|
|
||||||
tf_dummies = "# This file is autogenerated by the command `make fix-copies`, do not edit.\n"
|
|
||||||
tf_dummies += "from ..file_utils import requires_tf\n\n"
|
|
||||||
tf_dummies += "\n".join([create_dummy_object(o, type="tf") for o in tf_objects])
|
|
||||||
|
|
||||||
flax_dummies = "# This file is autogenerated by the command `make fix-copies`, do not edit.\n"
|
|
||||||
flax_dummies += "from ..file_utils import requires_flax\n\n"
|
|
||||||
flax_dummies += "\n".join([create_dummy_object(o, type="flax") for o in flax_objects])
|
|
||||||
|
|
||||||
return sentencepiece_dummies, tokenizers_dummies, pt_dummies, tf_dummies, flax_dummies
|
|
||||||
|
|
||||||
|
|
||||||
def check_dummies(overwrite=False):
|
def check_dummies(overwrite=False):
|
||||||
""" Check if the dummy files are up to date and maybe `overwrite` with the right content. """
|
""" Check if the dummy files are up to date and maybe `overwrite` with the right content. """
|
||||||
sentencepiece_dummies, tokenizers_dummies, pt_dummies, tf_dummies, flax_dummies = create_dummy_files()
|
dummy_files = create_dummy_files()
|
||||||
|
# For special correspondence backend to shortcut as used in utils/dummy_xxx_objects.py
|
||||||
|
short_names = {"torch": "pt"}
|
||||||
|
|
||||||
|
# Locate actual dummy modules and read their content.
|
||||||
path = os.path.join(PATH_TO_TRANSFORMERS, "utils")
|
path = os.path.join(PATH_TO_TRANSFORMERS, "utils")
|
||||||
sentencepiece_file = os.path.join(path, "dummy_sentencepiece_objects.py")
|
dummy_file_paths = {
|
||||||
tokenizers_file = os.path.join(path, "dummy_tokenizers_objects.py")
|
backend: os.path.join(path, f"dummy_{short_names.get(backend, backend)}_objects.py")
|
||||||
pt_file = os.path.join(path, "dummy_pt_objects.py")
|
for backend in dummy_files.keys()
|
||||||
tf_file = os.path.join(path, "dummy_tf_objects.py")
|
}
|
||||||
flax_file = os.path.join(path, "dummy_flax_objects.py")
|
|
||||||
|
|
||||||
with open(sentencepiece_file, "r", encoding="utf-8", newline="\n") as f:
|
actual_dummies = {}
|
||||||
actual_sentencepiece_dummies = f.read()
|
for backend, file_path in dummy_file_paths.items():
|
||||||
with open(tokenizers_file, "r", encoding="utf-8", newline="\n") as f:
|
with open(file_path, "r", encoding="utf-8", newline="\n") as f:
|
||||||
actual_tokenizers_dummies = f.read()
|
actual_dummies[backend] = f.read()
|
||||||
with open(pt_file, "r", encoding="utf-8", newline="\n") as f:
|
|
||||||
actual_pt_dummies = f.read()
|
|
||||||
with open(tf_file, "r", encoding="utf-8", newline="\n") as f:
|
|
||||||
actual_tf_dummies = f.read()
|
|
||||||
with open(flax_file, "r", encoding="utf-8", newline="\n") as f:
|
|
||||||
actual_flax_dummies = f.read()
|
|
||||||
|
|
||||||
if sentencepiece_dummies != actual_sentencepiece_dummies:
|
for backend in dummy_files.keys():
|
||||||
if overwrite:
|
if dummy_files[backend] != actual_dummies[backend]:
|
||||||
print("Updating transformers.utils.dummy_sentencepiece_objects.py as the main __init__ has new objects.")
|
if overwrite:
|
||||||
with open(sentencepiece_file, "w", encoding="utf-8", newline="\n") as f:
|
print(
|
||||||
f.write(sentencepiece_dummies)
|
f"Updating transformers.utils.dummy_{short_names.get(backend, backend)}_objects.py as the main "
|
||||||
else:
|
"__init__ has new objects."
|
||||||
raise ValueError(
|
)
|
||||||
"The main __init__ has objects that are not present in transformers.utils.dummy_sentencepiece_objects.py.",
|
with open(dummy_file_paths[backend], "w", encoding="utf-8", newline="\n") as f:
|
||||||
"Run `make fix-copies` to fix this.",
|
f.write(dummy_files[backend])
|
||||||
)
|
else:
|
||||||
|
raise ValueError(
|
||||||
if tokenizers_dummies != actual_tokenizers_dummies:
|
"The main __init__ has objects that are not present in "
|
||||||
if overwrite:
|
f"transformers.utils.dummy_{short_names.get(backend, backend)}_objects.py. Run `make fix-copies` "
|
||||||
print("Updating transformers.utils.dummy_tokenizers_objects.py as the main __init__ has new objects.")
|
"to fix this."
|
||||||
with open(tokenizers_file, "w", encoding="utf-8", newline="\n") as f:
|
)
|
||||||
f.write(tokenizers_dummies)
|
|
||||||
else:
|
|
||||||
raise ValueError(
|
|
||||||
"The main __init__ has objects that are not present in transformers.utils.dummy_tokenizers_objects.py.",
|
|
||||||
"Run `make fix-copies` to fix this.",
|
|
||||||
)
|
|
||||||
|
|
||||||
if pt_dummies != actual_pt_dummies:
|
|
||||||
if overwrite:
|
|
||||||
print("Updating transformers.utils.dummy_pt_objects.py as the main __init__ has new objects.")
|
|
||||||
with open(pt_file, "w", encoding="utf-8", newline="\n") as f:
|
|
||||||
f.write(pt_dummies)
|
|
||||||
else:
|
|
||||||
raise ValueError(
|
|
||||||
"The main __init__ has objects that are not present in transformers.utils.dummy_pt_objects.py.",
|
|
||||||
"Run `make fix-copies` to fix this.",
|
|
||||||
)
|
|
||||||
|
|
||||||
if tf_dummies != actual_tf_dummies:
|
|
||||||
if overwrite:
|
|
||||||
print("Updating transformers.utils.dummy_tf_objects.py as the main __init__ has new objects.")
|
|
||||||
with open(tf_file, "w", encoding="utf-8", newline="\n") as f:
|
|
||||||
f.write(tf_dummies)
|
|
||||||
else:
|
|
||||||
raise ValueError(
|
|
||||||
"The main __init__ has objects that are not present in transformers.utils.dummy_pt_objects.py.",
|
|
||||||
"Run `make fix-copies` to fix this.",
|
|
||||||
)
|
|
||||||
|
|
||||||
if flax_dummies != actual_flax_dummies:
|
|
||||||
if overwrite:
|
|
||||||
print("Updating transformers.utils.dummy_flax_objects.py as the main __init__ has new objects.")
|
|
||||||
with open(flax_file, "w", encoding="utf-8", newline="\n") as f:
|
|
||||||
f.write(flax_dummies)
|
|
||||||
else:
|
|
||||||
raise ValueError(
|
|
||||||
"The main __init__ has objects that are not present in transformers.utils.dummy_flax_objects.py.",
|
|
||||||
"Run `make fix-copies` to fix this.",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|||||||
@@ -413,9 +413,6 @@ SHOULD_HAVE_THEIR_OWN_PAGE = [
|
|||||||
def ignore_undocumented(name):
|
def ignore_undocumented(name):
|
||||||
"""Rules to determine if `name` should be undocumented."""
|
"""Rules to determine if `name` should be undocumented."""
|
||||||
# NOT DOCUMENTED ON PURPOSE.
|
# NOT DOCUMENTED ON PURPOSE.
|
||||||
# Magic attributes are not documented.
|
|
||||||
if name.startswith("__"):
|
|
||||||
return True
|
|
||||||
# Constants uppercase are not documented.
|
# Constants uppercase are not documented.
|
||||||
if name.isupper():
|
if name.isupper():
|
||||||
return True
|
return True
|
||||||
@@ -459,7 +456,9 @@ def ignore_undocumented(name):
|
|||||||
def check_all_objects_are_documented():
|
def check_all_objects_are_documented():
|
||||||
""" Check all models are properly documented."""
|
""" Check all models are properly documented."""
|
||||||
documented_objs = find_all_documented_objects()
|
documented_objs = find_all_documented_objects()
|
||||||
undocumented_objs = [c for c in dir(transformers) if c not in documented_objs and not ignore_undocumented(c)]
|
modules = transformers._modules
|
||||||
|
objects = [c for c in dir(transformers) if c not in modules and not c.startswith("_")]
|
||||||
|
undocumented_objs = [c for c in objects if c not in documented_objs and not ignore_undocumented(c)]
|
||||||
if len(undocumented_objs) > 0:
|
if len(undocumented_objs) > 0:
|
||||||
raise Exception(
|
raise Exception(
|
||||||
"The following objects are in the public init so should be documented:\n - "
|
"The following objects are in the public init so should be documented:\n - "
|
||||||
|
|||||||
Reference in New Issue
Block a user