Add tests for legacy load by url and fix bugs (#19078)
This commit is contained in:
@@ -360,6 +360,12 @@ class ConfigTestUtils(unittest.TestCase):
|
||||
# This check we did call the fake head request
|
||||
mock_head.assert_called()
|
||||
|
||||
def test_legacy_load_from_url(self):
|
||||
# This test is for deprecated behavior and can be removed in v5
|
||||
_ = BertConfig.from_pretrained(
|
||||
"https://huggingface.co/hf-internal-testing/tiny-random-bert/resolve/main/config.json"
|
||||
)
|
||||
|
||||
|
||||
class ConfigurationVersioningTest(unittest.TestCase):
|
||||
def test_local_versioning(self):
|
||||
|
||||
@@ -182,6 +182,12 @@ class FeatureExtractorUtilTester(unittest.TestCase):
|
||||
# This check we did call the fake head request
|
||||
mock_head.assert_called()
|
||||
|
||||
def test_legacy_load_from_url(self):
|
||||
# This test is for deprecated behavior and can be removed in v5
|
||||
_ = Wav2Vec2FeatureExtractor.from_pretrained(
|
||||
"https://huggingface.co/hf-internal-testing/tiny-random-wav2vec2/resolve/main/preprocessor_config.json"
|
||||
)
|
||||
|
||||
|
||||
@is_staging_test
|
||||
class FeatureExtractorPushToHubTester(unittest.TestCase):
|
||||
|
||||
@@ -33,6 +33,7 @@ import numpy as np
|
||||
|
||||
import transformers
|
||||
from huggingface_hub import HfFolder, delete_repo, set_access_token
|
||||
from huggingface_hub.file_download import http_get
|
||||
from requests.exceptions import HTTPError
|
||||
from transformers import (
|
||||
AutoConfig,
|
||||
@@ -2949,6 +2950,26 @@ class ModelUtilsTest(TestCasePlus):
|
||||
# This check we did call the fake head request
|
||||
mock_head.assert_called()
|
||||
|
||||
def test_load_from_one_file(self):
|
||||
try:
|
||||
tmp_file = tempfile.mktemp()
|
||||
with open(tmp_file, "wb") as f:
|
||||
http_get(
|
||||
"https://huggingface.co/hf-internal-testing/tiny-random-bert/resolve/main/pytorch_model.bin", f
|
||||
)
|
||||
|
||||
config = BertConfig.from_pretrained("hf-internal-testing/tiny-random-bert")
|
||||
_ = BertModel.from_pretrained(tmp_file, config=config)
|
||||
finally:
|
||||
os.remove(tmp_file)
|
||||
|
||||
def test_legacy_load_from_url(self):
|
||||
# This test is for deprecated behavior and can be removed in v5
|
||||
config = BertConfig.from_pretrained("hf-internal-testing/tiny-random-bert")
|
||||
_ = BertModel.from_pretrained(
|
||||
"https://huggingface.co/hf-internal-testing/tiny-random-bert/resolve/main/pytorch_model.bin", config=config
|
||||
)
|
||||
|
||||
|
||||
@require_torch
|
||||
@is_staging_test
|
||||
|
||||
@@ -30,6 +30,7 @@ from typing import List, Tuple, get_type_hints
|
||||
from datasets import Dataset
|
||||
|
||||
from huggingface_hub import HfFolder, Repository, delete_repo, set_access_token
|
||||
from huggingface_hub.file_download import http_get
|
||||
from requests.exceptions import HTTPError
|
||||
from transformers import is_tf_available, is_torch_available
|
||||
from transformers.configuration_utils import PretrainedConfig
|
||||
@@ -1927,6 +1928,24 @@ class UtilsFunctionsTest(unittest.TestCase):
|
||||
# This check we did call the fake head request
|
||||
mock_head.assert_called()
|
||||
|
||||
def test_load_from_one_file(self):
|
||||
try:
|
||||
tmp_file = tempfile.mktemp()
|
||||
with open(tmp_file, "wb") as f:
|
||||
http_get("https://huggingface.co/hf-internal-testing/tiny-random-bert/resolve/main/tf_model.h5", f)
|
||||
|
||||
config = BertConfig.from_pretrained("hf-internal-testing/tiny-random-bert")
|
||||
_ = TFBertModel.from_pretrained(tmp_file, config=config)
|
||||
finally:
|
||||
os.remove(tmp_file)
|
||||
|
||||
def test_legacy_load_from_url(self):
|
||||
# This test is for deprecated behavior and can be removed in v5
|
||||
config = BertConfig.from_pretrained("hf-internal-testing/tiny-random-bert")
|
||||
_ = TFBertModel.from_pretrained(
|
||||
"https://huggingface.co/hf-internal-testing/tiny-random-bert/resolve/main/tf_model.h5", config=config
|
||||
)
|
||||
|
||||
# tests whether the unpack_inputs function behaves as expected
|
||||
def test_unpack_inputs(self):
|
||||
class DummyModel:
|
||||
|
||||
@@ -3891,15 +3891,20 @@ class TokenizerUtilTester(unittest.TestCase):
|
||||
mock_head.assert_called()
|
||||
|
||||
def test_legacy_load_from_one_file(self):
|
||||
# This test is for deprecated behavior and can be removed in v5
|
||||
try:
|
||||
tmp_file = tempfile.mktemp()
|
||||
with open(tmp_file, "wb") as f:
|
||||
http_get("https://huggingface.co/albert-base-v1/resolve/main/spiece.model", f)
|
||||
|
||||
AlbertTokenizer.from_pretrained(tmp_file)
|
||||
_ = AlbertTokenizer.from_pretrained(tmp_file)
|
||||
finally:
|
||||
os.remove(tmp_file)
|
||||
|
||||
def test_legacy_load_from_url(self):
|
||||
# This test is for deprecated behavior and can be removed in v5
|
||||
_ = AlbertTokenizer.from_pretrained("https://huggingface.co/albert-base-v1/resolve/main/spiece.model")
|
||||
|
||||
|
||||
@is_staging_test
|
||||
class TokenizerPushToHubTester(unittest.TestCase):
|
||||
|
||||
Reference in New Issue
Block a user