Transformers fast import part 2 (#9446)

* Main init work

* Add version

* Change from absolute to relative imports

* Fix imports

* One more typo

* More typos

* Styling

* Make quality script pass

* Add necessary replace in template

* Fix typos

* Spaces are ignored in replace for some reason

* Forgot one models.

* Fixes for import

Co-authored-by: LysandreJik <lysandre.debut@reseau.eseo.fr>

* Add documentation

* Styling

Co-authored-by: LysandreJik <lysandre.debut@reseau.eseo.fr>
This commit is contained in:
Sylvain Gugger
2021-01-07 09:36:14 -05:00
committed by GitHub
parent a400fe8931
commit 758ed3332b
56 changed files with 2426 additions and 1377 deletions

View File

@@ -19,9 +19,8 @@ from argparse import ArgumentParser, Namespace
from pathlib import Path
from typing import List
from transformers.commands import BaseTransformersCLICommand
from ..utils import logging
from . import BaseTransformersCLICommand
try:

View File

@@ -14,9 +14,8 @@
from argparse import ArgumentParser, Namespace
from transformers.commands import BaseTransformersCLICommand
from ..utils import logging
from . import BaseTransformersCLICommand
def convert_command_factory(args: Namespace):
@@ -87,7 +86,7 @@ class ConvertCommand(BaseTransformersCLICommand):
def run(self):
if self._model_type == "albert":
try:
from transformers.models.albert.convert_albert_original_tf_checkpoint_to_pytorch import (
from ..models.albert.convert_albert_original_tf_checkpoint_to_pytorch import (
convert_tf_checkpoint_to_pytorch,
)
except ImportError:
@@ -96,7 +95,7 @@ class ConvertCommand(BaseTransformersCLICommand):
convert_tf_checkpoint_to_pytorch(self._tf_checkpoint, self._config, self._pytorch_dump_output)
elif self._model_type == "bert":
try:
from transformers.models.bert.convert_bert_original_tf_checkpoint_to_pytorch import (
from ..models.bert.convert_bert_original_tf_checkpoint_to_pytorch import (
convert_tf_checkpoint_to_pytorch,
)
except ImportError:
@@ -105,7 +104,7 @@ class ConvertCommand(BaseTransformersCLICommand):
convert_tf_checkpoint_to_pytorch(self._tf_checkpoint, self._config, self._pytorch_dump_output)
elif self._model_type == "funnel":
try:
from transformers.models.funnel.convert_funnel_original_tf_checkpoint_to_pytorch import (
from ..models.funnel.convert_funnel_original_tf_checkpoint_to_pytorch import (
convert_tf_checkpoint_to_pytorch,
)
except ImportError:
@@ -113,14 +112,14 @@ class ConvertCommand(BaseTransformersCLICommand):
convert_tf_checkpoint_to_pytorch(self._tf_checkpoint, self._config, self._pytorch_dump_output)
elif self._model_type == "gpt":
from transformers.models.openai.convert_openai_original_tf_checkpoint_to_pytorch import (
from ..models.openai.convert_openai_original_tf_checkpoint_to_pytorch import (
convert_openai_checkpoint_to_pytorch,
)
convert_openai_checkpoint_to_pytorch(self._tf_checkpoint, self._config, self._pytorch_dump_output)
elif self._model_type == "transfo_xl":
try:
from transformers.models.transfo_xl.convert_transfo_xl_original_tf_checkpoint_to_pytorch import (
from ..models.transfo_xl.convert_transfo_xl_original_tf_checkpoint_to_pytorch import (
convert_transfo_xl_checkpoint_to_pytorch,
)
except ImportError:
@@ -137,7 +136,7 @@ class ConvertCommand(BaseTransformersCLICommand):
)
elif self._model_type == "gpt2":
try:
from transformers.models.gpt2.convert_gpt2_original_tf_checkpoint_to_pytorch import (
from ..models.gpt2.convert_gpt2_original_tf_checkpoint_to_pytorch import (
convert_gpt2_checkpoint_to_pytorch,
)
except ImportError:
@@ -146,7 +145,7 @@ class ConvertCommand(BaseTransformersCLICommand):
convert_gpt2_checkpoint_to_pytorch(self._tf_checkpoint, self._config, self._pytorch_dump_output)
elif self._model_type == "xlnet":
try:
from transformers.models.xlnet.convert_xlnet_original_tf_checkpoint_to_pytorch import (
from ..models.xlnet.convert_xlnet_original_tf_checkpoint_to_pytorch import (
convert_xlnet_checkpoint_to_pytorch,
)
except ImportError:
@@ -156,13 +155,13 @@ class ConvertCommand(BaseTransformersCLICommand):
self._tf_checkpoint, self._config, self._pytorch_dump_output, self._finetuning_task_name
)
elif self._model_type == "xlm":
from transformers.models.xlm.convert_xlm_original_pytorch_checkpoint_to_pytorch import (
from ..models.xlm.convert_xlm_original_pytorch_checkpoint_to_pytorch import (
convert_xlm_checkpoint_to_pytorch,
)
convert_xlm_checkpoint_to_pytorch(self._tf_checkpoint, self._pytorch_dump_output)
elif self._model_type == "lxmert":
from transformers.models.lxmert.convert_lxmert_original_pytorch_checkpoint_to_pytorch import (
from ..models.lxmert.convert_lxmert_original_pytorch_checkpoint_to_pytorch import (
convert_lxmert_checkpoint_to_pytorch,
)

View File

@@ -14,7 +14,7 @@
from argparse import ArgumentParser
from transformers.commands import BaseTransformersCLICommand
from . import BaseTransformersCLICommand
def download_command_factory(args):
@@ -40,7 +40,7 @@ class DownloadCommand(BaseTransformersCLICommand):
self._force = force
def run(self):
from transformers import AutoModel, AutoTokenizer
from ..models.auto import AutoModel, AutoTokenizer
AutoModel.from_pretrained(self._model, cache_dir=self._cache, force_download=self._force)
AutoTokenizer.from_pretrained(self._model, cache_dir=self._cache, force_download=self._force)

View File

@@ -15,9 +15,9 @@
import platform
from argparse import ArgumentParser
from transformers import __version__ as version
from transformers import is_tf_available, is_torch_available
from transformers.commands import BaseTransformersCLICommand
from .. import __version__ as version
from ..file_utils import is_tf_available, is_torch_available
from . import BaseTransformersCLICommand
def info_command_factory(_):

View File

@@ -25,9 +25,9 @@ from contextlib import AbstractContextManager
from typing import Dict, List, Optional
import requests
from transformers.commands import BaseTransformersCLICommand
from ..utils import logging
from . import BaseTransformersCLICommand
logger = logging.get_logger(__name__) # pylint: disable=invalid-name

View File

@@ -14,10 +14,9 @@
from argparse import ArgumentParser
from transformers.commands import BaseTransformersCLICommand
from transformers.pipelines import SUPPORTED_TASKS, Pipeline, PipelineDataFormat, pipeline
from ..pipelines import SUPPORTED_TASKS, Pipeline, PipelineDataFormat, pipeline
from ..utils import logging
from . import BaseTransformersCLICommand
logger = logging.get_logger(__name__) # pylint: disable=invalid-name

View File

@@ -15,11 +15,9 @@
from argparse import ArgumentParser, Namespace
from typing import Any, List, Optional
from transformers import Pipeline
from transformers.commands import BaseTransformersCLICommand
from transformers.pipelines import SUPPORTED_TASKS, pipeline
from ..pipelines import SUPPORTED_TASKS, Pipeline, pipeline
from ..utils import logging
from . import BaseTransformersCLICommand
try:

View File

@@ -15,11 +15,11 @@
import os
from argparse import ArgumentParser, Namespace
from transformers import SingleSentenceClassificationProcessor as Processor
from transformers import TextClassificationPipeline, is_tf_available, is_torch_available
from transformers.commands import BaseTransformersCLICommand
from ..data import SingleSentenceClassificationProcessor as Processor
from ..file_utils import is_tf_available, is_torch_available
from ..pipelines import TextClassificationPipeline
from ..utils import logging
from . import BaseTransformersCLICommand
if not is_tf_available() and not is_torch_available():

View File

@@ -15,14 +15,14 @@
from argparse import ArgumentParser
from transformers.commands.add_new_model import AddNewModelCommand
from transformers.commands.convert import ConvertCommand
from transformers.commands.download import DownloadCommand
from transformers.commands.env import EnvironmentCommand
from transformers.commands.lfs import LfsCommands
from transformers.commands.run import RunCommand
from transformers.commands.serving import ServeCommand
from transformers.commands.user import UserCommands
from .add_new_model import AddNewModelCommand
from .convert import ConvertCommand
from .download import DownloadCommand
from .env import EnvironmentCommand
from .lfs import LfsCommands
from .run import RunCommand
from .serving import ServeCommand
from .user import UserCommands
def main():

View File

@@ -20,8 +20,9 @@ from getpass import getpass
from typing import List, Union
from requests.exceptions import HTTPError
from transformers.commands import BaseTransformersCLICommand
from transformers.hf_api import HfApi, HfFolder
from ..hf_api import HfApi, HfFolder
from . import BaseTransformersCLICommand
UPLOAD_MAX_FILES = 15