From 81babb227e6d6505be088ac452f3cda8a14c2255 Mon Sep 17 00:00:00 2001 From: Morgan Funtowicz Date: Tue, 3 Dec 2019 14:56:57 +0100 Subject: [PATCH] Added download command through the cli. It allows to predownload models and tokenizers. --- transformers-cli | 4 +++- transformers/commands/download.py | 29 +++++++++++++++++++++++++++++ 2 files changed, 32 insertions(+), 1 deletion(-) mode change 100644 => 100755 transformers-cli create mode 100644 transformers/commands/download.py diff --git a/transformers-cli b/transformers-cli old mode 100644 new mode 100755 index 397b382308..168e6e6f32 --- a/transformers-cli +++ b/transformers-cli @@ -1,6 +1,7 @@ #!/usr/bin/env python from argparse import ArgumentParser +from transformers.commands.download import DownloadCommand from transformers.commands.serving import ServeCommand from transformers.commands.user import UserCommands from transformers.commands.train import TrainCommand @@ -11,10 +12,11 @@ if __name__ == '__main__': commands_parser = parser.add_subparsers(help='transformers-cli command helpers') # Register commands + ConvertCommand.register_subcommand(commands_parser) + DownloadCommand.register_subcommand(commands_parser) ServeCommand.register_subcommand(commands_parser) UserCommands.register_subcommand(commands_parser) TrainCommand.register_subcommand(commands_parser) - ConvertCommand.register_subcommand(commands_parser) # Let's go args = parser.parse_args() diff --git a/transformers/commands/download.py b/transformers/commands/download.py new file mode 100644 index 0000000000..0938f135d2 --- /dev/null +++ b/transformers/commands/download.py @@ -0,0 +1,29 @@ +from argparse import ArgumentParser + +from transformers.commands import BaseTransformersCLICommand + + +def download_command_factory(args): + return DownloadCommand(args.model, args.cache_dir, args.force) + + +class DownloadCommand(BaseTransformersCLICommand): + + @staticmethod + def register_subcommand(parser: ArgumentParser): + download_parser = parser.add_parser('download') + download_parser.add_argument('--cache-dir', type=str, default=None, help='Path to location to store the models') + download_parser.add_argument('--force', action='store_true', help='Force the model to be download even if already in cache-dir') + download_parser.add_argument('model', type=str, help='Name of the model to download') + download_parser.set_defaults(func=download_command_factory) + + def __init__(self, model: str, cache: str, force: bool): + self._model = model + self._cache = cache + self._force = force + + def run(self): + from transformers import AutoModel, AutoTokenizer + + AutoModel.from_pretrained(self._model, cache_dir=self._cache, force_download=self._force) + AutoTokenizer.from_pretrained(self._model, cache_dir=self._cache, force_download=self._force) \ No newline at end of file