Add tests for legacy load by url and fix bugs (#19078)

This commit is contained in:
Sylvain Gugger
2022-09-16 17:20:02 -04:00
committed by GitHub
parent ae219532e3
commit ca485e562b
9 changed files with 62 additions and 5 deletions

View File

@@ -30,6 +30,7 @@ from typing import List, Tuple, get_type_hints
from datasets import Dataset
from huggingface_hub import HfFolder, Repository, delete_repo, set_access_token
from huggingface_hub.file_download import http_get
from requests.exceptions import HTTPError
from transformers import is_tf_available, is_torch_available
from transformers.configuration_utils import PretrainedConfig
@@ -1927,6 +1928,24 @@ class UtilsFunctionsTest(unittest.TestCase):
# This check we did call the fake head request
mock_head.assert_called()
def test_load_from_one_file(self):
try:
tmp_file = tempfile.mktemp()
with open(tmp_file, "wb") as f:
http_get("https://huggingface.co/hf-internal-testing/tiny-random-bert/resolve/main/tf_model.h5", f)
config = BertConfig.from_pretrained("hf-internal-testing/tiny-random-bert")
_ = TFBertModel.from_pretrained(tmp_file, config=config)
finally:
os.remove(tmp_file)
def test_legacy_load_from_url(self):
# This test is for deprecated behavior and can be removed in v5
config = BertConfig.from_pretrained("hf-internal-testing/tiny-random-bert")
_ = TFBertModel.from_pretrained(
"https://huggingface.co/hf-internal-testing/tiny-random-bert/resolve/main/tf_model.h5", config=config
)
# tests whether the unpack_inputs function behaves as expected
def test_unpack_inputs(self):
class DummyModel: