Split common test from core tests (#24284)
This commit is contained in:
@@ -16,80 +16,11 @@
|
|||||||
import copy
|
import copy
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import shutil
|
|
||||||
import sys
|
|
||||||
import tempfile
|
import tempfile
|
||||||
import unittest
|
|
||||||
import unittest.mock as mock
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
from huggingface_hub import HfFolder, delete_repo
|
from transformers import is_torch_available
|
||||||
from requests.exceptions import HTTPError
|
|
||||||
|
|
||||||
from transformers import AutoConfig, BertConfig, GPT2Config, is_torch_available
|
from .test_configuration_utils import config_common_kwargs
|
||||||
from transformers.configuration_utils import PretrainedConfig
|
|
||||||
from transformers.testing_utils import TOKEN, USER, is_staging_test
|
|
||||||
|
|
||||||
|
|
||||||
sys.path.append(str(Path(__file__).parent.parent / "utils"))
|
|
||||||
|
|
||||||
from test_module.custom_configuration import CustomConfig # noqa E402
|
|
||||||
|
|
||||||
|
|
||||||
config_common_kwargs = {
|
|
||||||
"return_dict": False,
|
|
||||||
"output_hidden_states": True,
|
|
||||||
"output_attentions": True,
|
|
||||||
"torchscript": True,
|
|
||||||
"torch_dtype": "float16",
|
|
||||||
"use_bfloat16": True,
|
|
||||||
"tf_legacy_loss": True,
|
|
||||||
"pruned_heads": {"a": 1},
|
|
||||||
"tie_word_embeddings": False,
|
|
||||||
"is_decoder": True,
|
|
||||||
"cross_attention_hidden_size": 128,
|
|
||||||
"add_cross_attention": True,
|
|
||||||
"tie_encoder_decoder": True,
|
|
||||||
"max_length": 50,
|
|
||||||
"min_length": 3,
|
|
||||||
"do_sample": True,
|
|
||||||
"early_stopping": True,
|
|
||||||
"num_beams": 3,
|
|
||||||
"num_beam_groups": 3,
|
|
||||||
"diversity_penalty": 0.5,
|
|
||||||
"temperature": 2.0,
|
|
||||||
"top_k": 10,
|
|
||||||
"top_p": 0.7,
|
|
||||||
"typical_p": 0.2,
|
|
||||||
"repetition_penalty": 0.8,
|
|
||||||
"length_penalty": 0.8,
|
|
||||||
"no_repeat_ngram_size": 5,
|
|
||||||
"encoder_no_repeat_ngram_size": 5,
|
|
||||||
"bad_words_ids": [1, 2, 3],
|
|
||||||
"num_return_sequences": 3,
|
|
||||||
"chunk_size_feed_forward": 5,
|
|
||||||
"output_scores": True,
|
|
||||||
"return_dict_in_generate": True,
|
|
||||||
"forced_bos_token_id": 2,
|
|
||||||
"forced_eos_token_id": 3,
|
|
||||||
"remove_invalid_values": True,
|
|
||||||
"architectures": ["BertModel"],
|
|
||||||
"finetuning_task": "translation",
|
|
||||||
"id2label": {0: "label"},
|
|
||||||
"label2id": {"label": "0"},
|
|
||||||
"tokenizer_class": "BertTokenizerFast",
|
|
||||||
"prefix": "prefix",
|
|
||||||
"bos_token_id": 6,
|
|
||||||
"pad_token_id": 7,
|
|
||||||
"eos_token_id": 8,
|
|
||||||
"sep_token_id": 9,
|
|
||||||
"decoder_start_token_id": 10,
|
|
||||||
"exponential_decay_length_penalty": (5, 1.01),
|
|
||||||
"suppress_tokens": [0, 1],
|
|
||||||
"begin_suppress_tokens": 2,
|
|
||||||
"task_specific_params": {"translation": "some_params"},
|
|
||||||
"problem_type": "regression",
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class ConfigTester(object):
|
class ConfigTester(object):
|
||||||
@@ -220,200 +151,3 @@ class ConfigTester(object):
|
|||||||
self.create_and_test_config_with_num_labels()
|
self.create_and_test_config_with_num_labels()
|
||||||
self.check_config_can_be_init_without_params()
|
self.check_config_can_be_init_without_params()
|
||||||
self.check_config_arguments_init()
|
self.check_config_arguments_init()
|
||||||
|
|
||||||
|
|
||||||
@is_staging_test
|
|
||||||
class ConfigPushToHubTester(unittest.TestCase):
|
|
||||||
@classmethod
|
|
||||||
def setUpClass(cls):
|
|
||||||
cls._token = TOKEN
|
|
||||||
HfFolder.save_token(TOKEN)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def tearDownClass(cls):
|
|
||||||
try:
|
|
||||||
delete_repo(token=cls._token, repo_id="test-config")
|
|
||||||
except HTTPError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
try:
|
|
||||||
delete_repo(token=cls._token, repo_id="valid_org/test-config-org")
|
|
||||||
except HTTPError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
try:
|
|
||||||
delete_repo(token=cls._token, repo_id="test-dynamic-config")
|
|
||||||
except HTTPError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
def test_push_to_hub(self):
|
|
||||||
config = BertConfig(
|
|
||||||
vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
|
|
||||||
)
|
|
||||||
config.push_to_hub("test-config", use_auth_token=self._token)
|
|
||||||
|
|
||||||
new_config = BertConfig.from_pretrained(f"{USER}/test-config")
|
|
||||||
for k, v in config.to_dict().items():
|
|
||||||
if k != "transformers_version":
|
|
||||||
self.assertEqual(v, getattr(new_config, k))
|
|
||||||
|
|
||||||
# Reset repo
|
|
||||||
delete_repo(token=self._token, repo_id="test-config")
|
|
||||||
|
|
||||||
# Push to hub via save_pretrained
|
|
||||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
|
||||||
config.save_pretrained(tmp_dir, repo_id="test-config", push_to_hub=True, use_auth_token=self._token)
|
|
||||||
|
|
||||||
new_config = BertConfig.from_pretrained(f"{USER}/test-config")
|
|
||||||
for k, v in config.to_dict().items():
|
|
||||||
if k != "transformers_version":
|
|
||||||
self.assertEqual(v, getattr(new_config, k))
|
|
||||||
|
|
||||||
def test_push_to_hub_in_organization(self):
|
|
||||||
config = BertConfig(
|
|
||||||
vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
|
|
||||||
)
|
|
||||||
config.push_to_hub("valid_org/test-config-org", use_auth_token=self._token)
|
|
||||||
|
|
||||||
new_config = BertConfig.from_pretrained("valid_org/test-config-org")
|
|
||||||
for k, v in config.to_dict().items():
|
|
||||||
if k != "transformers_version":
|
|
||||||
self.assertEqual(v, getattr(new_config, k))
|
|
||||||
|
|
||||||
# Reset repo
|
|
||||||
delete_repo(token=self._token, repo_id="valid_org/test-config-org")
|
|
||||||
|
|
||||||
# Push to hub via save_pretrained
|
|
||||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
|
||||||
config.save_pretrained(
|
|
||||||
tmp_dir, repo_id="valid_org/test-config-org", push_to_hub=True, use_auth_token=self._token
|
|
||||||
)
|
|
||||||
|
|
||||||
new_config = BertConfig.from_pretrained("valid_org/test-config-org")
|
|
||||||
for k, v in config.to_dict().items():
|
|
||||||
if k != "transformers_version":
|
|
||||||
self.assertEqual(v, getattr(new_config, k))
|
|
||||||
|
|
||||||
def test_push_to_hub_dynamic_config(self):
|
|
||||||
CustomConfig.register_for_auto_class()
|
|
||||||
config = CustomConfig(attribute=42)
|
|
||||||
|
|
||||||
config.push_to_hub("test-dynamic-config", use_auth_token=self._token)
|
|
||||||
|
|
||||||
# This has added the proper auto_map field to the config
|
|
||||||
self.assertDictEqual(config.auto_map, {"AutoConfig": "custom_configuration.CustomConfig"})
|
|
||||||
|
|
||||||
new_config = AutoConfig.from_pretrained(f"{USER}/test-dynamic-config", trust_remote_code=True)
|
|
||||||
# Can't make an isinstance check because the new_config is from the FakeConfig class of a dynamic module
|
|
||||||
self.assertEqual(new_config.__class__.__name__, "CustomConfig")
|
|
||||||
self.assertEqual(new_config.attribute, 42)
|
|
||||||
|
|
||||||
|
|
||||||
class ConfigTestUtils(unittest.TestCase):
|
|
||||||
def test_config_from_string(self):
|
|
||||||
c = GPT2Config()
|
|
||||||
|
|
||||||
# attempt to modify each of int/float/bool/str config records and verify they were updated
|
|
||||||
n_embd = c.n_embd + 1 # int
|
|
||||||
resid_pdrop = c.resid_pdrop + 1.0 # float
|
|
||||||
scale_attn_weights = not c.scale_attn_weights # bool
|
|
||||||
summary_type = c.summary_type + "foo" # str
|
|
||||||
c.update_from_string(
|
|
||||||
f"n_embd={n_embd},resid_pdrop={resid_pdrop},scale_attn_weights={scale_attn_weights},summary_type={summary_type}"
|
|
||||||
)
|
|
||||||
self.assertEqual(n_embd, c.n_embd, "mismatch for key: n_embd")
|
|
||||||
self.assertEqual(resid_pdrop, c.resid_pdrop, "mismatch for key: resid_pdrop")
|
|
||||||
self.assertEqual(scale_attn_weights, c.scale_attn_weights, "mismatch for key: scale_attn_weights")
|
|
||||||
self.assertEqual(summary_type, c.summary_type, "mismatch for key: summary_type")
|
|
||||||
|
|
||||||
def test_config_common_kwargs_is_complete(self):
|
|
||||||
base_config = PretrainedConfig()
|
|
||||||
missing_keys = [key for key in base_config.__dict__ if key not in config_common_kwargs]
|
|
||||||
# If this part of the test fails, you have arguments to addin config_common_kwargs above.
|
|
||||||
self.assertListEqual(
|
|
||||||
missing_keys, ["is_encoder_decoder", "_name_or_path", "_commit_hash", "transformers_version"]
|
|
||||||
)
|
|
||||||
keys_with_defaults = [key for key, value in config_common_kwargs.items() if value == getattr(base_config, key)]
|
|
||||||
if len(keys_with_defaults) > 0:
|
|
||||||
raise ValueError(
|
|
||||||
"The following keys are set with the default values in"
|
|
||||||
" `test_configuration_common.config_common_kwargs` pick another value for them:"
|
|
||||||
f" {', '.join(keys_with_defaults)}."
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_from_pretrained_subfolder(self):
|
|
||||||
with self.assertRaises(OSError):
|
|
||||||
# config is in subfolder, the following should not work without specifying the subfolder
|
|
||||||
_ = BertConfig.from_pretrained("hf-internal-testing/tiny-random-bert-subfolder")
|
|
||||||
|
|
||||||
config = BertConfig.from_pretrained("hf-internal-testing/tiny-random-bert-subfolder", subfolder="bert")
|
|
||||||
|
|
||||||
self.assertIsNotNone(config)
|
|
||||||
|
|
||||||
def test_cached_files_are_used_when_internet_is_down(self):
|
|
||||||
# A mock response for an HTTP head request to emulate server down
|
|
||||||
response_mock = mock.Mock()
|
|
||||||
response_mock.status_code = 500
|
|
||||||
response_mock.headers = {}
|
|
||||||
response_mock.raise_for_status.side_effect = HTTPError
|
|
||||||
response_mock.json.return_value = {}
|
|
||||||
|
|
||||||
# Download this model to make sure it's in the cache.
|
|
||||||
_ = BertConfig.from_pretrained("hf-internal-testing/tiny-random-bert")
|
|
||||||
|
|
||||||
# Under the mock environment we get a 500 error when trying to reach the model.
|
|
||||||
with mock.patch("requests.Session.request", return_value=response_mock) as mock_head:
|
|
||||||
_ = BertConfig.from_pretrained("hf-internal-testing/tiny-random-bert")
|
|
||||||
# This check we did call the fake head request
|
|
||||||
mock_head.assert_called()
|
|
||||||
|
|
||||||
def test_legacy_load_from_url(self):
|
|
||||||
# This test is for deprecated behavior and can be removed in v5
|
|
||||||
_ = BertConfig.from_pretrained(
|
|
||||||
"https://huggingface.co/hf-internal-testing/tiny-random-bert/resolve/main/config.json"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class ConfigurationVersioningTest(unittest.TestCase):
|
|
||||||
def test_local_versioning(self):
|
|
||||||
configuration = AutoConfig.from_pretrained("bert-base-cased")
|
|
||||||
configuration.configuration_files = ["config.4.0.0.json"]
|
|
||||||
|
|
||||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
|
||||||
configuration.save_pretrained(tmp_dir)
|
|
||||||
configuration.hidden_size = 2
|
|
||||||
json.dump(configuration.to_dict(), open(os.path.join(tmp_dir, "config.4.0.0.json"), "w"))
|
|
||||||
|
|
||||||
# This should pick the new configuration file as the version of Transformers is > 4.0.0
|
|
||||||
new_configuration = AutoConfig.from_pretrained(tmp_dir)
|
|
||||||
self.assertEqual(new_configuration.hidden_size, 2)
|
|
||||||
|
|
||||||
# Will need to be adjusted if we reach v42 and this test is still here.
|
|
||||||
# Should pick the old configuration file as the version of Transformers is < 4.42.0
|
|
||||||
configuration.configuration_files = ["config.42.0.0.json"]
|
|
||||||
configuration.hidden_size = 768
|
|
||||||
configuration.save_pretrained(tmp_dir)
|
|
||||||
shutil.move(os.path.join(tmp_dir, "config.4.0.0.json"), os.path.join(tmp_dir, "config.42.0.0.json"))
|
|
||||||
new_configuration = AutoConfig.from_pretrained(tmp_dir)
|
|
||||||
self.assertEqual(new_configuration.hidden_size, 768)
|
|
||||||
|
|
||||||
def test_repo_versioning_before(self):
|
|
||||||
# This repo has two configuration files, one for v4.0.0 and above with a different hidden size.
|
|
||||||
repo = "hf-internal-testing/test-two-configs"
|
|
||||||
|
|
||||||
import transformers as new_transformers
|
|
||||||
|
|
||||||
new_transformers.configuration_utils.__version__ = "v4.0.0"
|
|
||||||
new_configuration, kwargs = new_transformers.models.auto.AutoConfig.from_pretrained(
|
|
||||||
repo, return_unused_kwargs=True
|
|
||||||
)
|
|
||||||
self.assertEqual(new_configuration.hidden_size, 2)
|
|
||||||
# This checks `_configuration_file` ia not kept in the kwargs by mistake.
|
|
||||||
self.assertDictEqual(kwargs, {})
|
|
||||||
|
|
||||||
# Testing an older version by monkey-patching the version in the module it's used.
|
|
||||||
import transformers as old_transformers
|
|
||||||
|
|
||||||
old_transformers.configuration_utils.__version__ = "v3.0.0"
|
|
||||||
old_configuration = old_transformers.models.auto.AutoConfig.from_pretrained(repo)
|
|
||||||
self.assertEqual(old_configuration.hidden_size, 768)
|
|
||||||
|
|||||||
286
tests/test_configuration_utils.py
Normal file
286
tests/test_configuration_utils.py
Normal file
@@ -0,0 +1,286 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2019 HuggingFace Inc.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import shutil
|
||||||
|
import sys
|
||||||
|
import tempfile
|
||||||
|
import unittest
|
||||||
|
import unittest.mock as mock
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from huggingface_hub import HfFolder, delete_repo
|
||||||
|
from requests.exceptions import HTTPError
|
||||||
|
|
||||||
|
from transformers import AutoConfig, BertConfig, GPT2Config
|
||||||
|
from transformers.configuration_utils import PretrainedConfig
|
||||||
|
from transformers.testing_utils import TOKEN, USER, is_staging_test
|
||||||
|
|
||||||
|
|
||||||
|
sys.path.append(str(Path(__file__).parent.parent / "utils"))
|
||||||
|
|
||||||
|
from test_module.custom_configuration import CustomConfig # noqa E402
|
||||||
|
|
||||||
|
|
||||||
|
config_common_kwargs = {
|
||||||
|
"return_dict": False,
|
||||||
|
"output_hidden_states": True,
|
||||||
|
"output_attentions": True,
|
||||||
|
"torchscript": True,
|
||||||
|
"torch_dtype": "float16",
|
||||||
|
"use_bfloat16": True,
|
||||||
|
"tf_legacy_loss": True,
|
||||||
|
"pruned_heads": {"a": 1},
|
||||||
|
"tie_word_embeddings": False,
|
||||||
|
"is_decoder": True,
|
||||||
|
"cross_attention_hidden_size": 128,
|
||||||
|
"add_cross_attention": True,
|
||||||
|
"tie_encoder_decoder": True,
|
||||||
|
"max_length": 50,
|
||||||
|
"min_length": 3,
|
||||||
|
"do_sample": True,
|
||||||
|
"early_stopping": True,
|
||||||
|
"num_beams": 3,
|
||||||
|
"num_beam_groups": 3,
|
||||||
|
"diversity_penalty": 0.5,
|
||||||
|
"temperature": 2.0,
|
||||||
|
"top_k": 10,
|
||||||
|
"top_p": 0.7,
|
||||||
|
"typical_p": 0.2,
|
||||||
|
"repetition_penalty": 0.8,
|
||||||
|
"length_penalty": 0.8,
|
||||||
|
"no_repeat_ngram_size": 5,
|
||||||
|
"encoder_no_repeat_ngram_size": 5,
|
||||||
|
"bad_words_ids": [1, 2, 3],
|
||||||
|
"num_return_sequences": 3,
|
||||||
|
"chunk_size_feed_forward": 5,
|
||||||
|
"output_scores": True,
|
||||||
|
"return_dict_in_generate": True,
|
||||||
|
"forced_bos_token_id": 2,
|
||||||
|
"forced_eos_token_id": 3,
|
||||||
|
"remove_invalid_values": True,
|
||||||
|
"architectures": ["BertModel"],
|
||||||
|
"finetuning_task": "translation",
|
||||||
|
"id2label": {0: "label"},
|
||||||
|
"label2id": {"label": "0"},
|
||||||
|
"tokenizer_class": "BertTokenizerFast",
|
||||||
|
"prefix": "prefix",
|
||||||
|
"bos_token_id": 6,
|
||||||
|
"pad_token_id": 7,
|
||||||
|
"eos_token_id": 8,
|
||||||
|
"sep_token_id": 9,
|
||||||
|
"decoder_start_token_id": 10,
|
||||||
|
"exponential_decay_length_penalty": (5, 1.01),
|
||||||
|
"suppress_tokens": [0, 1],
|
||||||
|
"begin_suppress_tokens": 2,
|
||||||
|
"task_specific_params": {"translation": "some_params"},
|
||||||
|
"problem_type": "regression",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@is_staging_test
|
||||||
|
class ConfigPushToHubTester(unittest.TestCase):
|
||||||
|
@classmethod
|
||||||
|
def setUpClass(cls):
|
||||||
|
cls._token = TOKEN
|
||||||
|
HfFolder.save_token(TOKEN)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def tearDownClass(cls):
|
||||||
|
try:
|
||||||
|
delete_repo(token=cls._token, repo_id="test-config")
|
||||||
|
except HTTPError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
try:
|
||||||
|
delete_repo(token=cls._token, repo_id="valid_org/test-config-org")
|
||||||
|
except HTTPError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
try:
|
||||||
|
delete_repo(token=cls._token, repo_id="test-dynamic-config")
|
||||||
|
except HTTPError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def test_push_to_hub(self):
|
||||||
|
config = BertConfig(
|
||||||
|
vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
|
||||||
|
)
|
||||||
|
config.push_to_hub("test-config", use_auth_token=self._token)
|
||||||
|
|
||||||
|
new_config = BertConfig.from_pretrained(f"{USER}/test-config")
|
||||||
|
for k, v in config.to_dict().items():
|
||||||
|
if k != "transformers_version":
|
||||||
|
self.assertEqual(v, getattr(new_config, k))
|
||||||
|
|
||||||
|
# Reset repo
|
||||||
|
delete_repo(token=self._token, repo_id="test-config")
|
||||||
|
|
||||||
|
# Push to hub via save_pretrained
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
config.save_pretrained(tmp_dir, repo_id="test-config", push_to_hub=True, use_auth_token=self._token)
|
||||||
|
|
||||||
|
new_config = BertConfig.from_pretrained(f"{USER}/test-config")
|
||||||
|
for k, v in config.to_dict().items():
|
||||||
|
if k != "transformers_version":
|
||||||
|
self.assertEqual(v, getattr(new_config, k))
|
||||||
|
|
||||||
|
def test_push_to_hub_in_organization(self):
|
||||||
|
config = BertConfig(
|
||||||
|
vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
|
||||||
|
)
|
||||||
|
config.push_to_hub("valid_org/test-config-org", use_auth_token=self._token)
|
||||||
|
|
||||||
|
new_config = BertConfig.from_pretrained("valid_org/test-config-org")
|
||||||
|
for k, v in config.to_dict().items():
|
||||||
|
if k != "transformers_version":
|
||||||
|
self.assertEqual(v, getattr(new_config, k))
|
||||||
|
|
||||||
|
# Reset repo
|
||||||
|
delete_repo(token=self._token, repo_id="valid_org/test-config-org")
|
||||||
|
|
||||||
|
# Push to hub via save_pretrained
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
config.save_pretrained(
|
||||||
|
tmp_dir, repo_id="valid_org/test-config-org", push_to_hub=True, use_auth_token=self._token
|
||||||
|
)
|
||||||
|
|
||||||
|
new_config = BertConfig.from_pretrained("valid_org/test-config-org")
|
||||||
|
for k, v in config.to_dict().items():
|
||||||
|
if k != "transformers_version":
|
||||||
|
self.assertEqual(v, getattr(new_config, k))
|
||||||
|
|
||||||
|
def test_push_to_hub_dynamic_config(self):
|
||||||
|
CustomConfig.register_for_auto_class()
|
||||||
|
config = CustomConfig(attribute=42)
|
||||||
|
|
||||||
|
config.push_to_hub("test-dynamic-config", use_auth_token=self._token)
|
||||||
|
|
||||||
|
# This has added the proper auto_map field to the config
|
||||||
|
self.assertDictEqual(config.auto_map, {"AutoConfig": "custom_configuration.CustomConfig"})
|
||||||
|
|
||||||
|
new_config = AutoConfig.from_pretrained(f"{USER}/test-dynamic-config", trust_remote_code=True)
|
||||||
|
# Can't make an isinstance check because the new_config is from the FakeConfig class of a dynamic module
|
||||||
|
self.assertEqual(new_config.__class__.__name__, "CustomConfig")
|
||||||
|
self.assertEqual(new_config.attribute, 42)
|
||||||
|
|
||||||
|
|
||||||
|
class ConfigTestUtils(unittest.TestCase):
|
||||||
|
def test_config_from_string(self):
|
||||||
|
c = GPT2Config()
|
||||||
|
|
||||||
|
# attempt to modify each of int/float/bool/str config records and verify they were updated
|
||||||
|
n_embd = c.n_embd + 1 # int
|
||||||
|
resid_pdrop = c.resid_pdrop + 1.0 # float
|
||||||
|
scale_attn_weights = not c.scale_attn_weights # bool
|
||||||
|
summary_type = c.summary_type + "foo" # str
|
||||||
|
c.update_from_string(
|
||||||
|
f"n_embd={n_embd},resid_pdrop={resid_pdrop},scale_attn_weights={scale_attn_weights},summary_type={summary_type}"
|
||||||
|
)
|
||||||
|
self.assertEqual(n_embd, c.n_embd, "mismatch for key: n_embd")
|
||||||
|
self.assertEqual(resid_pdrop, c.resid_pdrop, "mismatch for key: resid_pdrop")
|
||||||
|
self.assertEqual(scale_attn_weights, c.scale_attn_weights, "mismatch for key: scale_attn_weights")
|
||||||
|
self.assertEqual(summary_type, c.summary_type, "mismatch for key: summary_type")
|
||||||
|
|
||||||
|
def test_config_common_kwargs_is_complete(self):
|
||||||
|
base_config = PretrainedConfig()
|
||||||
|
missing_keys = [key for key in base_config.__dict__ if key not in config_common_kwargs]
|
||||||
|
# If this part of the test fails, you have arguments to addin config_common_kwargs above.
|
||||||
|
self.assertListEqual(
|
||||||
|
missing_keys, ["is_encoder_decoder", "_name_or_path", "_commit_hash", "transformers_version"]
|
||||||
|
)
|
||||||
|
keys_with_defaults = [key for key, value in config_common_kwargs.items() if value == getattr(base_config, key)]
|
||||||
|
if len(keys_with_defaults) > 0:
|
||||||
|
raise ValueError(
|
||||||
|
"The following keys are set with the default values in"
|
||||||
|
" `test_configuration_common.config_common_kwargs` pick another value for them:"
|
||||||
|
f" {', '.join(keys_with_defaults)}."
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_from_pretrained_subfolder(self):
|
||||||
|
with self.assertRaises(OSError):
|
||||||
|
# config is in subfolder, the following should not work without specifying the subfolder
|
||||||
|
_ = BertConfig.from_pretrained("hf-internal-testing/tiny-random-bert-subfolder")
|
||||||
|
|
||||||
|
config = BertConfig.from_pretrained("hf-internal-testing/tiny-random-bert-subfolder", subfolder="bert")
|
||||||
|
|
||||||
|
self.assertIsNotNone(config)
|
||||||
|
|
||||||
|
def test_cached_files_are_used_when_internet_is_down(self):
|
||||||
|
# A mock response for an HTTP head request to emulate server down
|
||||||
|
response_mock = mock.Mock()
|
||||||
|
response_mock.status_code = 500
|
||||||
|
response_mock.headers = {}
|
||||||
|
response_mock.raise_for_status.side_effect = HTTPError
|
||||||
|
response_mock.json.return_value = {}
|
||||||
|
|
||||||
|
# Download this model to make sure it's in the cache.
|
||||||
|
_ = BertConfig.from_pretrained("hf-internal-testing/tiny-random-bert")
|
||||||
|
|
||||||
|
# Under the mock environment we get a 500 error when trying to reach the model.
|
||||||
|
with mock.patch("requests.Session.request", return_value=response_mock) as mock_head:
|
||||||
|
_ = BertConfig.from_pretrained("hf-internal-testing/tiny-random-bert")
|
||||||
|
# This check we did call the fake head request
|
||||||
|
mock_head.assert_called()
|
||||||
|
|
||||||
|
def test_legacy_load_from_url(self):
|
||||||
|
# This test is for deprecated behavior and can be removed in v5
|
||||||
|
_ = BertConfig.from_pretrained(
|
||||||
|
"https://huggingface.co/hf-internal-testing/tiny-random-bert/resolve/main/config.json"
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_local_versioning(self):
|
||||||
|
configuration = AutoConfig.from_pretrained("bert-base-cased")
|
||||||
|
configuration.configuration_files = ["config.4.0.0.json"]
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
configuration.save_pretrained(tmp_dir)
|
||||||
|
configuration.hidden_size = 2
|
||||||
|
json.dump(configuration.to_dict(), open(os.path.join(tmp_dir, "config.4.0.0.json"), "w"))
|
||||||
|
|
||||||
|
# This should pick the new configuration file as the version of Transformers is > 4.0.0
|
||||||
|
new_configuration = AutoConfig.from_pretrained(tmp_dir)
|
||||||
|
self.assertEqual(new_configuration.hidden_size, 2)
|
||||||
|
|
||||||
|
# Will need to be adjusted if we reach v42 and this test is still here.
|
||||||
|
# Should pick the old configuration file as the version of Transformers is < 4.42.0
|
||||||
|
configuration.configuration_files = ["config.42.0.0.json"]
|
||||||
|
configuration.hidden_size = 768
|
||||||
|
configuration.save_pretrained(tmp_dir)
|
||||||
|
shutil.move(os.path.join(tmp_dir, "config.4.0.0.json"), os.path.join(tmp_dir, "config.42.0.0.json"))
|
||||||
|
new_configuration = AutoConfig.from_pretrained(tmp_dir)
|
||||||
|
self.assertEqual(new_configuration.hidden_size, 768)
|
||||||
|
|
||||||
|
def test_repo_versioning_before(self):
|
||||||
|
# This repo has two configuration files, one for v4.0.0 and above with a different hidden size.
|
||||||
|
repo = "hf-internal-testing/test-two-configs"
|
||||||
|
|
||||||
|
import transformers as new_transformers
|
||||||
|
|
||||||
|
new_transformers.configuration_utils.__version__ = "v4.0.0"
|
||||||
|
new_configuration, kwargs = new_transformers.models.auto.AutoConfig.from_pretrained(
|
||||||
|
repo, return_unused_kwargs=True
|
||||||
|
)
|
||||||
|
self.assertEqual(new_configuration.hidden_size, 2)
|
||||||
|
# This checks `_configuration_file` ia not kept in the kwargs by mistake.
|
||||||
|
self.assertDictEqual(kwargs, {})
|
||||||
|
|
||||||
|
# Testing an older version by monkey-patching the version in the module it's used.
|
||||||
|
import transformers as old_transformers
|
||||||
|
|
||||||
|
old_transformers.configuration_utils.__version__ = "v3.0.0"
|
||||||
|
old_configuration = old_transformers.models.auto.AutoConfig.from_pretrained(repo)
|
||||||
|
self.assertEqual(old_configuration.hidden_size, 768)
|
||||||
@@ -16,25 +16,9 @@
|
|||||||
|
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import sys
|
|
||||||
import tempfile
|
import tempfile
|
||||||
import unittest
|
|
||||||
import unittest.mock as mock
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
from huggingface_hub import HfFolder, delete_repo
|
from transformers.testing_utils import check_json_file_has_correct_format
|
||||||
from requests.exceptions import HTTPError
|
|
||||||
|
|
||||||
from transformers import AutoFeatureExtractor, Wav2Vec2FeatureExtractor
|
|
||||||
from transformers.testing_utils import TOKEN, USER, check_json_file_has_correct_format, get_tests_dir, is_staging_test
|
|
||||||
|
|
||||||
|
|
||||||
sys.path.append(str(Path(__file__).parent.parent / "utils"))
|
|
||||||
|
|
||||||
from test_module.custom_feature_extraction import CustomFeatureExtractor # noqa E402
|
|
||||||
|
|
||||||
|
|
||||||
SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR = get_tests_dir("fixtures")
|
|
||||||
|
|
||||||
|
|
||||||
class FeatureExtractionSavingTestMixin:
|
class FeatureExtractionSavingTestMixin:
|
||||||
@@ -69,112 +53,3 @@ class FeatureExtractionSavingTestMixin:
|
|||||||
def test_init_without_params(self):
|
def test_init_without_params(self):
|
||||||
feat_extract = self.feature_extraction_class()
|
feat_extract = self.feature_extraction_class()
|
||||||
self.assertIsNotNone(feat_extract)
|
self.assertIsNotNone(feat_extract)
|
||||||
|
|
||||||
|
|
||||||
class FeatureExtractorUtilTester(unittest.TestCase):
|
|
||||||
def test_cached_files_are_used_when_internet_is_down(self):
|
|
||||||
# A mock response for an HTTP head request to emulate server down
|
|
||||||
response_mock = mock.Mock()
|
|
||||||
response_mock.status_code = 500
|
|
||||||
response_mock.headers = {}
|
|
||||||
response_mock.raise_for_status.side_effect = HTTPError
|
|
||||||
response_mock.json.return_value = {}
|
|
||||||
|
|
||||||
# Download this model to make sure it's in the cache.
|
|
||||||
_ = Wav2Vec2FeatureExtractor.from_pretrained("hf-internal-testing/tiny-random-wav2vec2")
|
|
||||||
# Under the mock environment we get a 500 error when trying to reach the model.
|
|
||||||
with mock.patch("requests.Session.request", return_value=response_mock) as mock_head:
|
|
||||||
_ = Wav2Vec2FeatureExtractor.from_pretrained("hf-internal-testing/tiny-random-wav2vec2")
|
|
||||||
# This check we did call the fake head request
|
|
||||||
mock_head.assert_called()
|
|
||||||
|
|
||||||
def test_legacy_load_from_url(self):
|
|
||||||
# This test is for deprecated behavior and can be removed in v5
|
|
||||||
_ = Wav2Vec2FeatureExtractor.from_pretrained(
|
|
||||||
"https://huggingface.co/hf-internal-testing/tiny-random-wav2vec2/resolve/main/preprocessor_config.json"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@is_staging_test
|
|
||||||
class FeatureExtractorPushToHubTester(unittest.TestCase):
|
|
||||||
@classmethod
|
|
||||||
def setUpClass(cls):
|
|
||||||
cls._token = TOKEN
|
|
||||||
HfFolder.save_token(TOKEN)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def tearDownClass(cls):
|
|
||||||
try:
|
|
||||||
delete_repo(token=cls._token, repo_id="test-feature-extractor")
|
|
||||||
except HTTPError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
try:
|
|
||||||
delete_repo(token=cls._token, repo_id="valid_org/test-feature-extractor-org")
|
|
||||||
except HTTPError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
try:
|
|
||||||
delete_repo(token=cls._token, repo_id="test-dynamic-feature-extractor")
|
|
||||||
except HTTPError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
def test_push_to_hub(self):
|
|
||||||
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR)
|
|
||||||
feature_extractor.push_to_hub("test-feature-extractor", use_auth_token=self._token)
|
|
||||||
|
|
||||||
new_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(f"{USER}/test-feature-extractor")
|
|
||||||
for k, v in feature_extractor.__dict__.items():
|
|
||||||
self.assertEqual(v, getattr(new_feature_extractor, k))
|
|
||||||
|
|
||||||
# Reset repo
|
|
||||||
delete_repo(token=self._token, repo_id="test-feature-extractor")
|
|
||||||
|
|
||||||
# Push to hub via save_pretrained
|
|
||||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
|
||||||
feature_extractor.save_pretrained(
|
|
||||||
tmp_dir, repo_id="test-feature-extractor", push_to_hub=True, use_auth_token=self._token
|
|
||||||
)
|
|
||||||
|
|
||||||
new_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(f"{USER}/test-feature-extractor")
|
|
||||||
for k, v in feature_extractor.__dict__.items():
|
|
||||||
self.assertEqual(v, getattr(new_feature_extractor, k))
|
|
||||||
|
|
||||||
def test_push_to_hub_in_organization(self):
|
|
||||||
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR)
|
|
||||||
feature_extractor.push_to_hub("valid_org/test-feature-extractor", use_auth_token=self._token)
|
|
||||||
|
|
||||||
new_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("valid_org/test-feature-extractor")
|
|
||||||
for k, v in feature_extractor.__dict__.items():
|
|
||||||
self.assertEqual(v, getattr(new_feature_extractor, k))
|
|
||||||
|
|
||||||
# Reset repo
|
|
||||||
delete_repo(token=self._token, repo_id="valid_org/test-feature-extractor")
|
|
||||||
|
|
||||||
# Push to hub via save_pretrained
|
|
||||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
|
||||||
feature_extractor.save_pretrained(
|
|
||||||
tmp_dir, repo_id="valid_org/test-feature-extractor-org", push_to_hub=True, use_auth_token=self._token
|
|
||||||
)
|
|
||||||
|
|
||||||
new_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("valid_org/test-feature-extractor-org")
|
|
||||||
for k, v in feature_extractor.__dict__.items():
|
|
||||||
self.assertEqual(v, getattr(new_feature_extractor, k))
|
|
||||||
|
|
||||||
def test_push_to_hub_dynamic_feature_extractor(self):
|
|
||||||
CustomFeatureExtractor.register_for_auto_class()
|
|
||||||
feature_extractor = CustomFeatureExtractor.from_pretrained(SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR)
|
|
||||||
|
|
||||||
feature_extractor.push_to_hub("test-dynamic-feature-extractor", use_auth_token=self._token)
|
|
||||||
|
|
||||||
# This has added the proper auto_map field to the config
|
|
||||||
self.assertDictEqual(
|
|
||||||
feature_extractor.auto_map,
|
|
||||||
{"AutoFeatureExtractor": "custom_feature_extraction.CustomFeatureExtractor"},
|
|
||||||
)
|
|
||||||
|
|
||||||
new_feature_extractor = AutoFeatureExtractor.from_pretrained(
|
|
||||||
f"{USER}/test-dynamic-feature-extractor", trust_remote_code=True
|
|
||||||
)
|
|
||||||
# Can't make an isinstance check because the new_feature_extractor is from the CustomFeatureExtractor class of a dynamic module
|
|
||||||
self.assertEqual(new_feature_extractor.__class__.__name__, "CustomFeatureExtractor")
|
|
||||||
|
|||||||
144
tests/test_feature_extraction_utils.py
Normal file
144
tests/test_feature_extraction_utils.py
Normal file
@@ -0,0 +1,144 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2021 HuggingFace Inc.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
import sys
|
||||||
|
import tempfile
|
||||||
|
import unittest
|
||||||
|
import unittest.mock as mock
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from huggingface_hub import HfFolder, delete_repo
|
||||||
|
from requests.exceptions import HTTPError
|
||||||
|
|
||||||
|
from transformers import AutoFeatureExtractor, Wav2Vec2FeatureExtractor
|
||||||
|
from transformers.testing_utils import TOKEN, USER, get_tests_dir, is_staging_test
|
||||||
|
|
||||||
|
|
||||||
|
sys.path.append(str(Path(__file__).parent.parent / "utils"))
|
||||||
|
|
||||||
|
from test_module.custom_feature_extraction import CustomFeatureExtractor # noqa E402
|
||||||
|
|
||||||
|
|
||||||
|
SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR = get_tests_dir("fixtures")
|
||||||
|
|
||||||
|
|
||||||
|
class FeatureExtractorUtilTester(unittest.TestCase):
|
||||||
|
def test_cached_files_are_used_when_internet_is_down(self):
|
||||||
|
# A mock response for an HTTP head request to emulate server down
|
||||||
|
response_mock = mock.Mock()
|
||||||
|
response_mock.status_code = 500
|
||||||
|
response_mock.headers = {}
|
||||||
|
response_mock.raise_for_status.side_effect = HTTPError
|
||||||
|
response_mock.json.return_value = {}
|
||||||
|
|
||||||
|
# Download this model to make sure it's in the cache.
|
||||||
|
_ = Wav2Vec2FeatureExtractor.from_pretrained("hf-internal-testing/tiny-random-wav2vec2")
|
||||||
|
# Under the mock environment we get a 500 error when trying to reach the model.
|
||||||
|
with mock.patch("requests.Session.request", return_value=response_mock) as mock_head:
|
||||||
|
_ = Wav2Vec2FeatureExtractor.from_pretrained("hf-internal-testing/tiny-random-wav2vec2")
|
||||||
|
# This check we did call the fake head request
|
||||||
|
mock_head.assert_called()
|
||||||
|
|
||||||
|
def test_legacy_load_from_url(self):
|
||||||
|
# This test is for deprecated behavior and can be removed in v5
|
||||||
|
_ = Wav2Vec2FeatureExtractor.from_pretrained(
|
||||||
|
"https://huggingface.co/hf-internal-testing/tiny-random-wav2vec2/resolve/main/preprocessor_config.json"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@is_staging_test
|
||||||
|
class FeatureExtractorPushToHubTester(unittest.TestCase):
|
||||||
|
@classmethod
|
||||||
|
def setUpClass(cls):
|
||||||
|
cls._token = TOKEN
|
||||||
|
HfFolder.save_token(TOKEN)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def tearDownClass(cls):
|
||||||
|
try:
|
||||||
|
delete_repo(token=cls._token, repo_id="test-feature-extractor")
|
||||||
|
except HTTPError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
try:
|
||||||
|
delete_repo(token=cls._token, repo_id="valid_org/test-feature-extractor-org")
|
||||||
|
except HTTPError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
try:
|
||||||
|
delete_repo(token=cls._token, repo_id="test-dynamic-feature-extractor")
|
||||||
|
except HTTPError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def test_push_to_hub(self):
|
||||||
|
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR)
|
||||||
|
feature_extractor.push_to_hub("test-feature-extractor", use_auth_token=self._token)
|
||||||
|
|
||||||
|
new_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(f"{USER}/test-feature-extractor")
|
||||||
|
for k, v in feature_extractor.__dict__.items():
|
||||||
|
self.assertEqual(v, getattr(new_feature_extractor, k))
|
||||||
|
|
||||||
|
# Reset repo
|
||||||
|
delete_repo(token=self._token, repo_id="test-feature-extractor")
|
||||||
|
|
||||||
|
# Push to hub via save_pretrained
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
feature_extractor.save_pretrained(
|
||||||
|
tmp_dir, repo_id="test-feature-extractor", push_to_hub=True, use_auth_token=self._token
|
||||||
|
)
|
||||||
|
|
||||||
|
new_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(f"{USER}/test-feature-extractor")
|
||||||
|
for k, v in feature_extractor.__dict__.items():
|
||||||
|
self.assertEqual(v, getattr(new_feature_extractor, k))
|
||||||
|
|
||||||
|
def test_push_to_hub_in_organization(self):
|
||||||
|
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR)
|
||||||
|
feature_extractor.push_to_hub("valid_org/test-feature-extractor", use_auth_token=self._token)
|
||||||
|
|
||||||
|
new_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("valid_org/test-feature-extractor")
|
||||||
|
for k, v in feature_extractor.__dict__.items():
|
||||||
|
self.assertEqual(v, getattr(new_feature_extractor, k))
|
||||||
|
|
||||||
|
# Reset repo
|
||||||
|
delete_repo(token=self._token, repo_id="valid_org/test-feature-extractor")
|
||||||
|
|
||||||
|
# Push to hub via save_pretrained
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
feature_extractor.save_pretrained(
|
||||||
|
tmp_dir, repo_id="valid_org/test-feature-extractor-org", push_to_hub=True, use_auth_token=self._token
|
||||||
|
)
|
||||||
|
|
||||||
|
new_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("valid_org/test-feature-extractor-org")
|
||||||
|
for k, v in feature_extractor.__dict__.items():
|
||||||
|
self.assertEqual(v, getattr(new_feature_extractor, k))
|
||||||
|
|
||||||
|
def test_push_to_hub_dynamic_feature_extractor(self):
|
||||||
|
CustomFeatureExtractor.register_for_auto_class()
|
||||||
|
feature_extractor = CustomFeatureExtractor.from_pretrained(SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR)
|
||||||
|
|
||||||
|
feature_extractor.push_to_hub("test-dynamic-feature-extractor", use_auth_token=self._token)
|
||||||
|
|
||||||
|
# This has added the proper auto_map field to the config
|
||||||
|
self.assertDictEqual(
|
||||||
|
feature_extractor.auto_map,
|
||||||
|
{"AutoFeatureExtractor": "custom_feature_extraction.CustomFeatureExtractor"},
|
||||||
|
)
|
||||||
|
|
||||||
|
new_feature_extractor = AutoFeatureExtractor.from_pretrained(
|
||||||
|
f"{USER}/test-dynamic-feature-extractor", trust_remote_code=True
|
||||||
|
)
|
||||||
|
# Can't make an isinstance check because the new_feature_extractor is from the CustomFeatureExtractor class of a dynamic module
|
||||||
|
self.assertEqual(new_feature_extractor.__class__.__name__, "CustomFeatureExtractor")
|
||||||
@@ -13,36 +13,14 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import sys
|
|
||||||
import tempfile
|
import tempfile
|
||||||
import unittest
|
|
||||||
import unittest.mock as mock
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
from huggingface_hub import HfFolder, delete_repo
|
from transformers.testing_utils import check_json_file_has_correct_format, require_torch, require_vision
|
||||||
from requests.exceptions import HTTPError
|
|
||||||
|
|
||||||
from transformers import AutoImageProcessor, ViTImageProcessor
|
|
||||||
from transformers.testing_utils import (
|
|
||||||
TOKEN,
|
|
||||||
USER,
|
|
||||||
check_json_file_has_correct_format,
|
|
||||||
get_tests_dir,
|
|
||||||
is_staging_test,
|
|
||||||
require_torch,
|
|
||||||
require_vision,
|
|
||||||
)
|
|
||||||
from transformers.utils import is_torch_available, is_vision_available
|
from transformers.utils import is_torch_available, is_vision_available
|
||||||
|
|
||||||
|
|
||||||
sys.path.append(str(Path(__file__).parent.parent / "utils"))
|
|
||||||
|
|
||||||
from test_module.custom_image_processing import CustomImageProcessor # noqa E402
|
|
||||||
|
|
||||||
|
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
@@ -51,9 +29,6 @@ if is_vision_available():
|
|||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
|
|
||||||
SAMPLE_IMAGE_PROCESSING_CONFIG_DIR = get_tests_dir("fixtures")
|
|
||||||
|
|
||||||
|
|
||||||
def prepare_image_inputs(image_processor_tester, equal_resolution=False, numpify=False, torchify=False):
|
def prepare_image_inputs(image_processor_tester, equal_resolution=False, numpify=False, torchify=False):
|
||||||
"""This function prepares a list of PIL images, or a list of numpy arrays if one specifies numpify=True,
|
"""This function prepares a list of PIL images, or a list of numpy arrays if one specifies numpify=True,
|
||||||
or a list of PyTorch tensors if one specifies torchify=True.
|
or a list of PyTorch tensors if one specifies torchify=True.
|
||||||
@@ -201,123 +176,3 @@ class ImageProcessingSavingTestMixin:
|
|||||||
self.assertEqual(encoding.pixel_values.device, torch.device("cpu"))
|
self.assertEqual(encoding.pixel_values.device, torch.device("cpu"))
|
||||||
self.assertEqual(encoding.pixel_values.dtype, torch.float16)
|
self.assertEqual(encoding.pixel_values.dtype, torch.float16)
|
||||||
self.assertEqual(encoding.input_ids.dtype, torch.long)
|
self.assertEqual(encoding.input_ids.dtype, torch.long)
|
||||||
|
|
||||||
|
|
||||||
class ImageProcessorUtilTester(unittest.TestCase):
|
|
||||||
def test_cached_files_are_used_when_internet_is_down(self):
|
|
||||||
# A mock response for an HTTP head request to emulate server down
|
|
||||||
response_mock = mock.Mock()
|
|
||||||
response_mock.status_code = 500
|
|
||||||
response_mock.headers = {}
|
|
||||||
response_mock.raise_for_status.side_effect = HTTPError
|
|
||||||
response_mock.json.return_value = {}
|
|
||||||
|
|
||||||
# Download this model to make sure it's in the cache.
|
|
||||||
_ = ViTImageProcessor.from_pretrained("hf-internal-testing/tiny-random-vit")
|
|
||||||
# Under the mock environment we get a 500 error when trying to reach the model.
|
|
||||||
with mock.patch("requests.Session.request", return_value=response_mock) as mock_head:
|
|
||||||
_ = ViTImageProcessor.from_pretrained("hf-internal-testing/tiny-random-vit")
|
|
||||||
# This check we did call the fake head request
|
|
||||||
mock_head.assert_called()
|
|
||||||
|
|
||||||
def test_legacy_load_from_url(self):
|
|
||||||
# This test is for deprecated behavior and can be removed in v5
|
|
||||||
_ = ViTImageProcessor.from_pretrained(
|
|
||||||
"https://huggingface.co/hf-internal-testing/tiny-random-vit/resolve/main/preprocessor_config.json"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@is_staging_test
|
|
||||||
class ImageProcessorPushToHubTester(unittest.TestCase):
|
|
||||||
@classmethod
|
|
||||||
def setUpClass(cls):
|
|
||||||
cls._token = TOKEN
|
|
||||||
HfFolder.save_token(TOKEN)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def tearDownClass(cls):
|
|
||||||
try:
|
|
||||||
delete_repo(token=cls._token, repo_id="test-image-processor")
|
|
||||||
except HTTPError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
try:
|
|
||||||
delete_repo(token=cls._token, repo_id="valid_org/test-image-processor-org")
|
|
||||||
except HTTPError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
try:
|
|
||||||
delete_repo(token=cls._token, repo_id="test-dynamic-image-processor")
|
|
||||||
except HTTPError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
def test_push_to_hub(self):
|
|
||||||
image_processor = ViTImageProcessor.from_pretrained(SAMPLE_IMAGE_PROCESSING_CONFIG_DIR)
|
|
||||||
image_processor.push_to_hub("test-image-processor", use_auth_token=self._token)
|
|
||||||
|
|
||||||
new_image_processor = ViTImageProcessor.from_pretrained(f"{USER}/test-image-processor")
|
|
||||||
for k, v in image_processor.__dict__.items():
|
|
||||||
self.assertEqual(v, getattr(new_image_processor, k))
|
|
||||||
|
|
||||||
# Reset repo
|
|
||||||
delete_repo(token=self._token, repo_id="test-image-processor")
|
|
||||||
|
|
||||||
# Push to hub via save_pretrained
|
|
||||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
|
||||||
image_processor.save_pretrained(
|
|
||||||
tmp_dir, repo_id="test-image-processor", push_to_hub=True, use_auth_token=self._token
|
|
||||||
)
|
|
||||||
|
|
||||||
new_image_processor = ViTImageProcessor.from_pretrained(f"{USER}/test-image-processor")
|
|
||||||
for k, v in image_processor.__dict__.items():
|
|
||||||
self.assertEqual(v, getattr(new_image_processor, k))
|
|
||||||
|
|
||||||
def test_push_to_hub_in_organization(self):
|
|
||||||
image_processor = ViTImageProcessor.from_pretrained(SAMPLE_IMAGE_PROCESSING_CONFIG_DIR)
|
|
||||||
image_processor.push_to_hub("valid_org/test-image-processor", use_auth_token=self._token)
|
|
||||||
|
|
||||||
new_image_processor = ViTImageProcessor.from_pretrained("valid_org/test-image-processor")
|
|
||||||
for k, v in image_processor.__dict__.items():
|
|
||||||
self.assertEqual(v, getattr(new_image_processor, k))
|
|
||||||
|
|
||||||
# Reset repo
|
|
||||||
delete_repo(token=self._token, repo_id="valid_org/test-image-processor")
|
|
||||||
|
|
||||||
# Push to hub via save_pretrained
|
|
||||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
|
||||||
image_processor.save_pretrained(
|
|
||||||
tmp_dir, repo_id="valid_org/test-image-processor-org", push_to_hub=True, use_auth_token=self._token
|
|
||||||
)
|
|
||||||
|
|
||||||
new_image_processor = ViTImageProcessor.from_pretrained("valid_org/test-image-processor-org")
|
|
||||||
for k, v in image_processor.__dict__.items():
|
|
||||||
self.assertEqual(v, getattr(new_image_processor, k))
|
|
||||||
|
|
||||||
def test_push_to_hub_dynamic_image_processor(self):
|
|
||||||
CustomImageProcessor.register_for_auto_class()
|
|
||||||
image_processor = CustomImageProcessor.from_pretrained(SAMPLE_IMAGE_PROCESSING_CONFIG_DIR)
|
|
||||||
|
|
||||||
image_processor.push_to_hub("test-dynamic-image-processor", use_auth_token=self._token)
|
|
||||||
|
|
||||||
# This has added the proper auto_map field to the config
|
|
||||||
self.assertDictEqual(
|
|
||||||
image_processor.auto_map,
|
|
||||||
{"ImageProcessor": "custom_image_processing.CustomImageProcessor"},
|
|
||||||
)
|
|
||||||
|
|
||||||
new_image_processor = AutoImageProcessor.from_pretrained(
|
|
||||||
f"{USER}/test-dynamic-image-processor", trust_remote_code=True
|
|
||||||
)
|
|
||||||
# Can't make an isinstance check because the new_image_processor is from the CustomImageProcessor class of a dynamic module
|
|
||||||
self.assertEqual(new_image_processor.__class__.__name__, "CustomImageProcessor")
|
|
||||||
|
|
||||||
def test_image_processor_from_pretrained_subfolder(self):
|
|
||||||
with self.assertRaises(OSError):
|
|
||||||
# config is in subfolder, the following should not work without specifying the subfolder
|
|
||||||
_ = AutoImageProcessor.from_pretrained("hf-internal-testing/stable-diffusion-all-variants")
|
|
||||||
|
|
||||||
config = AutoImageProcessor.from_pretrained(
|
|
||||||
"hf-internal-testing/stable-diffusion-all-variants", subfolder="feature_extractor"
|
|
||||||
)
|
|
||||||
|
|
||||||
self.assertIsNotNone(config)
|
|
||||||
|
|||||||
154
tests/test_image_processing_utils.py
Normal file
154
tests/test_image_processing_utils.py
Normal file
@@ -0,0 +1,154 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2023 HuggingFace Inc.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import sys
|
||||||
|
import tempfile
|
||||||
|
import unittest
|
||||||
|
import unittest.mock as mock
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from huggingface_hub import HfFolder, delete_repo
|
||||||
|
from requests.exceptions import HTTPError
|
||||||
|
|
||||||
|
from transformers import AutoImageProcessor, ViTImageProcessor
|
||||||
|
from transformers.testing_utils import TOKEN, USER, get_tests_dir, is_staging_test
|
||||||
|
|
||||||
|
|
||||||
|
sys.path.append(str(Path(__file__).parent.parent / "utils"))
|
||||||
|
|
||||||
|
from test_module.custom_image_processing import CustomImageProcessor # noqa E402
|
||||||
|
|
||||||
|
|
||||||
|
SAMPLE_IMAGE_PROCESSING_CONFIG_DIR = get_tests_dir("fixtures")
|
||||||
|
|
||||||
|
|
||||||
|
class ImageProcessorUtilTester(unittest.TestCase):
|
||||||
|
def test_cached_files_are_used_when_internet_is_down(self):
|
||||||
|
# A mock response for an HTTP head request to emulate server down
|
||||||
|
response_mock = mock.Mock()
|
||||||
|
response_mock.status_code = 500
|
||||||
|
response_mock.headers = {}
|
||||||
|
response_mock.raise_for_status.side_effect = HTTPError
|
||||||
|
response_mock.json.return_value = {}
|
||||||
|
|
||||||
|
# Download this model to make sure it's in the cache.
|
||||||
|
_ = ViTImageProcessor.from_pretrained("hf-internal-testing/tiny-random-vit")
|
||||||
|
# Under the mock environment we get a 500 error when trying to reach the model.
|
||||||
|
with mock.patch("requests.Session.request", return_value=response_mock) as mock_head:
|
||||||
|
_ = ViTImageProcessor.from_pretrained("hf-internal-testing/tiny-random-vit")
|
||||||
|
# This check we did call the fake head request
|
||||||
|
mock_head.assert_called()
|
||||||
|
|
||||||
|
def test_legacy_load_from_url(self):
|
||||||
|
# This test is for deprecated behavior and can be removed in v5
|
||||||
|
_ = ViTImageProcessor.from_pretrained(
|
||||||
|
"https://huggingface.co/hf-internal-testing/tiny-random-vit/resolve/main/preprocessor_config.json"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@is_staging_test
|
||||||
|
class ImageProcessorPushToHubTester(unittest.TestCase):
|
||||||
|
@classmethod
|
||||||
|
def setUpClass(cls):
|
||||||
|
cls._token = TOKEN
|
||||||
|
HfFolder.save_token(TOKEN)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def tearDownClass(cls):
|
||||||
|
try:
|
||||||
|
delete_repo(token=cls._token, repo_id="test-image-processor")
|
||||||
|
except HTTPError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
try:
|
||||||
|
delete_repo(token=cls._token, repo_id="valid_org/test-image-processor-org")
|
||||||
|
except HTTPError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
try:
|
||||||
|
delete_repo(token=cls._token, repo_id="test-dynamic-image-processor")
|
||||||
|
except HTTPError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def test_push_to_hub(self):
|
||||||
|
image_processor = ViTImageProcessor.from_pretrained(SAMPLE_IMAGE_PROCESSING_CONFIG_DIR)
|
||||||
|
image_processor.push_to_hub("test-image-processor", use_auth_token=self._token)
|
||||||
|
|
||||||
|
new_image_processor = ViTImageProcessor.from_pretrained(f"{USER}/test-image-processor")
|
||||||
|
for k, v in image_processor.__dict__.items():
|
||||||
|
self.assertEqual(v, getattr(new_image_processor, k))
|
||||||
|
|
||||||
|
# Reset repo
|
||||||
|
delete_repo(token=self._token, repo_id="test-image-processor")
|
||||||
|
|
||||||
|
# Push to hub via save_pretrained
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
image_processor.save_pretrained(
|
||||||
|
tmp_dir, repo_id="test-image-processor", push_to_hub=True, use_auth_token=self._token
|
||||||
|
)
|
||||||
|
|
||||||
|
new_image_processor = ViTImageProcessor.from_pretrained(f"{USER}/test-image-processor")
|
||||||
|
for k, v in image_processor.__dict__.items():
|
||||||
|
self.assertEqual(v, getattr(new_image_processor, k))
|
||||||
|
|
||||||
|
def test_push_to_hub_in_organization(self):
|
||||||
|
image_processor = ViTImageProcessor.from_pretrained(SAMPLE_IMAGE_PROCESSING_CONFIG_DIR)
|
||||||
|
image_processor.push_to_hub("valid_org/test-image-processor", use_auth_token=self._token)
|
||||||
|
|
||||||
|
new_image_processor = ViTImageProcessor.from_pretrained("valid_org/test-image-processor")
|
||||||
|
for k, v in image_processor.__dict__.items():
|
||||||
|
self.assertEqual(v, getattr(new_image_processor, k))
|
||||||
|
|
||||||
|
# Reset repo
|
||||||
|
delete_repo(token=self._token, repo_id="valid_org/test-image-processor")
|
||||||
|
|
||||||
|
# Push to hub via save_pretrained
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
image_processor.save_pretrained(
|
||||||
|
tmp_dir, repo_id="valid_org/test-image-processor-org", push_to_hub=True, use_auth_token=self._token
|
||||||
|
)
|
||||||
|
|
||||||
|
new_image_processor = ViTImageProcessor.from_pretrained("valid_org/test-image-processor-org")
|
||||||
|
for k, v in image_processor.__dict__.items():
|
||||||
|
self.assertEqual(v, getattr(new_image_processor, k))
|
||||||
|
|
||||||
|
def test_push_to_hub_dynamic_image_processor(self):
|
||||||
|
CustomImageProcessor.register_for_auto_class()
|
||||||
|
image_processor = CustomImageProcessor.from_pretrained(SAMPLE_IMAGE_PROCESSING_CONFIG_DIR)
|
||||||
|
|
||||||
|
image_processor.push_to_hub("test-dynamic-image-processor", use_auth_token=self._token)
|
||||||
|
|
||||||
|
# This has added the proper auto_map field to the config
|
||||||
|
self.assertDictEqual(
|
||||||
|
image_processor.auto_map,
|
||||||
|
{"ImageProcessor": "custom_image_processing.CustomImageProcessor"},
|
||||||
|
)
|
||||||
|
|
||||||
|
new_image_processor = AutoImageProcessor.from_pretrained(
|
||||||
|
f"{USER}/test-dynamic-image-processor", trust_remote_code=True
|
||||||
|
)
|
||||||
|
# Can't make an isinstance check because the new_image_processor is from the CustomImageProcessor class of a dynamic module
|
||||||
|
self.assertEqual(new_image_processor.__class__.__name__, "CustomImageProcessor")
|
||||||
|
|
||||||
|
def test_image_processor_from_pretrained_subfolder(self):
|
||||||
|
with self.assertRaises(OSError):
|
||||||
|
# config is in subfolder, the following should not work without specifying the subfolder
|
||||||
|
_ = AutoImageProcessor.from_pretrained("hf-internal-testing/stable-diffusion-all-variants")
|
||||||
|
|
||||||
|
config = AutoImageProcessor.from_pretrained(
|
||||||
|
"hf-internal-testing/stable-diffusion-all-variants", subfolder="feature_extractor"
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertIsNotNone(config)
|
||||||
File diff suppressed because it is too large
Load Diff
@@ -17,25 +17,14 @@ import inspect
|
|||||||
import json
|
import json
|
||||||
import random
|
import random
|
||||||
import tempfile
|
import tempfile
|
||||||
import unittest
|
|
||||||
from typing import List, Tuple
|
from typing import List, Tuple
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from huggingface_hub import HfFolder, delete_repo
|
|
||||||
from requests.exceptions import HTTPError
|
|
||||||
|
|
||||||
import transformers
|
import transformers
|
||||||
from transformers import BertConfig, is_flax_available, is_torch_available
|
from transformers import is_flax_available, is_torch_available
|
||||||
from transformers.models.auto import get_values
|
from transformers.models.auto import get_values
|
||||||
from transformers.testing_utils import (
|
from transformers.testing_utils import CaptureLogger, is_pt_flax_cross_test, require_flax, torch_device
|
||||||
TOKEN,
|
|
||||||
USER,
|
|
||||||
CaptureLogger,
|
|
||||||
is_pt_flax_cross_test,
|
|
||||||
is_staging_test,
|
|
||||||
require_flax,
|
|
||||||
torch_device,
|
|
||||||
)
|
|
||||||
from transformers.utils import CONFIG_NAME, GENERATION_CONFIG_NAME, logging
|
from transformers.utils import CONFIG_NAME, GENERATION_CONFIG_NAME, logging
|
||||||
from transformers.utils.generic import ModelOutput
|
from transformers.utils.generic import ModelOutput
|
||||||
|
|
||||||
@@ -69,14 +58,6 @@ if is_torch_available():
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
|
||||||
def _config_zero_init(config):
|
|
||||||
configs_no_init = copy.deepcopy(config)
|
|
||||||
for key in configs_no_init.__dict__.keys():
|
|
||||||
if "_range" in key or "_std" in key or "initializer_factor" in key:
|
|
||||||
setattr(configs_no_init, key, 1e-10)
|
|
||||||
return configs_no_init
|
|
||||||
|
|
||||||
|
|
||||||
def ids_tensor(shape, vocab_size, rng=None):
|
def ids_tensor(shape, vocab_size, rng=None):
|
||||||
"""Creates a random int32 tensor of the shape within the vocab size."""
|
"""Creates a random int32 tensor of the shape within the vocab size."""
|
||||||
if rng is None:
|
if rng is None:
|
||||||
@@ -1164,155 +1145,3 @@ class FlaxModelTesterMixin:
|
|||||||
# ensure that the outputs remain precisely equal
|
# ensure that the outputs remain precisely equal
|
||||||
for output, remat_output in zip(outputs, remat_outputs):
|
for output, remat_output in zip(outputs, remat_outputs):
|
||||||
self.assertTrue((output == remat_output).all())
|
self.assertTrue((output == remat_output).all())
|
||||||
|
|
||||||
|
|
||||||
@require_flax
|
|
||||||
@is_staging_test
|
|
||||||
class FlaxModelPushToHubTester(unittest.TestCase):
|
|
||||||
@classmethod
|
|
||||||
def setUpClass(cls):
|
|
||||||
cls._token = TOKEN
|
|
||||||
HfFolder.save_token(TOKEN)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def tearDownClass(cls):
|
|
||||||
try:
|
|
||||||
delete_repo(token=cls._token, repo_id="test-model-flax")
|
|
||||||
except HTTPError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
try:
|
|
||||||
delete_repo(token=cls._token, repo_id="valid_org/test-model-flax-org")
|
|
||||||
except HTTPError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
def test_push_to_hub(self):
|
|
||||||
config = BertConfig(
|
|
||||||
vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
|
|
||||||
)
|
|
||||||
model = FlaxBertModel(config)
|
|
||||||
model.push_to_hub("test-model-flax", use_auth_token=self._token)
|
|
||||||
|
|
||||||
new_model = FlaxBertModel.from_pretrained(f"{USER}/test-model-flax")
|
|
||||||
|
|
||||||
base_params = flatten_dict(unfreeze(model.params))
|
|
||||||
new_params = flatten_dict(unfreeze(new_model.params))
|
|
||||||
|
|
||||||
for key in base_params.keys():
|
|
||||||
max_diff = (base_params[key] - new_params[key]).sum().item()
|
|
||||||
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
|
|
||||||
|
|
||||||
# Reset repo
|
|
||||||
delete_repo(token=self._token, repo_id="test-model-flax")
|
|
||||||
|
|
||||||
# Push to hub via save_pretrained
|
|
||||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
|
||||||
model.save_pretrained(tmp_dir, repo_id="test-model-flax", push_to_hub=True, use_auth_token=self._token)
|
|
||||||
|
|
||||||
new_model = FlaxBertModel.from_pretrained(f"{USER}/test-model-flax")
|
|
||||||
|
|
||||||
base_params = flatten_dict(unfreeze(model.params))
|
|
||||||
new_params = flatten_dict(unfreeze(new_model.params))
|
|
||||||
|
|
||||||
for key in base_params.keys():
|
|
||||||
max_diff = (base_params[key] - new_params[key]).sum().item()
|
|
||||||
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
|
|
||||||
|
|
||||||
def test_push_to_hub_in_organization(self):
|
|
||||||
config = BertConfig(
|
|
||||||
vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
|
|
||||||
)
|
|
||||||
model = FlaxBertModel(config)
|
|
||||||
model.push_to_hub("valid_org/test-model-flax-org", use_auth_token=self._token)
|
|
||||||
|
|
||||||
new_model = FlaxBertModel.from_pretrained("valid_org/test-model-flax-org")
|
|
||||||
|
|
||||||
base_params = flatten_dict(unfreeze(model.params))
|
|
||||||
new_params = flatten_dict(unfreeze(new_model.params))
|
|
||||||
|
|
||||||
for key in base_params.keys():
|
|
||||||
max_diff = (base_params[key] - new_params[key]).sum().item()
|
|
||||||
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
|
|
||||||
|
|
||||||
# Reset repo
|
|
||||||
delete_repo(token=self._token, repo_id="valid_org/test-model-flax-org")
|
|
||||||
|
|
||||||
# Push to hub via save_pretrained
|
|
||||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
|
||||||
model.save_pretrained(
|
|
||||||
tmp_dir, repo_id="valid_org/test-model-flax-org", push_to_hub=True, use_auth_token=self._token
|
|
||||||
)
|
|
||||||
|
|
||||||
new_model = FlaxBertModel.from_pretrained("valid_org/test-model-flax-org")
|
|
||||||
|
|
||||||
base_params = flatten_dict(unfreeze(model.params))
|
|
||||||
new_params = flatten_dict(unfreeze(new_model.params))
|
|
||||||
|
|
||||||
for key in base_params.keys():
|
|
||||||
max_diff = (base_params[key] - new_params[key]).sum().item()
|
|
||||||
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
|
|
||||||
|
|
||||||
|
|
||||||
def check_models_equal(model1, model2):
|
|
||||||
models_are_equal = True
|
|
||||||
flat_params_1 = flatten_dict(model1.params)
|
|
||||||
flat_params_2 = flatten_dict(model2.params)
|
|
||||||
for key in flat_params_1.keys():
|
|
||||||
if np.sum(np.abs(flat_params_1[key] - flat_params_2[key])) > 1e-4:
|
|
||||||
models_are_equal = False
|
|
||||||
|
|
||||||
return models_are_equal
|
|
||||||
|
|
||||||
|
|
||||||
@require_flax
|
|
||||||
class FlaxModelUtilsTest(unittest.TestCase):
|
|
||||||
def test_model_from_pretrained_subfolder(self):
|
|
||||||
config = BertConfig.from_pretrained("hf-internal-testing/tiny-bert-flax-only")
|
|
||||||
model = FlaxBertModel(config)
|
|
||||||
|
|
||||||
subfolder = "bert"
|
|
||||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
|
||||||
model.save_pretrained(os.path.join(tmp_dir, subfolder))
|
|
||||||
|
|
||||||
with self.assertRaises(OSError):
|
|
||||||
_ = FlaxBertModel.from_pretrained(tmp_dir)
|
|
||||||
|
|
||||||
model_loaded = FlaxBertModel.from_pretrained(tmp_dir, subfolder=subfolder)
|
|
||||||
|
|
||||||
self.assertTrue(check_models_equal(model, model_loaded))
|
|
||||||
|
|
||||||
def test_model_from_pretrained_subfolder_sharded(self):
|
|
||||||
config = BertConfig.from_pretrained("hf-internal-testing/tiny-bert-flax-only")
|
|
||||||
model = FlaxBertModel(config)
|
|
||||||
|
|
||||||
subfolder = "bert"
|
|
||||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
|
||||||
model.save_pretrained(os.path.join(tmp_dir, subfolder), max_shard_size="10KB")
|
|
||||||
|
|
||||||
with self.assertRaises(OSError):
|
|
||||||
_ = FlaxBertModel.from_pretrained(tmp_dir)
|
|
||||||
|
|
||||||
model_loaded = FlaxBertModel.from_pretrained(tmp_dir, subfolder=subfolder)
|
|
||||||
|
|
||||||
self.assertTrue(check_models_equal(model, model_loaded))
|
|
||||||
|
|
||||||
def test_model_from_pretrained_hub_subfolder(self):
|
|
||||||
subfolder = "bert"
|
|
||||||
model_id = "hf-internal-testing/tiny-random-bert-subfolder"
|
|
||||||
|
|
||||||
with self.assertRaises(OSError):
|
|
||||||
_ = FlaxBertModel.from_pretrained(model_id)
|
|
||||||
|
|
||||||
model = FlaxBertModel.from_pretrained(model_id, subfolder=subfolder)
|
|
||||||
|
|
||||||
self.assertIsNotNone(model)
|
|
||||||
|
|
||||||
def test_model_from_pretrained_hub_subfolder_sharded(self):
|
|
||||||
subfolder = "bert"
|
|
||||||
model_id = "hf-internal-testing/tiny-random-bert-sharded-subfolder"
|
|
||||||
with self.assertRaises(OSError):
|
|
||||||
_ = FlaxBertModel.from_pretrained(model_id)
|
|
||||||
|
|
||||||
model = FlaxBertModel.from_pretrained(model_id, subfolder=subfolder)
|
|
||||||
|
|
||||||
self.assertIsNotNone(model)
|
|
||||||
|
|||||||
186
tests/test_modeling_flax_utils.py
Normal file
186
tests/test_modeling_flax_utils.py
Normal file
@@ -0,0 +1,186 @@
|
|||||||
|
# Copyright 2020 The HuggingFace Team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import tempfile
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from huggingface_hub import HfFolder, delete_repo
|
||||||
|
from requests.exceptions import HTTPError
|
||||||
|
|
||||||
|
from transformers import BertConfig, is_flax_available
|
||||||
|
from transformers.testing_utils import TOKEN, USER, is_staging_test, require_flax
|
||||||
|
|
||||||
|
|
||||||
|
if is_flax_available():
|
||||||
|
import os
|
||||||
|
|
||||||
|
from flax.core.frozen_dict import unfreeze
|
||||||
|
from flax.traverse_util import flatten_dict
|
||||||
|
|
||||||
|
from transformers import FlaxBertModel
|
||||||
|
|
||||||
|
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.12" # assumed parallelism: 8
|
||||||
|
|
||||||
|
|
||||||
|
@require_flax
|
||||||
|
@is_staging_test
|
||||||
|
class FlaxModelPushToHubTester(unittest.TestCase):
|
||||||
|
@classmethod
|
||||||
|
def setUpClass(cls):
|
||||||
|
cls._token = TOKEN
|
||||||
|
HfFolder.save_token(TOKEN)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def tearDownClass(cls):
|
||||||
|
try:
|
||||||
|
delete_repo(token=cls._token, repo_id="test-model-flax")
|
||||||
|
except HTTPError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
try:
|
||||||
|
delete_repo(token=cls._token, repo_id="valid_org/test-model-flax-org")
|
||||||
|
except HTTPError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def test_push_to_hub(self):
|
||||||
|
config = BertConfig(
|
||||||
|
vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
|
||||||
|
)
|
||||||
|
model = FlaxBertModel(config)
|
||||||
|
model.push_to_hub("test-model-flax", use_auth_token=self._token)
|
||||||
|
|
||||||
|
new_model = FlaxBertModel.from_pretrained(f"{USER}/test-model-flax")
|
||||||
|
|
||||||
|
base_params = flatten_dict(unfreeze(model.params))
|
||||||
|
new_params = flatten_dict(unfreeze(new_model.params))
|
||||||
|
|
||||||
|
for key in base_params.keys():
|
||||||
|
max_diff = (base_params[key] - new_params[key]).sum().item()
|
||||||
|
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
|
||||||
|
|
||||||
|
# Reset repo
|
||||||
|
delete_repo(token=self._token, repo_id="test-model-flax")
|
||||||
|
|
||||||
|
# Push to hub via save_pretrained
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
model.save_pretrained(tmp_dir, repo_id="test-model-flax", push_to_hub=True, use_auth_token=self._token)
|
||||||
|
|
||||||
|
new_model = FlaxBertModel.from_pretrained(f"{USER}/test-model-flax")
|
||||||
|
|
||||||
|
base_params = flatten_dict(unfreeze(model.params))
|
||||||
|
new_params = flatten_dict(unfreeze(new_model.params))
|
||||||
|
|
||||||
|
for key in base_params.keys():
|
||||||
|
max_diff = (base_params[key] - new_params[key]).sum().item()
|
||||||
|
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
|
||||||
|
|
||||||
|
def test_push_to_hub_in_organization(self):
|
||||||
|
config = BertConfig(
|
||||||
|
vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
|
||||||
|
)
|
||||||
|
model = FlaxBertModel(config)
|
||||||
|
model.push_to_hub("valid_org/test-model-flax-org", use_auth_token=self._token)
|
||||||
|
|
||||||
|
new_model = FlaxBertModel.from_pretrained("valid_org/test-model-flax-org")
|
||||||
|
|
||||||
|
base_params = flatten_dict(unfreeze(model.params))
|
||||||
|
new_params = flatten_dict(unfreeze(new_model.params))
|
||||||
|
|
||||||
|
for key in base_params.keys():
|
||||||
|
max_diff = (base_params[key] - new_params[key]).sum().item()
|
||||||
|
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
|
||||||
|
|
||||||
|
# Reset repo
|
||||||
|
delete_repo(token=self._token, repo_id="valid_org/test-model-flax-org")
|
||||||
|
|
||||||
|
# Push to hub via save_pretrained
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
model.save_pretrained(
|
||||||
|
tmp_dir, repo_id="valid_org/test-model-flax-org", push_to_hub=True, use_auth_token=self._token
|
||||||
|
)
|
||||||
|
|
||||||
|
new_model = FlaxBertModel.from_pretrained("valid_org/test-model-flax-org")
|
||||||
|
|
||||||
|
base_params = flatten_dict(unfreeze(model.params))
|
||||||
|
new_params = flatten_dict(unfreeze(new_model.params))
|
||||||
|
|
||||||
|
for key in base_params.keys():
|
||||||
|
max_diff = (base_params[key] - new_params[key]).sum().item()
|
||||||
|
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
|
||||||
|
|
||||||
|
|
||||||
|
def check_models_equal(model1, model2):
|
||||||
|
models_are_equal = True
|
||||||
|
flat_params_1 = flatten_dict(model1.params)
|
||||||
|
flat_params_2 = flatten_dict(model2.params)
|
||||||
|
for key in flat_params_1.keys():
|
||||||
|
if np.sum(np.abs(flat_params_1[key] - flat_params_2[key])) > 1e-4:
|
||||||
|
models_are_equal = False
|
||||||
|
|
||||||
|
return models_are_equal
|
||||||
|
|
||||||
|
|
||||||
|
@require_flax
|
||||||
|
class FlaxModelUtilsTest(unittest.TestCase):
|
||||||
|
def test_model_from_pretrained_subfolder(self):
|
||||||
|
config = BertConfig.from_pretrained("hf-internal-testing/tiny-bert-flax-only")
|
||||||
|
model = FlaxBertModel(config)
|
||||||
|
|
||||||
|
subfolder = "bert"
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
model.save_pretrained(os.path.join(tmp_dir, subfolder))
|
||||||
|
|
||||||
|
with self.assertRaises(OSError):
|
||||||
|
_ = FlaxBertModel.from_pretrained(tmp_dir)
|
||||||
|
|
||||||
|
model_loaded = FlaxBertModel.from_pretrained(tmp_dir, subfolder=subfolder)
|
||||||
|
|
||||||
|
self.assertTrue(check_models_equal(model, model_loaded))
|
||||||
|
|
||||||
|
def test_model_from_pretrained_subfolder_sharded(self):
|
||||||
|
config = BertConfig.from_pretrained("hf-internal-testing/tiny-bert-flax-only")
|
||||||
|
model = FlaxBertModel(config)
|
||||||
|
|
||||||
|
subfolder = "bert"
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
model.save_pretrained(os.path.join(tmp_dir, subfolder), max_shard_size="10KB")
|
||||||
|
|
||||||
|
with self.assertRaises(OSError):
|
||||||
|
_ = FlaxBertModel.from_pretrained(tmp_dir)
|
||||||
|
|
||||||
|
model_loaded = FlaxBertModel.from_pretrained(tmp_dir, subfolder=subfolder)
|
||||||
|
|
||||||
|
self.assertTrue(check_models_equal(model, model_loaded))
|
||||||
|
|
||||||
|
def test_model_from_pretrained_hub_subfolder(self):
|
||||||
|
subfolder = "bert"
|
||||||
|
model_id = "hf-internal-testing/tiny-random-bert-subfolder"
|
||||||
|
|
||||||
|
with self.assertRaises(OSError):
|
||||||
|
_ = FlaxBertModel.from_pretrained(model_id)
|
||||||
|
|
||||||
|
model = FlaxBertModel.from_pretrained(model_id, subfolder=subfolder)
|
||||||
|
|
||||||
|
self.assertIsNotNone(model)
|
||||||
|
|
||||||
|
def test_model_from_pretrained_hub_subfolder_sharded(self):
|
||||||
|
subfolder = "bert"
|
||||||
|
model_id = "hf-internal-testing/tiny-random-bert-sharded-subfolder"
|
||||||
|
with self.assertRaises(OSError):
|
||||||
|
_ = FlaxBertModel.from_pretrained(model_id)
|
||||||
|
|
||||||
|
model = FlaxBertModel.from_pretrained(model_id, subfolder=subfolder)
|
||||||
|
|
||||||
|
self.assertIsNotNone(model)
|
||||||
@@ -23,42 +23,24 @@ import os
|
|||||||
import random
|
import random
|
||||||
import tempfile
|
import tempfile
|
||||||
import unittest
|
import unittest
|
||||||
import unittest.mock as mock
|
|
||||||
from importlib import import_module
|
from importlib import import_module
|
||||||
from math import isnan
|
from math import isnan
|
||||||
from typing import List, Tuple
|
from typing import List, Tuple
|
||||||
|
|
||||||
from datasets import Dataset
|
from datasets import Dataset
|
||||||
from huggingface_hub import HfFolder, Repository, delete_repo
|
|
||||||
from huggingface_hub.file_download import http_get
|
|
||||||
from requests.exceptions import HTTPError
|
|
||||||
|
|
||||||
from transformers import is_tf_available, is_torch_available
|
from transformers import is_tf_available, is_torch_available
|
||||||
from transformers.configuration_utils import PretrainedConfig
|
|
||||||
from transformers.models.auto import get_values
|
from transformers.models.auto import get_values
|
||||||
from transformers.testing_utils import ( # noqa: F401
|
from transformers.testing_utils import ( # noqa: F401
|
||||||
TOKEN,
|
|
||||||
USER,
|
|
||||||
CaptureLogger,
|
CaptureLogger,
|
||||||
CaptureStdout,
|
|
||||||
_tf_gpu_memory_limit,
|
_tf_gpu_memory_limit,
|
||||||
is_pt_tf_cross_test,
|
is_pt_tf_cross_test,
|
||||||
is_staging_test,
|
|
||||||
require_safetensors,
|
|
||||||
require_tf,
|
require_tf,
|
||||||
require_tf2onnx,
|
require_tf2onnx,
|
||||||
slow,
|
slow,
|
||||||
tooslow,
|
|
||||||
torch_device,
|
torch_device,
|
||||||
)
|
)
|
||||||
from transformers.utils import (
|
from transformers.utils import CONFIG_NAME, GENERATION_CONFIG_NAME, logging
|
||||||
CONFIG_NAME,
|
|
||||||
GENERATION_CONFIG_NAME,
|
|
||||||
SAFE_WEIGHTS_NAME,
|
|
||||||
TF2_WEIGHTS_INDEX_NAME,
|
|
||||||
TF2_WEIGHTS_NAME,
|
|
||||||
logging,
|
|
||||||
)
|
|
||||||
from transformers.utils.generic import ModelOutput
|
from transformers.utils.generic import ModelOutput
|
||||||
|
|
||||||
|
|
||||||
@@ -66,7 +48,6 @@ logger = logging.get_logger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
if is_tf_available():
|
if is_tf_available():
|
||||||
import h5py
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
@@ -85,17 +66,8 @@ if is_tf_available():
|
|||||||
TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
|
TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
|
||||||
TF_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING,
|
TF_MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING,
|
||||||
TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
|
TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
|
||||||
BertConfig,
|
|
||||||
PreTrainedModel,
|
|
||||||
PushToHubCallback,
|
|
||||||
RagRetriever,
|
|
||||||
TFAutoModel,
|
TFAutoModel,
|
||||||
TFAutoModelForSequenceClassification,
|
TFAutoModelForSequenceClassification,
|
||||||
TFBertForMaskedLM,
|
|
||||||
TFBertForSequenceClassification,
|
|
||||||
TFBertModel,
|
|
||||||
TFPreTrainedModel,
|
|
||||||
TFRagModel,
|
|
||||||
TFSharedEmbeddings,
|
TFSharedEmbeddings,
|
||||||
)
|
)
|
||||||
from transformers.generation import (
|
from transformers.generation import (
|
||||||
@@ -108,8 +80,6 @@ if is_tf_available():
|
|||||||
TFSampleDecoderOnlyOutput,
|
TFSampleDecoderOnlyOutput,
|
||||||
TFSampleEncoderDecoderOutput,
|
TFSampleEncoderDecoderOutput,
|
||||||
)
|
)
|
||||||
from transformers.modeling_tf_utils import tf_shard_checkpoint, unpack_inputs
|
|
||||||
from transformers.tf_utils import stable_softmax
|
|
||||||
|
|
||||||
tf.config.experimental.enable_tensor_float_32_execution(False)
|
tf.config.experimental.enable_tensor_float_32_execution(False)
|
||||||
|
|
||||||
@@ -130,8 +100,6 @@ if is_tf_available():
|
|||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from transformers import BertModel
|
|
||||||
|
|
||||||
|
|
||||||
def _config_zero_init(config):
|
def _config_zero_init(config):
|
||||||
configs_no_init = copy.deepcopy(config)
|
configs_no_init = copy.deepcopy(config)
|
||||||
@@ -1995,544 +1963,3 @@ def floats_tensor(shape, scale=1.0, rng=None, name=None, dtype=None):
|
|||||||
values.append(rng.random() * scale)
|
values.append(rng.random() * scale)
|
||||||
|
|
||||||
return tf.reshape(tf.constant(values, dtype=dtype if dtype is not None else tf.float32), shape=shape)
|
return tf.reshape(tf.constant(values, dtype=dtype if dtype is not None else tf.float32), shape=shape)
|
||||||
|
|
||||||
|
|
||||||
@require_tf
|
|
||||||
class UtilsFunctionsTest(unittest.TestCase):
|
|
||||||
def test_cached_files_are_used_when_internet_is_down(self):
|
|
||||||
# A mock response for an HTTP head request to emulate server down
|
|
||||||
response_mock = mock.Mock()
|
|
||||||
response_mock.status_code = 500
|
|
||||||
response_mock.headers = {}
|
|
||||||
response_mock.raise_for_status.side_effect = HTTPError
|
|
||||||
response_mock.json.return_value = {}
|
|
||||||
|
|
||||||
# Download this model to make sure it's in the cache.
|
|
||||||
_ = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
|
|
||||||
|
|
||||||
# Under the mock environment we get a 500 error when trying to reach the model.
|
|
||||||
with mock.patch("requests.Session.request", return_value=response_mock) as mock_head:
|
|
||||||
_ = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
|
|
||||||
# This check we did call the fake head request
|
|
||||||
mock_head.assert_called()
|
|
||||||
|
|
||||||
def test_load_from_one_file(self):
|
|
||||||
try:
|
|
||||||
tmp_file = tempfile.mktemp()
|
|
||||||
with open(tmp_file, "wb") as f:
|
|
||||||
http_get("https://huggingface.co/hf-internal-testing/tiny-random-bert/resolve/main/tf_model.h5", f)
|
|
||||||
|
|
||||||
config = BertConfig.from_pretrained("hf-internal-testing/tiny-random-bert")
|
|
||||||
_ = TFBertModel.from_pretrained(tmp_file, config=config)
|
|
||||||
finally:
|
|
||||||
os.remove(tmp_file)
|
|
||||||
|
|
||||||
def test_legacy_load_from_url(self):
|
|
||||||
# This test is for deprecated behavior and can be removed in v5
|
|
||||||
config = BertConfig.from_pretrained("hf-internal-testing/tiny-random-bert")
|
|
||||||
_ = TFBertModel.from_pretrained(
|
|
||||||
"https://huggingface.co/hf-internal-testing/tiny-random-bert/resolve/main/tf_model.h5", config=config
|
|
||||||
)
|
|
||||||
|
|
||||||
# tests whether the unpack_inputs function behaves as expected
|
|
||||||
def test_unpack_inputs(self):
|
|
||||||
class DummyModel:
|
|
||||||
def __init__(self):
|
|
||||||
config_kwargs = {"output_attentions": False, "output_hidden_states": False, "return_dict": False}
|
|
||||||
self.config = PretrainedConfig(**config_kwargs)
|
|
||||||
self.main_input_name = "input_ids"
|
|
||||||
|
|
||||||
@unpack_inputs
|
|
||||||
def call(
|
|
||||||
self,
|
|
||||||
input_ids=None,
|
|
||||||
past_key_values=None,
|
|
||||||
output_attentions=None,
|
|
||||||
output_hidden_states=None,
|
|
||||||
return_dict=None,
|
|
||||||
):
|
|
||||||
return input_ids, past_key_values, output_attentions, output_hidden_states, return_dict
|
|
||||||
|
|
||||||
@unpack_inputs
|
|
||||||
def foo(self, pixel_values, output_attentions=None, output_hidden_states=None, return_dict=None):
|
|
||||||
return pixel_values, output_attentions, output_hidden_states, return_dict
|
|
||||||
|
|
||||||
dummy_model = DummyModel()
|
|
||||||
input_ids = tf.constant([0, 1, 2, 3], dtype=tf.int32)
|
|
||||||
past_key_values = tf.constant([4, 5, 6, 7], dtype=tf.int32)
|
|
||||||
pixel_values = tf.constant([8, 9, 10, 11], dtype=tf.int32)
|
|
||||||
|
|
||||||
# test case 1: Pass inputs as keyword arguments; Booleans are inherited from the config.
|
|
||||||
output = dummy_model.call(input_ids=input_ids, past_key_values=past_key_values)
|
|
||||||
tf.debugging.assert_equal(output[0], input_ids)
|
|
||||||
tf.debugging.assert_equal(output[1], past_key_values)
|
|
||||||
self.assertFalse(output[2])
|
|
||||||
self.assertFalse(output[3])
|
|
||||||
self.assertFalse(output[4])
|
|
||||||
|
|
||||||
# test case 2: Same as above, but with positional arguments.
|
|
||||||
output = dummy_model.call(input_ids, past_key_values)
|
|
||||||
tf.debugging.assert_equal(output[0], input_ids)
|
|
||||||
tf.debugging.assert_equal(output[1], past_key_values)
|
|
||||||
self.assertFalse(output[2])
|
|
||||||
self.assertFalse(output[3])
|
|
||||||
self.assertFalse(output[4])
|
|
||||||
|
|
||||||
# test case 3: We can also pack everything in the first input.
|
|
||||||
output = dummy_model.call(input_ids={"input_ids": input_ids, "past_key_values": past_key_values})
|
|
||||||
tf.debugging.assert_equal(output[0], input_ids)
|
|
||||||
tf.debugging.assert_equal(output[1], past_key_values)
|
|
||||||
self.assertFalse(output[2])
|
|
||||||
self.assertFalse(output[3])
|
|
||||||
self.assertFalse(output[4])
|
|
||||||
|
|
||||||
# test case 4: Explicit boolean arguments should override the config.
|
|
||||||
output = dummy_model.call(
|
|
||||||
input_ids=input_ids, past_key_values=past_key_values, output_attentions=False, return_dict=True
|
|
||||||
)
|
|
||||||
tf.debugging.assert_equal(output[0], input_ids)
|
|
||||||
tf.debugging.assert_equal(output[1], past_key_values)
|
|
||||||
self.assertFalse(output[2])
|
|
||||||
self.assertFalse(output[3])
|
|
||||||
self.assertTrue(output[4])
|
|
||||||
|
|
||||||
# test case 5: Unexpected arguments should raise an exception.
|
|
||||||
with self.assertRaises(ValueError):
|
|
||||||
output = dummy_model.call(input_ids=input_ids, past_key_values=past_key_values, foo="bar")
|
|
||||||
|
|
||||||
# test case 6: the decorator is independent from `main_input_name` -- it treats the first argument of the
|
|
||||||
# decorated function as its main input.
|
|
||||||
output = dummy_model.foo(pixel_values=pixel_values)
|
|
||||||
tf.debugging.assert_equal(output[0], pixel_values)
|
|
||||||
self.assertFalse(output[1])
|
|
||||||
self.assertFalse(output[2])
|
|
||||||
self.assertFalse(output[3])
|
|
||||||
|
|
||||||
# Tests whether the stable softmax is stable on CPU, with and without XLA
|
|
||||||
def test_xla_stable_softmax(self):
|
|
||||||
large_penalty = -1e9
|
|
||||||
n_tokens = 10
|
|
||||||
batch_size = 8
|
|
||||||
|
|
||||||
def masked_softmax(x, boolean_mask):
|
|
||||||
numerical_mask = (1.0 - tf.cast(boolean_mask, dtype=tf.float32)) * large_penalty
|
|
||||||
masked_x = x + numerical_mask
|
|
||||||
return stable_softmax(masked_x)
|
|
||||||
|
|
||||||
xla_masked_softmax = tf.function(masked_softmax, jit_compile=True)
|
|
||||||
xla_stable_softmax = tf.function(stable_softmax, jit_compile=True)
|
|
||||||
x = tf.random.normal((batch_size, n_tokens))
|
|
||||||
|
|
||||||
# Same outcome regardless of the boolean mask here
|
|
||||||
masked_tokens = random.randint(0, n_tokens)
|
|
||||||
boolean_mask = tf.convert_to_tensor([[1] * (n_tokens - masked_tokens) + [0] * masked_tokens], dtype=tf.int32)
|
|
||||||
|
|
||||||
# We can randomly mask a random numerical input OUTSIDE XLA
|
|
||||||
numerical_mask = (1.0 - tf.cast(boolean_mask, dtype=tf.float32)) * large_penalty
|
|
||||||
masked_x = x + numerical_mask
|
|
||||||
xla_out = xla_stable_softmax(masked_x)
|
|
||||||
out = stable_softmax(masked_x)
|
|
||||||
assert tf.experimental.numpy.allclose(xla_out, out)
|
|
||||||
|
|
||||||
# The stable softmax has the same output as the original softmax
|
|
||||||
unstable_out = tf.nn.softmax(masked_x)
|
|
||||||
assert tf.experimental.numpy.allclose(unstable_out, out)
|
|
||||||
|
|
||||||
# We can randomly mask a random numerical input INSIDE XLA
|
|
||||||
xla_out = xla_masked_softmax(x, boolean_mask)
|
|
||||||
out = masked_softmax(x, boolean_mask)
|
|
||||||
assert tf.experimental.numpy.allclose(xla_out, out)
|
|
||||||
|
|
||||||
def test_checkpoint_sharding_from_hub(self):
|
|
||||||
model = TFBertModel.from_pretrained("ArthurZ/tiny-random-bert-sharded")
|
|
||||||
# the model above is the same as the model below, just a sharded version.
|
|
||||||
ref_model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
|
|
||||||
for p1, p2 in zip(model.weights, ref_model.weights):
|
|
||||||
assert np.allclose(p1.numpy(), p2.numpy())
|
|
||||||
|
|
||||||
def test_sharded_checkpoint_with_prefix(self):
|
|
||||||
model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert", load_weight_prefix="a/b")
|
|
||||||
sharded_model = TFBertModel.from_pretrained("ArthurZ/tiny-random-bert-sharded", load_weight_prefix="a/b")
|
|
||||||
for p1, p2 in zip(model.weights, sharded_model.weights):
|
|
||||||
self.assertTrue(np.allclose(p1.numpy(), p2.numpy()))
|
|
||||||
self.assertTrue(p1.name.startswith("a/b/"))
|
|
||||||
self.assertTrue(p2.name.startswith("a/b/"))
|
|
||||||
|
|
||||||
def test_sharded_checkpoint_transfer(self):
|
|
||||||
# If this doesn't throw an error then the test passes
|
|
||||||
TFBertForSequenceClassification.from_pretrained("ArthurZ/tiny-random-bert-sharded")
|
|
||||||
|
|
||||||
@is_pt_tf_cross_test
|
|
||||||
def test_checkpoint_sharding_local_from_pt(self):
|
|
||||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
|
||||||
_ = Repository(local_dir=tmp_dir, clone_from="hf-internal-testing/tiny-random-bert-sharded")
|
|
||||||
model = TFBertModel.from_pretrained(tmp_dir, from_pt=True)
|
|
||||||
# the model above is the same as the model below, just a sharded pytorch version.
|
|
||||||
ref_model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
|
|
||||||
for p1, p2 in zip(model.weights, ref_model.weights):
|
|
||||||
assert np.allclose(p1.numpy(), p2.numpy())
|
|
||||||
|
|
||||||
@is_pt_tf_cross_test
|
|
||||||
def test_checkpoint_loading_with_prefix_from_pt(self):
|
|
||||||
model = TFBertModel.from_pretrained(
|
|
||||||
"hf-internal-testing/tiny-random-bert", from_pt=True, load_weight_prefix="a/b"
|
|
||||||
)
|
|
||||||
ref_model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert", from_pt=True)
|
|
||||||
for p1, p2 in zip(model.weights, ref_model.weights):
|
|
||||||
self.assertTrue(np.allclose(p1.numpy(), p2.numpy()))
|
|
||||||
self.assertTrue(p1.name.startswith("a/b/"))
|
|
||||||
|
|
||||||
@is_pt_tf_cross_test
|
|
||||||
def test_checkpoint_sharding_hub_from_pt(self):
|
|
||||||
model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert-sharded", from_pt=True)
|
|
||||||
# the model above is the same as the model below, just a sharded pytorch version.
|
|
||||||
ref_model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
|
|
||||||
for p1, p2 in zip(model.weights, ref_model.weights):
|
|
||||||
assert np.allclose(p1.numpy(), p2.numpy())
|
|
||||||
|
|
||||||
def test_shard_checkpoint(self):
|
|
||||||
# This is the model we will use, total size 340,000 bytes.
|
|
||||||
model = tf.keras.Sequential(
|
|
||||||
[
|
|
||||||
tf.keras.layers.Dense(200, use_bias=False), # size 80,000
|
|
||||||
tf.keras.layers.Dense(200, use_bias=False), # size 160,000
|
|
||||||
tf.keras.layers.Dense(100, use_bias=False), # size 80,000
|
|
||||||
tf.keras.layers.Dense(50, use_bias=False), # size 20,000
|
|
||||||
]
|
|
||||||
)
|
|
||||||
inputs = tf.zeros((1, 100), dtype=tf.float32)
|
|
||||||
model(inputs)
|
|
||||||
weights = model.weights
|
|
||||||
weights_dict = {w.name: w for w in weights}
|
|
||||||
with self.subTest("No shard when max size is bigger than model size"):
|
|
||||||
shards, index = tf_shard_checkpoint(weights)
|
|
||||||
self.assertIsNone(index)
|
|
||||||
self.assertDictEqual(shards, {TF2_WEIGHTS_NAME: weights})
|
|
||||||
|
|
||||||
with self.subTest("Test sharding, no weights bigger than max size"):
|
|
||||||
shards, index = tf_shard_checkpoint(weights, max_shard_size="300kB")
|
|
||||||
# Split is first two layers then last two.
|
|
||||||
self.assertDictEqual(
|
|
||||||
index,
|
|
||||||
{
|
|
||||||
"metadata": {"total_size": 340000},
|
|
||||||
"weight_map": {
|
|
||||||
"dense/kernel:0": "tf_model-00001-of-00002.h5",
|
|
||||||
"dense_1/kernel:0": "tf_model-00001-of-00002.h5",
|
|
||||||
"dense_2/kernel:0": "tf_model-00002-of-00002.h5",
|
|
||||||
"dense_3/kernel:0": "tf_model-00002-of-00002.h5",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
shard1 = [weights_dict["dense/kernel:0"], weights_dict["dense_1/kernel:0"]]
|
|
||||||
shard2 = [weights_dict["dense_2/kernel:0"], weights_dict["dense_3/kernel:0"]]
|
|
||||||
self.assertDictEqual(shards, {"tf_model-00001-of-00002.h5": shard1, "tf_model-00002-of-00002.h5": shard2})
|
|
||||||
|
|
||||||
with self.subTest("Test sharding with weights bigger than max size"):
|
|
||||||
shards, index = tf_shard_checkpoint(weights, max_shard_size="100kB")
|
|
||||||
# Split is first layer, second layer then last 2.
|
|
||||||
self.assertDictEqual(
|
|
||||||
index,
|
|
||||||
{
|
|
||||||
"metadata": {"total_size": 340000},
|
|
||||||
"weight_map": {
|
|
||||||
"dense/kernel:0": "tf_model-00001-of-00003.h5",
|
|
||||||
"dense_1/kernel:0": "tf_model-00002-of-00003.h5",
|
|
||||||
"dense_2/kernel:0": "tf_model-00003-of-00003.h5",
|
|
||||||
"dense_3/kernel:0": "tf_model-00003-of-00003.h5",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
shard1 = [weights_dict["dense/kernel:0"]]
|
|
||||||
shard2 = [weights_dict["dense_1/kernel:0"]]
|
|
||||||
shard3 = [weights_dict["dense_2/kernel:0"], weights_dict["dense_3/kernel:0"]]
|
|
||||||
self.assertDictEqual(
|
|
||||||
shards,
|
|
||||||
{
|
|
||||||
"tf_model-00001-of-00003.h5": shard1,
|
|
||||||
"tf_model-00002-of-00003.h5": shard2,
|
|
||||||
"tf_model-00003-of-00003.h5": shard3,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
@slow
|
|
||||||
def test_special_layer_name_sharding(self):
|
|
||||||
retriever = RagRetriever.from_pretrained("facebook/rag-token-nq", index_name="exact", use_dummy_dataset=True)
|
|
||||||
model = TFRagModel.from_pretrained("facebook/rag-token-nq", retriever=retriever)
|
|
||||||
|
|
||||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
|
||||||
for max_size in ["150kB", "150kiB", "200kB", "200kiB"]:
|
|
||||||
model.save_pretrained(tmp_dir, max_shard_size=max_size)
|
|
||||||
ref_model = TFRagModel.from_pretrained(tmp_dir, retriever=retriever)
|
|
||||||
for p1, p2 in zip(model.weights, ref_model.weights):
|
|
||||||
assert np.allclose(p1.numpy(), p2.numpy())
|
|
||||||
|
|
||||||
def test_checkpoint_sharding_local(self):
|
|
||||||
model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
|
|
||||||
|
|
||||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
|
||||||
# We use the same folder for various sizes to make sure a new save erases the old checkpoint.
|
|
||||||
for max_size in ["150kB", "150kiB", "200kB", "200kiB"]:
|
|
||||||
model.save_pretrained(tmp_dir, max_shard_size=max_size)
|
|
||||||
|
|
||||||
# Get each shard file and its size
|
|
||||||
shard_to_size = {}
|
|
||||||
for shard in os.listdir(tmp_dir):
|
|
||||||
if shard.endswith(".h5"):
|
|
||||||
shard_file = os.path.join(tmp_dir, shard)
|
|
||||||
shard_to_size[shard_file] = os.path.getsize(shard_file)
|
|
||||||
|
|
||||||
index_file = os.path.join(tmp_dir, TF2_WEIGHTS_INDEX_NAME)
|
|
||||||
# Check there is an index but no regular weight file
|
|
||||||
self.assertTrue(os.path.isfile(index_file))
|
|
||||||
self.assertFalse(os.path.isfile(os.path.join(tmp_dir, TF2_WEIGHTS_NAME)))
|
|
||||||
|
|
||||||
# Check a file is bigger than max_size only when it has a single weight
|
|
||||||
for shard_file, size in shard_to_size.items():
|
|
||||||
if max_size.endswith("kiB"):
|
|
||||||
max_size_int = int(max_size[:-3]) * 2**10
|
|
||||||
else:
|
|
||||||
max_size_int = int(max_size[:-2]) * 10**3
|
|
||||||
# Note: pickle adds some junk so the weight of the file can end up being slightly bigger than
|
|
||||||
# the size asked for (since we count parameters)
|
|
||||||
if size >= max_size_int + 50000:
|
|
||||||
with h5py.File(shard_file, "r") as state_file:
|
|
||||||
self.assertEqual(len(state_file), 1)
|
|
||||||
|
|
||||||
# Check the index and the shard files found match
|
|
||||||
with open(index_file, "r", encoding="utf-8") as f:
|
|
||||||
index = json.loads(f.read())
|
|
||||||
|
|
||||||
all_shards = set(index["weight_map"].values())
|
|
||||||
shards_found = {f for f in os.listdir(tmp_dir) if f.endswith(".h5")}
|
|
||||||
self.assertSetEqual(all_shards, shards_found)
|
|
||||||
|
|
||||||
# Finally, check the model can be reloaded
|
|
||||||
new_model = TFBertModel.from_pretrained(tmp_dir)
|
|
||||||
|
|
||||||
model.build()
|
|
||||||
new_model.build()
|
|
||||||
|
|
||||||
for p1, p2 in zip(model.weights, new_model.weights):
|
|
||||||
self.assertTrue(np.allclose(p1.numpy(), p2.numpy()))
|
|
||||||
|
|
||||||
@slow
|
|
||||||
def test_save_pretrained_signatures(self):
|
|
||||||
model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
|
|
||||||
|
|
||||||
# Short custom TF signature function.
|
|
||||||
# `input_signature` is specific to BERT.
|
|
||||||
@tf.function(
|
|
||||||
input_signature=[
|
|
||||||
[
|
|
||||||
tf.TensorSpec([None, None], tf.int32, name="input_ids"),
|
|
||||||
tf.TensorSpec([None, None], tf.int32, name="token_type_ids"),
|
|
||||||
tf.TensorSpec([None, None], tf.int32, name="attention_mask"),
|
|
||||||
]
|
|
||||||
]
|
|
||||||
)
|
|
||||||
def serving_fn(input):
|
|
||||||
return model(input)
|
|
||||||
|
|
||||||
# Using default signature (default behavior) overrides 'serving_default'
|
|
||||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
|
||||||
model.save_pretrained(tmp_dir, saved_model=True, signatures=None)
|
|
||||||
model_loaded = tf.keras.models.load_model(f"{tmp_dir}/saved_model/1")
|
|
||||||
self.assertTrue("serving_default" in list(model_loaded.signatures.keys()))
|
|
||||||
|
|
||||||
# Providing custom signature function
|
|
||||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
|
||||||
model.save_pretrained(tmp_dir, saved_model=True, signatures={"custom_signature": serving_fn})
|
|
||||||
model_loaded = tf.keras.models.load_model(f"{tmp_dir}/saved_model/1")
|
|
||||||
self.assertTrue("custom_signature" in list(model_loaded.signatures.keys()))
|
|
||||||
|
|
||||||
# Providing multiple custom signature function
|
|
||||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
|
||||||
model.save_pretrained(
|
|
||||||
tmp_dir,
|
|
||||||
saved_model=True,
|
|
||||||
signatures={"custom_signature_1": serving_fn, "custom_signature_2": serving_fn},
|
|
||||||
)
|
|
||||||
model_loaded = tf.keras.models.load_model(f"{tmp_dir}/saved_model/1")
|
|
||||||
self.assertTrue("custom_signature_1" in list(model_loaded.signatures.keys()))
|
|
||||||
self.assertTrue("custom_signature_2" in list(model_loaded.signatures.keys()))
|
|
||||||
|
|
||||||
@require_safetensors
|
|
||||||
def test_safetensors_save_and_load(self):
|
|
||||||
model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
|
|
||||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
|
||||||
model.save_pretrained(tmp_dir, safe_serialization=True)
|
|
||||||
# No tf_model.h5 file, only a model.safetensors
|
|
||||||
self.assertTrue(os.path.isfile(os.path.join(tmp_dir, SAFE_WEIGHTS_NAME)))
|
|
||||||
self.assertFalse(os.path.isfile(os.path.join(tmp_dir, TF2_WEIGHTS_NAME)))
|
|
||||||
|
|
||||||
new_model = TFBertModel.from_pretrained(tmp_dir)
|
|
||||||
|
|
||||||
# Check models are equal
|
|
||||||
for p1, p2 in zip(model.weights, new_model.weights):
|
|
||||||
self.assertTrue(np.allclose(p1.numpy(), p2.numpy()))
|
|
||||||
|
|
||||||
@is_pt_tf_cross_test
|
|
||||||
def test_safetensors_save_and_load_pt_to_tf(self):
|
|
||||||
model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
|
|
||||||
pt_model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
|
|
||||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
|
||||||
pt_model.save_pretrained(tmp_dir, safe_serialization=True)
|
|
||||||
# Check we have a model.safetensors file
|
|
||||||
self.assertTrue(os.path.isfile(os.path.join(tmp_dir, SAFE_WEIGHTS_NAME)))
|
|
||||||
|
|
||||||
new_model = TFBertModel.from_pretrained(tmp_dir)
|
|
||||||
|
|
||||||
# Check models are equal
|
|
||||||
for p1, p2 in zip(model.weights, new_model.weights):
|
|
||||||
self.assertTrue(np.allclose(p1.numpy(), p2.numpy()))
|
|
||||||
|
|
||||||
@require_safetensors
|
|
||||||
def test_safetensors_load_from_hub(self):
|
|
||||||
tf_model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
|
|
||||||
|
|
||||||
# Can load from the TF-formatted checkpoint
|
|
||||||
safetensors_model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert-safetensors-tf")
|
|
||||||
|
|
||||||
# Check models are equal
|
|
||||||
for p1, p2 in zip(safetensors_model.weights, tf_model.weights):
|
|
||||||
self.assertTrue(np.allclose(p1.numpy(), p2.numpy()))
|
|
||||||
|
|
||||||
# Can load from the PyTorch-formatted checkpoint
|
|
||||||
safetensors_model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert-safetensors")
|
|
||||||
|
|
||||||
# Check models are equal
|
|
||||||
for p1, p2 in zip(safetensors_model.weights, tf_model.weights):
|
|
||||||
self.assertTrue(np.allclose(p1.numpy(), p2.numpy()))
|
|
||||||
|
|
||||||
|
|
||||||
@require_tf
|
|
||||||
@is_staging_test
|
|
||||||
class TFModelPushToHubTester(unittest.TestCase):
|
|
||||||
@classmethod
|
|
||||||
def setUpClass(cls):
|
|
||||||
cls._token = TOKEN
|
|
||||||
HfFolder.save_token(TOKEN)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def tearDownClass(cls):
|
|
||||||
try:
|
|
||||||
delete_repo(token=cls._token, repo_id="test-model-tf")
|
|
||||||
except HTTPError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
try:
|
|
||||||
delete_repo(token=cls._token, repo_id="test-model-tf-callback")
|
|
||||||
except HTTPError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
try:
|
|
||||||
delete_repo(token=cls._token, repo_id="valid_org/test-model-tf-org")
|
|
||||||
except HTTPError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
def test_push_to_hub(self):
|
|
||||||
config = BertConfig(
|
|
||||||
vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
|
|
||||||
)
|
|
||||||
model = TFBertModel(config)
|
|
||||||
# Make sure model is properly initialized
|
|
||||||
model.build()
|
|
||||||
|
|
||||||
logging.set_verbosity_info()
|
|
||||||
logger = logging.get_logger("transformers.utils.hub")
|
|
||||||
with CaptureLogger(logger) as cl:
|
|
||||||
model.push_to_hub("test-model-tf", use_auth_token=self._token)
|
|
||||||
logging.set_verbosity_warning()
|
|
||||||
# Check the model card was created and uploaded.
|
|
||||||
self.assertIn("Uploading the following files to __DUMMY_TRANSFORMERS_USER__/test-model-tf", cl.out)
|
|
||||||
|
|
||||||
new_model = TFBertModel.from_pretrained(f"{USER}/test-model-tf")
|
|
||||||
models_equal = True
|
|
||||||
for p1, p2 in zip(model.weights, new_model.weights):
|
|
||||||
if not tf.math.reduce_all(p1 == p2):
|
|
||||||
models_equal = False
|
|
||||||
break
|
|
||||||
self.assertTrue(models_equal)
|
|
||||||
|
|
||||||
# Reset repo
|
|
||||||
delete_repo(token=self._token, repo_id="test-model-tf")
|
|
||||||
|
|
||||||
# Push to hub via save_pretrained
|
|
||||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
|
||||||
model.save_pretrained(tmp_dir, repo_id="test-model-tf", push_to_hub=True, use_auth_token=self._token)
|
|
||||||
|
|
||||||
new_model = TFBertModel.from_pretrained(f"{USER}/test-model-tf")
|
|
||||||
models_equal = True
|
|
||||||
for p1, p2 in zip(model.weights, new_model.weights):
|
|
||||||
if not tf.math.reduce_all(p1 == p2):
|
|
||||||
models_equal = False
|
|
||||||
break
|
|
||||||
self.assertTrue(models_equal)
|
|
||||||
|
|
||||||
@is_pt_tf_cross_test
|
|
||||||
def test_push_to_hub_callback(self):
|
|
||||||
config = BertConfig(
|
|
||||||
vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
|
|
||||||
)
|
|
||||||
model = TFBertForMaskedLM(config)
|
|
||||||
model.compile()
|
|
||||||
|
|
||||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
|
||||||
push_to_hub_callback = PushToHubCallback(
|
|
||||||
output_dir=tmp_dir,
|
|
||||||
hub_model_id="test-model-tf-callback",
|
|
||||||
hub_token=self._token,
|
|
||||||
)
|
|
||||||
model.fit(model.dummy_inputs, model.dummy_inputs, epochs=1, callbacks=[push_to_hub_callback])
|
|
||||||
|
|
||||||
new_model = TFBertForMaskedLM.from_pretrained(f"{USER}/test-model-tf-callback")
|
|
||||||
models_equal = True
|
|
||||||
for p1, p2 in zip(model.weights, new_model.weights):
|
|
||||||
if not tf.math.reduce_all(p1 == p2):
|
|
||||||
models_equal = False
|
|
||||||
break
|
|
||||||
self.assertTrue(models_equal)
|
|
||||||
|
|
||||||
tf_push_to_hub_params = dict(inspect.signature(TFPreTrainedModel.push_to_hub).parameters)
|
|
||||||
tf_push_to_hub_params.pop("base_model_card_args")
|
|
||||||
pt_push_to_hub_params = dict(inspect.signature(PreTrainedModel.push_to_hub).parameters)
|
|
||||||
pt_push_to_hub_params.pop("deprecated_kwargs")
|
|
||||||
self.assertDictEaual(tf_push_to_hub_params, pt_push_to_hub_params)
|
|
||||||
|
|
||||||
def test_push_to_hub_in_organization(self):
|
|
||||||
config = BertConfig(
|
|
||||||
vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
|
|
||||||
)
|
|
||||||
model = TFBertModel(config)
|
|
||||||
# Make sure model is properly initialized
|
|
||||||
model.build()
|
|
||||||
|
|
||||||
model.push_to_hub("valid_org/test-model-tf-org", use_auth_token=self._token)
|
|
||||||
|
|
||||||
new_model = TFBertModel.from_pretrained("valid_org/test-model-tf-org")
|
|
||||||
models_equal = True
|
|
||||||
for p1, p2 in zip(model.weights, new_model.weights):
|
|
||||||
if not tf.math.reduce_all(p1 == p2):
|
|
||||||
models_equal = False
|
|
||||||
break
|
|
||||||
self.assertTrue(models_equal)
|
|
||||||
|
|
||||||
# Reset repo
|
|
||||||
delete_repo(token=self._token, repo_id="valid_org/test-model-tf-org")
|
|
||||||
|
|
||||||
# Push to hub via save_pretrained
|
|
||||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
|
||||||
model.save_pretrained(
|
|
||||||
tmp_dir, push_to_hub=True, use_auth_token=self._token, repo_id="valid_org/test-model-tf-org"
|
|
||||||
)
|
|
||||||
|
|
||||||
new_model = TFBertModel.from_pretrained("valid_org/test-model-tf-org")
|
|
||||||
models_equal = True
|
|
||||||
for p1, p2 in zip(model.weights, new_model.weights):
|
|
||||||
if not tf.math.reduce_all(p1 == p2):
|
|
||||||
models_equal = False
|
|
||||||
break
|
|
||||||
self.assertTrue(models_equal)
|
|
||||||
|
|||||||
627
tests/test_modeling_tf_utils.py
Normal file
627
tests/test_modeling_tf_utils.py
Normal file
@@ -0,0 +1,627 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2019 HuggingFace Inc.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import inspect
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import random
|
||||||
|
import tempfile
|
||||||
|
import unittest
|
||||||
|
import unittest.mock as mock
|
||||||
|
|
||||||
|
from huggingface_hub import HfFolder, Repository, delete_repo
|
||||||
|
from huggingface_hub.file_download import http_get
|
||||||
|
from requests.exceptions import HTTPError
|
||||||
|
|
||||||
|
from transformers import is_tf_available, is_torch_available
|
||||||
|
from transformers.configuration_utils import PretrainedConfig
|
||||||
|
from transformers.testing_utils import ( # noqa: F401
|
||||||
|
TOKEN,
|
||||||
|
USER,
|
||||||
|
CaptureLogger,
|
||||||
|
_tf_gpu_memory_limit,
|
||||||
|
is_pt_tf_cross_test,
|
||||||
|
is_staging_test,
|
||||||
|
require_safetensors,
|
||||||
|
require_tf,
|
||||||
|
slow,
|
||||||
|
)
|
||||||
|
from transformers.utils import SAFE_WEIGHTS_NAME, TF2_WEIGHTS_INDEX_NAME, TF2_WEIGHTS_NAME, logging
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
if is_tf_available():
|
||||||
|
import h5py
|
||||||
|
import numpy as np
|
||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
|
from transformers import (
|
||||||
|
BertConfig,
|
||||||
|
PreTrainedModel,
|
||||||
|
PushToHubCallback,
|
||||||
|
RagRetriever,
|
||||||
|
TFBertForMaskedLM,
|
||||||
|
TFBertForSequenceClassification,
|
||||||
|
TFBertModel,
|
||||||
|
TFPreTrainedModel,
|
||||||
|
TFRagModel,
|
||||||
|
)
|
||||||
|
from transformers.modeling_tf_utils import tf_shard_checkpoint, unpack_inputs
|
||||||
|
from transformers.tf_utils import stable_softmax
|
||||||
|
|
||||||
|
tf.config.experimental.enable_tensor_float_32_execution(False)
|
||||||
|
|
||||||
|
if _tf_gpu_memory_limit is not None:
|
||||||
|
gpus = tf.config.list_physical_devices("GPU")
|
||||||
|
for gpu in gpus:
|
||||||
|
# Restrict TensorFlow to only allocate x GB of memory on the GPUs
|
||||||
|
try:
|
||||||
|
tf.config.set_logical_device_configuration(
|
||||||
|
gpu, [tf.config.LogicalDeviceConfiguration(memory_limit=_tf_gpu_memory_limit)]
|
||||||
|
)
|
||||||
|
logical_gpus = tf.config.list_logical_devices("GPU")
|
||||||
|
print("Logical GPUs", logical_gpus)
|
||||||
|
except RuntimeError as e:
|
||||||
|
# Virtual devices must be set before GPUs have been initialized
|
||||||
|
print(e)
|
||||||
|
|
||||||
|
if is_torch_available():
|
||||||
|
from transformers import BertModel
|
||||||
|
|
||||||
|
|
||||||
|
@require_tf
|
||||||
|
class TFModelUtilsTest(unittest.TestCase):
|
||||||
|
def test_cached_files_are_used_when_internet_is_down(self):
|
||||||
|
# A mock response for an HTTP head request to emulate server down
|
||||||
|
response_mock = mock.Mock()
|
||||||
|
response_mock.status_code = 500
|
||||||
|
response_mock.headers = {}
|
||||||
|
response_mock.raise_for_status.side_effect = HTTPError
|
||||||
|
response_mock.json.return_value = {}
|
||||||
|
|
||||||
|
# Download this model to make sure it's in the cache.
|
||||||
|
_ = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
|
||||||
|
|
||||||
|
# Under the mock environment we get a 500 error when trying to reach the model.
|
||||||
|
with mock.patch("requests.Session.request", return_value=response_mock) as mock_head:
|
||||||
|
_ = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
|
||||||
|
# This check we did call the fake head request
|
||||||
|
mock_head.assert_called()
|
||||||
|
|
||||||
|
def test_load_from_one_file(self):
|
||||||
|
try:
|
||||||
|
tmp_file = tempfile.mktemp()
|
||||||
|
with open(tmp_file, "wb") as f:
|
||||||
|
http_get("https://huggingface.co/hf-internal-testing/tiny-random-bert/resolve/main/tf_model.h5", f)
|
||||||
|
|
||||||
|
config = BertConfig.from_pretrained("hf-internal-testing/tiny-random-bert")
|
||||||
|
_ = TFBertModel.from_pretrained(tmp_file, config=config)
|
||||||
|
finally:
|
||||||
|
os.remove(tmp_file)
|
||||||
|
|
||||||
|
def test_legacy_load_from_url(self):
|
||||||
|
# This test is for deprecated behavior and can be removed in v5
|
||||||
|
config = BertConfig.from_pretrained("hf-internal-testing/tiny-random-bert")
|
||||||
|
_ = TFBertModel.from_pretrained(
|
||||||
|
"https://huggingface.co/hf-internal-testing/tiny-random-bert/resolve/main/tf_model.h5", config=config
|
||||||
|
)
|
||||||
|
|
||||||
|
# tests whether the unpack_inputs function behaves as expected
|
||||||
|
def test_unpack_inputs(self):
|
||||||
|
class DummyModel:
|
||||||
|
def __init__(self):
|
||||||
|
config_kwargs = {"output_attentions": False, "output_hidden_states": False, "return_dict": False}
|
||||||
|
self.config = PretrainedConfig(**config_kwargs)
|
||||||
|
self.main_input_name = "input_ids"
|
||||||
|
|
||||||
|
@unpack_inputs
|
||||||
|
def call(
|
||||||
|
self,
|
||||||
|
input_ids=None,
|
||||||
|
past_key_values=None,
|
||||||
|
output_attentions=None,
|
||||||
|
output_hidden_states=None,
|
||||||
|
return_dict=None,
|
||||||
|
):
|
||||||
|
return input_ids, past_key_values, output_attentions, output_hidden_states, return_dict
|
||||||
|
|
||||||
|
@unpack_inputs
|
||||||
|
def foo(self, pixel_values, output_attentions=None, output_hidden_states=None, return_dict=None):
|
||||||
|
return pixel_values, output_attentions, output_hidden_states, return_dict
|
||||||
|
|
||||||
|
dummy_model = DummyModel()
|
||||||
|
input_ids = tf.constant([0, 1, 2, 3], dtype=tf.int32)
|
||||||
|
past_key_values = tf.constant([4, 5, 6, 7], dtype=tf.int32)
|
||||||
|
pixel_values = tf.constant([8, 9, 10, 11], dtype=tf.int32)
|
||||||
|
|
||||||
|
# test case 1: Pass inputs as keyword arguments; Booleans are inherited from the config.
|
||||||
|
output = dummy_model.call(input_ids=input_ids, past_key_values=past_key_values)
|
||||||
|
tf.debugging.assert_equal(output[0], input_ids)
|
||||||
|
tf.debugging.assert_equal(output[1], past_key_values)
|
||||||
|
self.assertFalse(output[2])
|
||||||
|
self.assertFalse(output[3])
|
||||||
|
self.assertFalse(output[4])
|
||||||
|
|
||||||
|
# test case 2: Same as above, but with positional arguments.
|
||||||
|
output = dummy_model.call(input_ids, past_key_values)
|
||||||
|
tf.debugging.assert_equal(output[0], input_ids)
|
||||||
|
tf.debugging.assert_equal(output[1], past_key_values)
|
||||||
|
self.assertFalse(output[2])
|
||||||
|
self.assertFalse(output[3])
|
||||||
|
self.assertFalse(output[4])
|
||||||
|
|
||||||
|
# test case 3: We can also pack everything in the first input.
|
||||||
|
output = dummy_model.call(input_ids={"input_ids": input_ids, "past_key_values": past_key_values})
|
||||||
|
tf.debugging.assert_equal(output[0], input_ids)
|
||||||
|
tf.debugging.assert_equal(output[1], past_key_values)
|
||||||
|
self.assertFalse(output[2])
|
||||||
|
self.assertFalse(output[3])
|
||||||
|
self.assertFalse(output[4])
|
||||||
|
|
||||||
|
# test case 4: Explicit boolean arguments should override the config.
|
||||||
|
output = dummy_model.call(
|
||||||
|
input_ids=input_ids, past_key_values=past_key_values, output_attentions=False, return_dict=True
|
||||||
|
)
|
||||||
|
tf.debugging.assert_equal(output[0], input_ids)
|
||||||
|
tf.debugging.assert_equal(output[1], past_key_values)
|
||||||
|
self.assertFalse(output[2])
|
||||||
|
self.assertFalse(output[3])
|
||||||
|
self.assertTrue(output[4])
|
||||||
|
|
||||||
|
# test case 5: Unexpected arguments should raise an exception.
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
output = dummy_model.call(input_ids=input_ids, past_key_values=past_key_values, foo="bar")
|
||||||
|
|
||||||
|
# test case 6: the decorator is independent from `main_input_name` -- it treats the first argument of the
|
||||||
|
# decorated function as its main input.
|
||||||
|
output = dummy_model.foo(pixel_values=pixel_values)
|
||||||
|
tf.debugging.assert_equal(output[0], pixel_values)
|
||||||
|
self.assertFalse(output[1])
|
||||||
|
self.assertFalse(output[2])
|
||||||
|
self.assertFalse(output[3])
|
||||||
|
|
||||||
|
# Tests whether the stable softmax is stable on CPU, with and without XLA
|
||||||
|
def test_xla_stable_softmax(self):
|
||||||
|
large_penalty = -1e9
|
||||||
|
n_tokens = 10
|
||||||
|
batch_size = 8
|
||||||
|
|
||||||
|
def masked_softmax(x, boolean_mask):
|
||||||
|
numerical_mask = (1.0 - tf.cast(boolean_mask, dtype=tf.float32)) * large_penalty
|
||||||
|
masked_x = x + numerical_mask
|
||||||
|
return stable_softmax(masked_x)
|
||||||
|
|
||||||
|
xla_masked_softmax = tf.function(masked_softmax, jit_compile=True)
|
||||||
|
xla_stable_softmax = tf.function(stable_softmax, jit_compile=True)
|
||||||
|
x = tf.random.normal((batch_size, n_tokens))
|
||||||
|
|
||||||
|
# Same outcome regardless of the boolean mask here
|
||||||
|
masked_tokens = random.randint(0, n_tokens)
|
||||||
|
boolean_mask = tf.convert_to_tensor([[1] * (n_tokens - masked_tokens) + [0] * masked_tokens], dtype=tf.int32)
|
||||||
|
|
||||||
|
# We can randomly mask a random numerical input OUTSIDE XLA
|
||||||
|
numerical_mask = (1.0 - tf.cast(boolean_mask, dtype=tf.float32)) * large_penalty
|
||||||
|
masked_x = x + numerical_mask
|
||||||
|
xla_out = xla_stable_softmax(masked_x)
|
||||||
|
out = stable_softmax(masked_x)
|
||||||
|
assert tf.experimental.numpy.allclose(xla_out, out)
|
||||||
|
|
||||||
|
# The stable softmax has the same output as the original softmax
|
||||||
|
unstable_out = tf.nn.softmax(masked_x)
|
||||||
|
assert tf.experimental.numpy.allclose(unstable_out, out)
|
||||||
|
|
||||||
|
# We can randomly mask a random numerical input INSIDE XLA
|
||||||
|
xla_out = xla_masked_softmax(x, boolean_mask)
|
||||||
|
out = masked_softmax(x, boolean_mask)
|
||||||
|
assert tf.experimental.numpy.allclose(xla_out, out)
|
||||||
|
|
||||||
|
def test_checkpoint_sharding_from_hub(self):
|
||||||
|
model = TFBertModel.from_pretrained("ArthurZ/tiny-random-bert-sharded")
|
||||||
|
# the model above is the same as the model below, just a sharded version.
|
||||||
|
ref_model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
|
||||||
|
for p1, p2 in zip(model.weights, ref_model.weights):
|
||||||
|
assert np.allclose(p1.numpy(), p2.numpy())
|
||||||
|
|
||||||
|
def test_sharded_checkpoint_with_prefix(self):
|
||||||
|
model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert", load_weight_prefix="a/b")
|
||||||
|
sharded_model = TFBertModel.from_pretrained("ArthurZ/tiny-random-bert-sharded", load_weight_prefix="a/b")
|
||||||
|
for p1, p2 in zip(model.weights, sharded_model.weights):
|
||||||
|
self.assertTrue(np.allclose(p1.numpy(), p2.numpy()))
|
||||||
|
self.assertTrue(p1.name.startswith("a/b/"))
|
||||||
|
self.assertTrue(p2.name.startswith("a/b/"))
|
||||||
|
|
||||||
|
def test_sharded_checkpoint_transfer(self):
|
||||||
|
# If this doesn't throw an error then the test passes
|
||||||
|
TFBertForSequenceClassification.from_pretrained("ArthurZ/tiny-random-bert-sharded")
|
||||||
|
|
||||||
|
@is_pt_tf_cross_test
|
||||||
|
def test_checkpoint_sharding_local_from_pt(self):
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
_ = Repository(local_dir=tmp_dir, clone_from="hf-internal-testing/tiny-random-bert-sharded")
|
||||||
|
model = TFBertModel.from_pretrained(tmp_dir, from_pt=True)
|
||||||
|
# the model above is the same as the model below, just a sharded pytorch version.
|
||||||
|
ref_model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
|
||||||
|
for p1, p2 in zip(model.weights, ref_model.weights):
|
||||||
|
assert np.allclose(p1.numpy(), p2.numpy())
|
||||||
|
|
||||||
|
@is_pt_tf_cross_test
|
||||||
|
def test_checkpoint_loading_with_prefix_from_pt(self):
|
||||||
|
model = TFBertModel.from_pretrained(
|
||||||
|
"hf-internal-testing/tiny-random-bert", from_pt=True, load_weight_prefix="a/b"
|
||||||
|
)
|
||||||
|
ref_model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert", from_pt=True)
|
||||||
|
for p1, p2 in zip(model.weights, ref_model.weights):
|
||||||
|
self.assertTrue(np.allclose(p1.numpy(), p2.numpy()))
|
||||||
|
self.assertTrue(p1.name.startswith("a/b/"))
|
||||||
|
|
||||||
|
@is_pt_tf_cross_test
|
||||||
|
def test_checkpoint_sharding_hub_from_pt(self):
|
||||||
|
model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert-sharded", from_pt=True)
|
||||||
|
# the model above is the same as the model below, just a sharded pytorch version.
|
||||||
|
ref_model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
|
||||||
|
for p1, p2 in zip(model.weights, ref_model.weights):
|
||||||
|
assert np.allclose(p1.numpy(), p2.numpy())
|
||||||
|
|
||||||
|
def test_shard_checkpoint(self):
|
||||||
|
# This is the model we will use, total size 340,000 bytes.
|
||||||
|
model = tf.keras.Sequential(
|
||||||
|
[
|
||||||
|
tf.keras.layers.Dense(200, use_bias=False), # size 80,000
|
||||||
|
tf.keras.layers.Dense(200, use_bias=False), # size 160,000
|
||||||
|
tf.keras.layers.Dense(100, use_bias=False), # size 80,000
|
||||||
|
tf.keras.layers.Dense(50, use_bias=False), # size 20,000
|
||||||
|
]
|
||||||
|
)
|
||||||
|
inputs = tf.zeros((1, 100), dtype=tf.float32)
|
||||||
|
model(inputs)
|
||||||
|
weights = model.weights
|
||||||
|
weights_dict = {w.name: w for w in weights}
|
||||||
|
with self.subTest("No shard when max size is bigger than model size"):
|
||||||
|
shards, index = tf_shard_checkpoint(weights)
|
||||||
|
self.assertIsNone(index)
|
||||||
|
self.assertDictEqual(shards, {TF2_WEIGHTS_NAME: weights})
|
||||||
|
|
||||||
|
with self.subTest("Test sharding, no weights bigger than max size"):
|
||||||
|
shards, index = tf_shard_checkpoint(weights, max_shard_size="300kB")
|
||||||
|
# Split is first two layers then last two.
|
||||||
|
self.assertDictEqual(
|
||||||
|
index,
|
||||||
|
{
|
||||||
|
"metadata": {"total_size": 340000},
|
||||||
|
"weight_map": {
|
||||||
|
"dense/kernel:0": "tf_model-00001-of-00002.h5",
|
||||||
|
"dense_1/kernel:0": "tf_model-00001-of-00002.h5",
|
||||||
|
"dense_2/kernel:0": "tf_model-00002-of-00002.h5",
|
||||||
|
"dense_3/kernel:0": "tf_model-00002-of-00002.h5",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
shard1 = [weights_dict["dense/kernel:0"], weights_dict["dense_1/kernel:0"]]
|
||||||
|
shard2 = [weights_dict["dense_2/kernel:0"], weights_dict["dense_3/kernel:0"]]
|
||||||
|
self.assertDictEqual(shards, {"tf_model-00001-of-00002.h5": shard1, "tf_model-00002-of-00002.h5": shard2})
|
||||||
|
|
||||||
|
with self.subTest("Test sharding with weights bigger than max size"):
|
||||||
|
shards, index = tf_shard_checkpoint(weights, max_shard_size="100kB")
|
||||||
|
# Split is first layer, second layer then last 2.
|
||||||
|
self.assertDictEqual(
|
||||||
|
index,
|
||||||
|
{
|
||||||
|
"metadata": {"total_size": 340000},
|
||||||
|
"weight_map": {
|
||||||
|
"dense/kernel:0": "tf_model-00001-of-00003.h5",
|
||||||
|
"dense_1/kernel:0": "tf_model-00002-of-00003.h5",
|
||||||
|
"dense_2/kernel:0": "tf_model-00003-of-00003.h5",
|
||||||
|
"dense_3/kernel:0": "tf_model-00003-of-00003.h5",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
shard1 = [weights_dict["dense/kernel:0"]]
|
||||||
|
shard2 = [weights_dict["dense_1/kernel:0"]]
|
||||||
|
shard3 = [weights_dict["dense_2/kernel:0"], weights_dict["dense_3/kernel:0"]]
|
||||||
|
self.assertDictEqual(
|
||||||
|
shards,
|
||||||
|
{
|
||||||
|
"tf_model-00001-of-00003.h5": shard1,
|
||||||
|
"tf_model-00002-of-00003.h5": shard2,
|
||||||
|
"tf_model-00003-of-00003.h5": shard3,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
@slow
|
||||||
|
def test_special_layer_name_sharding(self):
|
||||||
|
retriever = RagRetriever.from_pretrained("facebook/rag-token-nq", index_name="exact", use_dummy_dataset=True)
|
||||||
|
model = TFRagModel.from_pretrained("facebook/rag-token-nq", retriever=retriever)
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
for max_size in ["150kB", "150kiB", "200kB", "200kiB"]:
|
||||||
|
model.save_pretrained(tmp_dir, max_shard_size=max_size)
|
||||||
|
ref_model = TFRagModel.from_pretrained(tmp_dir, retriever=retriever)
|
||||||
|
for p1, p2 in zip(model.weights, ref_model.weights):
|
||||||
|
assert np.allclose(p1.numpy(), p2.numpy())
|
||||||
|
|
||||||
|
def test_checkpoint_sharding_local(self):
|
||||||
|
model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
# We use the same folder for various sizes to make sure a new save erases the old checkpoint.
|
||||||
|
for max_size in ["150kB", "150kiB", "200kB", "200kiB"]:
|
||||||
|
model.save_pretrained(tmp_dir, max_shard_size=max_size)
|
||||||
|
|
||||||
|
# Get each shard file and its size
|
||||||
|
shard_to_size = {}
|
||||||
|
for shard in os.listdir(tmp_dir):
|
||||||
|
if shard.endswith(".h5"):
|
||||||
|
shard_file = os.path.join(tmp_dir, shard)
|
||||||
|
shard_to_size[shard_file] = os.path.getsize(shard_file)
|
||||||
|
|
||||||
|
index_file = os.path.join(tmp_dir, TF2_WEIGHTS_INDEX_NAME)
|
||||||
|
# Check there is an index but no regular weight file
|
||||||
|
self.assertTrue(os.path.isfile(index_file))
|
||||||
|
self.assertFalse(os.path.isfile(os.path.join(tmp_dir, TF2_WEIGHTS_NAME)))
|
||||||
|
|
||||||
|
# Check a file is bigger than max_size only when it has a single weight
|
||||||
|
for shard_file, size in shard_to_size.items():
|
||||||
|
if max_size.endswith("kiB"):
|
||||||
|
max_size_int = int(max_size[:-3]) * 2**10
|
||||||
|
else:
|
||||||
|
max_size_int = int(max_size[:-2]) * 10**3
|
||||||
|
# Note: pickle adds some junk so the weight of the file can end up being slightly bigger than
|
||||||
|
# the size asked for (since we count parameters)
|
||||||
|
if size >= max_size_int + 50000:
|
||||||
|
with h5py.File(shard_file, "r") as state_file:
|
||||||
|
self.assertEqual(len(state_file), 1)
|
||||||
|
|
||||||
|
# Check the index and the shard files found match
|
||||||
|
with open(index_file, "r", encoding="utf-8") as f:
|
||||||
|
index = json.loads(f.read())
|
||||||
|
|
||||||
|
all_shards = set(index["weight_map"].values())
|
||||||
|
shards_found = {f for f in os.listdir(tmp_dir) if f.endswith(".h5")}
|
||||||
|
self.assertSetEqual(all_shards, shards_found)
|
||||||
|
|
||||||
|
# Finally, check the model can be reloaded
|
||||||
|
new_model = TFBertModel.from_pretrained(tmp_dir)
|
||||||
|
|
||||||
|
model.build()
|
||||||
|
new_model.build()
|
||||||
|
|
||||||
|
for p1, p2 in zip(model.weights, new_model.weights):
|
||||||
|
self.assertTrue(np.allclose(p1.numpy(), p2.numpy()))
|
||||||
|
|
||||||
|
@slow
|
||||||
|
def test_save_pretrained_signatures(self):
|
||||||
|
model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
|
||||||
|
|
||||||
|
# Short custom TF signature function.
|
||||||
|
# `input_signature` is specific to BERT.
|
||||||
|
@tf.function(
|
||||||
|
input_signature=[
|
||||||
|
[
|
||||||
|
tf.TensorSpec([None, None], tf.int32, name="input_ids"),
|
||||||
|
tf.TensorSpec([None, None], tf.int32, name="token_type_ids"),
|
||||||
|
tf.TensorSpec([None, None], tf.int32, name="attention_mask"),
|
||||||
|
]
|
||||||
|
]
|
||||||
|
)
|
||||||
|
def serving_fn(input):
|
||||||
|
return model(input)
|
||||||
|
|
||||||
|
# Using default signature (default behavior) overrides 'serving_default'
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
model.save_pretrained(tmp_dir, saved_model=True, signatures=None)
|
||||||
|
model_loaded = tf.keras.models.load_model(f"{tmp_dir}/saved_model/1")
|
||||||
|
self.assertTrue("serving_default" in list(model_loaded.signatures.keys()))
|
||||||
|
|
||||||
|
# Providing custom signature function
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
model.save_pretrained(tmp_dir, saved_model=True, signatures={"custom_signature": serving_fn})
|
||||||
|
model_loaded = tf.keras.models.load_model(f"{tmp_dir}/saved_model/1")
|
||||||
|
self.assertTrue("custom_signature" in list(model_loaded.signatures.keys()))
|
||||||
|
|
||||||
|
# Providing multiple custom signature function
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
model.save_pretrained(
|
||||||
|
tmp_dir,
|
||||||
|
saved_model=True,
|
||||||
|
signatures={"custom_signature_1": serving_fn, "custom_signature_2": serving_fn},
|
||||||
|
)
|
||||||
|
model_loaded = tf.keras.models.load_model(f"{tmp_dir}/saved_model/1")
|
||||||
|
self.assertTrue("custom_signature_1" in list(model_loaded.signatures.keys()))
|
||||||
|
self.assertTrue("custom_signature_2" in list(model_loaded.signatures.keys()))
|
||||||
|
|
||||||
|
@require_safetensors
|
||||||
|
def test_safetensors_save_and_load(self):
|
||||||
|
model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
model.save_pretrained(tmp_dir, safe_serialization=True)
|
||||||
|
# No tf_model.h5 file, only a model.safetensors
|
||||||
|
self.assertTrue(os.path.isfile(os.path.join(tmp_dir, SAFE_WEIGHTS_NAME)))
|
||||||
|
self.assertFalse(os.path.isfile(os.path.join(tmp_dir, TF2_WEIGHTS_NAME)))
|
||||||
|
|
||||||
|
new_model = TFBertModel.from_pretrained(tmp_dir)
|
||||||
|
|
||||||
|
# Check models are equal
|
||||||
|
for p1, p2 in zip(model.weights, new_model.weights):
|
||||||
|
self.assertTrue(np.allclose(p1.numpy(), p2.numpy()))
|
||||||
|
|
||||||
|
@is_pt_tf_cross_test
|
||||||
|
def test_safetensors_save_and_load_pt_to_tf(self):
|
||||||
|
model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
|
||||||
|
pt_model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
pt_model.save_pretrained(tmp_dir, safe_serialization=True)
|
||||||
|
# Check we have a model.safetensors file
|
||||||
|
self.assertTrue(os.path.isfile(os.path.join(tmp_dir, SAFE_WEIGHTS_NAME)))
|
||||||
|
|
||||||
|
new_model = TFBertModel.from_pretrained(tmp_dir)
|
||||||
|
|
||||||
|
# Check models are equal
|
||||||
|
for p1, p2 in zip(model.weights, new_model.weights):
|
||||||
|
self.assertTrue(np.allclose(p1.numpy(), p2.numpy()))
|
||||||
|
|
||||||
|
@require_safetensors
|
||||||
|
def test_safetensors_load_from_hub(self):
|
||||||
|
tf_model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
|
||||||
|
|
||||||
|
# Can load from the TF-formatted checkpoint
|
||||||
|
safetensors_model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert-safetensors-tf")
|
||||||
|
|
||||||
|
# Check models are equal
|
||||||
|
for p1, p2 in zip(safetensors_model.weights, tf_model.weights):
|
||||||
|
self.assertTrue(np.allclose(p1.numpy(), p2.numpy()))
|
||||||
|
|
||||||
|
# Can load from the PyTorch-formatted checkpoint
|
||||||
|
safetensors_model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert-safetensors")
|
||||||
|
|
||||||
|
# Check models are equal
|
||||||
|
for p1, p2 in zip(safetensors_model.weights, tf_model.weights):
|
||||||
|
self.assertTrue(np.allclose(p1.numpy(), p2.numpy()))
|
||||||
|
|
||||||
|
|
||||||
|
@require_tf
|
||||||
|
@is_staging_test
|
||||||
|
class TFModelPushToHubTester(unittest.TestCase):
|
||||||
|
@classmethod
|
||||||
|
def setUpClass(cls):
|
||||||
|
cls._token = TOKEN
|
||||||
|
HfFolder.save_token(TOKEN)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def tearDownClass(cls):
|
||||||
|
try:
|
||||||
|
delete_repo(token=cls._token, repo_id="test-model-tf")
|
||||||
|
except HTTPError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
try:
|
||||||
|
delete_repo(token=cls._token, repo_id="test-model-tf-callback")
|
||||||
|
except HTTPError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
try:
|
||||||
|
delete_repo(token=cls._token, repo_id="valid_org/test-model-tf-org")
|
||||||
|
except HTTPError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def test_push_to_hub(self):
|
||||||
|
config = BertConfig(
|
||||||
|
vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
|
||||||
|
)
|
||||||
|
model = TFBertModel(config)
|
||||||
|
# Make sure model is properly initialized
|
||||||
|
model.build()
|
||||||
|
|
||||||
|
logging.set_verbosity_info()
|
||||||
|
logger = logging.get_logger("transformers.utils.hub")
|
||||||
|
with CaptureLogger(logger) as cl:
|
||||||
|
model.push_to_hub("test-model-tf", use_auth_token=self._token)
|
||||||
|
logging.set_verbosity_warning()
|
||||||
|
# Check the model card was created and uploaded.
|
||||||
|
self.assertIn("Uploading the following files to __DUMMY_TRANSFORMERS_USER__/test-model-tf", cl.out)
|
||||||
|
|
||||||
|
new_model = TFBertModel.from_pretrained(f"{USER}/test-model-tf")
|
||||||
|
models_equal = True
|
||||||
|
for p1, p2 in zip(model.weights, new_model.weights):
|
||||||
|
if not tf.math.reduce_all(p1 == p2):
|
||||||
|
models_equal = False
|
||||||
|
break
|
||||||
|
self.assertTrue(models_equal)
|
||||||
|
|
||||||
|
# Reset repo
|
||||||
|
delete_repo(token=self._token, repo_id="test-model-tf")
|
||||||
|
|
||||||
|
# Push to hub via save_pretrained
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
model.save_pretrained(tmp_dir, repo_id="test-model-tf", push_to_hub=True, use_auth_token=self._token)
|
||||||
|
|
||||||
|
new_model = TFBertModel.from_pretrained(f"{USER}/test-model-tf")
|
||||||
|
models_equal = True
|
||||||
|
for p1, p2 in zip(model.weights, new_model.weights):
|
||||||
|
if not tf.math.reduce_all(p1 == p2):
|
||||||
|
models_equal = False
|
||||||
|
break
|
||||||
|
self.assertTrue(models_equal)
|
||||||
|
|
||||||
|
@is_pt_tf_cross_test
|
||||||
|
def test_push_to_hub_callback(self):
|
||||||
|
config = BertConfig(
|
||||||
|
vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
|
||||||
|
)
|
||||||
|
model = TFBertForMaskedLM(config)
|
||||||
|
model.compile()
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
push_to_hub_callback = PushToHubCallback(
|
||||||
|
output_dir=tmp_dir,
|
||||||
|
hub_model_id="test-model-tf-callback",
|
||||||
|
hub_token=self._token,
|
||||||
|
)
|
||||||
|
model.fit(model.dummy_inputs, model.dummy_inputs, epochs=1, callbacks=[push_to_hub_callback])
|
||||||
|
|
||||||
|
new_model = TFBertForMaskedLM.from_pretrained(f"{USER}/test-model-tf-callback")
|
||||||
|
models_equal = True
|
||||||
|
for p1, p2 in zip(model.weights, new_model.weights):
|
||||||
|
if not tf.math.reduce_all(p1 == p2):
|
||||||
|
models_equal = False
|
||||||
|
break
|
||||||
|
self.assertTrue(models_equal)
|
||||||
|
|
||||||
|
tf_push_to_hub_params = dict(inspect.signature(TFPreTrainedModel.push_to_hub).parameters)
|
||||||
|
tf_push_to_hub_params.pop("base_model_card_args")
|
||||||
|
pt_push_to_hub_params = dict(inspect.signature(PreTrainedModel.push_to_hub).parameters)
|
||||||
|
pt_push_to_hub_params.pop("deprecated_kwargs")
|
||||||
|
self.assertDictEaual(tf_push_to_hub_params, pt_push_to_hub_params)
|
||||||
|
|
||||||
|
def test_push_to_hub_in_organization(self):
|
||||||
|
config = BertConfig(
|
||||||
|
vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
|
||||||
|
)
|
||||||
|
model = TFBertModel(config)
|
||||||
|
# Make sure model is properly initialized
|
||||||
|
model.build()
|
||||||
|
|
||||||
|
model.push_to_hub("valid_org/test-model-tf-org", use_auth_token=self._token)
|
||||||
|
|
||||||
|
new_model = TFBertModel.from_pretrained("valid_org/test-model-tf-org")
|
||||||
|
models_equal = True
|
||||||
|
for p1, p2 in zip(model.weights, new_model.weights):
|
||||||
|
if not tf.math.reduce_all(p1 == p2):
|
||||||
|
models_equal = False
|
||||||
|
break
|
||||||
|
self.assertTrue(models_equal)
|
||||||
|
|
||||||
|
# Reset repo
|
||||||
|
delete_repo(token=self._token, repo_id="valid_org/test-model-tf-org")
|
||||||
|
|
||||||
|
# Push to hub via save_pretrained
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
model.save_pretrained(
|
||||||
|
tmp_dir, push_to_hub=True, use_auth_token=self._token, repo_id="valid_org/test-model-tf-org"
|
||||||
|
)
|
||||||
|
|
||||||
|
new_model = TFBertModel.from_pretrained("valid_org/test-model-tf-org")
|
||||||
|
models_equal = True
|
||||||
|
for p1, p2 in zip(model.weights, new_model.weights):
|
||||||
|
if not tf.math.reduce_all(p1 == p2):
|
||||||
|
models_equal = False
|
||||||
|
break
|
||||||
|
self.assertTrue(models_equal)
|
||||||
976
tests/test_modeling_utils.py
Executable file
976
tests/test_modeling_utils.py
Executable file
@@ -0,0 +1,976 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2019 HuggingFace Inc.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import glob
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import os.path
|
||||||
|
import sys
|
||||||
|
import tempfile
|
||||||
|
import unittest
|
||||||
|
import unittest.mock as mock
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from huggingface_hub import HfFolder, delete_repo
|
||||||
|
from huggingface_hub.file_download import http_get
|
||||||
|
from pytest import mark
|
||||||
|
from requests.exceptions import HTTPError
|
||||||
|
|
||||||
|
from transformers import (
|
||||||
|
AutoConfig,
|
||||||
|
AutoModel,
|
||||||
|
PretrainedConfig,
|
||||||
|
is_torch_available,
|
||||||
|
logging,
|
||||||
|
)
|
||||||
|
from transformers.testing_utils import (
|
||||||
|
TOKEN,
|
||||||
|
USER,
|
||||||
|
CaptureLogger,
|
||||||
|
TestCasePlus,
|
||||||
|
is_staging_test,
|
||||||
|
require_accelerate,
|
||||||
|
require_safetensors,
|
||||||
|
require_torch,
|
||||||
|
require_torch_gpu,
|
||||||
|
require_torch_multi_gpu,
|
||||||
|
require_usr_bin_time,
|
||||||
|
slow,
|
||||||
|
)
|
||||||
|
from transformers.utils import (
|
||||||
|
SAFE_WEIGHTS_INDEX_NAME,
|
||||||
|
SAFE_WEIGHTS_NAME,
|
||||||
|
WEIGHTS_INDEX_NAME,
|
||||||
|
WEIGHTS_NAME,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
sys.path.append(str(Path(__file__).parent.parent / "utils"))
|
||||||
|
|
||||||
|
from test_module.custom_configuration import CustomConfig, NoSuperInitConfig # noqa E402
|
||||||
|
|
||||||
|
|
||||||
|
if is_torch_available():
|
||||||
|
import torch
|
||||||
|
from test_module.custom_modeling import CustomModel, NoSuperInitModel
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
|
from transformers import (
|
||||||
|
BERT_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||||
|
AutoModelForCausalLM,
|
||||||
|
AutoTokenizer,
|
||||||
|
BertConfig,
|
||||||
|
BertModel,
|
||||||
|
CLIPTextModel,
|
||||||
|
PreTrainedModel,
|
||||||
|
T5Config,
|
||||||
|
T5ForConditionalGeneration,
|
||||||
|
)
|
||||||
|
from transformers.modeling_utils import shard_checkpoint
|
||||||
|
|
||||||
|
# Fake pretrained models for tests
|
||||||
|
class BaseModel(PreTrainedModel):
|
||||||
|
config_class = PretrainedConfig
|
||||||
|
|
||||||
|
def __init__(self, config):
|
||||||
|
super().__init__(config)
|
||||||
|
self.linear = nn.Linear(4, 5)
|
||||||
|
self.linear_2 = nn.Linear(5, 6)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return self.linear_2(self.linear(x))
|
||||||
|
|
||||||
|
class ModelWithHead(PreTrainedModel):
|
||||||
|
base_model_prefix = "base"
|
||||||
|
config_class = PretrainedConfig
|
||||||
|
|
||||||
|
def _init_weights(self, module):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def __init__(self, config):
|
||||||
|
super().__init__(config)
|
||||||
|
self.base = BaseModel(config)
|
||||||
|
# linear is a common name between Base and Head on purpose.
|
||||||
|
self.linear = nn.Linear(6, 3)
|
||||||
|
self.linear2 = nn.Linear(3, 5)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
return self.linear2(self.linear(self.base(x)))
|
||||||
|
|
||||||
|
|
||||||
|
TINY_T5 = "patrickvonplaten/t5-tiny-random"
|
||||||
|
TINY_BERT_FOR_TOKEN_CLASSIFICATION = "hf-internal-testing/tiny-bert-for-token-classification"
|
||||||
|
|
||||||
|
|
||||||
|
def check_models_equal(model1, model2):
|
||||||
|
models_are_equal = True
|
||||||
|
for model1_p, model2_p in zip(model1.parameters(), model2.parameters()):
|
||||||
|
if model1_p.data.ne(model2_p.data).sum() > 0:
|
||||||
|
models_are_equal = False
|
||||||
|
|
||||||
|
return models_are_equal
|
||||||
|
|
||||||
|
|
||||||
|
@require_torch
|
||||||
|
class ModelUtilsTest(TestCasePlus):
|
||||||
|
@slow
|
||||||
|
def test_model_from_pretrained(self):
|
||||||
|
for model_name in BERT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
|
||||||
|
config = BertConfig.from_pretrained(model_name)
|
||||||
|
self.assertIsNotNone(config)
|
||||||
|
self.assertIsInstance(config, PretrainedConfig)
|
||||||
|
|
||||||
|
model = BertModel.from_pretrained(model_name)
|
||||||
|
model, loading_info = BertModel.from_pretrained(model_name, output_loading_info=True)
|
||||||
|
self.assertIsNotNone(model)
|
||||||
|
self.assertIsInstance(model, PreTrainedModel)
|
||||||
|
|
||||||
|
self.assertEqual(len(loading_info["missing_keys"]), 0)
|
||||||
|
self.assertEqual(len(loading_info["unexpected_keys"]), 8)
|
||||||
|
self.assertEqual(len(loading_info["mismatched_keys"]), 0)
|
||||||
|
self.assertEqual(len(loading_info["error_msgs"]), 0)
|
||||||
|
|
||||||
|
config = BertConfig.from_pretrained(model_name, output_attentions=True, output_hidden_states=True)
|
||||||
|
|
||||||
|
# Not sure this is the intended behavior. TODO fix Lysandre & Thom
|
||||||
|
config.name_or_path = model_name
|
||||||
|
|
||||||
|
model = BertModel.from_pretrained(model_name, output_attentions=True, output_hidden_states=True)
|
||||||
|
self.assertEqual(model.config.output_hidden_states, True)
|
||||||
|
self.assertEqual(model.config, config)
|
||||||
|
|
||||||
|
def test_model_from_pretrained_subfolder(self):
|
||||||
|
config = BertConfig.from_pretrained("hf-internal-testing/tiny-random-bert")
|
||||||
|
model = BertModel(config)
|
||||||
|
|
||||||
|
subfolder = "bert"
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
model.save_pretrained(os.path.join(tmp_dir, subfolder))
|
||||||
|
|
||||||
|
with self.assertRaises(OSError):
|
||||||
|
_ = BertModel.from_pretrained(tmp_dir)
|
||||||
|
|
||||||
|
model_loaded = BertModel.from_pretrained(tmp_dir, subfolder=subfolder)
|
||||||
|
|
||||||
|
self.assertTrue(check_models_equal(model, model_loaded))
|
||||||
|
|
||||||
|
def test_model_from_pretrained_subfolder_sharded(self):
|
||||||
|
config = BertConfig.from_pretrained("hf-internal-testing/tiny-random-bert")
|
||||||
|
model = BertModel(config)
|
||||||
|
|
||||||
|
subfolder = "bert"
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
model.save_pretrained(os.path.join(tmp_dir, subfolder), max_shard_size="10KB")
|
||||||
|
|
||||||
|
with self.assertRaises(OSError):
|
||||||
|
_ = BertModel.from_pretrained(tmp_dir)
|
||||||
|
|
||||||
|
model_loaded = BertModel.from_pretrained(tmp_dir, subfolder=subfolder)
|
||||||
|
|
||||||
|
self.assertTrue(check_models_equal(model, model_loaded))
|
||||||
|
|
||||||
|
def test_model_from_pretrained_hub_subfolder(self):
|
||||||
|
subfolder = "bert"
|
||||||
|
model_id = "hf-internal-testing/tiny-random-bert-subfolder"
|
||||||
|
with self.assertRaises(OSError):
|
||||||
|
_ = BertModel.from_pretrained(model_id)
|
||||||
|
|
||||||
|
model = BertModel.from_pretrained(model_id, subfolder=subfolder)
|
||||||
|
|
||||||
|
self.assertIsNotNone(model)
|
||||||
|
|
||||||
|
def test_model_from_pretrained_hub_subfolder_sharded(self):
|
||||||
|
subfolder = "bert"
|
||||||
|
model_id = "hf-internal-testing/tiny-random-bert-sharded-subfolder"
|
||||||
|
with self.assertRaises(OSError):
|
||||||
|
_ = BertModel.from_pretrained(model_id)
|
||||||
|
|
||||||
|
model = BertModel.from_pretrained(model_id, subfolder=subfolder)
|
||||||
|
|
||||||
|
self.assertIsNotNone(model)
|
||||||
|
|
||||||
|
def test_model_from_pretrained_with_different_pretrained_model_name(self):
|
||||||
|
model = T5ForConditionalGeneration.from_pretrained(TINY_T5)
|
||||||
|
self.assertIsNotNone(model)
|
||||||
|
|
||||||
|
logger = logging.get_logger("transformers.configuration_utils")
|
||||||
|
with CaptureLogger(logger) as cl:
|
||||||
|
BertModel.from_pretrained(TINY_T5)
|
||||||
|
self.assertTrue("You are using a model of type t5 to instantiate a model of type bert" in cl.out)
|
||||||
|
|
||||||
|
def test_model_from_config_torch_dtype(self):
|
||||||
|
# test that the model can be instantiated with dtype of user's choice - as long as it's a
|
||||||
|
# float dtype. To make it happen config.torch_dtype needs to be set before instantiating the
|
||||||
|
# model from the config object.
|
||||||
|
|
||||||
|
config = T5Config.from_pretrained(TINY_T5)
|
||||||
|
model = AutoModel.from_config(config)
|
||||||
|
# XXX: isn't supported
|
||||||
|
# model = T5ForConditionalGeneration.from_config(config)
|
||||||
|
self.assertEqual(model.dtype, torch.float32)
|
||||||
|
|
||||||
|
model = AutoModel.from_config(config, torch_dtype=torch.float16)
|
||||||
|
self.assertEqual(model.dtype, torch.float16)
|
||||||
|
|
||||||
|
# torch.set_default_dtype() supports only float dtypes, so will fail with non-float type
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
model = AutoModel.from_config(config, torch_dtype=torch.int64)
|
||||||
|
|
||||||
|
def test_model_from_pretrained_torch_dtype(self):
|
||||||
|
# test that the model can be instantiated with dtype of either
|
||||||
|
# 1. explicit from_pretrained's torch_dtype argument
|
||||||
|
# 2. via autodiscovery by looking at model weights (torch_dtype="auto")
|
||||||
|
# so if a model.half() was saved, we want it to be instantiated as such.
|
||||||
|
#
|
||||||
|
# test an explicit model class, but also AutoModel separately as the latter goes through a different code path
|
||||||
|
model_path = self.get_auto_remove_tmp_dir()
|
||||||
|
|
||||||
|
# baseline - we know TINY_T5 is fp32 model
|
||||||
|
model = T5ForConditionalGeneration.from_pretrained(TINY_T5)
|
||||||
|
self.assertEqual(model.dtype, torch.float32)
|
||||||
|
|
||||||
|
def remove_torch_dtype(model_path):
|
||||||
|
file = f"{model_path}/config.json"
|
||||||
|
with open(file, "r", encoding="utf-8") as f:
|
||||||
|
s = json.load(f)
|
||||||
|
s.pop("torch_dtype")
|
||||||
|
with open(file, "w", encoding="utf-8") as f:
|
||||||
|
json.dump(s, f)
|
||||||
|
|
||||||
|
# test the default fp32 save_pretrained => from_pretrained cycle
|
||||||
|
model.save_pretrained(model_path)
|
||||||
|
model = T5ForConditionalGeneration.from_pretrained(model_path)
|
||||||
|
self.assertEqual(model.dtype, torch.float32)
|
||||||
|
# 1. test torch_dtype="auto" via `config.torch_dtype`
|
||||||
|
model = T5ForConditionalGeneration.from_pretrained(model_path, torch_dtype="auto")
|
||||||
|
self.assertEqual(model.dtype, torch.float32)
|
||||||
|
# 2. test torch_dtype="auto" via auto-derivation
|
||||||
|
# now remove the torch_dtype entry from config.json and try "auto" again which should
|
||||||
|
# perform auto-derivation from weights
|
||||||
|
remove_torch_dtype(model_path)
|
||||||
|
model = T5ForConditionalGeneration.from_pretrained(model_path, torch_dtype="auto")
|
||||||
|
self.assertEqual(model.dtype, torch.float32)
|
||||||
|
|
||||||
|
# test forced loading in fp16 (even though the weights are in fp32)
|
||||||
|
model = T5ForConditionalGeneration.from_pretrained(model_path, torch_dtype=torch.float16)
|
||||||
|
self.assertEqual(model.dtype, torch.float16)
|
||||||
|
|
||||||
|
# test fp16 save_pretrained, loaded with auto-detection
|
||||||
|
model = model.half()
|
||||||
|
model.save_pretrained(model_path)
|
||||||
|
# 1. test torch_dtype="auto" via `config.torch_dtype`
|
||||||
|
model = T5ForConditionalGeneration.from_pretrained(model_path, torch_dtype="auto")
|
||||||
|
self.assertEqual(model.config.torch_dtype, torch.float16)
|
||||||
|
self.assertEqual(model.dtype, torch.float16)
|
||||||
|
# tests `config.torch_dtype` saving
|
||||||
|
with open(f"{model_path}/config.json") as f:
|
||||||
|
config_dict = json.load(f)
|
||||||
|
self.assertEqual(config_dict["torch_dtype"], "float16")
|
||||||
|
# 2. test torch_dtype="auto" via auto-derivation
|
||||||
|
# now same with using config info
|
||||||
|
remove_torch_dtype(model_path)
|
||||||
|
model = T5ForConditionalGeneration.from_pretrained(model_path, torch_dtype="auto")
|
||||||
|
self.assertEqual(model.dtype, torch.float16)
|
||||||
|
|
||||||
|
# 3. now retest that AutoModel behaves the same wrt torch_dtype="auto" as T5ForConditionalGeneration
|
||||||
|
model = AutoModel.from_pretrained(model_path, torch_dtype="auto")
|
||||||
|
self.assertEqual(model.dtype, torch.float16)
|
||||||
|
|
||||||
|
# test fp16 save_pretrained, loaded with the explicit fp16
|
||||||
|
model = T5ForConditionalGeneration.from_pretrained(model_path, torch_dtype=torch.float16)
|
||||||
|
self.assertEqual(model.dtype, torch.float16)
|
||||||
|
|
||||||
|
# test AutoModel separately as it goes through a different path
|
||||||
|
# test auto-detection - as currently TINY_T5 doesn't have torch_dtype entry
|
||||||
|
model = AutoModel.from_pretrained(TINY_T5, torch_dtype="auto")
|
||||||
|
# test that the config object didn't get polluted with torch_dtype="auto"
|
||||||
|
# there was a bug that after this call we ended up with config.torch_dtype=="auto"
|
||||||
|
self.assertNotEqual(model.config.torch_dtype, "auto")
|
||||||
|
# now test the outcome
|
||||||
|
self.assertEqual(model.dtype, torch.float32)
|
||||||
|
model = AutoModel.from_pretrained(TINY_T5, torch_dtype=torch.float16)
|
||||||
|
self.assertEqual(model.dtype, torch.float16)
|
||||||
|
|
||||||
|
# test model whose first param is not of a floating type, but int
|
||||||
|
model = AutoModel.from_pretrained(TINY_BERT_FOR_TOKEN_CLASSIFICATION, torch_dtype="auto")
|
||||||
|
self.assertEqual(model.dtype, torch.float32)
|
||||||
|
|
||||||
|
def test_no_super_init_config_and_model(self):
|
||||||
|
config = NoSuperInitConfig(attribute=32)
|
||||||
|
model = NoSuperInitModel(config)
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
model.save_pretrained(tmp_dir)
|
||||||
|
|
||||||
|
new_model = NoSuperInitModel.from_pretrained(tmp_dir)
|
||||||
|
|
||||||
|
for p1, p2 in zip(model.parameters(), new_model.parameters()):
|
||||||
|
self.assertTrue(torch.equal(p1, p2))
|
||||||
|
|
||||||
|
def test_shard_checkpoint(self):
|
||||||
|
# This is the model we will use, total size 340,000 bytes.
|
||||||
|
model = torch.nn.Sequential(
|
||||||
|
torch.nn.Linear(100, 200, bias=False), # size 80,000
|
||||||
|
torch.nn.Linear(200, 200, bias=False), # size 160,000
|
||||||
|
torch.nn.Linear(200, 100, bias=False), # size 80,000
|
||||||
|
torch.nn.Linear(100, 50, bias=False), # size 20,000
|
||||||
|
)
|
||||||
|
state_dict = model.state_dict()
|
||||||
|
|
||||||
|
with self.subTest("No shard when max size is bigger than model size"):
|
||||||
|
shards, index = shard_checkpoint(state_dict)
|
||||||
|
self.assertIsNone(index)
|
||||||
|
self.assertDictEqual(shards, {WEIGHTS_NAME: state_dict})
|
||||||
|
|
||||||
|
with self.subTest("Test sharding, no weights bigger than max size"):
|
||||||
|
shards, index = shard_checkpoint(state_dict, max_shard_size="300kB")
|
||||||
|
# Split is first two layers then last two.
|
||||||
|
self.assertDictEqual(
|
||||||
|
index,
|
||||||
|
{
|
||||||
|
"metadata": {"total_size": 340000},
|
||||||
|
"weight_map": {
|
||||||
|
"0.weight": "pytorch_model-00001-of-00002.bin",
|
||||||
|
"1.weight": "pytorch_model-00001-of-00002.bin",
|
||||||
|
"2.weight": "pytorch_model-00002-of-00002.bin",
|
||||||
|
"3.weight": "pytorch_model-00002-of-00002.bin",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
shard1 = {"0.weight": state_dict["0.weight"], "1.weight": state_dict["1.weight"]}
|
||||||
|
shard2 = {"2.weight": state_dict["2.weight"], "3.weight": state_dict["3.weight"]}
|
||||||
|
self.assertDictEqual(
|
||||||
|
shards, {"pytorch_model-00001-of-00002.bin": shard1, "pytorch_model-00002-of-00002.bin": shard2}
|
||||||
|
)
|
||||||
|
|
||||||
|
with self.subTest("Test sharding with weights bigger than max size"):
|
||||||
|
shards, index = shard_checkpoint(state_dict, max_shard_size="100kB")
|
||||||
|
# Split is first layer, second layer then last 2.
|
||||||
|
self.assertDictEqual(
|
||||||
|
index,
|
||||||
|
{
|
||||||
|
"metadata": {"total_size": 340000},
|
||||||
|
"weight_map": {
|
||||||
|
"0.weight": "pytorch_model-00001-of-00003.bin",
|
||||||
|
"1.weight": "pytorch_model-00002-of-00003.bin",
|
||||||
|
"2.weight": "pytorch_model-00003-of-00003.bin",
|
||||||
|
"3.weight": "pytorch_model-00003-of-00003.bin",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
shard1 = {"0.weight": state_dict["0.weight"]}
|
||||||
|
shard2 = {"1.weight": state_dict["1.weight"]}
|
||||||
|
shard3 = {"2.weight": state_dict["2.weight"], "3.weight": state_dict["3.weight"]}
|
||||||
|
self.assertDictEqual(
|
||||||
|
shards,
|
||||||
|
{
|
||||||
|
"pytorch_model-00001-of-00003.bin": shard1,
|
||||||
|
"pytorch_model-00002-of-00003.bin": shard2,
|
||||||
|
"pytorch_model-00003-of-00003.bin": shard3,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_checkpoint_sharding_local(self):
|
||||||
|
model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
# We use the same folder for various sizes to make sure a new save erases the old checkpoint.
|
||||||
|
for max_size in ["50kB", "50kiB", "100kB", "100kiB", "200kB", "200kiB"]:
|
||||||
|
model.save_pretrained(tmp_dir, max_shard_size=max_size)
|
||||||
|
|
||||||
|
# Get each shard file and its size
|
||||||
|
shard_to_size = {}
|
||||||
|
for shard in os.listdir(tmp_dir):
|
||||||
|
if shard.endswith(".bin"):
|
||||||
|
shard_file = os.path.join(tmp_dir, shard)
|
||||||
|
shard_to_size[shard_file] = os.path.getsize(shard_file)
|
||||||
|
|
||||||
|
index_file = os.path.join(tmp_dir, WEIGHTS_INDEX_NAME)
|
||||||
|
# Check there is an index but no regular weight file
|
||||||
|
self.assertTrue(os.path.isfile(index_file))
|
||||||
|
self.assertFalse(os.path.isfile(os.path.join(tmp_dir, WEIGHTS_NAME)))
|
||||||
|
|
||||||
|
# Check a file is bigger than max_size only when it has a single weight
|
||||||
|
for shard_file, size in shard_to_size.items():
|
||||||
|
if max_size.endswith("kiB"):
|
||||||
|
max_size_int = int(max_size[:-3]) * 2**10
|
||||||
|
else:
|
||||||
|
max_size_int = int(max_size[:-2]) * 10**3
|
||||||
|
# Note: pickle adds some junk so the weight of the file can end up being slightly bigger than
|
||||||
|
# the size asked for (since we count parameters)
|
||||||
|
if size >= max_size_int + 50000:
|
||||||
|
state_dict = torch.load(shard_file)
|
||||||
|
self.assertEqual(len(state_dict), 1)
|
||||||
|
|
||||||
|
# Check the index and the shard files found match
|
||||||
|
with open(index_file, "r", encoding="utf-8") as f:
|
||||||
|
index = json.loads(f.read())
|
||||||
|
|
||||||
|
all_shards = set(index["weight_map"].values())
|
||||||
|
shards_found = {f for f in os.listdir(tmp_dir) if f.endswith(".bin")}
|
||||||
|
self.assertSetEqual(all_shards, shards_found)
|
||||||
|
|
||||||
|
# Finally, check the model can be reloaded
|
||||||
|
new_model = BertModel.from_pretrained(tmp_dir)
|
||||||
|
for p1, p2 in zip(model.parameters(), new_model.parameters()):
|
||||||
|
self.assertTrue(torch.allclose(p1, p2))
|
||||||
|
|
||||||
|
def test_checkpoint_sharding_from_hub(self):
|
||||||
|
model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert-sharded")
|
||||||
|
# the model above is the same as the model below, just a sharded version.
|
||||||
|
ref_model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
|
||||||
|
for p1, p2 in zip(model.parameters(), ref_model.parameters()):
|
||||||
|
self.assertTrue(torch.allclose(p1, p2))
|
||||||
|
|
||||||
|
def test_checkpoint_variant_local(self):
|
||||||
|
model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
model.save_pretrained(tmp_dir, variant="v2")
|
||||||
|
|
||||||
|
weights_name = ".".join(WEIGHTS_NAME.split(".")[:-1] + ["v2"] + ["bin"])
|
||||||
|
|
||||||
|
weights_file = os.path.join(tmp_dir, weights_name)
|
||||||
|
self.assertTrue(os.path.isfile(weights_file))
|
||||||
|
self.assertFalse(os.path.isfile(os.path.join(tmp_dir, WEIGHTS_NAME)))
|
||||||
|
|
||||||
|
with self.assertRaises(EnvironmentError):
|
||||||
|
_ = BertModel.from_pretrained(tmp_dir)
|
||||||
|
|
||||||
|
new_model = BertModel.from_pretrained(tmp_dir, variant="v2")
|
||||||
|
|
||||||
|
for p1, p2 in zip(model.parameters(), new_model.parameters()):
|
||||||
|
self.assertTrue(torch.allclose(p1, p2))
|
||||||
|
|
||||||
|
def test_checkpoint_variant_local_sharded(self):
|
||||||
|
model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
model.save_pretrained(tmp_dir, variant="v2", max_shard_size="50kB")
|
||||||
|
|
||||||
|
weights_index_name = ".".join(WEIGHTS_INDEX_NAME.split(".")[:-1] + ["v2"] + ["json"])
|
||||||
|
weights_index_file = os.path.join(tmp_dir, weights_index_name)
|
||||||
|
self.assertTrue(os.path.isfile(weights_index_file))
|
||||||
|
self.assertFalse(os.path.isfile(os.path.join(tmp_dir, WEIGHTS_INDEX_NAME)))
|
||||||
|
|
||||||
|
for i in range(1, 6):
|
||||||
|
weights_name = ".".join(WEIGHTS_NAME.split(".")[:-1] + [f"v2-0000{i}-of-00006"] + ["bin"])
|
||||||
|
weights_name_file = os.path.join(tmp_dir, weights_name)
|
||||||
|
self.assertTrue(os.path.isfile(weights_name_file))
|
||||||
|
|
||||||
|
with self.assertRaises(EnvironmentError):
|
||||||
|
_ = BertModel.from_pretrained(tmp_dir)
|
||||||
|
|
||||||
|
new_model = BertModel.from_pretrained(tmp_dir, variant="v2")
|
||||||
|
|
||||||
|
for p1, p2 in zip(model.parameters(), new_model.parameters()):
|
||||||
|
self.assertTrue(torch.allclose(p1, p2))
|
||||||
|
|
||||||
|
@require_safetensors
|
||||||
|
def test_checkpoint_variant_local_safe(self):
|
||||||
|
model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
model.save_pretrained(tmp_dir, variant="v2", safe_serialization=True)
|
||||||
|
|
||||||
|
weights_name = ".".join(SAFE_WEIGHTS_NAME.split(".")[:-1] + ["v2"] + ["safetensors"])
|
||||||
|
|
||||||
|
weights_file = os.path.join(tmp_dir, weights_name)
|
||||||
|
self.assertTrue(os.path.isfile(weights_file))
|
||||||
|
self.assertFalse(os.path.isfile(os.path.join(tmp_dir, SAFE_WEIGHTS_NAME)))
|
||||||
|
|
||||||
|
with self.assertRaises(EnvironmentError):
|
||||||
|
_ = BertModel.from_pretrained(tmp_dir)
|
||||||
|
|
||||||
|
new_model = BertModel.from_pretrained(tmp_dir, variant="v2")
|
||||||
|
|
||||||
|
for p1, p2 in zip(model.parameters(), new_model.parameters()):
|
||||||
|
self.assertTrue(torch.allclose(p1, p2))
|
||||||
|
|
||||||
|
@require_safetensors
|
||||||
|
def test_checkpoint_variant_local_sharded_safe(self):
|
||||||
|
model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
model.save_pretrained(tmp_dir, variant="v2", max_shard_size="50kB", safe_serialization=True)
|
||||||
|
|
||||||
|
weights_index_name = ".".join(SAFE_WEIGHTS_INDEX_NAME.split(".")[:-1] + ["v2"] + ["json"])
|
||||||
|
weights_index_file = os.path.join(tmp_dir, weights_index_name)
|
||||||
|
self.assertTrue(os.path.isfile(weights_index_file))
|
||||||
|
self.assertFalse(os.path.isfile(os.path.join(tmp_dir, SAFE_WEIGHTS_INDEX_NAME)))
|
||||||
|
|
||||||
|
for i in range(1, 6):
|
||||||
|
weights_name = ".".join(SAFE_WEIGHTS_NAME.split(".")[:-1] + [f"v2-0000{i}-of-00006"] + ["safetensors"])
|
||||||
|
weights_name_file = os.path.join(tmp_dir, weights_name)
|
||||||
|
self.assertTrue(os.path.isfile(weights_name_file))
|
||||||
|
|
||||||
|
with self.assertRaises(EnvironmentError):
|
||||||
|
_ = BertModel.from_pretrained(tmp_dir)
|
||||||
|
|
||||||
|
new_model = BertModel.from_pretrained(tmp_dir, variant="v2")
|
||||||
|
|
||||||
|
for p1, p2 in zip(model.parameters(), new_model.parameters()):
|
||||||
|
self.assertTrue(torch.allclose(p1, p2))
|
||||||
|
|
||||||
|
def test_checkpoint_variant_hub(self):
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
with self.assertRaises(EnvironmentError):
|
||||||
|
_ = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert-variant", cache_dir=tmp_dir)
|
||||||
|
model = BertModel.from_pretrained(
|
||||||
|
"hf-internal-testing/tiny-random-bert-variant", cache_dir=tmp_dir, variant="v2"
|
||||||
|
)
|
||||||
|
self.assertIsNotNone(model)
|
||||||
|
|
||||||
|
def test_checkpoint_variant_hub_sharded(self):
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
with self.assertRaises(EnvironmentError):
|
||||||
|
_ = BertModel.from_pretrained(
|
||||||
|
"hf-internal-testing/tiny-random-bert-variant-sharded", cache_dir=tmp_dir
|
||||||
|
)
|
||||||
|
model = BertModel.from_pretrained(
|
||||||
|
"hf-internal-testing/tiny-random-bert-variant-sharded", cache_dir=tmp_dir, variant="v2"
|
||||||
|
)
|
||||||
|
self.assertIsNotNone(model)
|
||||||
|
|
||||||
|
@require_safetensors
|
||||||
|
def test_checkpoint_variant_hub_safe(self):
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
with self.assertRaises(EnvironmentError):
|
||||||
|
_ = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert-variant-safe", cache_dir=tmp_dir)
|
||||||
|
model = BertModel.from_pretrained(
|
||||||
|
"hf-internal-testing/tiny-random-bert-variant-safe", cache_dir=tmp_dir, variant="v2"
|
||||||
|
)
|
||||||
|
self.assertIsNotNone(model)
|
||||||
|
|
||||||
|
@require_safetensors
|
||||||
|
def test_checkpoint_variant_hub_sharded_safe(self):
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
with self.assertRaises(EnvironmentError):
|
||||||
|
_ = BertModel.from_pretrained(
|
||||||
|
"hf-internal-testing/tiny-random-bert-variant-sharded-safe", cache_dir=tmp_dir
|
||||||
|
)
|
||||||
|
model = BertModel.from_pretrained(
|
||||||
|
"hf-internal-testing/tiny-random-bert-variant-sharded-safe", cache_dir=tmp_dir, variant="v2"
|
||||||
|
)
|
||||||
|
self.assertIsNotNone(model)
|
||||||
|
|
||||||
|
def test_checkpoint_variant_save_load(self):
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
model = BertModel.from_pretrained(
|
||||||
|
"hf-internal-testing/tiny-random-bert-variant", cache_dir=tmp_dir, variant="v2"
|
||||||
|
)
|
||||||
|
weights_name = ".".join(WEIGHTS_NAME.split(".")[:-1] + ["v2"] + ["bin"])
|
||||||
|
|
||||||
|
model.save_pretrained(tmp_dir, variant="v2")
|
||||||
|
# saving will create a variant checkpoint
|
||||||
|
self.assertTrue(os.path.isfile(os.path.join(tmp_dir, weights_name)))
|
||||||
|
|
||||||
|
model.save_pretrained(tmp_dir)
|
||||||
|
# saving shouldn't delete variant checkpoints
|
||||||
|
weights_name = ".".join(WEIGHTS_NAME.split(".")[:-1] + ["v2"] + ["bin"])
|
||||||
|
self.assertTrue(os.path.isfile(os.path.join(tmp_dir, weights_name)))
|
||||||
|
|
||||||
|
# there should be a normal checkpoint
|
||||||
|
self.assertTrue(os.path.isfile(os.path.join(tmp_dir, WEIGHTS_NAME)))
|
||||||
|
|
||||||
|
self.assertIsNotNone(model)
|
||||||
|
|
||||||
|
@require_accelerate
|
||||||
|
@mark.accelerate_tests
|
||||||
|
def test_from_pretrained_low_cpu_mem_usage_functional(self):
|
||||||
|
# test that we can use `from_pretrained(..., low_cpu_mem_usage=True)` with normal and
|
||||||
|
# sharded models
|
||||||
|
|
||||||
|
mnames = [
|
||||||
|
"hf-internal-testing/tiny-random-bert-sharded",
|
||||||
|
"hf-internal-testing/tiny-random-bert",
|
||||||
|
]
|
||||||
|
for mname in mnames:
|
||||||
|
_ = BertModel.from_pretrained(mname, low_cpu_mem_usage=True)
|
||||||
|
|
||||||
|
@require_usr_bin_time
|
||||||
|
@require_accelerate
|
||||||
|
@mark.accelerate_tests
|
||||||
|
def test_from_pretrained_low_cpu_mem_usage_measured(self):
|
||||||
|
# test that `from_pretrained(..., low_cpu_mem_usage=True)` uses less cpu memory than default
|
||||||
|
|
||||||
|
mname = "bert-base-cased"
|
||||||
|
|
||||||
|
preamble = "from transformers import AutoModel"
|
||||||
|
one_liner_str = f'{preamble}; AutoModel.from_pretrained("{mname}", low_cpu_mem_usage=False)'
|
||||||
|
max_rss_normal = self.python_one_liner_max_rss(one_liner_str)
|
||||||
|
# print(f"{max_rss_normal=}")
|
||||||
|
|
||||||
|
one_liner_str = f'{preamble}; AutoModel.from_pretrained("{mname}", low_cpu_mem_usage=True)'
|
||||||
|
max_rss_low_mem = self.python_one_liner_max_rss(one_liner_str)
|
||||||
|
# print(f"{max_rss_low_mem=}")
|
||||||
|
|
||||||
|
diff_bytes = max_rss_normal - max_rss_low_mem
|
||||||
|
diff_percent = diff_bytes / max_rss_low_mem
|
||||||
|
# print(f"{diff_bytes=}, {diff_percent=}")
|
||||||
|
# ideally we would compare that the diff is close to ~1x checkpoint size in bytes, but
|
||||||
|
# measuring cpu memory on linux is very tricky and inconsistent, so instead let's check that
|
||||||
|
# it's at least 15% less cpu memory consumed
|
||||||
|
|
||||||
|
self.assertGreater(
|
||||||
|
diff_percent,
|
||||||
|
0.15,
|
||||||
|
"should use less CPU memory for low_cpu_mem_usage=True, "
|
||||||
|
f"but got max_rss_normal={max_rss_normal} and max_rss_low_mem={max_rss_low_mem}",
|
||||||
|
)
|
||||||
|
|
||||||
|
# if you want to compare things manually, let's first look at the size of the model in bytes
|
||||||
|
# model = BertModel.from_pretrained(mname, low_cpu_mem_usage=False)
|
||||||
|
# total_numel = sum(dict((p.data_ptr(), p.numel()) for p in model.parameters()).values())
|
||||||
|
# total_bytes = total_numel * 4 # 420MB
|
||||||
|
# Now the diff_bytes should be very close to total_bytes, but the reports are inconsistent.
|
||||||
|
# The easiest way to test this is to switch the model and torch.load to do all the work on
|
||||||
|
# gpu - that way one can measure exactly the total and peak memory used. Perhaps once we add
|
||||||
|
# functionality to load models directly on gpu, this test can be rewritten to use torch's
|
||||||
|
# cuda memory tracking and then we should be able to do a much more precise test.
|
||||||
|
|
||||||
|
@require_accelerate
|
||||||
|
@mark.accelerate_tests
|
||||||
|
@require_torch_multi_gpu
|
||||||
|
@slow
|
||||||
|
def test_model_parallelism_gpt2(self):
|
||||||
|
device_map = {"transformer.wte": 0, "transformer.wpe": 0, "lm_head": 0, "transformer.ln_f": 1}
|
||||||
|
for i in range(12):
|
||||||
|
device_map[f"transformer.h.{i}"] = 0 if i <= 5 else 1
|
||||||
|
|
||||||
|
model = AutoModelForCausalLM.from_pretrained("gpt2", device_map=device_map)
|
||||||
|
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained("gpt2")
|
||||||
|
inputs = tokenizer("Hello, my name is", return_tensors="pt")
|
||||||
|
output = model.generate(inputs["input_ids"].to(0))
|
||||||
|
|
||||||
|
text_output = tokenizer.decode(output[0].tolist())
|
||||||
|
self.assertEqual(text_output, "Hello, my name is John. I'm a writer, and I'm a writer. I'm")
|
||||||
|
|
||||||
|
@require_accelerate
|
||||||
|
@mark.accelerate_tests
|
||||||
|
@require_torch_gpu
|
||||||
|
def test_from_pretrained_disk_offload_task_model(self):
|
||||||
|
model = AutoModel.from_pretrained("hf-internal-testing/tiny-random-gpt2")
|
||||||
|
device_map = {
|
||||||
|
"transformer.wte": 0,
|
||||||
|
"transformer.wpe": 0,
|
||||||
|
"transformer.h.0": "cpu",
|
||||||
|
"transformer.h.1": "cpu",
|
||||||
|
"transformer.h.2": "cpu",
|
||||||
|
"transformer.h.3": "disk",
|
||||||
|
"transformer.h.4": "disk",
|
||||||
|
"transformer.ln_f": 0,
|
||||||
|
"lm_head": 0,
|
||||||
|
}
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
inputs = torch.tensor([[1, 2, 3]]).to(0)
|
||||||
|
|
||||||
|
model.save_pretrained(tmp_dir)
|
||||||
|
new_model = AutoModelForCausalLM.from_pretrained(tmp_dir).to(0)
|
||||||
|
outputs1 = new_model.to(0)(inputs)
|
||||||
|
|
||||||
|
offload_folder = os.path.join(tmp_dir, "offload")
|
||||||
|
new_model_with_offload = AutoModelForCausalLM.from_pretrained(
|
||||||
|
tmp_dir, device_map=device_map, offload_folder=offload_folder
|
||||||
|
)
|
||||||
|
outputs2 = new_model_with_offload(inputs)
|
||||||
|
|
||||||
|
self.assertTrue(torch.allclose(outputs1.logits.cpu(), outputs2.logits.cpu()))
|
||||||
|
|
||||||
|
# With state dict temp offload
|
||||||
|
offload_folder = os.path.join(tmp_dir, "offload")
|
||||||
|
new_model_with_offload = AutoModelForCausalLM.from_pretrained(
|
||||||
|
tmp_dir,
|
||||||
|
device_map=device_map,
|
||||||
|
offload_folder=offload_folder,
|
||||||
|
offload_state_dict=True,
|
||||||
|
)
|
||||||
|
outputs2 = new_model_with_offload(inputs)
|
||||||
|
|
||||||
|
self.assertTrue(torch.allclose(outputs1.logits.cpu(), outputs2.logits.cpu()))
|
||||||
|
|
||||||
|
def test_cached_files_are_used_when_internet_is_down(self):
|
||||||
|
# A mock response for an HTTP head request to emulate server down
|
||||||
|
response_mock = mock.Mock()
|
||||||
|
response_mock.status_code = 500
|
||||||
|
response_mock.headers = {}
|
||||||
|
response_mock.raise_for_status.side_effect = HTTPError
|
||||||
|
response_mock.json.return_value = {}
|
||||||
|
|
||||||
|
# Download this model to make sure it's in the cache.
|
||||||
|
_ = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
|
||||||
|
|
||||||
|
# Under the mock environment we get a 500 error when trying to reach the model.
|
||||||
|
with mock.patch("requests.Session.request", return_value=response_mock) as mock_head:
|
||||||
|
_ = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
|
||||||
|
# This check we did call the fake head request
|
||||||
|
mock_head.assert_called()
|
||||||
|
|
||||||
|
def test_load_from_one_file(self):
|
||||||
|
try:
|
||||||
|
tmp_file = tempfile.mktemp()
|
||||||
|
with open(tmp_file, "wb") as f:
|
||||||
|
http_get(
|
||||||
|
"https://huggingface.co/hf-internal-testing/tiny-random-bert/resolve/main/pytorch_model.bin", f
|
||||||
|
)
|
||||||
|
|
||||||
|
config = BertConfig.from_pretrained("hf-internal-testing/tiny-random-bert")
|
||||||
|
_ = BertModel.from_pretrained(tmp_file, config=config)
|
||||||
|
finally:
|
||||||
|
os.remove(tmp_file)
|
||||||
|
|
||||||
|
def test_legacy_load_from_url(self):
|
||||||
|
# This test is for deprecated behavior and can be removed in v5
|
||||||
|
config = BertConfig.from_pretrained("hf-internal-testing/tiny-random-bert")
|
||||||
|
_ = BertModel.from_pretrained(
|
||||||
|
"https://huggingface.co/hf-internal-testing/tiny-random-bert/resolve/main/pytorch_model.bin", config=config
|
||||||
|
)
|
||||||
|
|
||||||
|
@require_safetensors
|
||||||
|
def test_use_safetensors(self):
|
||||||
|
# test nice error message if no safetensor files available
|
||||||
|
with self.assertRaises(OSError) as env_error:
|
||||||
|
AutoModel.from_pretrained("hf-internal-testing/tiny-random-RobertaModel", use_safetensors=True)
|
||||||
|
|
||||||
|
self.assertTrue(
|
||||||
|
"model.safetensors or model.safetensors.index.json and thus cannot be loaded with `safetensors`"
|
||||||
|
in str(env_error.exception)
|
||||||
|
)
|
||||||
|
|
||||||
|
# test that error if only safetensors is available
|
||||||
|
with self.assertRaises(OSError) as env_error:
|
||||||
|
BertModel.from_pretrained("hf-internal-testing/tiny-random-bert-safetensors", use_safetensors=False)
|
||||||
|
|
||||||
|
self.assertTrue("does not appear to have a file named pytorch_model.bin" in str(env_error.exception))
|
||||||
|
|
||||||
|
# test that only safetensors if both available and use_safetensors=False
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
CLIPTextModel.from_pretrained(
|
||||||
|
"hf-internal-testing/diffusers-stable-diffusion-tiny-all",
|
||||||
|
subfolder="text_encoder",
|
||||||
|
use_safetensors=False,
|
||||||
|
cache_dir=tmp_dir,
|
||||||
|
)
|
||||||
|
|
||||||
|
all_downloaded_files = glob.glob(os.path.join(tmp_dir, "*", "snapshots", "*", "*", "*"))
|
||||||
|
self.assertTrue(any(f.endswith("bin") for f in all_downloaded_files))
|
||||||
|
self.assertFalse(any(f.endswith("safetensors") for f in all_downloaded_files))
|
||||||
|
|
||||||
|
# test that no safetensors if both available and use_safetensors=True
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
CLIPTextModel.from_pretrained(
|
||||||
|
"hf-internal-testing/diffusers-stable-diffusion-tiny-all",
|
||||||
|
subfolder="text_encoder",
|
||||||
|
use_safetensors=True,
|
||||||
|
cache_dir=tmp_dir,
|
||||||
|
)
|
||||||
|
|
||||||
|
all_downloaded_files = glob.glob(os.path.join(tmp_dir, "*", "snapshots", "*", "*", "*"))
|
||||||
|
self.assertTrue(any(f.endswith("safetensors") for f in all_downloaded_files))
|
||||||
|
self.assertFalse(any(f.endswith("bin") for f in all_downloaded_files))
|
||||||
|
|
||||||
|
@require_safetensors
|
||||||
|
def test_safetensors_save_and_load(self):
|
||||||
|
model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
model.save_pretrained(tmp_dir, safe_serialization=True)
|
||||||
|
# No pytorch_model.bin file, only a model.safetensors
|
||||||
|
self.assertTrue(os.path.isfile(os.path.join(tmp_dir, SAFE_WEIGHTS_NAME)))
|
||||||
|
self.assertFalse(os.path.isfile(os.path.join(tmp_dir, WEIGHTS_NAME)))
|
||||||
|
|
||||||
|
new_model = BertModel.from_pretrained(tmp_dir)
|
||||||
|
|
||||||
|
# Check models are equal
|
||||||
|
for p1, p2 in zip(model.parameters(), new_model.parameters()):
|
||||||
|
self.assertTrue(torch.allclose(p1, p2))
|
||||||
|
|
||||||
|
@require_safetensors
|
||||||
|
def test_safetensors_load_from_hub(self):
|
||||||
|
safetensors_model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert-safetensors")
|
||||||
|
pytorch_model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
|
||||||
|
|
||||||
|
# Check models are equal
|
||||||
|
for p1, p2 in zip(safetensors_model.parameters(), pytorch_model.parameters()):
|
||||||
|
self.assertTrue(torch.allclose(p1, p2))
|
||||||
|
|
||||||
|
@require_safetensors
|
||||||
|
def test_safetensors_save_and_load_sharded(self):
|
||||||
|
model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
model.save_pretrained(tmp_dir, safe_serialization=True, max_shard_size="100kB")
|
||||||
|
# No pytorch_model.bin index file, only a model.safetensors index
|
||||||
|
self.assertFalse(os.path.isfile(os.path.join(tmp_dir, WEIGHTS_INDEX_NAME)))
|
||||||
|
self.assertTrue(os.path.isfile(os.path.join(tmp_dir, SAFE_WEIGHTS_INDEX_NAME)))
|
||||||
|
# No regular weights file
|
||||||
|
self.assertFalse(os.path.isfile(os.path.join(tmp_dir, WEIGHTS_NAME)))
|
||||||
|
self.assertFalse(os.path.isfile(os.path.join(tmp_dir, SAFE_WEIGHTS_NAME)))
|
||||||
|
|
||||||
|
new_model = BertModel.from_pretrained(tmp_dir)
|
||||||
|
|
||||||
|
# Check models are equal
|
||||||
|
for p1, p2 in zip(model.parameters(), new_model.parameters()):
|
||||||
|
self.assertTrue(torch.allclose(p1, p2))
|
||||||
|
|
||||||
|
@require_safetensors
|
||||||
|
def test_safetensors_load_from_hub_sharded(self):
|
||||||
|
safetensors_model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert-sharded-safetensors")
|
||||||
|
pytorch_model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert-sharded")
|
||||||
|
|
||||||
|
# Check models are equal
|
||||||
|
for p1, p2 in zip(safetensors_model.parameters(), pytorch_model.parameters()):
|
||||||
|
self.assertTrue(torch.allclose(p1, p2))
|
||||||
|
|
||||||
|
def test_base_model_to_head_model_load(self):
|
||||||
|
base_model = BaseModel(PretrainedConfig())
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
base_model.save_pretrained(tmp_dir)
|
||||||
|
|
||||||
|
# Can load a base model in a model with head
|
||||||
|
model = ModelWithHead.from_pretrained(tmp_dir)
|
||||||
|
for p1, p2 in zip(model.base.parameters(), base_model.parameters()):
|
||||||
|
self.assertTrue(torch.allclose(p1, p2))
|
||||||
|
|
||||||
|
# It doesn't work if the state dict has a mix of keys of the head and base without prefix though.
|
||||||
|
base_state_dict = base_model.state_dict()
|
||||||
|
head_state_dict = model.state_dict()
|
||||||
|
base_state_dict["linear2.weight"] = head_state_dict["linear2.weight"]
|
||||||
|
base_state_dict["linear2.bias"] = head_state_dict["linear2.bias"]
|
||||||
|
torch.save(base_state_dict, os.path.join(tmp_dir, WEIGHTS_NAME))
|
||||||
|
|
||||||
|
with self.assertRaisesRegex(
|
||||||
|
ValueError, "The state dictionary of the model you are trying to load is corrupted."
|
||||||
|
):
|
||||||
|
_ = ModelWithHead.from_pretrained(tmp_dir)
|
||||||
|
|
||||||
|
@require_torch_gpu
|
||||||
|
@slow
|
||||||
|
def test_pretrained_low_mem_new_config(self):
|
||||||
|
# Checking for 1 model(the same one which was described in the issue) .
|
||||||
|
model_ids = ["gpt2"]
|
||||||
|
|
||||||
|
for model_id in model_ids:
|
||||||
|
model_config = AutoConfig.from_pretrained(pretrained_model_name_or_path=model_id)
|
||||||
|
model_config.n_layer = 48
|
||||||
|
model_config.n_head = 25
|
||||||
|
model_config.n_embd = 1600
|
||||||
|
model = AutoModelForCausalLM.from_pretrained(
|
||||||
|
pretrained_model_name_or_path=model_id,
|
||||||
|
config=model_config,
|
||||||
|
ignore_mismatched_sizes=True,
|
||||||
|
torch_dtype=torch.float16,
|
||||||
|
low_cpu_mem_usage=True,
|
||||||
|
)
|
||||||
|
model_ref = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path=model_id)
|
||||||
|
|
||||||
|
self.assertEqual(model.__class__.__name__, model_ref.__class__.__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@require_torch
|
||||||
|
@is_staging_test
|
||||||
|
class ModelPushToHubTester(unittest.TestCase):
|
||||||
|
@classmethod
|
||||||
|
def setUpClass(cls):
|
||||||
|
cls._token = TOKEN
|
||||||
|
HfFolder.save_token(TOKEN)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def tearDownClass(cls):
|
||||||
|
try:
|
||||||
|
delete_repo(token=cls._token, repo_id="test-model")
|
||||||
|
except HTTPError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
try:
|
||||||
|
delete_repo(token=cls._token, repo_id="valid_org/test-model-org")
|
||||||
|
except HTTPError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
try:
|
||||||
|
delete_repo(token=cls._token, repo_id="test-dynamic-model")
|
||||||
|
except HTTPError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def test_push_to_hub(self):
|
||||||
|
config = BertConfig(
|
||||||
|
vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
|
||||||
|
)
|
||||||
|
model = BertModel(config)
|
||||||
|
model.push_to_hub("test-model", use_auth_token=self._token)
|
||||||
|
|
||||||
|
new_model = BertModel.from_pretrained(f"{USER}/test-model")
|
||||||
|
for p1, p2 in zip(model.parameters(), new_model.parameters()):
|
||||||
|
self.assertTrue(torch.equal(p1, p2))
|
||||||
|
|
||||||
|
# Reset repo
|
||||||
|
delete_repo(token=self._token, repo_id="test-model")
|
||||||
|
|
||||||
|
# Push to hub via save_pretrained
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
model.save_pretrained(tmp_dir, repo_id="test-model", push_to_hub=True, use_auth_token=self._token)
|
||||||
|
|
||||||
|
new_model = BertModel.from_pretrained(f"{USER}/test-model")
|
||||||
|
for p1, p2 in zip(model.parameters(), new_model.parameters()):
|
||||||
|
self.assertTrue(torch.equal(p1, p2))
|
||||||
|
|
||||||
|
def test_push_to_hub_in_organization(self):
|
||||||
|
config = BertConfig(
|
||||||
|
vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
|
||||||
|
)
|
||||||
|
model = BertModel(config)
|
||||||
|
model.push_to_hub("valid_org/test-model-org", use_auth_token=self._token)
|
||||||
|
|
||||||
|
new_model = BertModel.from_pretrained("valid_org/test-model-org")
|
||||||
|
for p1, p2 in zip(model.parameters(), new_model.parameters()):
|
||||||
|
self.assertTrue(torch.equal(p1, p2))
|
||||||
|
|
||||||
|
# Reset repo
|
||||||
|
delete_repo(token=self._token, repo_id="valid_org/test-model-org")
|
||||||
|
|
||||||
|
# Push to hub via save_pretrained
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
model.save_pretrained(
|
||||||
|
tmp_dir, push_to_hub=True, use_auth_token=self._token, repo_id="valid_org/test-model-org"
|
||||||
|
)
|
||||||
|
|
||||||
|
new_model = BertModel.from_pretrained("valid_org/test-model-org")
|
||||||
|
for p1, p2 in zip(model.parameters(), new_model.parameters()):
|
||||||
|
self.assertTrue(torch.equal(p1, p2))
|
||||||
|
|
||||||
|
def test_push_to_hub_dynamic_model(self):
|
||||||
|
CustomConfig.register_for_auto_class()
|
||||||
|
CustomModel.register_for_auto_class()
|
||||||
|
|
||||||
|
config = CustomConfig(hidden_size=32)
|
||||||
|
model = CustomModel(config)
|
||||||
|
|
||||||
|
model.push_to_hub("test-dynamic-model", use_auth_token=self._token)
|
||||||
|
# checks
|
||||||
|
self.assertDictEqual(
|
||||||
|
config.auto_map,
|
||||||
|
{"AutoConfig": "custom_configuration.CustomConfig", "AutoModel": "custom_modeling.CustomModel"},
|
||||||
|
)
|
||||||
|
|
||||||
|
new_model = AutoModel.from_pretrained(f"{USER}/test-dynamic-model", trust_remote_code=True)
|
||||||
|
# Can't make an isinstance check because the new_model is from the CustomModel class of a dynamic module
|
||||||
|
self.assertEqual(new_model.__class__.__name__, "CustomModel")
|
||||||
|
for p1, p2 in zip(model.parameters(), new_model.parameters()):
|
||||||
|
self.assertTrue(torch.equal(p1, p2))
|
||||||
|
|
||||||
|
config = AutoConfig.from_pretrained(f"{USER}/test-dynamic-model", trust_remote_code=True)
|
||||||
|
new_model = AutoModel.from_config(config, trust_remote_code=True)
|
||||||
|
self.assertEqual(new_model.__class__.__name__, "CustomModel")
|
||||||
@@ -21,28 +21,20 @@ import os
|
|||||||
import pickle
|
import pickle
|
||||||
import re
|
import re
|
||||||
import shutil
|
import shutil
|
||||||
import sys
|
|
||||||
import tempfile
|
import tempfile
|
||||||
import traceback
|
import traceback
|
||||||
import unittest
|
import unittest
|
||||||
import unittest.mock as mock
|
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from itertools import takewhile
|
from itertools import takewhile
|
||||||
from pathlib import Path
|
|
||||||
from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Union
|
from typing import TYPE_CHECKING, Any, Dict, List, Tuple, Union
|
||||||
|
|
||||||
from huggingface_hub import HfFolder, delete_repo
|
|
||||||
from huggingface_hub.file_download import http_get
|
|
||||||
from parameterized import parameterized
|
from parameterized import parameterized
|
||||||
from requests.exceptions import HTTPError
|
|
||||||
|
|
||||||
from transformers import (
|
from transformers import (
|
||||||
AlbertTokenizer,
|
AlbertTokenizer,
|
||||||
AlbertTokenizerFast,
|
AlbertTokenizerFast,
|
||||||
AutoTokenizer,
|
|
||||||
BertTokenizer,
|
BertTokenizer,
|
||||||
BertTokenizerFast,
|
BertTokenizerFast,
|
||||||
GPT2TokenizerFast,
|
|
||||||
PreTrainedTokenizer,
|
PreTrainedTokenizer,
|
||||||
PreTrainedTokenizerBase,
|
PreTrainedTokenizerBase,
|
||||||
PreTrainedTokenizerFast,
|
PreTrainedTokenizerFast,
|
||||||
@@ -51,24 +43,20 @@ from transformers import (
|
|||||||
TrainingArguments,
|
TrainingArguments,
|
||||||
is_flax_available,
|
is_flax_available,
|
||||||
is_tf_available,
|
is_tf_available,
|
||||||
is_tokenizers_available,
|
|
||||||
is_torch_available,
|
is_torch_available,
|
||||||
logging,
|
logging,
|
||||||
)
|
)
|
||||||
from transformers.testing_utils import (
|
from transformers.testing_utils import (
|
||||||
TOKEN,
|
|
||||||
USER,
|
|
||||||
check_json_file_has_correct_format,
|
check_json_file_has_correct_format,
|
||||||
get_tests_dir,
|
get_tests_dir,
|
||||||
is_pt_tf_cross_test,
|
is_pt_tf_cross_test,
|
||||||
is_staging_test,
|
|
||||||
require_tf,
|
require_tf,
|
||||||
require_tokenizers,
|
require_tokenizers,
|
||||||
require_torch,
|
require_torch,
|
||||||
run_test_in_subprocess,
|
run_test_in_subprocess,
|
||||||
slow,
|
slow,
|
||||||
)
|
)
|
||||||
from transformers.tokenization_utils import AddedToken, Trie
|
from transformers.tokenization_utils import AddedToken
|
||||||
|
|
||||||
|
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
@@ -79,15 +67,6 @@ if TYPE_CHECKING:
|
|||||||
from transformers import PretrainedConfig, PreTrainedModel, TFPreTrainedModel
|
from transformers import PretrainedConfig, PreTrainedModel, TFPreTrainedModel
|
||||||
|
|
||||||
|
|
||||||
sys.path.append(str(Path(__file__).parent.parent / "utils"))
|
|
||||||
|
|
||||||
from test_module.custom_tokenization import CustomTokenizer # noqa E402
|
|
||||||
|
|
||||||
|
|
||||||
if is_tokenizers_available():
|
|
||||||
from test_module.custom_tokenization_fast import CustomTokenizerFast
|
|
||||||
|
|
||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
NON_ENGLISH_TAGS = ["chinese", "dutch", "french", "finnish", "german", "multilingual"]
|
NON_ENGLISH_TAGS = ["chinese", "dutch", "french", "finnish", "german", "multilingual"]
|
||||||
@@ -3974,238 +3953,3 @@ class TokenizerTesterMixin:
|
|||||||
tokenizer.clean_up_tokenization_spaces = True
|
tokenizer.clean_up_tokenization_spaces = True
|
||||||
decoded = tokenizer.decode(tokens)
|
decoded = tokenizer.decode(tokens)
|
||||||
assert decoded == "[CLS] this shouldn't be! he'll go. [SEP]"
|
assert decoded == "[CLS] this shouldn't be! he'll go. [SEP]"
|
||||||
|
|
||||||
|
|
||||||
class TokenizerUtilTester(unittest.TestCase):
|
|
||||||
def test_cached_files_are_used_when_internet_is_down(self):
|
|
||||||
# A mock response for an HTTP head request to emulate server down
|
|
||||||
response_mock = mock.Mock()
|
|
||||||
response_mock.status_code = 500
|
|
||||||
response_mock.headers = {}
|
|
||||||
response_mock.raise_for_status.side_effect = HTTPError
|
|
||||||
response_mock.json.return_value = {}
|
|
||||||
|
|
||||||
# Download this model to make sure it's in the cache.
|
|
||||||
_ = BertTokenizer.from_pretrained("hf-internal-testing/tiny-random-bert")
|
|
||||||
|
|
||||||
# Under the mock environment we get a 500 error when trying to reach the tokenizer.
|
|
||||||
with mock.patch("requests.Session.request", return_value=response_mock) as mock_head:
|
|
||||||
_ = BertTokenizer.from_pretrained("hf-internal-testing/tiny-random-bert")
|
|
||||||
# This check we did call the fake head request
|
|
||||||
mock_head.assert_called()
|
|
||||||
|
|
||||||
@require_tokenizers
|
|
||||||
def test_cached_files_are_used_when_internet_is_down_missing_files(self):
|
|
||||||
# A mock response for an HTTP head request to emulate server down
|
|
||||||
response_mock = mock.Mock()
|
|
||||||
response_mock.status_code = 500
|
|
||||||
response_mock.headers = {}
|
|
||||||
response_mock.raise_for_status.side_effect = HTTPError
|
|
||||||
response_mock.json.return_value = {}
|
|
||||||
|
|
||||||
# Download this model to make sure it's in the cache.
|
|
||||||
_ = GPT2TokenizerFast.from_pretrained("gpt2")
|
|
||||||
|
|
||||||
# Under the mock environment we get a 500 error when trying to reach the tokenizer.
|
|
||||||
with mock.patch("requests.Session.request", return_value=response_mock) as mock_head:
|
|
||||||
_ = GPT2TokenizerFast.from_pretrained("gpt2")
|
|
||||||
# This check we did call the fake head request
|
|
||||||
mock_head.assert_called()
|
|
||||||
|
|
||||||
def test_legacy_load_from_one_file(self):
|
|
||||||
# This test is for deprecated behavior and can be removed in v5
|
|
||||||
try:
|
|
||||||
tmp_file = tempfile.mktemp()
|
|
||||||
with open(tmp_file, "wb") as f:
|
|
||||||
http_get("https://huggingface.co/albert-base-v1/resolve/main/spiece.model", f)
|
|
||||||
|
|
||||||
_ = AlbertTokenizer.from_pretrained(tmp_file)
|
|
||||||
finally:
|
|
||||||
os.remove(tmp_file)
|
|
||||||
|
|
||||||
# Supporting this legacy load introduced a weird bug where the tokenizer would load local files if they are in
|
|
||||||
# the current folder and have the right name.
|
|
||||||
if os.path.isfile("tokenizer.json"):
|
|
||||||
# We skip the test if the user has a `tokenizer.json` in this folder to avoid deleting it.
|
|
||||||
return
|
|
||||||
try:
|
|
||||||
with open("tokenizer.json", "wb") as f:
|
|
||||||
http_get("https://huggingface.co/hf-internal-testing/tiny-random-bert/blob/main/tokenizer.json", f)
|
|
||||||
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
|
|
||||||
# The tiny random BERT has a vocab size of 1024, tiny gpt2 as a vocab size of 1000
|
|
||||||
self.assertEqual(tokenizer.vocab_size, 1000)
|
|
||||||
# Tokenizer should depend on the remote checkpoint, not the local tokenizer.json file.
|
|
||||||
|
|
||||||
finally:
|
|
||||||
os.remove("tokenizer.json")
|
|
||||||
|
|
||||||
def test_legacy_load_from_url(self):
|
|
||||||
# This test is for deprecated behavior and can be removed in v5
|
|
||||||
_ = AlbertTokenizer.from_pretrained("https://huggingface.co/albert-base-v1/resolve/main/spiece.model")
|
|
||||||
|
|
||||||
|
|
||||||
@is_staging_test
|
|
||||||
class TokenizerPushToHubTester(unittest.TestCase):
|
|
||||||
vocab_tokens = ["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]", "bla", "blou"]
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def setUpClass(cls):
|
|
||||||
cls._token = TOKEN
|
|
||||||
HfFolder.save_token(TOKEN)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def tearDownClass(cls):
|
|
||||||
try:
|
|
||||||
delete_repo(token=cls._token, repo_id="test-tokenizer")
|
|
||||||
except HTTPError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
try:
|
|
||||||
delete_repo(token=cls._token, repo_id="valid_org/test-tokenizer-org")
|
|
||||||
except HTTPError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
try:
|
|
||||||
delete_repo(token=cls._token, repo_id="test-dynamic-tokenizer")
|
|
||||||
except HTTPError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
def test_push_to_hub(self):
|
|
||||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
|
||||||
vocab_file = os.path.join(tmp_dir, "vocab.txt")
|
|
||||||
with open(vocab_file, "w", encoding="utf-8") as vocab_writer:
|
|
||||||
vocab_writer.write("".join([x + "\n" for x in self.vocab_tokens]))
|
|
||||||
tokenizer = BertTokenizer(vocab_file)
|
|
||||||
|
|
||||||
tokenizer.push_to_hub("test-tokenizer", use_auth_token=self._token)
|
|
||||||
new_tokenizer = BertTokenizer.from_pretrained(f"{USER}/test-tokenizer")
|
|
||||||
self.assertDictEqual(new_tokenizer.vocab, tokenizer.vocab)
|
|
||||||
|
|
||||||
# Reset repo
|
|
||||||
delete_repo(token=self._token, repo_id="test-tokenizer")
|
|
||||||
|
|
||||||
# Push to hub via save_pretrained
|
|
||||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
|
||||||
tokenizer.save_pretrained(tmp_dir, repo_id="test-tokenizer", push_to_hub=True, use_auth_token=self._token)
|
|
||||||
|
|
||||||
new_tokenizer = BertTokenizer.from_pretrained(f"{USER}/test-tokenizer")
|
|
||||||
self.assertDictEqual(new_tokenizer.vocab, tokenizer.vocab)
|
|
||||||
|
|
||||||
def test_push_to_hub_in_organization(self):
|
|
||||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
|
||||||
vocab_file = os.path.join(tmp_dir, "vocab.txt")
|
|
||||||
with open(vocab_file, "w", encoding="utf-8") as vocab_writer:
|
|
||||||
vocab_writer.write("".join([x + "\n" for x in self.vocab_tokens]))
|
|
||||||
tokenizer = BertTokenizer(vocab_file)
|
|
||||||
|
|
||||||
tokenizer.push_to_hub("valid_org/test-tokenizer-org", use_auth_token=self._token)
|
|
||||||
new_tokenizer = BertTokenizer.from_pretrained("valid_org/test-tokenizer-org")
|
|
||||||
self.assertDictEqual(new_tokenizer.vocab, tokenizer.vocab)
|
|
||||||
|
|
||||||
# Reset repo
|
|
||||||
delete_repo(token=self._token, repo_id="valid_org/test-tokenizer-org")
|
|
||||||
|
|
||||||
# Push to hub via save_pretrained
|
|
||||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
|
||||||
tokenizer.save_pretrained(
|
|
||||||
tmp_dir, repo_id="valid_org/test-tokenizer-org", push_to_hub=True, use_auth_token=self._token
|
|
||||||
)
|
|
||||||
|
|
||||||
new_tokenizer = BertTokenizer.from_pretrained("valid_org/test-tokenizer-org")
|
|
||||||
self.assertDictEqual(new_tokenizer.vocab, tokenizer.vocab)
|
|
||||||
|
|
||||||
@require_tokenizers
|
|
||||||
def test_push_to_hub_dynamic_tokenizer(self):
|
|
||||||
CustomTokenizer.register_for_auto_class()
|
|
||||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
|
||||||
vocab_file = os.path.join(tmp_dir, "vocab.txt")
|
|
||||||
with open(vocab_file, "w", encoding="utf-8") as vocab_writer:
|
|
||||||
vocab_writer.write("".join([x + "\n" for x in self.vocab_tokens]))
|
|
||||||
tokenizer = CustomTokenizer(vocab_file)
|
|
||||||
|
|
||||||
# No fast custom tokenizer
|
|
||||||
tokenizer.push_to_hub("test-dynamic-tokenizer", use_auth_token=self._token)
|
|
||||||
|
|
||||||
tokenizer = AutoTokenizer.from_pretrained(f"{USER}/test-dynamic-tokenizer", trust_remote_code=True)
|
|
||||||
# Can't make an isinstance check because the new_model.config is from the CustomTokenizer class of a dynamic module
|
|
||||||
self.assertEqual(tokenizer.__class__.__name__, "CustomTokenizer")
|
|
||||||
|
|
||||||
# Fast and slow custom tokenizer
|
|
||||||
CustomTokenizerFast.register_for_auto_class()
|
|
||||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
|
||||||
vocab_file = os.path.join(tmp_dir, "vocab.txt")
|
|
||||||
with open(vocab_file, "w", encoding="utf-8") as vocab_writer:
|
|
||||||
vocab_writer.write("".join([x + "\n" for x in self.vocab_tokens]))
|
|
||||||
|
|
||||||
bert_tokenizer = BertTokenizerFast.from_pretrained(tmp_dir)
|
|
||||||
bert_tokenizer.save_pretrained(tmp_dir)
|
|
||||||
tokenizer = CustomTokenizerFast.from_pretrained(tmp_dir)
|
|
||||||
|
|
||||||
tokenizer.push_to_hub("test-dynamic-tokenizer", use_auth_token=self._token)
|
|
||||||
|
|
||||||
tokenizer = AutoTokenizer.from_pretrained(f"{USER}/test-dynamic-tokenizer", trust_remote_code=True)
|
|
||||||
# Can't make an isinstance check because the new_model.config is from the FakeConfig class of a dynamic module
|
|
||||||
self.assertEqual(tokenizer.__class__.__name__, "CustomTokenizerFast")
|
|
||||||
tokenizer = AutoTokenizer.from_pretrained(
|
|
||||||
f"{USER}/test-dynamic-tokenizer", use_fast=False, trust_remote_code=True
|
|
||||||
)
|
|
||||||
# Can't make an isinstance check because the new_model.config is from the FakeConfig class of a dynamic module
|
|
||||||
self.assertEqual(tokenizer.__class__.__name__, "CustomTokenizer")
|
|
||||||
|
|
||||||
|
|
||||||
class TrieTest(unittest.TestCase):
|
|
||||||
def test_trie(self):
|
|
||||||
trie = Trie()
|
|
||||||
trie.add("Hello 友達")
|
|
||||||
self.assertEqual(trie.data, {"H": {"e": {"l": {"l": {"o": {" ": {"友": {"達": {"": 1}}}}}}}}})
|
|
||||||
trie.add("Hello")
|
|
||||||
trie.data
|
|
||||||
self.assertEqual(trie.data, {"H": {"e": {"l": {"l": {"o": {"": 1, " ": {"友": {"達": {"": 1}}}}}}}}})
|
|
||||||
|
|
||||||
def test_trie_split(self):
|
|
||||||
trie = Trie()
|
|
||||||
self.assertEqual(trie.split("[CLS] This is a extra_id_100"), ["[CLS] This is a extra_id_100"])
|
|
||||||
trie.add("[CLS]")
|
|
||||||
trie.add("extra_id_1")
|
|
||||||
trie.add("extra_id_100")
|
|
||||||
self.assertEqual(trie.split("[CLS] This is a extra_id_100"), ["[CLS]", " This is a ", "extra_id_100"])
|
|
||||||
|
|
||||||
def test_trie_single(self):
|
|
||||||
trie = Trie()
|
|
||||||
trie.add("A")
|
|
||||||
self.assertEqual(trie.split("ABC"), ["A", "BC"])
|
|
||||||
self.assertEqual(trie.split("BCA"), ["BC", "A"])
|
|
||||||
|
|
||||||
def test_trie_final(self):
|
|
||||||
trie = Trie()
|
|
||||||
trie.add("TOKEN]")
|
|
||||||
trie.add("[SPECIAL_TOKEN]")
|
|
||||||
self.assertEqual(trie.split("This is something [SPECIAL_TOKEN]"), ["This is something ", "[SPECIAL_TOKEN]"])
|
|
||||||
|
|
||||||
def test_trie_subtokens(self):
|
|
||||||
trie = Trie()
|
|
||||||
trie.add("A")
|
|
||||||
trie.add("P")
|
|
||||||
trie.add("[SPECIAL_TOKEN]")
|
|
||||||
self.assertEqual(trie.split("This is something [SPECIAL_TOKEN]"), ["This is something ", "[SPECIAL_TOKEN]"])
|
|
||||||
|
|
||||||
def test_trie_suffix_tokens(self):
|
|
||||||
trie = Trie()
|
|
||||||
trie.add("AB")
|
|
||||||
trie.add("B")
|
|
||||||
trie.add("C")
|
|
||||||
self.assertEqual(trie.split("ABC"), ["AB", "C"])
|
|
||||||
|
|
||||||
def test_trie_skip(self):
|
|
||||||
trie = Trie()
|
|
||||||
trie.add("ABC")
|
|
||||||
trie.add("B")
|
|
||||||
trie.add("CD")
|
|
||||||
self.assertEqual(trie.split("ABCD"), ["ABC", "D"])
|
|
||||||
|
|
||||||
def test_cut_text_hardening(self):
|
|
||||||
# Even if the offsets are wrong, we necessarily output correct string
|
|
||||||
# parts.
|
|
||||||
trie = Trie()
|
|
||||||
parts = trie.cut_text("ABC", [0, 0, 2, 1, 2, 3])
|
|
||||||
self.assertEqual(parts, ["AB", "C"])
|
|
||||||
|
|||||||
280
tests/test_tokenization_utils.py
Normal file
280
tests/test_tokenization_utils.py
Normal file
@@ -0,0 +1,280 @@
|
|||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2019 HuggingFace Inc.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import tempfile
|
||||||
|
import unittest
|
||||||
|
import unittest.mock as mock
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from huggingface_hub import HfFolder, delete_repo
|
||||||
|
from huggingface_hub.file_download import http_get
|
||||||
|
from requests.exceptions import HTTPError
|
||||||
|
|
||||||
|
from transformers import (
|
||||||
|
AlbertTokenizer,
|
||||||
|
AutoTokenizer,
|
||||||
|
BertTokenizer,
|
||||||
|
BertTokenizerFast,
|
||||||
|
GPT2TokenizerFast,
|
||||||
|
is_tokenizers_available,
|
||||||
|
)
|
||||||
|
from transformers.testing_utils import TOKEN, USER, is_staging_test, require_tokenizers
|
||||||
|
from transformers.tokenization_utils import Trie
|
||||||
|
|
||||||
|
|
||||||
|
sys.path.append(str(Path(__file__).parent.parent / "utils"))
|
||||||
|
|
||||||
|
from test_module.custom_tokenization import CustomTokenizer # noqa E402
|
||||||
|
|
||||||
|
|
||||||
|
if is_tokenizers_available():
|
||||||
|
from test_module.custom_tokenization_fast import CustomTokenizerFast
|
||||||
|
|
||||||
|
|
||||||
|
class TokenizerUtilTester(unittest.TestCase):
|
||||||
|
def test_cached_files_are_used_when_internet_is_down(self):
|
||||||
|
# A mock response for an HTTP head request to emulate server down
|
||||||
|
response_mock = mock.Mock()
|
||||||
|
response_mock.status_code = 500
|
||||||
|
response_mock.headers = {}
|
||||||
|
response_mock.raise_for_status.side_effect = HTTPError
|
||||||
|
response_mock.json.return_value = {}
|
||||||
|
|
||||||
|
# Download this model to make sure it's in the cache.
|
||||||
|
_ = BertTokenizer.from_pretrained("hf-internal-testing/tiny-random-bert")
|
||||||
|
|
||||||
|
# Under the mock environment we get a 500 error when trying to reach the tokenizer.
|
||||||
|
with mock.patch("requests.Session.request", return_value=response_mock) as mock_head:
|
||||||
|
_ = BertTokenizer.from_pretrained("hf-internal-testing/tiny-random-bert")
|
||||||
|
# This check we did call the fake head request
|
||||||
|
mock_head.assert_called()
|
||||||
|
|
||||||
|
@require_tokenizers
|
||||||
|
def test_cached_files_are_used_when_internet_is_down_missing_files(self):
|
||||||
|
# A mock response for an HTTP head request to emulate server down
|
||||||
|
response_mock = mock.Mock()
|
||||||
|
response_mock.status_code = 500
|
||||||
|
response_mock.headers = {}
|
||||||
|
response_mock.raise_for_status.side_effect = HTTPError
|
||||||
|
response_mock.json.return_value = {}
|
||||||
|
|
||||||
|
# Download this model to make sure it's in the cache.
|
||||||
|
_ = GPT2TokenizerFast.from_pretrained("gpt2")
|
||||||
|
|
||||||
|
# Under the mock environment we get a 500 error when trying to reach the tokenizer.
|
||||||
|
with mock.patch("requests.Session.request", return_value=response_mock) as mock_head:
|
||||||
|
_ = GPT2TokenizerFast.from_pretrained("gpt2")
|
||||||
|
# This check we did call the fake head request
|
||||||
|
mock_head.assert_called()
|
||||||
|
|
||||||
|
def test_legacy_load_from_one_file(self):
|
||||||
|
# This test is for deprecated behavior and can be removed in v5
|
||||||
|
try:
|
||||||
|
tmp_file = tempfile.mktemp()
|
||||||
|
with open(tmp_file, "wb") as f:
|
||||||
|
http_get("https://huggingface.co/albert-base-v1/resolve/main/spiece.model", f)
|
||||||
|
|
||||||
|
_ = AlbertTokenizer.from_pretrained(tmp_file)
|
||||||
|
finally:
|
||||||
|
os.remove(tmp_file)
|
||||||
|
|
||||||
|
# Supporting this legacy load introduced a weird bug where the tokenizer would load local files if they are in
|
||||||
|
# the current folder and have the right name.
|
||||||
|
if os.path.isfile("tokenizer.json"):
|
||||||
|
# We skip the test if the user has a `tokenizer.json` in this folder to avoid deleting it.
|
||||||
|
return
|
||||||
|
try:
|
||||||
|
with open("tokenizer.json", "wb") as f:
|
||||||
|
http_get("https://huggingface.co/hf-internal-testing/tiny-random-bert/blob/main/tokenizer.json", f)
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
|
||||||
|
# The tiny random BERT has a vocab size of 1024, tiny gpt2 as a vocab size of 1000
|
||||||
|
self.assertEqual(tokenizer.vocab_size, 1000)
|
||||||
|
# Tokenizer should depend on the remote checkpoint, not the local tokenizer.json file.
|
||||||
|
|
||||||
|
finally:
|
||||||
|
os.remove("tokenizer.json")
|
||||||
|
|
||||||
|
def test_legacy_load_from_url(self):
|
||||||
|
# This test is for deprecated behavior and can be removed in v5
|
||||||
|
_ = AlbertTokenizer.from_pretrained("https://huggingface.co/albert-base-v1/resolve/main/spiece.model")
|
||||||
|
|
||||||
|
|
||||||
|
@is_staging_test
|
||||||
|
class TokenizerPushToHubTester(unittest.TestCase):
|
||||||
|
vocab_tokens = ["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]", "bla", "blou"]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def setUpClass(cls):
|
||||||
|
cls._token = TOKEN
|
||||||
|
HfFolder.save_token(TOKEN)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def tearDownClass(cls):
|
||||||
|
try:
|
||||||
|
delete_repo(token=cls._token, repo_id="test-tokenizer")
|
||||||
|
except HTTPError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
try:
|
||||||
|
delete_repo(token=cls._token, repo_id="valid_org/test-tokenizer-org")
|
||||||
|
except HTTPError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
try:
|
||||||
|
delete_repo(token=cls._token, repo_id="test-dynamic-tokenizer")
|
||||||
|
except HTTPError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def test_push_to_hub(self):
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
vocab_file = os.path.join(tmp_dir, "vocab.txt")
|
||||||
|
with open(vocab_file, "w", encoding="utf-8") as vocab_writer:
|
||||||
|
vocab_writer.write("".join([x + "\n" for x in self.vocab_tokens]))
|
||||||
|
tokenizer = BertTokenizer(vocab_file)
|
||||||
|
|
||||||
|
tokenizer.push_to_hub("test-tokenizer", use_auth_token=self._token)
|
||||||
|
new_tokenizer = BertTokenizer.from_pretrained(f"{USER}/test-tokenizer")
|
||||||
|
self.assertDictEqual(new_tokenizer.vocab, tokenizer.vocab)
|
||||||
|
|
||||||
|
# Reset repo
|
||||||
|
delete_repo(token=self._token, repo_id="test-tokenizer")
|
||||||
|
|
||||||
|
# Push to hub via save_pretrained
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
tokenizer.save_pretrained(tmp_dir, repo_id="test-tokenizer", push_to_hub=True, use_auth_token=self._token)
|
||||||
|
|
||||||
|
new_tokenizer = BertTokenizer.from_pretrained(f"{USER}/test-tokenizer")
|
||||||
|
self.assertDictEqual(new_tokenizer.vocab, tokenizer.vocab)
|
||||||
|
|
||||||
|
def test_push_to_hub_in_organization(self):
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
vocab_file = os.path.join(tmp_dir, "vocab.txt")
|
||||||
|
with open(vocab_file, "w", encoding="utf-8") as vocab_writer:
|
||||||
|
vocab_writer.write("".join([x + "\n" for x in self.vocab_tokens]))
|
||||||
|
tokenizer = BertTokenizer(vocab_file)
|
||||||
|
|
||||||
|
tokenizer.push_to_hub("valid_org/test-tokenizer-org", use_auth_token=self._token)
|
||||||
|
new_tokenizer = BertTokenizer.from_pretrained("valid_org/test-tokenizer-org")
|
||||||
|
self.assertDictEqual(new_tokenizer.vocab, tokenizer.vocab)
|
||||||
|
|
||||||
|
# Reset repo
|
||||||
|
delete_repo(token=self._token, repo_id="valid_org/test-tokenizer-org")
|
||||||
|
|
||||||
|
# Push to hub via save_pretrained
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
tokenizer.save_pretrained(
|
||||||
|
tmp_dir, repo_id="valid_org/test-tokenizer-org", push_to_hub=True, use_auth_token=self._token
|
||||||
|
)
|
||||||
|
|
||||||
|
new_tokenizer = BertTokenizer.from_pretrained("valid_org/test-tokenizer-org")
|
||||||
|
self.assertDictEqual(new_tokenizer.vocab, tokenizer.vocab)
|
||||||
|
|
||||||
|
@require_tokenizers
|
||||||
|
def test_push_to_hub_dynamic_tokenizer(self):
|
||||||
|
CustomTokenizer.register_for_auto_class()
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
vocab_file = os.path.join(tmp_dir, "vocab.txt")
|
||||||
|
with open(vocab_file, "w", encoding="utf-8") as vocab_writer:
|
||||||
|
vocab_writer.write("".join([x + "\n" for x in self.vocab_tokens]))
|
||||||
|
tokenizer = CustomTokenizer(vocab_file)
|
||||||
|
|
||||||
|
# No fast custom tokenizer
|
||||||
|
tokenizer.push_to_hub("test-dynamic-tokenizer", use_auth_token=self._token)
|
||||||
|
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(f"{USER}/test-dynamic-tokenizer", trust_remote_code=True)
|
||||||
|
# Can't make an isinstance check because the new_model.config is from the CustomTokenizer class of a dynamic module
|
||||||
|
self.assertEqual(tokenizer.__class__.__name__, "CustomTokenizer")
|
||||||
|
|
||||||
|
# Fast and slow custom tokenizer
|
||||||
|
CustomTokenizerFast.register_for_auto_class()
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
|
vocab_file = os.path.join(tmp_dir, "vocab.txt")
|
||||||
|
with open(vocab_file, "w", encoding="utf-8") as vocab_writer:
|
||||||
|
vocab_writer.write("".join([x + "\n" for x in self.vocab_tokens]))
|
||||||
|
|
||||||
|
bert_tokenizer = BertTokenizerFast.from_pretrained(tmp_dir)
|
||||||
|
bert_tokenizer.save_pretrained(tmp_dir)
|
||||||
|
tokenizer = CustomTokenizerFast.from_pretrained(tmp_dir)
|
||||||
|
|
||||||
|
tokenizer.push_to_hub("test-dynamic-tokenizer", use_auth_token=self._token)
|
||||||
|
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(f"{USER}/test-dynamic-tokenizer", trust_remote_code=True)
|
||||||
|
# Can't make an isinstance check because the new_model.config is from the FakeConfig class of a dynamic module
|
||||||
|
self.assertEqual(tokenizer.__class__.__name__, "CustomTokenizerFast")
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(
|
||||||
|
f"{USER}/test-dynamic-tokenizer", use_fast=False, trust_remote_code=True
|
||||||
|
)
|
||||||
|
# Can't make an isinstance check because the new_model.config is from the FakeConfig class of a dynamic module
|
||||||
|
self.assertEqual(tokenizer.__class__.__name__, "CustomTokenizer")
|
||||||
|
|
||||||
|
|
||||||
|
class TrieTest(unittest.TestCase):
|
||||||
|
def test_trie(self):
|
||||||
|
trie = Trie()
|
||||||
|
trie.add("Hello 友達")
|
||||||
|
self.assertEqual(trie.data, {"H": {"e": {"l": {"l": {"o": {" ": {"友": {"達": {"": 1}}}}}}}}})
|
||||||
|
trie.add("Hello")
|
||||||
|
trie.data
|
||||||
|
self.assertEqual(trie.data, {"H": {"e": {"l": {"l": {"o": {"": 1, " ": {"友": {"達": {"": 1}}}}}}}}})
|
||||||
|
|
||||||
|
def test_trie_split(self):
|
||||||
|
trie = Trie()
|
||||||
|
self.assertEqual(trie.split("[CLS] This is a extra_id_100"), ["[CLS] This is a extra_id_100"])
|
||||||
|
trie.add("[CLS]")
|
||||||
|
trie.add("extra_id_1")
|
||||||
|
trie.add("extra_id_100")
|
||||||
|
self.assertEqual(trie.split("[CLS] This is a extra_id_100"), ["[CLS]", " This is a ", "extra_id_100"])
|
||||||
|
|
||||||
|
def test_trie_single(self):
|
||||||
|
trie = Trie()
|
||||||
|
trie.add("A")
|
||||||
|
self.assertEqual(trie.split("ABC"), ["A", "BC"])
|
||||||
|
self.assertEqual(trie.split("BCA"), ["BC", "A"])
|
||||||
|
|
||||||
|
def test_trie_final(self):
|
||||||
|
trie = Trie()
|
||||||
|
trie.add("TOKEN]")
|
||||||
|
trie.add("[SPECIAL_TOKEN]")
|
||||||
|
self.assertEqual(trie.split("This is something [SPECIAL_TOKEN]"), ["This is something ", "[SPECIAL_TOKEN]"])
|
||||||
|
|
||||||
|
def test_trie_subtokens(self):
|
||||||
|
trie = Trie()
|
||||||
|
trie.add("A")
|
||||||
|
trie.add("P")
|
||||||
|
trie.add("[SPECIAL_TOKEN]")
|
||||||
|
self.assertEqual(trie.split("This is something [SPECIAL_TOKEN]"), ["This is something ", "[SPECIAL_TOKEN]"])
|
||||||
|
|
||||||
|
def test_trie_suffix_tokens(self):
|
||||||
|
trie = Trie()
|
||||||
|
trie.add("AB")
|
||||||
|
trie.add("B")
|
||||||
|
trie.add("C")
|
||||||
|
self.assertEqual(trie.split("ABC"), ["AB", "C"])
|
||||||
|
|
||||||
|
def test_trie_skip(self):
|
||||||
|
trie = Trie()
|
||||||
|
trie.add("ABC")
|
||||||
|
trie.add("B")
|
||||||
|
trie.add("CD")
|
||||||
|
self.assertEqual(trie.split("ABCD"), ["ABC", "D"])
|
||||||
|
|
||||||
|
def test_cut_text_hardening(self):
|
||||||
|
# Even if the offsets are wrong, we necessarily output correct string
|
||||||
|
# parts.
|
||||||
|
trie = Trie()
|
||||||
|
parts = trie.cut_text("ABC", [0, 0, 2, 1, 2, 3])
|
||||||
|
self.assertEqual(parts, ["AB", "C"])
|
||||||
Reference in New Issue
Block a user