add trust_remote_code option to CLI download cmd (#24097)
* add trust_remote_code option * require_torch
This commit is contained in:
@@ -18,7 +18,7 @@ from . import BaseTransformersCLICommand
|
||||
|
||||
|
||||
def download_command_factory(args):
|
||||
return DownloadCommand(args.model, args.cache_dir, args.force)
|
||||
return DownloadCommand(args.model, args.cache_dir, args.force, args.trust_remote_code)
|
||||
|
||||
|
||||
class DownloadCommand(BaseTransformersCLICommand):
|
||||
@@ -31,16 +31,26 @@ class DownloadCommand(BaseTransformersCLICommand):
|
||||
download_parser.add_argument(
|
||||
"--force", action="store_true", help="Force the model to be download even if already in cache-dir"
|
||||
)
|
||||
download_parser.add_argument(
|
||||
"--trust-remote-code",
|
||||
action="store_true",
|
||||
help="Whether or not to allow for custom models defined on the Hub in their own modeling files. Use only if you've reviewed the code as it will execute on your local machine",
|
||||
)
|
||||
download_parser.add_argument("model", type=str, help="Name of the model to download")
|
||||
download_parser.set_defaults(func=download_command_factory)
|
||||
|
||||
def __init__(self, model: str, cache: str, force: bool):
|
||||
def __init__(self, model: str, cache: str, force: bool, trust_remote_code: bool):
|
||||
self._model = model
|
||||
self._cache = cache
|
||||
self._force = force
|
||||
self._trust_remote_code = trust_remote_code
|
||||
|
||||
def run(self):
|
||||
from ..models.auto import AutoModel, AutoTokenizer
|
||||
|
||||
AutoModel.from_pretrained(self._model, cache_dir=self._cache, force_download=self._force)
|
||||
AutoTokenizer.from_pretrained(self._model, cache_dir=self._cache, force_download=self._force)
|
||||
AutoModel.from_pretrained(
|
||||
self._model, cache_dir=self._cache, force_download=self._force, trust_remote_code=self._trust_remote_code
|
||||
)
|
||||
AutoTokenizer.from_pretrained(
|
||||
self._model, cache_dir=self._cache, force_download=self._force, trust_remote_code=self._trust_remote_code
|
||||
)
|
||||
|
||||
@@ -18,7 +18,7 @@ import shutil
|
||||
import unittest
|
||||
from unittest.mock import patch
|
||||
|
||||
from transformers.testing_utils import CaptureStd, is_pt_tf_cross_test
|
||||
from transformers.testing_utils import CaptureStd, is_pt_tf_cross_test, require_torch
|
||||
|
||||
|
||||
class CLITest(unittest.TestCase):
|
||||
@@ -45,3 +45,47 @@ class CLITest(unittest.TestCase):
|
||||
|
||||
# The original repo has no TF weights -- if they exist, they were created by the CLI
|
||||
self.assertTrue(os.path.exists("/tmp/hf-internal-testing/tiny-random-gptj/tf_model.h5"))
|
||||
|
||||
@require_torch
|
||||
@patch("sys.argv", ["fakeprogrampath", "download", "hf-internal-testing/tiny-random-gptj", "--cache-dir", "/tmp"])
|
||||
def test_cli_download(self):
|
||||
import transformers.commands.transformers_cli
|
||||
|
||||
# # remove any previously downloaded model to start clean
|
||||
shutil.rmtree("/tmp/models--hf-internal-testing--tiny-random-gptj", ignore_errors=True)
|
||||
|
||||
# run the command
|
||||
transformers.commands.transformers_cli.main()
|
||||
|
||||
# check if the model files are downloaded correctly on /tmp/models--hf-internal-testing--tiny-random-gptj
|
||||
self.assertTrue(os.path.exists("/tmp/models--hf-internal-testing--tiny-random-gptj/blobs"))
|
||||
self.assertTrue(os.path.exists("/tmp/models--hf-internal-testing--tiny-random-gptj/refs"))
|
||||
self.assertTrue(os.path.exists("/tmp/models--hf-internal-testing--tiny-random-gptj/snapshots"))
|
||||
|
||||
@require_torch
|
||||
@patch(
|
||||
"sys.argv",
|
||||
[
|
||||
"fakeprogrampath",
|
||||
"download",
|
||||
"hf-internal-testing/test_dynamic_model_with_tokenizer",
|
||||
"--trust-remote-code",
|
||||
"--cache-dir",
|
||||
"/tmp",
|
||||
],
|
||||
)
|
||||
def test_cli_download_trust_remote(self):
|
||||
import transformers.commands.transformers_cli
|
||||
|
||||
# # remove any previously downloaded model to start clean
|
||||
shutil.rmtree("/tmp/models--hf-internal-testing--test_dynamic_model_with_tokenizer", ignore_errors=True)
|
||||
|
||||
# run the command
|
||||
transformers.commands.transformers_cli.main()
|
||||
|
||||
# check if the model files are downloaded correctly on /tmp/models--hf-internal-testing--test_dynamic_model_with_tokenizer
|
||||
self.assertTrue(os.path.exists("/tmp/models--hf-internal-testing--test_dynamic_model_with_tokenizer/blobs"))
|
||||
self.assertTrue(os.path.exists("/tmp/models--hf-internal-testing--test_dynamic_model_with_tokenizer/refs"))
|
||||
self.assertTrue(
|
||||
os.path.exists("/tmp/models--hf-internal-testing--test_dynamic_model_with_tokenizer/snapshots")
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user