From ca485e562b675341409e3e27724072fb11e10af7 Mon Sep 17 00:00:00 2001 From: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Date: Fri, 16 Sep 2022 17:20:02 -0400 Subject: [PATCH] Add tests for legacy load by url and fix bugs (#19078) --- src/transformers/modeling_flax_utils.py | 2 +- src/transformers/modeling_tf_utils.py | 2 +- src/transformers/modeling_utils.py | 2 +- src/transformers/tokenization_utils_base.py | 2 +- tests/test_configuration_common.py | 6 ++++++ tests/test_feature_extraction_common.py | 6 ++++++ tests/test_modeling_common.py | 21 +++++++++++++++++++++ tests/test_modeling_tf_common.py | 19 +++++++++++++++++++ tests/test_tokenization_common.py | 7 ++++++- 9 files changed, 62 insertions(+), 5 deletions(-) diff --git a/src/transformers/modeling_flax_utils.py b/src/transformers/modeling_flax_utils.py index 92d307e8cd..3299b543b7 100644 --- a/src/transformers/modeling_flax_utils.py +++ b/src/transformers/modeling_flax_utils.py @@ -680,7 +680,7 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin): archive_file = pretrained_model_name_or_path is_local = True elif is_remote_url(pretrained_model_name_or_path): - archive_file = pretrained_model_name_or_path + filename = pretrained_model_name_or_path resolved_archive_file = download_url(pretrained_model_name_or_path) else: filename = WEIGHTS_NAME if from_pt else FLAX_WEIGHTS_NAME diff --git a/src/transformers/modeling_tf_utils.py b/src/transformers/modeling_tf_utils.py index a90d1f0ebe..af4eab5908 100644 --- a/src/transformers/modeling_tf_utils.py +++ b/src/transformers/modeling_tf_utils.py @@ -2418,7 +2418,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu archive_file = pretrained_model_name_or_path + ".index" is_local = True elif is_remote_url(pretrained_model_name_or_path): - archive_file = pretrained_model_name_or_path + filename = pretrained_model_name_or_path resolved_archive_file = download_url(pretrained_model_name_or_path) else: # set correct filename diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index af32c3f98f..79a8542d8b 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -2005,7 +2005,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix archive_file = os.path.join(subfolder, pretrained_model_name_or_path + ".index") is_local = True elif is_remote_url(pretrained_model_name_or_path): - archive_file = pretrained_model_name_or_path + filename = pretrained_model_name_or_path resolved_archive_file = download_url(pretrained_model_name_or_path) else: # set correct filename diff --git a/src/transformers/tokenization_utils_base.py b/src/transformers/tokenization_utils_base.py index 2e7ac0be0f..54d562136d 100644 --- a/src/transformers/tokenization_utils_base.py +++ b/src/transformers/tokenization_utils_base.py @@ -1670,7 +1670,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin): init_configuration = {} is_local = os.path.isdir(pretrained_model_name_or_path) - if os.path.isfile(pretrained_model_name_or_path): + if os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path): if len(cls.vocab_files_names) > 1: raise ValueError( f"Calling {cls.__name__}.from_pretrained() with the path to a single file or url is not " diff --git a/tests/test_configuration_common.py b/tests/test_configuration_common.py index a7283b5f31..c2d48ef662 100644 --- a/tests/test_configuration_common.py +++ b/tests/test_configuration_common.py @@ -360,6 +360,12 @@ class ConfigTestUtils(unittest.TestCase): # This check we did call the fake head request mock_head.assert_called() + def test_legacy_load_from_url(self): + # This test is for deprecated behavior and can be removed in v5 + _ = BertConfig.from_pretrained( + "https://huggingface.co/hf-internal-testing/tiny-random-bert/resolve/main/config.json" + ) + class ConfigurationVersioningTest(unittest.TestCase): def test_local_versioning(self): diff --git a/tests/test_feature_extraction_common.py b/tests/test_feature_extraction_common.py index 61bd85e892..7b7c33a964 100644 --- a/tests/test_feature_extraction_common.py +++ b/tests/test_feature_extraction_common.py @@ -182,6 +182,12 @@ class FeatureExtractorUtilTester(unittest.TestCase): # This check we did call the fake head request mock_head.assert_called() + def test_legacy_load_from_url(self): + # This test is for deprecated behavior and can be removed in v5 + _ = Wav2Vec2FeatureExtractor.from_pretrained( + "https://huggingface.co/hf-internal-testing/tiny-random-wav2vec2/resolve/main/preprocessor_config.json" + ) + @is_staging_test class FeatureExtractorPushToHubTester(unittest.TestCase): diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 6c4814c1a8..082f2a8a90 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -33,6 +33,7 @@ import numpy as np import transformers from huggingface_hub import HfFolder, delete_repo, set_access_token +from huggingface_hub.file_download import http_get from requests.exceptions import HTTPError from transformers import ( AutoConfig, @@ -2949,6 +2950,26 @@ class ModelUtilsTest(TestCasePlus): # This check we did call the fake head request mock_head.assert_called() + def test_load_from_one_file(self): + try: + tmp_file = tempfile.mktemp() + with open(tmp_file, "wb") as f: + http_get( + "https://huggingface.co/hf-internal-testing/tiny-random-bert/resolve/main/pytorch_model.bin", f + ) + + config = BertConfig.from_pretrained("hf-internal-testing/tiny-random-bert") + _ = BertModel.from_pretrained(tmp_file, config=config) + finally: + os.remove(tmp_file) + + def test_legacy_load_from_url(self): + # This test is for deprecated behavior and can be removed in v5 + config = BertConfig.from_pretrained("hf-internal-testing/tiny-random-bert") + _ = BertModel.from_pretrained( + "https://huggingface.co/hf-internal-testing/tiny-random-bert/resolve/main/pytorch_model.bin", config=config + ) + @require_torch @is_staging_test diff --git a/tests/test_modeling_tf_common.py b/tests/test_modeling_tf_common.py index 620d84083e..9977578b51 100644 --- a/tests/test_modeling_tf_common.py +++ b/tests/test_modeling_tf_common.py @@ -30,6 +30,7 @@ from typing import List, Tuple, get_type_hints from datasets import Dataset from huggingface_hub import HfFolder, Repository, delete_repo, set_access_token +from huggingface_hub.file_download import http_get from requests.exceptions import HTTPError from transformers import is_tf_available, is_torch_available from transformers.configuration_utils import PretrainedConfig @@ -1927,6 +1928,24 @@ class UtilsFunctionsTest(unittest.TestCase): # This check we did call the fake head request mock_head.assert_called() + def test_load_from_one_file(self): + try: + tmp_file = tempfile.mktemp() + with open(tmp_file, "wb") as f: + http_get("https://huggingface.co/hf-internal-testing/tiny-random-bert/resolve/main/tf_model.h5", f) + + config = BertConfig.from_pretrained("hf-internal-testing/tiny-random-bert") + _ = TFBertModel.from_pretrained(tmp_file, config=config) + finally: + os.remove(tmp_file) + + def test_legacy_load_from_url(self): + # This test is for deprecated behavior and can be removed in v5 + config = BertConfig.from_pretrained("hf-internal-testing/tiny-random-bert") + _ = TFBertModel.from_pretrained( + "https://huggingface.co/hf-internal-testing/tiny-random-bert/resolve/main/tf_model.h5", config=config + ) + # tests whether the unpack_inputs function behaves as expected def test_unpack_inputs(self): class DummyModel: diff --git a/tests/test_tokenization_common.py b/tests/test_tokenization_common.py index ef6eb421b4..48add3f4f9 100644 --- a/tests/test_tokenization_common.py +++ b/tests/test_tokenization_common.py @@ -3891,15 +3891,20 @@ class TokenizerUtilTester(unittest.TestCase): mock_head.assert_called() def test_legacy_load_from_one_file(self): + # This test is for deprecated behavior and can be removed in v5 try: tmp_file = tempfile.mktemp() with open(tmp_file, "wb") as f: http_get("https://huggingface.co/albert-base-v1/resolve/main/spiece.model", f) - AlbertTokenizer.from_pretrained(tmp_file) + _ = AlbertTokenizer.from_pretrained(tmp_file) finally: os.remove(tmp_file) + def test_legacy_load_from_url(self): + # This test is for deprecated behavior and can be removed in v5 + _ = AlbertTokenizer.from_pretrained("https://huggingface.co/albert-base-v1/resolve/main/spiece.model") + @is_staging_test class TokenizerPushToHubTester(unittest.TestCase):