From a73883ae9ec66cb35a8222f204a5f2fafc326d3f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Radam=C3=A9s=20Ajna?= Date: Thu, 8 Jun 2023 08:13:57 -0700 Subject: [PATCH] add trust_remote_code option to CLI download cmd (#24097) * add trust_remote_code option * require_torch --- src/transformers/commands/download.py | 18 ++++++++--- tests/utils/test_cli.py | 46 ++++++++++++++++++++++++++- 2 files changed, 59 insertions(+), 5 deletions(-) diff --git a/src/transformers/commands/download.py b/src/transformers/commands/download.py index 3c224555df..8af3c6397b 100644 --- a/src/transformers/commands/download.py +++ b/src/transformers/commands/download.py @@ -18,7 +18,7 @@ from . import BaseTransformersCLICommand def download_command_factory(args): - return DownloadCommand(args.model, args.cache_dir, args.force) + return DownloadCommand(args.model, args.cache_dir, args.force, args.trust_remote_code) class DownloadCommand(BaseTransformersCLICommand): @@ -31,16 +31,26 @@ class DownloadCommand(BaseTransformersCLICommand): download_parser.add_argument( "--force", action="store_true", help="Force the model to be download even if already in cache-dir" ) + download_parser.add_argument( + "--trust-remote-code", + action="store_true", + help="Whether or not to allow for custom models defined on the Hub in their own modeling files. Use only if you've reviewed the code as it will execute on your local machine", + ) download_parser.add_argument("model", type=str, help="Name of the model to download") download_parser.set_defaults(func=download_command_factory) - def __init__(self, model: str, cache: str, force: bool): + def __init__(self, model: str, cache: str, force: bool, trust_remote_code: bool): self._model = model self._cache = cache self._force = force + self._trust_remote_code = trust_remote_code def run(self): from ..models.auto import AutoModel, AutoTokenizer - AutoModel.from_pretrained(self._model, cache_dir=self._cache, force_download=self._force) - AutoTokenizer.from_pretrained(self._model, cache_dir=self._cache, force_download=self._force) + AutoModel.from_pretrained( + self._model, cache_dir=self._cache, force_download=self._force, trust_remote_code=self._trust_remote_code + ) + AutoTokenizer.from_pretrained( + self._model, cache_dir=self._cache, force_download=self._force, trust_remote_code=self._trust_remote_code + ) diff --git a/tests/utils/test_cli.py b/tests/utils/test_cli.py index f39aa60067..fc7b8ebb5e 100644 --- a/tests/utils/test_cli.py +++ b/tests/utils/test_cli.py @@ -18,7 +18,7 @@ import shutil import unittest from unittest.mock import patch -from transformers.testing_utils import CaptureStd, is_pt_tf_cross_test +from transformers.testing_utils import CaptureStd, is_pt_tf_cross_test, require_torch class CLITest(unittest.TestCase): @@ -45,3 +45,47 @@ class CLITest(unittest.TestCase): # The original repo has no TF weights -- if they exist, they were created by the CLI self.assertTrue(os.path.exists("/tmp/hf-internal-testing/tiny-random-gptj/tf_model.h5")) + + @require_torch + @patch("sys.argv", ["fakeprogrampath", "download", "hf-internal-testing/tiny-random-gptj", "--cache-dir", "/tmp"]) + def test_cli_download(self): + import transformers.commands.transformers_cli + + # # remove any previously downloaded model to start clean + shutil.rmtree("/tmp/models--hf-internal-testing--tiny-random-gptj", ignore_errors=True) + + # run the command + transformers.commands.transformers_cli.main() + + # check if the model files are downloaded correctly on /tmp/models--hf-internal-testing--tiny-random-gptj + self.assertTrue(os.path.exists("/tmp/models--hf-internal-testing--tiny-random-gptj/blobs")) + self.assertTrue(os.path.exists("/tmp/models--hf-internal-testing--tiny-random-gptj/refs")) + self.assertTrue(os.path.exists("/tmp/models--hf-internal-testing--tiny-random-gptj/snapshots")) + + @require_torch + @patch( + "sys.argv", + [ + "fakeprogrampath", + "download", + "hf-internal-testing/test_dynamic_model_with_tokenizer", + "--trust-remote-code", + "--cache-dir", + "/tmp", + ], + ) + def test_cli_download_trust_remote(self): + import transformers.commands.transformers_cli + + # # remove any previously downloaded model to start clean + shutil.rmtree("/tmp/models--hf-internal-testing--test_dynamic_model_with_tokenizer", ignore_errors=True) + + # run the command + transformers.commands.transformers_cli.main() + + # check if the model files are downloaded correctly on /tmp/models--hf-internal-testing--test_dynamic_model_with_tokenizer + self.assertTrue(os.path.exists("/tmp/models--hf-internal-testing--test_dynamic_model_with_tokenizer/blobs")) + self.assertTrue(os.path.exists("/tmp/models--hf-internal-testing--test_dynamic_model_with_tokenizer/refs")) + self.assertTrue( + os.path.exists("/tmp/models--hf-internal-testing--test_dynamic_model_with_tokenizer/snapshots") + )