Added download command through the cli.

It allows to predownload models and tokenizers.
This commit is contained in:
Morgan Funtowicz
2019-12-03 14:56:57 +01:00
parent 31a3a73ee3
commit 81babb227e
2 changed files with 32 additions and 1 deletions

4
transformers-cli Normal file → Executable file
View File

@@ -1,6 +1,7 @@
#!/usr/bin/env python #!/usr/bin/env python
from argparse import ArgumentParser from argparse import ArgumentParser
from transformers.commands.download import DownloadCommand
from transformers.commands.serving import ServeCommand from transformers.commands.serving import ServeCommand
from transformers.commands.user import UserCommands from transformers.commands.user import UserCommands
from transformers.commands.train import TrainCommand from transformers.commands.train import TrainCommand
@@ -11,10 +12,11 @@ if __name__ == '__main__':
commands_parser = parser.add_subparsers(help='transformers-cli command helpers') commands_parser = parser.add_subparsers(help='transformers-cli command helpers')
# Register commands # Register commands
ConvertCommand.register_subcommand(commands_parser)
DownloadCommand.register_subcommand(commands_parser)
ServeCommand.register_subcommand(commands_parser) ServeCommand.register_subcommand(commands_parser)
UserCommands.register_subcommand(commands_parser) UserCommands.register_subcommand(commands_parser)
TrainCommand.register_subcommand(commands_parser) TrainCommand.register_subcommand(commands_parser)
ConvertCommand.register_subcommand(commands_parser)
# Let's go # Let's go
args = parser.parse_args() args = parser.parse_args()

View File

@@ -0,0 +1,29 @@
from argparse import ArgumentParser
from transformers.commands import BaseTransformersCLICommand
def download_command_factory(args):
return DownloadCommand(args.model, args.cache_dir, args.force)
class DownloadCommand(BaseTransformersCLICommand):
@staticmethod
def register_subcommand(parser: ArgumentParser):
download_parser = parser.add_parser('download')
download_parser.add_argument('--cache-dir', type=str, default=None, help='Path to location to store the models')
download_parser.add_argument('--force', action='store_true', help='Force the model to be download even if already in cache-dir')
download_parser.add_argument('model', type=str, help='Name of the model to download')
download_parser.set_defaults(func=download_command_factory)
def __init__(self, model: str, cache: str, force: bool):
self._model = model
self._cache = cache
self._force = force
def run(self):
from transformers import AutoModel, AutoTokenizer
AutoModel.from_pretrained(self._model, cache_dir=self._cache, force_download=self._force)
AutoTokenizer.from_pretrained(self._model, cache_dir=self._cache, force_download=self._force)