Transformers fast import part 2 (#9446)
* Main init work * Add version * Change from absolute to relative imports * Fix imports * One more typo * More typos * Styling * Make quality script pass * Add necessary replace in template * Fix typos * Spaces are ignored in replace for some reason * Forgot one models. * Fixes for import Co-authored-by: LysandreJik <lysandre.debut@reseau.eseo.fr> * Add documentation * Styling Co-authored-by: LysandreJik <lysandre.debut@reseau.eseo.fr>
This commit is contained in:
@@ -19,9 +19,8 @@ from argparse import ArgumentParser, Namespace
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
from transformers.commands import BaseTransformersCLICommand
|
||||
|
||||
from ..utils import logging
|
||||
from . import BaseTransformersCLICommand
|
||||
|
||||
|
||||
try:
|
||||
|
||||
@@ -14,9 +14,8 @@
|
||||
|
||||
from argparse import ArgumentParser, Namespace
|
||||
|
||||
from transformers.commands import BaseTransformersCLICommand
|
||||
|
||||
from ..utils import logging
|
||||
from . import BaseTransformersCLICommand
|
||||
|
||||
|
||||
def convert_command_factory(args: Namespace):
|
||||
@@ -87,7 +86,7 @@ class ConvertCommand(BaseTransformersCLICommand):
|
||||
def run(self):
|
||||
if self._model_type == "albert":
|
||||
try:
|
||||
from transformers.models.albert.convert_albert_original_tf_checkpoint_to_pytorch import (
|
||||
from ..models.albert.convert_albert_original_tf_checkpoint_to_pytorch import (
|
||||
convert_tf_checkpoint_to_pytorch,
|
||||
)
|
||||
except ImportError:
|
||||
@@ -96,7 +95,7 @@ class ConvertCommand(BaseTransformersCLICommand):
|
||||
convert_tf_checkpoint_to_pytorch(self._tf_checkpoint, self._config, self._pytorch_dump_output)
|
||||
elif self._model_type == "bert":
|
||||
try:
|
||||
from transformers.models.bert.convert_bert_original_tf_checkpoint_to_pytorch import (
|
||||
from ..models.bert.convert_bert_original_tf_checkpoint_to_pytorch import (
|
||||
convert_tf_checkpoint_to_pytorch,
|
||||
)
|
||||
except ImportError:
|
||||
@@ -105,7 +104,7 @@ class ConvertCommand(BaseTransformersCLICommand):
|
||||
convert_tf_checkpoint_to_pytorch(self._tf_checkpoint, self._config, self._pytorch_dump_output)
|
||||
elif self._model_type == "funnel":
|
||||
try:
|
||||
from transformers.models.funnel.convert_funnel_original_tf_checkpoint_to_pytorch import (
|
||||
from ..models.funnel.convert_funnel_original_tf_checkpoint_to_pytorch import (
|
||||
convert_tf_checkpoint_to_pytorch,
|
||||
)
|
||||
except ImportError:
|
||||
@@ -113,14 +112,14 @@ class ConvertCommand(BaseTransformersCLICommand):
|
||||
|
||||
convert_tf_checkpoint_to_pytorch(self._tf_checkpoint, self._config, self._pytorch_dump_output)
|
||||
elif self._model_type == "gpt":
|
||||
from transformers.models.openai.convert_openai_original_tf_checkpoint_to_pytorch import (
|
||||
from ..models.openai.convert_openai_original_tf_checkpoint_to_pytorch import (
|
||||
convert_openai_checkpoint_to_pytorch,
|
||||
)
|
||||
|
||||
convert_openai_checkpoint_to_pytorch(self._tf_checkpoint, self._config, self._pytorch_dump_output)
|
||||
elif self._model_type == "transfo_xl":
|
||||
try:
|
||||
from transformers.models.transfo_xl.convert_transfo_xl_original_tf_checkpoint_to_pytorch import (
|
||||
from ..models.transfo_xl.convert_transfo_xl_original_tf_checkpoint_to_pytorch import (
|
||||
convert_transfo_xl_checkpoint_to_pytorch,
|
||||
)
|
||||
except ImportError:
|
||||
@@ -137,7 +136,7 @@ class ConvertCommand(BaseTransformersCLICommand):
|
||||
)
|
||||
elif self._model_type == "gpt2":
|
||||
try:
|
||||
from transformers.models.gpt2.convert_gpt2_original_tf_checkpoint_to_pytorch import (
|
||||
from ..models.gpt2.convert_gpt2_original_tf_checkpoint_to_pytorch import (
|
||||
convert_gpt2_checkpoint_to_pytorch,
|
||||
)
|
||||
except ImportError:
|
||||
@@ -146,7 +145,7 @@ class ConvertCommand(BaseTransformersCLICommand):
|
||||
convert_gpt2_checkpoint_to_pytorch(self._tf_checkpoint, self._config, self._pytorch_dump_output)
|
||||
elif self._model_type == "xlnet":
|
||||
try:
|
||||
from transformers.models.xlnet.convert_xlnet_original_tf_checkpoint_to_pytorch import (
|
||||
from ..models.xlnet.convert_xlnet_original_tf_checkpoint_to_pytorch import (
|
||||
convert_xlnet_checkpoint_to_pytorch,
|
||||
)
|
||||
except ImportError:
|
||||
@@ -156,13 +155,13 @@ class ConvertCommand(BaseTransformersCLICommand):
|
||||
self._tf_checkpoint, self._config, self._pytorch_dump_output, self._finetuning_task_name
|
||||
)
|
||||
elif self._model_type == "xlm":
|
||||
from transformers.models.xlm.convert_xlm_original_pytorch_checkpoint_to_pytorch import (
|
||||
from ..models.xlm.convert_xlm_original_pytorch_checkpoint_to_pytorch import (
|
||||
convert_xlm_checkpoint_to_pytorch,
|
||||
)
|
||||
|
||||
convert_xlm_checkpoint_to_pytorch(self._tf_checkpoint, self._pytorch_dump_output)
|
||||
elif self._model_type == "lxmert":
|
||||
from transformers.models.lxmert.convert_lxmert_original_pytorch_checkpoint_to_pytorch import (
|
||||
from ..models.lxmert.convert_lxmert_original_pytorch_checkpoint_to_pytorch import (
|
||||
convert_lxmert_checkpoint_to_pytorch,
|
||||
)
|
||||
|
||||
|
||||
@@ -14,7 +14,7 @@
|
||||
|
||||
from argparse import ArgumentParser
|
||||
|
||||
from transformers.commands import BaseTransformersCLICommand
|
||||
from . import BaseTransformersCLICommand
|
||||
|
||||
|
||||
def download_command_factory(args):
|
||||
@@ -40,7 +40,7 @@ class DownloadCommand(BaseTransformersCLICommand):
|
||||
self._force = force
|
||||
|
||||
def run(self):
|
||||
from transformers import AutoModel, AutoTokenizer
|
||||
from ..models.auto import AutoModel, AutoTokenizer
|
||||
|
||||
AutoModel.from_pretrained(self._model, cache_dir=self._cache, force_download=self._force)
|
||||
AutoTokenizer.from_pretrained(self._model, cache_dir=self._cache, force_download=self._force)
|
||||
|
||||
@@ -15,9 +15,9 @@
|
||||
import platform
|
||||
from argparse import ArgumentParser
|
||||
|
||||
from transformers import __version__ as version
|
||||
from transformers import is_tf_available, is_torch_available
|
||||
from transformers.commands import BaseTransformersCLICommand
|
||||
from .. import __version__ as version
|
||||
from ..file_utils import is_tf_available, is_torch_available
|
||||
from . import BaseTransformersCLICommand
|
||||
|
||||
|
||||
def info_command_factory(_):
|
||||
|
||||
@@ -25,9 +25,9 @@ from contextlib import AbstractContextManager
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
import requests
|
||||
from transformers.commands import BaseTransformersCLICommand
|
||||
|
||||
from ..utils import logging
|
||||
from . import BaseTransformersCLICommand
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
@@ -14,10 +14,9 @@
|
||||
|
||||
from argparse import ArgumentParser
|
||||
|
||||
from transformers.commands import BaseTransformersCLICommand
|
||||
from transformers.pipelines import SUPPORTED_TASKS, Pipeline, PipelineDataFormat, pipeline
|
||||
|
||||
from ..pipelines import SUPPORTED_TASKS, Pipeline, PipelineDataFormat, pipeline
|
||||
from ..utils import logging
|
||||
from . import BaseTransformersCLICommand
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
@@ -15,11 +15,9 @@
|
||||
from argparse import ArgumentParser, Namespace
|
||||
from typing import Any, List, Optional
|
||||
|
||||
from transformers import Pipeline
|
||||
from transformers.commands import BaseTransformersCLICommand
|
||||
from transformers.pipelines import SUPPORTED_TASKS, pipeline
|
||||
|
||||
from ..pipelines import SUPPORTED_TASKS, Pipeline, pipeline
|
||||
from ..utils import logging
|
||||
from . import BaseTransformersCLICommand
|
||||
|
||||
|
||||
try:
|
||||
|
||||
@@ -15,11 +15,11 @@
|
||||
import os
|
||||
from argparse import ArgumentParser, Namespace
|
||||
|
||||
from transformers import SingleSentenceClassificationProcessor as Processor
|
||||
from transformers import TextClassificationPipeline, is_tf_available, is_torch_available
|
||||
from transformers.commands import BaseTransformersCLICommand
|
||||
|
||||
from ..data import SingleSentenceClassificationProcessor as Processor
|
||||
from ..file_utils import is_tf_available, is_torch_available
|
||||
from ..pipelines import TextClassificationPipeline
|
||||
from ..utils import logging
|
||||
from . import BaseTransformersCLICommand
|
||||
|
||||
|
||||
if not is_tf_available() and not is_torch_available():
|
||||
|
||||
@@ -15,14 +15,14 @@
|
||||
|
||||
from argparse import ArgumentParser
|
||||
|
||||
from transformers.commands.add_new_model import AddNewModelCommand
|
||||
from transformers.commands.convert import ConvertCommand
|
||||
from transformers.commands.download import DownloadCommand
|
||||
from transformers.commands.env import EnvironmentCommand
|
||||
from transformers.commands.lfs import LfsCommands
|
||||
from transformers.commands.run import RunCommand
|
||||
from transformers.commands.serving import ServeCommand
|
||||
from transformers.commands.user import UserCommands
|
||||
from .add_new_model import AddNewModelCommand
|
||||
from .convert import ConvertCommand
|
||||
from .download import DownloadCommand
|
||||
from .env import EnvironmentCommand
|
||||
from .lfs import LfsCommands
|
||||
from .run import RunCommand
|
||||
from .serving import ServeCommand
|
||||
from .user import UserCommands
|
||||
|
||||
|
||||
def main():
|
||||
|
||||
@@ -20,8 +20,9 @@ from getpass import getpass
|
||||
from typing import List, Union
|
||||
|
||||
from requests.exceptions import HTTPError
|
||||
from transformers.commands import BaseTransformersCLICommand
|
||||
from transformers.hf_api import HfApi, HfFolder
|
||||
|
||||
from ..hf_api import HfApi, HfFolder
|
||||
from . import BaseTransformersCLICommand
|
||||
|
||||
|
||||
UPLOAD_MAX_FILES = 15
|
||||
|
||||
Reference in New Issue
Block a user