Move some test files (tets/test_xxx_utils.py) to tests/utils (#31730)
* move * move * move * move * Update tests/utils/test_image_processing_utils.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> --------- Co-authored-by: ydshieh <ydshieh@users.noreply.github.com> Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
457
tests/utils/test_cache_utils.py
Normal file
457
tests/utils/test_cache_utils.py
Normal file
@@ -0,0 +1,457 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2023 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
from parameterized import parameterized
|
||||
|
||||
from transformers import set_seed
|
||||
from transformers.testing_utils import (
|
||||
is_torch_available,
|
||||
require_auto_gptq,
|
||||
require_torch,
|
||||
require_torch_gpu,
|
||||
slow,
|
||||
torch_device,
|
||||
)
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
from transformers import (
|
||||
AutoModelForCausalLM,
|
||||
AutoTokenizer,
|
||||
DynamicCache,
|
||||
LlamaConfig,
|
||||
LlamaForCausalLM,
|
||||
SinkCache,
|
||||
StaticCache,
|
||||
)
|
||||
|
||||
|
||||
@require_torch
|
||||
class CacheTest(unittest.TestCase):
|
||||
def test_dynamic_cache_retrocompatibility(self):
|
||||
"""Tests that we can convert back and forth between the legacy cache format and DynamicCache"""
|
||||
legacy_cache = ()
|
||||
new_cache = DynamicCache()
|
||||
|
||||
# Creates a new cache with 10 layers in both formats
|
||||
for layer_idx in range(10):
|
||||
new_key = torch.rand((2, 4, 8, 16))
|
||||
new_value = torch.rand((2, 4, 8, 16))
|
||||
new_cache.update(new_key, new_value, layer_idx)
|
||||
legacy_cache += ((new_key, new_value),)
|
||||
|
||||
# Sanity check 1: they must have the same shapes
|
||||
self.assertTrue(len(legacy_cache), len(new_cache))
|
||||
for layer_idx in range(10):
|
||||
self.assertTrue(len(legacy_cache[layer_idx]), len(legacy_cache[layer_idx]))
|
||||
for key_value_idx in range(2):
|
||||
self.assertTrue(
|
||||
legacy_cache[layer_idx][key_value_idx].shape == new_cache[layer_idx][key_value_idx].shape
|
||||
)
|
||||
|
||||
# Sanity check 2: we can get the sequence length in multiple ways with DynamicCache, and they return the
|
||||
# expected value
|
||||
self.assertTrue(legacy_cache[0][0].shape[-2] == new_cache[0][0].shape[-2] == new_cache.get_seq_length() == 8)
|
||||
|
||||
# Sanity check 3: they must be equal, and both support indexing
|
||||
for layer_idx in range(10):
|
||||
for key_value_idx in range(2):
|
||||
self.assertTrue(
|
||||
torch.allclose(new_cache[layer_idx][key_value_idx], legacy_cache[layer_idx][key_value_idx])
|
||||
)
|
||||
|
||||
# Test 1: We can convert from legacy to new with no changes
|
||||
from_legacy = DynamicCache.from_legacy_cache(legacy_cache)
|
||||
for layer_idx in range(10):
|
||||
for key_value_idx in range(2):
|
||||
self.assertTrue(
|
||||
torch.allclose(from_legacy[layer_idx][key_value_idx], legacy_cache[layer_idx][key_value_idx])
|
||||
)
|
||||
|
||||
# Test 2: We can convert from new to legacy with no changes
|
||||
to_legacy = new_cache.to_legacy_cache()
|
||||
for layer_idx in range(10):
|
||||
for key_value_idx in range(2):
|
||||
self.assertTrue(
|
||||
torch.allclose(to_legacy[layer_idx][key_value_idx], new_cache[layer_idx][key_value_idx])
|
||||
)
|
||||
|
||||
def test_reorder_cache_retrocompatibility(self):
|
||||
"""Tests that Cache.reorder_cache is retrocompatible with the legacy code path"""
|
||||
legacy_reorder_fn = LlamaForCausalLM._reorder_cache # An example of a legacy `_reorder_cache` function
|
||||
|
||||
legacy_cache = ()
|
||||
new_cache = DynamicCache()
|
||||
|
||||
# Creates a new cache with 10 layers in both formats
|
||||
for layer_idx in range(10):
|
||||
new_key = torch.rand((4, 4, 8, 16))
|
||||
new_value = torch.rand((4, 4, 8, 16))
|
||||
new_cache.update(new_key, new_value, layer_idx)
|
||||
legacy_cache += ((new_key, new_value),)
|
||||
|
||||
# Let's create some dummy beam indices. From the shape above, it is equivalent to the case where num_beams=4
|
||||
# and batch_size=1
|
||||
beam_idx = torch.randint(low=0, high=4, size=(4,))
|
||||
|
||||
legacy_cache_reordered = legacy_reorder_fn(legacy_cache, beam_idx)
|
||||
new_cache.reorder_cache(beam_idx)
|
||||
|
||||
# Let's check that the results are the same
|
||||
for layer_idx in range(10):
|
||||
for key_value_idx in range(2):
|
||||
self.assertTrue(
|
||||
torch.allclose(
|
||||
new_cache[layer_idx][key_value_idx], legacy_cache_reordered[layer_idx][key_value_idx]
|
||||
)
|
||||
)
|
||||
|
||||
def test_static_cache_mha_mqa_gqa(self):
|
||||
"""
|
||||
Tests that static cache works with multi-head attention (MHA), grouped query attention (GQA), and multi-query
|
||||
attention (MQA)
|
||||
"""
|
||||
|
||||
def _random_kvs(config):
|
||||
# shape for key and values: (batch_size, num_heads, seq_len, head_dim)
|
||||
random_keys = torch.rand(
|
||||
(1, config.num_key_value_heads, 1, config.hidden_size // config.num_attention_heads),
|
||||
device=torch_device,
|
||||
)
|
||||
random_values = torch.rand(
|
||||
(1, config.num_key_value_heads, 1, config.hidden_size // config.num_attention_heads),
|
||||
device=torch_device,
|
||||
)
|
||||
return random_keys, random_values
|
||||
|
||||
mha_config = LlamaConfig(num_attention_heads=32)
|
||||
mha_static_cache = StaticCache(config=mha_config, max_batch_size=1, max_cache_len=10, device=torch_device)
|
||||
cached_keys, cached_values = mha_static_cache.update(
|
||||
*_random_kvs(mha_config), 0, cache_kwargs={"cache_position": torch.arange(1)}
|
||||
)
|
||||
self.assertTrue(cached_keys.shape == (1, 32, 10, 128))
|
||||
self.assertTrue(cached_values.shape == (1, 32, 10, 128))
|
||||
|
||||
gqa_config = LlamaConfig(num_attention_heads=32, num_key_value_heads=4)
|
||||
gqa_static_cache = StaticCache(config=gqa_config, max_batch_size=1, max_cache_len=10, device=torch_device)
|
||||
cached_keys, cached_values = gqa_static_cache.update(
|
||||
*_random_kvs(gqa_config), 0, cache_kwargs={"cache_position": torch.arange(1)}
|
||||
)
|
||||
self.assertTrue(cached_keys.shape == (1, 4, 10, 128))
|
||||
self.assertTrue(cached_values.shape == (1, 4, 10, 128))
|
||||
|
||||
mqa_config = LlamaConfig(num_attention_heads=32, num_key_value_heads=1)
|
||||
mqa_static_cache = StaticCache(config=mqa_config, max_batch_size=1, max_cache_len=10, device=torch_device)
|
||||
cached_keys, cached_values = mqa_static_cache.update(
|
||||
*_random_kvs(mqa_config), 0, cache_kwargs={"cache_position": torch.arange(1)}
|
||||
)
|
||||
self.assertTrue(cached_keys.shape == (1, 1, 10, 128))
|
||||
self.assertTrue(cached_values.shape == (1, 1, 10, 128))
|
||||
|
||||
|
||||
@require_torch_gpu
|
||||
@slow
|
||||
class CacheIntegrationTest(unittest.TestCase):
|
||||
def test_dynamic_cache_hard(self):
|
||||
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf", padding_side="left")
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
"meta-llama/Llama-2-7b-hf", device_map="auto", torch_dtype=torch.float16
|
||||
)
|
||||
inputs = tokenizer(["Here's everything I know about cats. Cats"], return_tensors="pt").to(model.device)
|
||||
|
||||
# DynamicCache and the legacy cache format should be equivalent
|
||||
set_seed(0)
|
||||
gen_out_legacy = model.generate(**inputs, do_sample=True, max_new_tokens=256)
|
||||
set_seed(0)
|
||||
gen_out = model.generate(**inputs, do_sample=True, max_new_tokens=256, past_key_values=DynamicCache())
|
||||
self.assertListEqual(gen_out_legacy.tolist(), gen_out.tolist())
|
||||
|
||||
decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)
|
||||
expected_text = (
|
||||
"Here's everything I know about cats. Cats are mysterious creatures. They can't talk, and they don't like "
|
||||
"to be held. They don't play fetch, and they don't like to be hugged. But they do like to be petted.\n"
|
||||
"Cats are also very independent. They don't like to be told what to do, and they don't like to be told "
|
||||
"what to eat. They are also very territorial. They don't like to share their food or their toys.\nCats "
|
||||
"are also very curious. They like to explore, and they like to play. They are also very fast. They can "
|
||||
"run very fast, and they can jump very high.\nCats are also very smart. They can learn tricks, and they "
|
||||
"can solve problems. They are also very playful. They like to play with toys, and they like to play with "
|
||||
"other cats.\nCats are also very affectionate. They like to be petted, and they like to be held. They "
|
||||
"also like to be scratched.\nCats are also very clean. They like to groom themselves, and they like to "
|
||||
"clean their litter box.\nCats are also very independent. They don't"
|
||||
)
|
||||
self.assertEqual(decoded[0], expected_text)
|
||||
|
||||
def test_dynamic_cache_batched(self):
|
||||
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf", padding_side="left")
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
"meta-llama/Llama-2-7b-hf", device_map="auto", torch_dtype=torch.float16
|
||||
)
|
||||
inputs = tokenizer(["A sequence: 1, 2, 3, 4, 5", "A sequence: A, B, C"], padding=True, return_tensors="pt").to(
|
||||
model.device
|
||||
)
|
||||
|
||||
gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=10, past_key_values=DynamicCache())
|
||||
decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)
|
||||
expected_text = ["A sequence: 1, 2, 3, 4, 5, 6, 7, 8,", "A sequence: A, B, C, D, E, F, G, H"]
|
||||
self.assertListEqual(decoded, expected_text)
|
||||
|
||||
def test_dynamic_cache_beam_search(self):
|
||||
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf", padding_side="left")
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
"meta-llama/Llama-2-7b-hf", device_map="auto", torch_dtype=torch.float16
|
||||
)
|
||||
|
||||
inputs = tokenizer(["The best color is"], return_tensors="pt").to(model.device)
|
||||
gen_out = model.generate(
|
||||
**inputs,
|
||||
do_sample=False,
|
||||
max_new_tokens=20,
|
||||
num_beams=2,
|
||||
num_return_sequences=2,
|
||||
)
|
||||
decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)
|
||||
expected_text = [
|
||||
"The best color is the one that makes you feel good.\nThe best color is the one that makes you feel good",
|
||||
"The best color is the one that suits you.\nThe best color is the one that suits you. The",
|
||||
]
|
||||
self.assertListEqual(decoded, expected_text)
|
||||
|
||||
@require_auto_gptq
|
||||
def test_sink_cache_hard(self):
|
||||
tokenizer = AutoTokenizer.from_pretrained("TheBloke/LLaMa-7B-GPTQ")
|
||||
model = AutoModelForCausalLM.from_pretrained("TheBloke/LLaMa-7B-GPTQ", device_map="auto")
|
||||
|
||||
inputs = tokenizer(["Vaswani et al. (2017) introduced the Transformers"], return_tensors="pt").to(model.device)
|
||||
|
||||
# Set up the SinkCache. Using a small window length to contain computational complexity. If this example is run
|
||||
# without a SinkCache, the last few tokens are gibberish (ends in "of the of the of a of a of")
|
||||
cache = SinkCache(window_length=508, num_sink_tokens=4)
|
||||
gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=3000, past_key_values=cache)
|
||||
decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)
|
||||
self.assertTrue(decoded[0].endswith("to perform a variety of tasks. The Transformer is a neural network"))
|
||||
|
||||
def test_sink_cache_iterative_prompts(self):
|
||||
"""Tests that SinkCache supports more than one new token at once, when shifting the cache"""
|
||||
tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
"HuggingFaceH4/zephyr-7b-beta", device_map="auto", torch_dtype=torch.float16
|
||||
)
|
||||
prompt = (
|
||||
"Compose an engaging travel blog post about a recent trip to Hawaii, highlighting cultural experiences "
|
||||
"and must-see attractions."
|
||||
)
|
||||
|
||||
# Prepare generation settings
|
||||
cache = SinkCache(window_length=256, num_sink_tokens=4)
|
||||
input_ids = torch.tensor([], device=model.device, dtype=torch.int)
|
||||
for _ in range(3):
|
||||
# Tokenize the prompt with the correct chat template
|
||||
chat = [{"role": "user", "content": prompt}]
|
||||
tokenized_chat = tokenizer.apply_chat_template(chat, return_tensors="pt", add_generation_prompt=True).to(
|
||||
model.device
|
||||
)
|
||||
input_ids = torch.cat((input_ids, tokenized_chat), dim=1)
|
||||
|
||||
# Perform the generation
|
||||
gen_out = model.generate(
|
||||
input_ids, do_sample=False, max_new_tokens=100, past_key_values=cache, use_cache=True
|
||||
)
|
||||
input_ids = gen_out
|
||||
|
||||
# We went well beyond the cache length
|
||||
self.assertTrue(input_ids.shape[1] > cache.get_max_length() * 1.5)
|
||||
|
||||
# And it still produces a coherent english
|
||||
decoded = tokenizer.batch_decode(input_ids, skip_special_tokens=True)
|
||||
last_output = (
|
||||
"<|assistant|>\nAs the sun began to set over the Pacific Ocean, I found myself standing on the shores of "
|
||||
"Waikiki Beach, my heart filled with awe and wonder. I had just returned from a two-week journey to the "
|
||||
"beautiful island of Hawaii, and it had been an unforgettable experience filled with cultural experiences "
|
||||
"and must-see attractions that left me breathless.\n\nOne of the most memorable experiences of my trip "
|
||||
"was visiting the historic district of Honolulu. Here,"
|
||||
)
|
||||
self.assertTrue(decoded[0].endswith(last_output))
|
||||
|
||||
@require_torch_gpu
|
||||
@parameterized.expand(["eager", "sdpa", "flash_attention_2"])
|
||||
def test_static_cache_greedy_decoding_pad_left(self, attn_implementation):
|
||||
EXPECTED_GENERATION = [
|
||||
"The best color is the one that complements the skin tone of the",
|
||||
"We should not undermind the issues at hand.\nWe should not undermind the issues",
|
||||
]
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
"NousResearch/Llama-2-7b-chat-hf", padding_side="left", pad_token="<s>"
|
||||
)
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
"NousResearch/Llama-2-7b-chat-hf",
|
||||
torch_dtype=torch.bfloat16,
|
||||
attn_implementation=attn_implementation,
|
||||
).to(torch_device)
|
||||
inputs = tokenizer(
|
||||
["The best color is", "We should not undermind the issues at hand"], padding=True, return_tensors="pt"
|
||||
).to(model.device)
|
||||
|
||||
set_seed(0)
|
||||
gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=10)
|
||||
decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)
|
||||
with self.subTest(f"{attn_implementation}, dynamic"):
|
||||
self.assertListEqual(decoded, EXPECTED_GENERATION)
|
||||
|
||||
set_seed(0)
|
||||
model.generation_config.cache_implementation = "static"
|
||||
gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=10)
|
||||
decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)
|
||||
with self.subTest(f"{attn_implementation}, static, eager"):
|
||||
self.assertListEqual(decoded, EXPECTED_GENERATION)
|
||||
|
||||
set_seed(0)
|
||||
model.forward = torch.compile(model.forward)
|
||||
gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=10)
|
||||
decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)
|
||||
with self.subTest(f"{attn_implementation}, static, compiled"):
|
||||
self.assertListEqual(decoded, EXPECTED_GENERATION)
|
||||
|
||||
@require_torch_gpu
|
||||
@parameterized.expand(["eager", "sdpa", "flash_attention_2"])
|
||||
def test_static_cache_greedy_decoding_pad_right(self, attn_implementation):
|
||||
EXPECTED_GENERATION = [
|
||||
"The best color isЋ the one that complements the skin tone of",
|
||||
"We should not undermind the issues at hand.\nWe should not undermind the issues",
|
||||
]
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
"NousResearch/Llama-2-7b-chat-hf", padding_side="right", pad_token="<s>"
|
||||
)
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
"NousResearch/Llama-2-7b-chat-hf",
|
||||
torch_dtype=torch.bfloat16,
|
||||
attn_implementation=attn_implementation,
|
||||
).to(torch_device)
|
||||
inputs = tokenizer(
|
||||
["The best color is", "We should not undermind the issues at hand"], padding=True, return_tensors="pt"
|
||||
).to(model.device)
|
||||
|
||||
set_seed(0)
|
||||
gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=10)
|
||||
decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)
|
||||
with self.subTest(f"{attn_implementation}, dynamic"):
|
||||
self.assertListEqual(decoded, EXPECTED_GENERATION)
|
||||
|
||||
set_seed(0)
|
||||
model.generation_config.cache_implementation = "static"
|
||||
gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=10)
|
||||
decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)
|
||||
with self.subTest(f"{attn_implementation}, static, eager"):
|
||||
self.assertListEqual(decoded, EXPECTED_GENERATION)
|
||||
|
||||
set_seed(0)
|
||||
model._forward = model.forward
|
||||
compiled_forward = torch.compile(model.forward)
|
||||
|
||||
def compiled(func, input_ids, **kwargs):
|
||||
return func(input_ids, **kwargs)
|
||||
|
||||
def call(input_ids, **kwargs):
|
||||
if input_ids.shape[-1] == 1:
|
||||
return compiled(compiled_forward, input_ids, **kwargs)
|
||||
|
||||
return model._forward(input_ids, **kwargs)
|
||||
|
||||
model.forward = call
|
||||
|
||||
gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=10)
|
||||
decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)
|
||||
with self.subTest(f"{attn_implementation}, static, compiled"):
|
||||
self.assertListEqual(decoded, EXPECTED_GENERATION)
|
||||
|
||||
def test_dynamic_cache_extra_left_padding(self):
|
||||
"""Tests that adding extra left-padding does not affect the generation with the dynamic cache"""
|
||||
EXPECTED_GENERATION = [
|
||||
"The best color is the one that complements the skin tone of the",
|
||||
"We should not undermind the issues at hand.\nWe should not undermind the issues",
|
||||
]
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
"NousResearch/Llama-2-7b-chat-hf", padding_side="left", pad_token="<s>"
|
||||
)
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
"NousResearch/Llama-2-7b-chat-hf",
|
||||
torch_dtype=torch.bfloat16,
|
||||
).to(torch_device)
|
||||
inputs = tokenizer(
|
||||
["The best color is", "We should not undermind the issues at hand"], padding=True, return_tensors="pt"
|
||||
).to(model.device)
|
||||
|
||||
gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=10)
|
||||
decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)
|
||||
self.assertListEqual(decoded, EXPECTED_GENERATION)
|
||||
|
||||
# Now with extra left-padding
|
||||
inputs_expanded = tokenizer(
|
||||
["The best color is", "We should not undermind the issues at hand"],
|
||||
padding=True,
|
||||
return_tensors="pt",
|
||||
pad_to_multiple_of=32,
|
||||
).to(model.device)
|
||||
self.assertTrue(inputs.input_ids.shape[1] < inputs_expanded.input_ids.shape[1])
|
||||
gen_out = model.generate(**inputs_expanded, do_sample=False, max_new_tokens=10)
|
||||
decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)
|
||||
self.assertListEqual(decoded, EXPECTED_GENERATION)
|
||||
|
||||
def test_static_cache_extra_left_padding(self):
|
||||
"""Tests that adding extra left-padding does not affect the generation with the static cache"""
|
||||
EXPECTED_GENERATION = [
|
||||
"The best color is the one that complements the skin tone of the",
|
||||
"We should not undermind the issues at hand.\nWe should not undermind the issues",
|
||||
]
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
"NousResearch/Llama-2-7b-chat-hf", padding_side="left", pad_token="<s>"
|
||||
)
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
"NousResearch/Llama-2-7b-chat-hf",
|
||||
torch_dtype=torch.bfloat16,
|
||||
).to(torch_device)
|
||||
inputs = tokenizer(
|
||||
["The best color is", "We should not undermind the issues at hand"], padding=True, return_tensors="pt"
|
||||
).to(model.device)
|
||||
|
||||
model.generation_config.cache_implementation = "static"
|
||||
|
||||
gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=10)
|
||||
decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)
|
||||
self.assertListEqual(decoded, EXPECTED_GENERATION)
|
||||
|
||||
# Now with extra left-padding
|
||||
inputs_expanded = tokenizer(
|
||||
["The best color is", "We should not undermind the issues at hand"],
|
||||
padding=True,
|
||||
return_tensors="pt",
|
||||
pad_to_multiple_of=32,
|
||||
).to(model.device)
|
||||
self.assertTrue(inputs.input_ids.shape[1] < inputs_expanded.input_ids.shape[1])
|
||||
gen_out = model.generate(**inputs_expanded, do_sample=False, max_new_tokens=10)
|
||||
decoded = tokenizer.batch_decode(gen_out, skip_special_tokens=True)
|
||||
self.assertListEqual(decoded, EXPECTED_GENERATION)
|
||||
|
||||
@unittest.skip(reason="TODO @gante static cache's does not support beam search yet")
|
||||
def test_static_cache_beam_search(self):
|
||||
pass
|
||||
315
tests/utils/test_configuration_utils.py
Normal file
315
tests/utils/test_configuration_utils.py
Normal file
@@ -0,0 +1,315 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2019 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import json
|
||||
import os
|
||||
import shutil
|
||||
import sys
|
||||
import tempfile
|
||||
import unittest
|
||||
import unittest.mock as mock
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
|
||||
from huggingface_hub import HfFolder, delete_repo
|
||||
from requests.exceptions import HTTPError
|
||||
|
||||
from transformers import AutoConfig, BertConfig, GPT2Config
|
||||
from transformers.configuration_utils import PretrainedConfig
|
||||
from transformers.testing_utils import TOKEN, USER, is_staging_test
|
||||
|
||||
|
||||
sys.path.append(str(Path(__file__).parent.parent.parent / "utils"))
|
||||
|
||||
from test_module.custom_configuration import CustomConfig # noqa E402
|
||||
|
||||
|
||||
config_common_kwargs = {
|
||||
"return_dict": False,
|
||||
"output_hidden_states": True,
|
||||
"output_attentions": True,
|
||||
"torchscript": True,
|
||||
"torch_dtype": "float16",
|
||||
"use_bfloat16": True,
|
||||
"tf_legacy_loss": True,
|
||||
"pruned_heads": {"a": 1},
|
||||
"tie_word_embeddings": False,
|
||||
"is_decoder": True,
|
||||
"cross_attention_hidden_size": 128,
|
||||
"add_cross_attention": True,
|
||||
"tie_encoder_decoder": True,
|
||||
"max_length": 50,
|
||||
"min_length": 3,
|
||||
"do_sample": True,
|
||||
"early_stopping": True,
|
||||
"num_beams": 3,
|
||||
"num_beam_groups": 3,
|
||||
"diversity_penalty": 0.5,
|
||||
"temperature": 2.0,
|
||||
"top_k": 10,
|
||||
"top_p": 0.7,
|
||||
"typical_p": 0.2,
|
||||
"repetition_penalty": 0.8,
|
||||
"length_penalty": 0.8,
|
||||
"no_repeat_ngram_size": 5,
|
||||
"encoder_no_repeat_ngram_size": 5,
|
||||
"bad_words_ids": [1, 2, 3],
|
||||
"num_return_sequences": 3,
|
||||
"chunk_size_feed_forward": 5,
|
||||
"output_scores": True,
|
||||
"return_dict_in_generate": True,
|
||||
"forced_bos_token_id": 2,
|
||||
"forced_eos_token_id": 3,
|
||||
"remove_invalid_values": True,
|
||||
"architectures": ["BertModel"],
|
||||
"finetuning_task": "translation",
|
||||
"id2label": {0: "label"},
|
||||
"label2id": {"label": "0"},
|
||||
"tokenizer_class": "BertTokenizerFast",
|
||||
"prefix": "prefix",
|
||||
"bos_token_id": 6,
|
||||
"pad_token_id": 7,
|
||||
"eos_token_id": 8,
|
||||
"sep_token_id": 9,
|
||||
"decoder_start_token_id": 10,
|
||||
"exponential_decay_length_penalty": (5, 1.01),
|
||||
"suppress_tokens": [0, 1],
|
||||
"begin_suppress_tokens": 2,
|
||||
"task_specific_params": {"translation": "some_params"},
|
||||
"problem_type": "regression",
|
||||
}
|
||||
|
||||
|
||||
@is_staging_test
|
||||
class ConfigPushToHubTester(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls._token = TOKEN
|
||||
HfFolder.save_token(TOKEN)
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
try:
|
||||
delete_repo(token=cls._token, repo_id="test-config")
|
||||
except HTTPError:
|
||||
pass
|
||||
|
||||
try:
|
||||
delete_repo(token=cls._token, repo_id="valid_org/test-config-org")
|
||||
except HTTPError:
|
||||
pass
|
||||
|
||||
try:
|
||||
delete_repo(token=cls._token, repo_id="test-dynamic-config")
|
||||
except HTTPError:
|
||||
pass
|
||||
|
||||
def test_push_to_hub(self):
|
||||
config = BertConfig(
|
||||
vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
|
||||
)
|
||||
config.push_to_hub("test-config", token=self._token)
|
||||
|
||||
new_config = BertConfig.from_pretrained(f"{USER}/test-config")
|
||||
for k, v in config.to_dict().items():
|
||||
if k != "transformers_version":
|
||||
self.assertEqual(v, getattr(new_config, k))
|
||||
|
||||
# Reset repo
|
||||
delete_repo(token=self._token, repo_id="test-config")
|
||||
|
||||
# Push to hub via save_pretrained
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
config.save_pretrained(tmp_dir, repo_id="test-config", push_to_hub=True, token=self._token)
|
||||
|
||||
new_config = BertConfig.from_pretrained(f"{USER}/test-config")
|
||||
for k, v in config.to_dict().items():
|
||||
if k != "transformers_version":
|
||||
self.assertEqual(v, getattr(new_config, k))
|
||||
|
||||
def test_push_to_hub_in_organization(self):
|
||||
config = BertConfig(
|
||||
vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
|
||||
)
|
||||
config.push_to_hub("valid_org/test-config-org", token=self._token)
|
||||
|
||||
new_config = BertConfig.from_pretrained("valid_org/test-config-org")
|
||||
for k, v in config.to_dict().items():
|
||||
if k != "transformers_version":
|
||||
self.assertEqual(v, getattr(new_config, k))
|
||||
|
||||
# Reset repo
|
||||
delete_repo(token=self._token, repo_id="valid_org/test-config-org")
|
||||
|
||||
# Push to hub via save_pretrained
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
config.save_pretrained(tmp_dir, repo_id="valid_org/test-config-org", push_to_hub=True, token=self._token)
|
||||
|
||||
new_config = BertConfig.from_pretrained("valid_org/test-config-org")
|
||||
for k, v in config.to_dict().items():
|
||||
if k != "transformers_version":
|
||||
self.assertEqual(v, getattr(new_config, k))
|
||||
|
||||
def test_push_to_hub_dynamic_config(self):
|
||||
CustomConfig.register_for_auto_class()
|
||||
config = CustomConfig(attribute=42)
|
||||
|
||||
config.push_to_hub("test-dynamic-config", token=self._token)
|
||||
|
||||
# This has added the proper auto_map field to the config
|
||||
self.assertDictEqual(config.auto_map, {"AutoConfig": "custom_configuration.CustomConfig"})
|
||||
|
||||
new_config = AutoConfig.from_pretrained(f"{USER}/test-dynamic-config", trust_remote_code=True)
|
||||
# Can't make an isinstance check because the new_config is from the FakeConfig class of a dynamic module
|
||||
self.assertEqual(new_config.__class__.__name__, "CustomConfig")
|
||||
self.assertEqual(new_config.attribute, 42)
|
||||
|
||||
|
||||
class ConfigTestUtils(unittest.TestCase):
|
||||
def test_config_from_string(self):
|
||||
c = GPT2Config()
|
||||
|
||||
# attempt to modify each of int/float/bool/str config records and verify they were updated
|
||||
n_embd = c.n_embd + 1 # int
|
||||
resid_pdrop = c.resid_pdrop + 1.0 # float
|
||||
scale_attn_weights = not c.scale_attn_weights # bool
|
||||
summary_type = c.summary_type + "foo" # str
|
||||
c.update_from_string(
|
||||
f"n_embd={n_embd},resid_pdrop={resid_pdrop},scale_attn_weights={scale_attn_weights},summary_type={summary_type}"
|
||||
)
|
||||
self.assertEqual(n_embd, c.n_embd, "mismatch for key: n_embd")
|
||||
self.assertEqual(resid_pdrop, c.resid_pdrop, "mismatch for key: resid_pdrop")
|
||||
self.assertEqual(scale_attn_weights, c.scale_attn_weights, "mismatch for key: scale_attn_weights")
|
||||
self.assertEqual(summary_type, c.summary_type, "mismatch for key: summary_type")
|
||||
|
||||
def test_config_common_kwargs_is_complete(self):
|
||||
base_config = PretrainedConfig()
|
||||
missing_keys = [key for key in base_config.__dict__ if key not in config_common_kwargs]
|
||||
# If this part of the test fails, you have arguments to addin config_common_kwargs above.
|
||||
self.assertListEqual(
|
||||
missing_keys,
|
||||
[
|
||||
"is_encoder_decoder",
|
||||
"_name_or_path",
|
||||
"_commit_hash",
|
||||
"_attn_implementation_internal",
|
||||
"transformers_version",
|
||||
],
|
||||
)
|
||||
keys_with_defaults = [key for key, value in config_common_kwargs.items() if value == getattr(base_config, key)]
|
||||
if len(keys_with_defaults) > 0:
|
||||
raise ValueError(
|
||||
"The following keys are set with the default values in"
|
||||
" `test_configuration_common.config_common_kwargs` pick another value for them:"
|
||||
f" {', '.join(keys_with_defaults)}."
|
||||
)
|
||||
|
||||
def test_nested_config_load_from_dict(self):
|
||||
config = AutoConfig.from_pretrained(
|
||||
"hf-internal-testing/tiny-random-CLIPModel", text_config={"num_hidden_layers": 2}
|
||||
)
|
||||
self.assertNotIsInstance(config.text_config, dict)
|
||||
self.assertEqual(config.text_config.__class__.__name__, "CLIPTextConfig")
|
||||
|
||||
def test_from_pretrained_subfolder(self):
|
||||
with self.assertRaises(OSError):
|
||||
# config is in subfolder, the following should not work without specifying the subfolder
|
||||
_ = BertConfig.from_pretrained("hf-internal-testing/tiny-random-bert-subfolder")
|
||||
|
||||
config = BertConfig.from_pretrained("hf-internal-testing/tiny-random-bert-subfolder", subfolder="bert")
|
||||
|
||||
self.assertIsNotNone(config)
|
||||
|
||||
def test_cached_files_are_used_when_internet_is_down(self):
|
||||
# A mock response for an HTTP head request to emulate server down
|
||||
response_mock = mock.Mock()
|
||||
response_mock.status_code = 500
|
||||
response_mock.headers = {}
|
||||
response_mock.raise_for_status.side_effect = HTTPError
|
||||
response_mock.json.return_value = {}
|
||||
|
||||
# Download this model to make sure it's in the cache.
|
||||
_ = BertConfig.from_pretrained("hf-internal-testing/tiny-random-bert")
|
||||
|
||||
# Under the mock environment we get a 500 error when trying to reach the model.
|
||||
with mock.patch("requests.Session.request", return_value=response_mock) as mock_head:
|
||||
_ = BertConfig.from_pretrained("hf-internal-testing/tiny-random-bert")
|
||||
# This check we did call the fake head request
|
||||
mock_head.assert_called()
|
||||
|
||||
def test_local_versioning(self):
|
||||
configuration = AutoConfig.from_pretrained("google-bert/bert-base-cased")
|
||||
configuration.configuration_files = ["config.4.0.0.json"]
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
configuration.save_pretrained(tmp_dir)
|
||||
configuration.hidden_size = 2
|
||||
json.dump(configuration.to_dict(), open(os.path.join(tmp_dir, "config.4.0.0.json"), "w"))
|
||||
|
||||
# This should pick the new configuration file as the version of Transformers is > 4.0.0
|
||||
new_configuration = AutoConfig.from_pretrained(tmp_dir)
|
||||
self.assertEqual(new_configuration.hidden_size, 2)
|
||||
|
||||
# Will need to be adjusted if we reach v42 and this test is still here.
|
||||
# Should pick the old configuration file as the version of Transformers is < 4.42.0
|
||||
configuration.configuration_files = ["config.42.0.0.json"]
|
||||
configuration.hidden_size = 768
|
||||
configuration.save_pretrained(tmp_dir)
|
||||
shutil.move(os.path.join(tmp_dir, "config.4.0.0.json"), os.path.join(tmp_dir, "config.42.0.0.json"))
|
||||
new_configuration = AutoConfig.from_pretrained(tmp_dir)
|
||||
self.assertEqual(new_configuration.hidden_size, 768)
|
||||
|
||||
def test_repo_versioning_before(self):
|
||||
# This repo has two configuration files, one for v4.0.0 and above with a different hidden size.
|
||||
repo = "hf-internal-testing/test-two-configs"
|
||||
|
||||
import transformers as new_transformers
|
||||
|
||||
new_transformers.configuration_utils.__version__ = "v4.0.0"
|
||||
new_configuration, kwargs = new_transformers.models.auto.AutoConfig.from_pretrained(
|
||||
repo, return_unused_kwargs=True
|
||||
)
|
||||
self.assertEqual(new_configuration.hidden_size, 2)
|
||||
# This checks `_configuration_file` ia not kept in the kwargs by mistake.
|
||||
self.assertDictEqual(kwargs, {})
|
||||
|
||||
# Testing an older version by monkey-patching the version in the module it's used.
|
||||
import transformers as old_transformers
|
||||
|
||||
old_transformers.configuration_utils.__version__ = "v3.0.0"
|
||||
old_configuration = old_transformers.models.auto.AutoConfig.from_pretrained(repo)
|
||||
self.assertEqual(old_configuration.hidden_size, 768)
|
||||
|
||||
def test_saving_config_with_custom_generation_kwargs_raises_warning(self):
|
||||
config = BertConfig(min_length=3) # `min_length = 3` is a non-default generation kwarg
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
with self.assertLogs("transformers.configuration_utils", level="WARNING") as logs:
|
||||
config.save_pretrained(tmp_dir)
|
||||
self.assertEqual(len(logs.output), 1)
|
||||
self.assertIn("min_length", logs.output[0])
|
||||
|
||||
def test_has_non_default_generation_parameters(self):
|
||||
config = BertConfig()
|
||||
self.assertFalse(config._has_non_default_generation_parameters())
|
||||
config = BertConfig(min_length=3)
|
||||
self.assertTrue(config._has_non_default_generation_parameters())
|
||||
config = BertConfig(min_length=0) # `min_length = 0` is a default generation kwarg
|
||||
self.assertFalse(config._has_non_default_generation_parameters())
|
||||
|
||||
def test_loading_config_do_not_raise_future_warnings(self):
|
||||
"""Regression test for https://github.com/huggingface/transformers/issues/31002."""
|
||||
# Loading config should not raise a FutureWarning. It was the case before.
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("error")
|
||||
PretrainedConfig.from_pretrained("bert-base-uncased")
|
||||
138
tests/utils/test_feature_extraction_utils.py
Normal file
138
tests/utils/test_feature_extraction_utils.py
Normal file
@@ -0,0 +1,138 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2021 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import sys
|
||||
import tempfile
|
||||
import unittest
|
||||
import unittest.mock as mock
|
||||
from pathlib import Path
|
||||
|
||||
from huggingface_hub import HfFolder, delete_repo
|
||||
from requests.exceptions import HTTPError
|
||||
|
||||
from transformers import AutoFeatureExtractor, Wav2Vec2FeatureExtractor
|
||||
from transformers.testing_utils import TOKEN, USER, get_tests_dir, is_staging_test
|
||||
|
||||
|
||||
sys.path.append(str(Path(__file__).parent.parent / "utils"))
|
||||
|
||||
from test_module.custom_feature_extraction import CustomFeatureExtractor # noqa E402
|
||||
|
||||
|
||||
SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR = get_tests_dir("fixtures")
|
||||
|
||||
|
||||
class FeatureExtractorUtilTester(unittest.TestCase):
|
||||
def test_cached_files_are_used_when_internet_is_down(self):
|
||||
# A mock response for an HTTP head request to emulate server down
|
||||
response_mock = mock.Mock()
|
||||
response_mock.status_code = 500
|
||||
response_mock.headers = {}
|
||||
response_mock.raise_for_status.side_effect = HTTPError
|
||||
response_mock.json.return_value = {}
|
||||
|
||||
# Download this model to make sure it's in the cache.
|
||||
_ = Wav2Vec2FeatureExtractor.from_pretrained("hf-internal-testing/tiny-random-wav2vec2")
|
||||
# Under the mock environment we get a 500 error when trying to reach the model.
|
||||
with mock.patch("requests.Session.request", return_value=response_mock) as mock_head:
|
||||
_ = Wav2Vec2FeatureExtractor.from_pretrained("hf-internal-testing/tiny-random-wav2vec2")
|
||||
# This check we did call the fake head request
|
||||
mock_head.assert_called()
|
||||
|
||||
|
||||
@is_staging_test
|
||||
class FeatureExtractorPushToHubTester(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls._token = TOKEN
|
||||
HfFolder.save_token(TOKEN)
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
try:
|
||||
delete_repo(token=cls._token, repo_id="test-feature-extractor")
|
||||
except HTTPError:
|
||||
pass
|
||||
|
||||
try:
|
||||
delete_repo(token=cls._token, repo_id="valid_org/test-feature-extractor-org")
|
||||
except HTTPError:
|
||||
pass
|
||||
|
||||
try:
|
||||
delete_repo(token=cls._token, repo_id="test-dynamic-feature-extractor")
|
||||
except HTTPError:
|
||||
pass
|
||||
|
||||
def test_push_to_hub(self):
|
||||
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR)
|
||||
feature_extractor.push_to_hub("test-feature-extractor", token=self._token)
|
||||
|
||||
new_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(f"{USER}/test-feature-extractor")
|
||||
for k, v in feature_extractor.__dict__.items():
|
||||
self.assertEqual(v, getattr(new_feature_extractor, k))
|
||||
|
||||
# Reset repo
|
||||
delete_repo(token=self._token, repo_id="test-feature-extractor")
|
||||
|
||||
# Push to hub via save_pretrained
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
feature_extractor.save_pretrained(
|
||||
tmp_dir, repo_id="test-feature-extractor", push_to_hub=True, token=self._token
|
||||
)
|
||||
|
||||
new_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(f"{USER}/test-feature-extractor")
|
||||
for k, v in feature_extractor.__dict__.items():
|
||||
self.assertEqual(v, getattr(new_feature_extractor, k))
|
||||
|
||||
def test_push_to_hub_in_organization(self):
|
||||
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR)
|
||||
feature_extractor.push_to_hub("valid_org/test-feature-extractor", token=self._token)
|
||||
|
||||
new_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("valid_org/test-feature-extractor")
|
||||
for k, v in feature_extractor.__dict__.items():
|
||||
self.assertEqual(v, getattr(new_feature_extractor, k))
|
||||
|
||||
# Reset repo
|
||||
delete_repo(token=self._token, repo_id="valid_org/test-feature-extractor")
|
||||
|
||||
# Push to hub via save_pretrained
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
feature_extractor.save_pretrained(
|
||||
tmp_dir, repo_id="valid_org/test-feature-extractor-org", push_to_hub=True, token=self._token
|
||||
)
|
||||
|
||||
new_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("valid_org/test-feature-extractor-org")
|
||||
for k, v in feature_extractor.__dict__.items():
|
||||
self.assertEqual(v, getattr(new_feature_extractor, k))
|
||||
|
||||
def test_push_to_hub_dynamic_feature_extractor(self):
|
||||
CustomFeatureExtractor.register_for_auto_class()
|
||||
feature_extractor = CustomFeatureExtractor.from_pretrained(SAMPLE_FEATURE_EXTRACTION_CONFIG_DIR)
|
||||
|
||||
feature_extractor.push_to_hub("test-dynamic-feature-extractor", token=self._token)
|
||||
|
||||
# This has added the proper auto_map field to the config
|
||||
self.assertDictEqual(
|
||||
feature_extractor.auto_map,
|
||||
{"AutoFeatureExtractor": "custom_feature_extraction.CustomFeatureExtractor"},
|
||||
)
|
||||
|
||||
new_feature_extractor = AutoFeatureExtractor.from_pretrained(
|
||||
f"{USER}/test-dynamic-feature-extractor", trust_remote_code=True
|
||||
)
|
||||
# Can't make an isinstance check because the new_feature_extractor is from the CustomFeatureExtractor class of a dynamic module
|
||||
self.assertEqual(new_feature_extractor.__class__.__name__, "CustomFeatureExtractor")
|
||||
@@ -1,5 +1,5 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2022 HuggingFace Inc.
|
||||
# Copyright 2024 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@@ -13,9 +13,140 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import sys
|
||||
import tempfile
|
||||
import unittest
|
||||
import unittest.mock as mock
|
||||
from pathlib import Path
|
||||
|
||||
from huggingface_hub import HfFolder, delete_repo
|
||||
from requests.exceptions import HTTPError
|
||||
|
||||
from transformers import AutoImageProcessor, ViTImageProcessor
|
||||
from transformers.image_processing_utils import get_size_dict
|
||||
from transformers.testing_utils import TOKEN, USER, get_tests_dir, is_staging_test
|
||||
|
||||
|
||||
sys.path.append(str(Path(__file__).parent.parent / "utils"))
|
||||
|
||||
from test_module.custom_image_processing import CustomImageProcessor # noqa E402
|
||||
|
||||
|
||||
SAMPLE_IMAGE_PROCESSING_CONFIG_DIR = get_tests_dir("fixtures")
|
||||
|
||||
|
||||
class ImageProcessorUtilTester(unittest.TestCase):
|
||||
def test_cached_files_are_used_when_internet_is_down(self):
|
||||
# A mock response for an HTTP head request to emulate server down
|
||||
response_mock = mock.Mock()
|
||||
response_mock.status_code = 500
|
||||
response_mock.headers = {}
|
||||
response_mock.raise_for_status.side_effect = HTTPError
|
||||
response_mock.json.return_value = {}
|
||||
|
||||
# Download this model to make sure it's in the cache.
|
||||
_ = ViTImageProcessor.from_pretrained("hf-internal-testing/tiny-random-vit")
|
||||
# Under the mock environment we get a 500 error when trying to reach the model.
|
||||
with mock.patch("requests.Session.request", return_value=response_mock) as mock_head:
|
||||
_ = ViTImageProcessor.from_pretrained("hf-internal-testing/tiny-random-vit")
|
||||
# This check we did call the fake head request
|
||||
mock_head.assert_called()
|
||||
|
||||
def test_image_processor_from_pretrained_subfolder(self):
|
||||
with self.assertRaises(OSError):
|
||||
# config is in subfolder, the following should not work without specifying the subfolder
|
||||
_ = AutoImageProcessor.from_pretrained("hf-internal-testing/stable-diffusion-all-variants")
|
||||
|
||||
config = AutoImageProcessor.from_pretrained(
|
||||
"hf-internal-testing/stable-diffusion-all-variants", subfolder="feature_extractor"
|
||||
)
|
||||
|
||||
self.assertIsNotNone(config)
|
||||
|
||||
|
||||
@is_staging_test
|
||||
class ImageProcessorPushToHubTester(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls._token = TOKEN
|
||||
HfFolder.save_token(TOKEN)
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
try:
|
||||
delete_repo(token=cls._token, repo_id="test-image-processor")
|
||||
except HTTPError:
|
||||
pass
|
||||
|
||||
try:
|
||||
delete_repo(token=cls._token, repo_id="valid_org/test-image-processor-org")
|
||||
except HTTPError:
|
||||
pass
|
||||
|
||||
try:
|
||||
delete_repo(token=cls._token, repo_id="test-dynamic-image-processor")
|
||||
except HTTPError:
|
||||
pass
|
||||
|
||||
def test_push_to_hub(self):
|
||||
image_processor = ViTImageProcessor.from_pretrained(SAMPLE_IMAGE_PROCESSING_CONFIG_DIR)
|
||||
image_processor.push_to_hub("test-image-processor", token=self._token)
|
||||
|
||||
new_image_processor = ViTImageProcessor.from_pretrained(f"{USER}/test-image-processor")
|
||||
for k, v in image_processor.__dict__.items():
|
||||
self.assertEqual(v, getattr(new_image_processor, k))
|
||||
|
||||
# Reset repo
|
||||
delete_repo(token=self._token, repo_id="test-image-processor")
|
||||
|
||||
# Push to hub via save_pretrained
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
image_processor.save_pretrained(
|
||||
tmp_dir, repo_id="test-image-processor", push_to_hub=True, token=self._token
|
||||
)
|
||||
|
||||
new_image_processor = ViTImageProcessor.from_pretrained(f"{USER}/test-image-processor")
|
||||
for k, v in image_processor.__dict__.items():
|
||||
self.assertEqual(v, getattr(new_image_processor, k))
|
||||
|
||||
def test_push_to_hub_in_organization(self):
|
||||
image_processor = ViTImageProcessor.from_pretrained(SAMPLE_IMAGE_PROCESSING_CONFIG_DIR)
|
||||
image_processor.push_to_hub("valid_org/test-image-processor", token=self._token)
|
||||
|
||||
new_image_processor = ViTImageProcessor.from_pretrained("valid_org/test-image-processor")
|
||||
for k, v in image_processor.__dict__.items():
|
||||
self.assertEqual(v, getattr(new_image_processor, k))
|
||||
|
||||
# Reset repo
|
||||
delete_repo(token=self._token, repo_id="valid_org/test-image-processor")
|
||||
|
||||
# Push to hub via save_pretrained
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
image_processor.save_pretrained(
|
||||
tmp_dir, repo_id="valid_org/test-image-processor-org", push_to_hub=True, token=self._token
|
||||
)
|
||||
|
||||
new_image_processor = ViTImageProcessor.from_pretrained("valid_org/test-image-processor-org")
|
||||
for k, v in image_processor.__dict__.items():
|
||||
self.assertEqual(v, getattr(new_image_processor, k))
|
||||
|
||||
def test_push_to_hub_dynamic_image_processor(self):
|
||||
CustomImageProcessor.register_for_auto_class()
|
||||
image_processor = CustomImageProcessor.from_pretrained(SAMPLE_IMAGE_PROCESSING_CONFIG_DIR)
|
||||
|
||||
image_processor.push_to_hub("test-dynamic-image-processor", token=self._token)
|
||||
|
||||
# This has added the proper auto_map field to the config
|
||||
self.assertDictEqual(
|
||||
image_processor.auto_map,
|
||||
{"AutoImageProcessor": "custom_image_processing.CustomImageProcessor"},
|
||||
)
|
||||
|
||||
new_image_processor = AutoImageProcessor.from_pretrained(
|
||||
f"{USER}/test-dynamic-image-processor", trust_remote_code=True
|
||||
)
|
||||
# Can't make an isinstance check because the new_image_processor is from the CustomImageProcessor class of a dynamic module
|
||||
self.assertEqual(new_image_processor.__class__.__name__, "CustomImageProcessor")
|
||||
|
||||
|
||||
class ImageProcessingUtilsTester(unittest.TestCase):
|
||||
|
||||
403
tests/utils/test_modeling_flax_utils.py
Normal file
403
tests/utils/test_modeling_flax_utils.py
Normal file
@@ -0,0 +1,403 @@
|
||||
# Copyright 2020 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
from huggingface_hub import HfFolder, delete_repo, snapshot_download
|
||||
from requests.exceptions import HTTPError
|
||||
|
||||
from transformers import BertConfig, BertModel, is_flax_available, is_torch_available
|
||||
from transformers.testing_utils import (
|
||||
TOKEN,
|
||||
USER,
|
||||
CaptureLogger,
|
||||
is_pt_flax_cross_test,
|
||||
is_staging_test,
|
||||
require_flax,
|
||||
require_safetensors,
|
||||
require_torch,
|
||||
)
|
||||
from transformers.utils import FLAX_WEIGHTS_NAME, SAFE_WEIGHTS_NAME, logging
|
||||
|
||||
|
||||
if is_flax_available():
|
||||
import os
|
||||
|
||||
from flax.core.frozen_dict import unfreeze
|
||||
from flax.traverse_util import flatten_dict
|
||||
|
||||
from transformers import FlaxBertModel
|
||||
|
||||
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.12" # assumed parallelism: 8
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
|
||||
@require_flax
|
||||
@is_staging_test
|
||||
class FlaxModelPushToHubTester(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls._token = TOKEN
|
||||
HfFolder.save_token(TOKEN)
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
try:
|
||||
delete_repo(token=cls._token, repo_id="test-model-flax")
|
||||
except HTTPError:
|
||||
pass
|
||||
|
||||
try:
|
||||
delete_repo(token=cls._token, repo_id="valid_org/test-model-flax-org")
|
||||
except HTTPError:
|
||||
pass
|
||||
|
||||
def test_push_to_hub(self):
|
||||
config = BertConfig(
|
||||
vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
|
||||
)
|
||||
model = FlaxBertModel(config)
|
||||
model.push_to_hub("test-model-flax", token=self._token)
|
||||
|
||||
new_model = FlaxBertModel.from_pretrained(f"{USER}/test-model-flax")
|
||||
|
||||
base_params = flatten_dict(unfreeze(model.params))
|
||||
new_params = flatten_dict(unfreeze(new_model.params))
|
||||
|
||||
for key in base_params.keys():
|
||||
max_diff = (base_params[key] - new_params[key]).sum().item()
|
||||
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
|
||||
|
||||
# Reset repo
|
||||
delete_repo(token=self._token, repo_id="test-model-flax")
|
||||
|
||||
# Push to hub via save_pretrained
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
model.save_pretrained(tmp_dir, repo_id="test-model-flax", push_to_hub=True, token=self._token)
|
||||
|
||||
new_model = FlaxBertModel.from_pretrained(f"{USER}/test-model-flax")
|
||||
|
||||
base_params = flatten_dict(unfreeze(model.params))
|
||||
new_params = flatten_dict(unfreeze(new_model.params))
|
||||
|
||||
for key in base_params.keys():
|
||||
max_diff = (base_params[key] - new_params[key]).sum().item()
|
||||
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
|
||||
|
||||
def test_push_to_hub_in_organization(self):
|
||||
config = BertConfig(
|
||||
vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
|
||||
)
|
||||
model = FlaxBertModel(config)
|
||||
model.push_to_hub("valid_org/test-model-flax-org", token=self._token)
|
||||
|
||||
new_model = FlaxBertModel.from_pretrained("valid_org/test-model-flax-org")
|
||||
|
||||
base_params = flatten_dict(unfreeze(model.params))
|
||||
new_params = flatten_dict(unfreeze(new_model.params))
|
||||
|
||||
for key in base_params.keys():
|
||||
max_diff = (base_params[key] - new_params[key]).sum().item()
|
||||
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
|
||||
|
||||
# Reset repo
|
||||
delete_repo(token=self._token, repo_id="valid_org/test-model-flax-org")
|
||||
|
||||
# Push to hub via save_pretrained
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
model.save_pretrained(
|
||||
tmp_dir, repo_id="valid_org/test-model-flax-org", push_to_hub=True, token=self._token
|
||||
)
|
||||
|
||||
new_model = FlaxBertModel.from_pretrained("valid_org/test-model-flax-org")
|
||||
|
||||
base_params = flatten_dict(unfreeze(model.params))
|
||||
new_params = flatten_dict(unfreeze(new_model.params))
|
||||
|
||||
for key in base_params.keys():
|
||||
max_diff = (base_params[key] - new_params[key]).sum().item()
|
||||
self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")
|
||||
|
||||
|
||||
def check_models_equal(model1, model2):
|
||||
models_are_equal = True
|
||||
flat_params_1 = flatten_dict(model1.params)
|
||||
flat_params_2 = flatten_dict(model2.params)
|
||||
for key in flat_params_1.keys():
|
||||
if np.sum(np.abs(flat_params_1[key] - flat_params_2[key])) > 1e-4:
|
||||
models_are_equal = False
|
||||
|
||||
return models_are_equal
|
||||
|
||||
|
||||
@require_flax
|
||||
class FlaxModelUtilsTest(unittest.TestCase):
|
||||
def test_model_from_pretrained_subfolder(self):
|
||||
config = BertConfig.from_pretrained("hf-internal-testing/tiny-bert-flax-only")
|
||||
model = FlaxBertModel(config)
|
||||
|
||||
subfolder = "bert"
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
model.save_pretrained(os.path.join(tmp_dir, subfolder))
|
||||
|
||||
with self.assertRaises(OSError):
|
||||
_ = FlaxBertModel.from_pretrained(tmp_dir)
|
||||
|
||||
model_loaded = FlaxBertModel.from_pretrained(tmp_dir, subfolder=subfolder)
|
||||
|
||||
self.assertTrue(check_models_equal(model, model_loaded))
|
||||
|
||||
def test_model_from_pretrained_subfolder_sharded(self):
|
||||
config = BertConfig.from_pretrained("hf-internal-testing/tiny-bert-flax-only")
|
||||
model = FlaxBertModel(config)
|
||||
|
||||
subfolder = "bert"
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
model.save_pretrained(os.path.join(tmp_dir, subfolder), max_shard_size="10KB")
|
||||
|
||||
with self.assertRaises(OSError):
|
||||
_ = FlaxBertModel.from_pretrained(tmp_dir)
|
||||
|
||||
model_loaded = FlaxBertModel.from_pretrained(tmp_dir, subfolder=subfolder)
|
||||
|
||||
self.assertTrue(check_models_equal(model, model_loaded))
|
||||
|
||||
def test_model_from_pretrained_hub_subfolder(self):
|
||||
subfolder = "bert"
|
||||
model_id = "hf-internal-testing/tiny-random-bert-subfolder"
|
||||
|
||||
with self.assertRaises(OSError):
|
||||
_ = FlaxBertModel.from_pretrained(model_id)
|
||||
|
||||
model = FlaxBertModel.from_pretrained(model_id, subfolder=subfolder)
|
||||
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
def test_model_from_pretrained_hub_subfolder_sharded(self):
|
||||
subfolder = "bert"
|
||||
model_id = "hf-internal-testing/tiny-random-bert-sharded-subfolder"
|
||||
with self.assertRaises(OSError):
|
||||
_ = FlaxBertModel.from_pretrained(model_id)
|
||||
|
||||
model = FlaxBertModel.from_pretrained(model_id, subfolder=subfolder)
|
||||
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
@require_safetensors
|
||||
def test_safetensors_save_and_load(self):
|
||||
model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-flax-only")
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
model.save_pretrained(tmp_dir, safe_serialization=True)
|
||||
|
||||
# No msgpack file, only a model.safetensors
|
||||
self.assertTrue(os.path.isfile(os.path.join(tmp_dir, SAFE_WEIGHTS_NAME)))
|
||||
self.assertFalse(os.path.isfile(os.path.join(tmp_dir, FLAX_WEIGHTS_NAME)))
|
||||
|
||||
new_model = FlaxBertModel.from_pretrained(tmp_dir)
|
||||
|
||||
self.assertTrue(check_models_equal(model, new_model))
|
||||
|
||||
@require_flax
|
||||
@require_torch
|
||||
@is_pt_flax_cross_test
|
||||
def test_safetensors_save_and_load_pt_to_flax(self):
|
||||
model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-random-bert", from_pt=True)
|
||||
pt_model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
pt_model.save_pretrained(tmp_dir)
|
||||
|
||||
# Check we have a model.safetensors file
|
||||
self.assertTrue(os.path.isfile(os.path.join(tmp_dir, SAFE_WEIGHTS_NAME)))
|
||||
|
||||
new_model = FlaxBertModel.from_pretrained(tmp_dir)
|
||||
|
||||
# Check models are equal
|
||||
self.assertTrue(check_models_equal(model, new_model))
|
||||
|
||||
@require_safetensors
|
||||
def test_safetensors_load_from_hub(self):
|
||||
"""
|
||||
This test checks that we can load safetensors from a checkpoint that only has those on the Hub
|
||||
"""
|
||||
flax_model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-flax-only")
|
||||
|
||||
# Can load from the Flax-formatted checkpoint
|
||||
safetensors_model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-flax-safetensors-only")
|
||||
self.assertTrue(check_models_equal(flax_model, safetensors_model))
|
||||
|
||||
@require_safetensors
|
||||
def test_safetensors_load_from_local(self):
|
||||
"""
|
||||
This test checks that we can load safetensors from a checkpoint that only has those on the Hub
|
||||
"""
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
location = snapshot_download("hf-internal-testing/tiny-bert-flax-only", cache_dir=tmp)
|
||||
flax_model = FlaxBertModel.from_pretrained(location)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
location = snapshot_download("hf-internal-testing/tiny-bert-flax-safetensors-only", cache_dir=tmp)
|
||||
safetensors_model = FlaxBertModel.from_pretrained(location)
|
||||
|
||||
self.assertTrue(check_models_equal(flax_model, safetensors_model))
|
||||
|
||||
@require_safetensors
|
||||
@is_pt_flax_cross_test
|
||||
def test_safetensors_load_from_hub_from_safetensors_pt(self):
|
||||
"""
|
||||
This test checks that we can load safetensors from a checkpoint that only has those on the Hub.
|
||||
saved in the "pt" format.
|
||||
"""
|
||||
flax_model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-msgpack")
|
||||
|
||||
# Can load from the PyTorch-formatted checkpoint
|
||||
safetensors_model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-safetensors")
|
||||
self.assertTrue(check_models_equal(flax_model, safetensors_model))
|
||||
|
||||
@require_safetensors
|
||||
@require_torch
|
||||
@is_pt_flax_cross_test
|
||||
def test_safetensors_load_from_hub_from_safetensors_pt_bf16(self):
|
||||
"""
|
||||
This test checks that we can load safetensors from a checkpoint that only has those on the Hub.
|
||||
saved in the "pt" format.
|
||||
"""
|
||||
import torch
|
||||
|
||||
model = BertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-safetensors")
|
||||
model.to(torch.bfloat16)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
model.save_pretrained(tmp)
|
||||
flax_model = FlaxBertModel.from_pretrained(tmp)
|
||||
|
||||
# Can load from the PyTorch-formatted checkpoint
|
||||
safetensors_model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-safetensors-bf16")
|
||||
self.assertTrue(check_models_equal(flax_model, safetensors_model))
|
||||
|
||||
@require_safetensors
|
||||
@is_pt_flax_cross_test
|
||||
def test_safetensors_load_from_local_from_safetensors_pt(self):
|
||||
"""
|
||||
This test checks that we can load safetensors from a checkpoint that only has those on the Hub.
|
||||
saved in the "pt" format.
|
||||
"""
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
location = snapshot_download("hf-internal-testing/tiny-bert-msgpack", cache_dir=tmp)
|
||||
flax_model = FlaxBertModel.from_pretrained(location)
|
||||
|
||||
# Can load from the PyTorch-formatted checkpoint
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
location = snapshot_download("hf-internal-testing/tiny-bert-pt-safetensors", cache_dir=tmp)
|
||||
safetensors_model = FlaxBertModel.from_pretrained(location)
|
||||
|
||||
self.assertTrue(check_models_equal(flax_model, safetensors_model))
|
||||
|
||||
@require_safetensors
|
||||
def test_safetensors_load_from_hub_msgpack_before_safetensors(self):
|
||||
"""
|
||||
This test checks that we'll first download msgpack weights before safetensors
|
||||
The safetensors file on that repo is a pt safetensors and therefore cannot be loaded without PyTorch
|
||||
"""
|
||||
FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-safetensors-msgpack")
|
||||
|
||||
@require_safetensors
|
||||
def test_safetensors_load_from_local_msgpack_before_safetensors(self):
|
||||
"""
|
||||
This test checks that we'll first download msgpack weights before safetensors
|
||||
The safetensors file on that repo is a pt safetensors and therefore cannot be loaded without PyTorch
|
||||
"""
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
location = snapshot_download("hf-internal-testing/tiny-bert-pt-safetensors-msgpack", cache_dir=tmp)
|
||||
FlaxBertModel.from_pretrained(location)
|
||||
|
||||
@require_safetensors
|
||||
def test_safetensors_flax_from_flax(self):
|
||||
model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-flax-only")
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
model.save_pretrained(tmp_dir, safe_serialization=True)
|
||||
new_model = FlaxBertModel.from_pretrained(tmp_dir)
|
||||
|
||||
self.assertTrue(check_models_equal(model, new_model))
|
||||
|
||||
@require_safetensors
|
||||
@require_torch
|
||||
@is_pt_flax_cross_test
|
||||
def test_safetensors_flax_from_torch(self):
|
||||
hub_model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-flax-only")
|
||||
model = BertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-only")
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
model.save_pretrained(tmp_dir, safe_serialization=True)
|
||||
new_model = FlaxBertModel.from_pretrained(tmp_dir)
|
||||
|
||||
self.assertTrue(check_models_equal(hub_model, new_model))
|
||||
|
||||
@require_safetensors
|
||||
def test_safetensors_flax_from_sharded_msgpack_with_sharded_safetensors_local(self):
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
path = snapshot_download(
|
||||
"hf-internal-testing/tiny-bert-flax-safetensors-msgpack-sharded", cache_dir=tmp_dir
|
||||
)
|
||||
|
||||
# This should not raise even if there are two types of sharded weights
|
||||
FlaxBertModel.from_pretrained(path)
|
||||
|
||||
@require_safetensors
|
||||
def test_safetensors_flax_from_sharded_msgpack_with_sharded_safetensors_hub(self):
|
||||
# This should not raise even if there are two types of sharded weights
|
||||
# This should discard the safetensors weights in favor of the msgpack sharded weights
|
||||
FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-flax-safetensors-msgpack-sharded")
|
||||
|
||||
@require_safetensors
|
||||
def test_safetensors_from_pt_bf16(self):
|
||||
# This should not raise; should be able to load bf16-serialized torch safetensors without issue
|
||||
# and without torch.
|
||||
logger = logging.get_logger("transformers.modeling_flax_utils")
|
||||
|
||||
with CaptureLogger(logger) as cl:
|
||||
FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-safetensors-bf16")
|
||||
|
||||
self.assertTrue(
|
||||
"Some of the weights of FlaxBertModel were initialized in bfloat16 precision from the model checkpoint"
|
||||
in cl.out
|
||||
)
|
||||
|
||||
@require_torch
|
||||
@require_safetensors
|
||||
@is_pt_flax_cross_test
|
||||
def test_from_pt_bf16(self):
|
||||
model = BertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-only")
|
||||
model.to(torch.bfloat16)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
model.save_pretrained(tmp_dir, safe_serialization=False)
|
||||
|
||||
logger = logging.get_logger("transformers.modeling_flax_utils")
|
||||
|
||||
with CaptureLogger(logger) as cl:
|
||||
new_model = FlaxBertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-safetensors-bf16")
|
||||
|
||||
self.assertTrue(
|
||||
"Some of the weights of FlaxBertModel were initialized in bfloat16 precision from the model checkpoint"
|
||||
in cl.out
|
||||
)
|
||||
|
||||
flat_params_1 = flatten_dict(new_model.params)
|
||||
for value in flat_params_1.values():
|
||||
self.assertEqual(value.dtype, "bfloat16")
|
||||
802
tests/utils/test_modeling_tf_utils.py
Normal file
802
tests/utils/test_modeling_tf_utils.py
Normal file
@@ -0,0 +1,802 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2019 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import inspect
|
||||
import json
|
||||
import os
|
||||
import random
|
||||
import tempfile
|
||||
import unittest
|
||||
import unittest.mock as mock
|
||||
|
||||
from huggingface_hub import HfFolder, Repository, delete_repo, snapshot_download
|
||||
from requests.exceptions import HTTPError
|
||||
|
||||
from transformers import is_tf_available, is_torch_available
|
||||
from transformers.configuration_utils import PretrainedConfig
|
||||
from transformers.testing_utils import ( # noqa: F401
|
||||
TOKEN,
|
||||
USER,
|
||||
CaptureLogger,
|
||||
_tf_gpu_memory_limit,
|
||||
is_pt_tf_cross_test,
|
||||
is_staging_test,
|
||||
require_safetensors,
|
||||
require_tf,
|
||||
require_torch,
|
||||
slow,
|
||||
)
|
||||
from transformers.utils import (
|
||||
SAFE_WEIGHTS_INDEX_NAME,
|
||||
SAFE_WEIGHTS_NAME,
|
||||
TF2_WEIGHTS_INDEX_NAME,
|
||||
TF2_WEIGHTS_NAME,
|
||||
logging,
|
||||
)
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
if is_tf_available():
|
||||
import h5py
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
|
||||
from transformers import (
|
||||
BertConfig,
|
||||
PreTrainedModel,
|
||||
PushToHubCallback,
|
||||
RagRetriever,
|
||||
TFAutoModel,
|
||||
TFBertForMaskedLM,
|
||||
TFBertForSequenceClassification,
|
||||
TFBertModel,
|
||||
TFPreTrainedModel,
|
||||
TFRagModel,
|
||||
)
|
||||
from transformers.modeling_tf_utils import keras, tf_shard_checkpoint, unpack_inputs
|
||||
from transformers.tf_utils import stable_softmax
|
||||
|
||||
tf.config.experimental.enable_tensor_float_32_execution(False)
|
||||
|
||||
if _tf_gpu_memory_limit is not None:
|
||||
gpus = tf.config.list_physical_devices("GPU")
|
||||
for gpu in gpus:
|
||||
# Restrict TensorFlow to only allocate x GB of memory on the GPUs
|
||||
try:
|
||||
tf.config.set_logical_device_configuration(
|
||||
gpu, [tf.config.LogicalDeviceConfiguration(memory_limit=_tf_gpu_memory_limit)]
|
||||
)
|
||||
logical_gpus = tf.config.list_logical_devices("GPU")
|
||||
print("Logical GPUs", logical_gpus)
|
||||
except RuntimeError as e:
|
||||
# Virtual devices must be set before GPUs have been initialized
|
||||
print(e)
|
||||
|
||||
if is_torch_available():
|
||||
from transformers import BertModel
|
||||
|
||||
|
||||
@require_tf
|
||||
class TFModelUtilsTest(unittest.TestCase):
|
||||
def test_cached_files_are_used_when_internet_is_down(self):
|
||||
# A mock response for an HTTP head request to emulate server down
|
||||
response_mock = mock.Mock()
|
||||
response_mock.status_code = 500
|
||||
response_mock.headers = {}
|
||||
response_mock.raise_for_status.side_effect = HTTPError
|
||||
response_mock.json.return_value = {}
|
||||
|
||||
# Download this model to make sure it's in the cache.
|
||||
_ = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
|
||||
|
||||
# Under the mock environment we get a 500 error when trying to reach the model.
|
||||
with mock.patch("requests.Session.request", return_value=response_mock) as mock_head:
|
||||
_ = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
|
||||
# This check we did call the fake head request
|
||||
mock_head.assert_called()
|
||||
|
||||
# tests whether the unpack_inputs function behaves as expected
|
||||
def test_unpack_inputs(self):
|
||||
class DummyModel:
|
||||
def __init__(self):
|
||||
config_kwargs = {"output_attentions": False, "output_hidden_states": False, "return_dict": False}
|
||||
self.config = PretrainedConfig(**config_kwargs)
|
||||
self.main_input_name = "input_ids"
|
||||
|
||||
@unpack_inputs
|
||||
def call(
|
||||
self,
|
||||
input_ids=None,
|
||||
past_key_values=None,
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
return_dict=None,
|
||||
):
|
||||
return input_ids, past_key_values, output_attentions, output_hidden_states, return_dict
|
||||
|
||||
@unpack_inputs
|
||||
def foo(self, pixel_values, output_attentions=None, output_hidden_states=None, return_dict=None):
|
||||
return pixel_values, output_attentions, output_hidden_states, return_dict
|
||||
|
||||
dummy_model = DummyModel()
|
||||
input_ids = tf.constant([0, 1, 2, 3], dtype=tf.int32)
|
||||
past_key_values = tf.constant([4, 5, 6, 7], dtype=tf.int32)
|
||||
pixel_values = tf.constant([8, 9, 10, 11], dtype=tf.int32)
|
||||
|
||||
# test case 1: Pass inputs as keyword arguments; Booleans are inherited from the config.
|
||||
output = dummy_model.call(input_ids=input_ids, past_key_values=past_key_values)
|
||||
tf.debugging.assert_equal(output[0], input_ids)
|
||||
tf.debugging.assert_equal(output[1], past_key_values)
|
||||
self.assertFalse(output[2])
|
||||
self.assertFalse(output[3])
|
||||
self.assertFalse(output[4])
|
||||
|
||||
# test case 2: Same as above, but with positional arguments.
|
||||
output = dummy_model.call(input_ids, past_key_values)
|
||||
tf.debugging.assert_equal(output[0], input_ids)
|
||||
tf.debugging.assert_equal(output[1], past_key_values)
|
||||
self.assertFalse(output[2])
|
||||
self.assertFalse(output[3])
|
||||
self.assertFalse(output[4])
|
||||
|
||||
# test case 3: We can also pack everything in the first input.
|
||||
output = dummy_model.call(input_ids={"input_ids": input_ids, "past_key_values": past_key_values})
|
||||
tf.debugging.assert_equal(output[0], input_ids)
|
||||
tf.debugging.assert_equal(output[1], past_key_values)
|
||||
self.assertFalse(output[2])
|
||||
self.assertFalse(output[3])
|
||||
self.assertFalse(output[4])
|
||||
|
||||
# test case 4: Explicit boolean arguments should override the config.
|
||||
output = dummy_model.call(
|
||||
input_ids=input_ids, past_key_values=past_key_values, output_attentions=False, return_dict=True
|
||||
)
|
||||
tf.debugging.assert_equal(output[0], input_ids)
|
||||
tf.debugging.assert_equal(output[1], past_key_values)
|
||||
self.assertFalse(output[2])
|
||||
self.assertFalse(output[3])
|
||||
self.assertTrue(output[4])
|
||||
|
||||
# test case 5: Unexpected arguments should raise an exception.
|
||||
with self.assertRaises(ValueError):
|
||||
output = dummy_model.call(input_ids=input_ids, past_key_values=past_key_values, foo="bar")
|
||||
|
||||
# test case 6: the decorator is independent from `main_input_name` -- it treats the first argument of the
|
||||
# decorated function as its main input.
|
||||
output = dummy_model.foo(pixel_values=pixel_values)
|
||||
tf.debugging.assert_equal(output[0], pixel_values)
|
||||
self.assertFalse(output[1])
|
||||
self.assertFalse(output[2])
|
||||
self.assertFalse(output[3])
|
||||
|
||||
# Tests whether the stable softmax is stable on CPU, with and without XLA
|
||||
def test_xla_stable_softmax(self):
|
||||
large_penalty = -1e9
|
||||
n_tokens = 10
|
||||
batch_size = 8
|
||||
|
||||
def masked_softmax(x, boolean_mask):
|
||||
numerical_mask = (1.0 - tf.cast(boolean_mask, dtype=tf.float32)) * large_penalty
|
||||
masked_x = x + numerical_mask
|
||||
return stable_softmax(masked_x)
|
||||
|
||||
xla_masked_softmax = tf.function(masked_softmax, jit_compile=True)
|
||||
xla_stable_softmax = tf.function(stable_softmax, jit_compile=True)
|
||||
x = tf.random.normal((batch_size, n_tokens))
|
||||
|
||||
# Same outcome regardless of the boolean mask here
|
||||
masked_tokens = random.randint(0, n_tokens)
|
||||
boolean_mask = tf.convert_to_tensor([[1] * (n_tokens - masked_tokens) + [0] * masked_tokens], dtype=tf.int32)
|
||||
|
||||
# We can randomly mask a random numerical input OUTSIDE XLA
|
||||
numerical_mask = (1.0 - tf.cast(boolean_mask, dtype=tf.float32)) * large_penalty
|
||||
masked_x = x + numerical_mask
|
||||
xla_out = xla_stable_softmax(masked_x)
|
||||
out = stable_softmax(masked_x)
|
||||
assert tf.experimental.numpy.allclose(xla_out, out)
|
||||
|
||||
# The stable softmax has the same output as the original softmax
|
||||
unstable_out = tf.nn.softmax(masked_x)
|
||||
assert tf.experimental.numpy.allclose(unstable_out, out)
|
||||
|
||||
# We can randomly mask a random numerical input INSIDE XLA
|
||||
xla_out = xla_masked_softmax(x, boolean_mask)
|
||||
out = masked_softmax(x, boolean_mask)
|
||||
assert tf.experimental.numpy.allclose(xla_out, out)
|
||||
|
||||
def test_checkpoint_sharding_from_hub(self):
|
||||
model = TFBertModel.from_pretrained("ArthurZ/tiny-random-bert-sharded")
|
||||
# the model above is the same as the model below, just a sharded version.
|
||||
ref_model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
|
||||
for p1, p2 in zip(model.weights, ref_model.weights):
|
||||
assert np.allclose(p1.numpy(), p2.numpy())
|
||||
|
||||
def test_sharded_checkpoint_with_prefix(self):
|
||||
model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert", load_weight_prefix="a/b")
|
||||
sharded_model = TFBertModel.from_pretrained("ArthurZ/tiny-random-bert-sharded", load_weight_prefix="a/b")
|
||||
for p1, p2 in zip(model.weights, sharded_model.weights):
|
||||
self.assertTrue(np.allclose(p1.numpy(), p2.numpy()))
|
||||
self.assertTrue(p1.name.startswith("a/b/"))
|
||||
self.assertTrue(p2.name.startswith("a/b/"))
|
||||
|
||||
def test_sharded_checkpoint_transfer(self):
|
||||
# If this doesn't throw an error then the test passes
|
||||
TFBertForSequenceClassification.from_pretrained("ArthurZ/tiny-random-bert-sharded")
|
||||
|
||||
@is_pt_tf_cross_test
|
||||
def test_checkpoint_sharding_local_from_pt(self):
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
_ = Repository(local_dir=tmp_dir, clone_from="hf-internal-testing/tiny-random-bert-sharded")
|
||||
model = TFBertModel.from_pretrained(tmp_dir, from_pt=True)
|
||||
# the model above is the same as the model below, just a sharded pytorch version.
|
||||
ref_model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
|
||||
for p1, p2 in zip(model.weights, ref_model.weights):
|
||||
assert np.allclose(p1.numpy(), p2.numpy())
|
||||
|
||||
@is_pt_tf_cross_test
|
||||
def test_checkpoint_loading_with_prefix_from_pt(self):
|
||||
model = TFBertModel.from_pretrained(
|
||||
"hf-internal-testing/tiny-random-bert", from_pt=True, load_weight_prefix="a/b"
|
||||
)
|
||||
ref_model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert", from_pt=True)
|
||||
for p1, p2 in zip(model.weights, ref_model.weights):
|
||||
self.assertTrue(np.allclose(p1.numpy(), p2.numpy()))
|
||||
self.assertTrue(p1.name.startswith("a/b/"))
|
||||
|
||||
@is_pt_tf_cross_test
|
||||
def test_checkpoint_sharding_hub_from_pt(self):
|
||||
model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert-sharded", from_pt=True)
|
||||
# the model above is the same as the model below, just a sharded pytorch version.
|
||||
ref_model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
|
||||
for p1, p2 in zip(model.weights, ref_model.weights):
|
||||
assert np.allclose(p1.numpy(), p2.numpy())
|
||||
|
||||
def test_shard_checkpoint(self):
|
||||
# This is the model we will use, total size 340,000 bytes.
|
||||
model = keras.Sequential(
|
||||
[
|
||||
keras.layers.Dense(200, use_bias=False), # size 80,000
|
||||
keras.layers.Dense(200, use_bias=False), # size 160,000
|
||||
keras.layers.Dense(100, use_bias=False), # size 80,000
|
||||
keras.layers.Dense(50, use_bias=False), # size 20,000
|
||||
]
|
||||
)
|
||||
inputs = tf.zeros((1, 100), dtype=tf.float32)
|
||||
model(inputs)
|
||||
weights = model.weights
|
||||
weights_dict = {w.name: w for w in weights}
|
||||
with self.subTest("No shard when max size is bigger than model size"):
|
||||
shards, index = tf_shard_checkpoint(weights)
|
||||
self.assertIsNone(index)
|
||||
self.assertDictEqual(shards, {TF2_WEIGHTS_NAME: weights})
|
||||
|
||||
with self.subTest("Test sharding, no weights bigger than max size"):
|
||||
shards, index = tf_shard_checkpoint(weights, max_shard_size="300kB")
|
||||
# Split is first two layers then last two.
|
||||
self.assertDictEqual(
|
||||
index,
|
||||
{
|
||||
"metadata": {"total_size": 340000},
|
||||
"weight_map": {
|
||||
"dense/kernel:0": "tf_model-00001-of-00002.h5",
|
||||
"dense_1/kernel:0": "tf_model-00001-of-00002.h5",
|
||||
"dense_2/kernel:0": "tf_model-00002-of-00002.h5",
|
||||
"dense_3/kernel:0": "tf_model-00002-of-00002.h5",
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
shard1 = [weights_dict["dense/kernel:0"], weights_dict["dense_1/kernel:0"]]
|
||||
shard2 = [weights_dict["dense_2/kernel:0"], weights_dict["dense_3/kernel:0"]]
|
||||
self.assertDictEqual(shards, {"tf_model-00001-of-00002.h5": shard1, "tf_model-00002-of-00002.h5": shard2})
|
||||
|
||||
with self.subTest("Test sharding with weights bigger than max size"):
|
||||
shards, index = tf_shard_checkpoint(weights, max_shard_size="100kB")
|
||||
# Split is first layer, second layer then last 2.
|
||||
self.assertDictEqual(
|
||||
index,
|
||||
{
|
||||
"metadata": {"total_size": 340000},
|
||||
"weight_map": {
|
||||
"dense/kernel:0": "tf_model-00001-of-00003.h5",
|
||||
"dense_1/kernel:0": "tf_model-00002-of-00003.h5",
|
||||
"dense_2/kernel:0": "tf_model-00003-of-00003.h5",
|
||||
"dense_3/kernel:0": "tf_model-00003-of-00003.h5",
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
shard1 = [weights_dict["dense/kernel:0"]]
|
||||
shard2 = [weights_dict["dense_1/kernel:0"]]
|
||||
shard3 = [weights_dict["dense_2/kernel:0"], weights_dict["dense_3/kernel:0"]]
|
||||
self.assertDictEqual(
|
||||
shards,
|
||||
{
|
||||
"tf_model-00001-of-00003.h5": shard1,
|
||||
"tf_model-00002-of-00003.h5": shard2,
|
||||
"tf_model-00003-of-00003.h5": shard3,
|
||||
},
|
||||
)
|
||||
|
||||
@slow
|
||||
def test_special_layer_name_sharding(self):
|
||||
retriever = RagRetriever.from_pretrained("facebook/rag-token-nq", index_name="exact", use_dummy_dataset=True)
|
||||
model = TFRagModel.from_pretrained("facebook/rag-token-nq", retriever=retriever)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
for max_size in ["150kB", "150kiB", "200kB", "200kiB"]:
|
||||
model.save_pretrained(tmp_dir, max_shard_size=max_size)
|
||||
ref_model = TFRagModel.from_pretrained(tmp_dir, retriever=retriever)
|
||||
for p1, p2 in zip(model.weights, ref_model.weights):
|
||||
assert np.allclose(p1.numpy(), p2.numpy())
|
||||
|
||||
@require_safetensors
|
||||
def test_checkpoint_sharding_local(self):
|
||||
model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
# We use the same folder for various sizes to make sure a new save erases the old checkpoint.
|
||||
for max_size in ["150kB", "150kiB", "200kB", "200kiB"]:
|
||||
model.save_pretrained(tmp_dir, max_shard_size=max_size)
|
||||
|
||||
# Get each shard file and its size
|
||||
shard_to_size = {}
|
||||
for shard in os.listdir(tmp_dir):
|
||||
if shard.endswith(".h5"):
|
||||
shard_file = os.path.join(tmp_dir, shard)
|
||||
shard_to_size[shard_file] = os.path.getsize(shard_file)
|
||||
|
||||
index_file = os.path.join(tmp_dir, TF2_WEIGHTS_INDEX_NAME)
|
||||
# Check there is an index but no regular weight file
|
||||
self.assertTrue(os.path.isfile(index_file))
|
||||
self.assertFalse(os.path.isfile(os.path.join(tmp_dir, TF2_WEIGHTS_NAME)))
|
||||
|
||||
# Check a file is bigger than max_size only when it has a single weight
|
||||
for shard_file, size in shard_to_size.items():
|
||||
if max_size.endswith("kiB"):
|
||||
max_size_int = int(max_size[:-3]) * 2**10
|
||||
else:
|
||||
max_size_int = int(max_size[:-2]) * 10**3
|
||||
# Note: pickle adds some junk so the weight of the file can end up being slightly bigger than
|
||||
# the size asked for (since we count parameters)
|
||||
if size >= max_size_int + 50000:
|
||||
with h5py.File(shard_file, "r") as state_file:
|
||||
self.assertEqual(len(state_file), 1)
|
||||
|
||||
# Check the index and the shard files found match
|
||||
with open(index_file, "r", encoding="utf-8") as f:
|
||||
index = json.loads(f.read())
|
||||
|
||||
all_shards = set(index["weight_map"].values())
|
||||
shards_found = {f for f in os.listdir(tmp_dir) if f.endswith(".h5")}
|
||||
self.assertSetEqual(all_shards, shards_found)
|
||||
|
||||
# Finally, check the model can be reloaded
|
||||
new_model = TFBertModel.from_pretrained(tmp_dir)
|
||||
|
||||
model.build_in_name_scope()
|
||||
new_model.build_in_name_scope()
|
||||
|
||||
for p1, p2 in zip(model.weights, new_model.weights):
|
||||
self.assertTrue(np.allclose(p1.numpy(), p2.numpy()))
|
||||
|
||||
def test_safetensors_checkpoint_sharding_local(self):
|
||||
model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
# We use the same folder for various sizes to make sure a new save erases the old checkpoint.
|
||||
for max_size in ["150kB", "150kiB", "200kB", "200kiB"]:
|
||||
model.save_pretrained(tmp_dir, max_shard_size=max_size, safe_serialization=True)
|
||||
|
||||
# Get each shard file and its size
|
||||
shard_to_size = {}
|
||||
for shard in os.listdir(tmp_dir):
|
||||
if shard.endswith(".h5"):
|
||||
shard_file = os.path.join(tmp_dir, shard)
|
||||
shard_to_size[shard_file] = os.path.getsize(shard_file)
|
||||
|
||||
index_file = os.path.join(tmp_dir, SAFE_WEIGHTS_INDEX_NAME)
|
||||
# Check there is an index but no regular weight file
|
||||
self.assertTrue(os.path.isfile(index_file))
|
||||
self.assertFalse(os.path.isfile(os.path.join(tmp_dir, SAFE_WEIGHTS_NAME)))
|
||||
self.assertFalse(os.path.isfile(os.path.join(tmp_dir, TF2_WEIGHTS_NAME)))
|
||||
self.assertFalse(os.path.isfile(os.path.join(tmp_dir, TF2_WEIGHTS_INDEX_NAME)))
|
||||
|
||||
# Check the index and the shard files found match
|
||||
with open(index_file, "r", encoding="utf-8") as f:
|
||||
index = json.loads(f.read())
|
||||
|
||||
all_shards = set(index["weight_map"].values())
|
||||
shards_found = {f for f in os.listdir(tmp_dir) if f.endswith(".safetensors")}
|
||||
self.assertSetEqual(all_shards, shards_found)
|
||||
|
||||
# Finally, check the model can be reloaded
|
||||
new_model = TFBertModel.from_pretrained(tmp_dir)
|
||||
|
||||
model.build_in_name_scope()
|
||||
new_model.build_in_name_scope()
|
||||
|
||||
for p1, p2 in zip(model.weights, new_model.weights):
|
||||
self.assertTrue(np.allclose(p1.numpy(), p2.numpy()))
|
||||
|
||||
@is_pt_tf_cross_test
|
||||
@require_safetensors
|
||||
def test_bfloat16_torch_loading(self):
|
||||
# Assert that neither of these raise an error - both repos contain bfloat16 tensors
|
||||
model1 = TFAutoModel.from_pretrained("Rocketknight1/tiny-random-gpt2-bfloat16-pt", from_pt=True)
|
||||
model2 = TFAutoModel.from_pretrained("Rocketknight1/tiny-random-gpt2-bfloat16") # PT-format safetensors
|
||||
# Check that PT and safetensors loading paths end up with the same values
|
||||
for weight1, weight2 in zip(model1.weights, model2.weights):
|
||||
self.assertTrue(tf.reduce_all(weight1 == weight2))
|
||||
|
||||
@slow
|
||||
def test_save_pretrained_signatures(self):
|
||||
model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
|
||||
|
||||
# Short custom TF signature function.
|
||||
# `input_signature` is specific to BERT.
|
||||
@tf.function(
|
||||
input_signature=[
|
||||
[
|
||||
tf.TensorSpec([None, None], tf.int32, name="input_ids"),
|
||||
tf.TensorSpec([None, None], tf.int32, name="token_type_ids"),
|
||||
tf.TensorSpec([None, None], tf.int32, name="attention_mask"),
|
||||
]
|
||||
]
|
||||
)
|
||||
def serving_fn(input):
|
||||
return model(input)
|
||||
|
||||
# Using default signature (default behavior) overrides 'serving_default'
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
model.save_pretrained(tmp_dir, saved_model=True, signatures=None)
|
||||
model_loaded = keras.models.load_model(f"{tmp_dir}/saved_model/1")
|
||||
self.assertTrue("serving_default" in list(model_loaded.signatures.keys()))
|
||||
|
||||
# Providing custom signature function
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
model.save_pretrained(tmp_dir, saved_model=True, signatures={"custom_signature": serving_fn})
|
||||
model_loaded = keras.models.load_model(f"{tmp_dir}/saved_model/1")
|
||||
self.assertTrue("custom_signature" in list(model_loaded.signatures.keys()))
|
||||
|
||||
# Providing multiple custom signature function
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
model.save_pretrained(
|
||||
tmp_dir,
|
||||
saved_model=True,
|
||||
signatures={"custom_signature_1": serving_fn, "custom_signature_2": serving_fn},
|
||||
)
|
||||
model_loaded = keras.models.load_model(f"{tmp_dir}/saved_model/1")
|
||||
self.assertTrue("custom_signature_1" in list(model_loaded.signatures.keys()))
|
||||
self.assertTrue("custom_signature_2" in list(model_loaded.signatures.keys()))
|
||||
|
||||
@require_safetensors
|
||||
def test_safetensors_save_and_load(self):
|
||||
model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
model.save_pretrained(tmp_dir, safe_serialization=True)
|
||||
# No tf_model.h5 file, only a model.safetensors
|
||||
self.assertTrue(os.path.isfile(os.path.join(tmp_dir, SAFE_WEIGHTS_NAME)))
|
||||
self.assertFalse(os.path.isfile(os.path.join(tmp_dir, SAFE_WEIGHTS_INDEX_NAME)))
|
||||
self.assertFalse(os.path.isfile(os.path.join(tmp_dir, TF2_WEIGHTS_NAME)))
|
||||
self.assertFalse(os.path.isfile(os.path.join(tmp_dir, TF2_WEIGHTS_INDEX_NAME)))
|
||||
|
||||
new_model = TFBertModel.from_pretrained(tmp_dir)
|
||||
|
||||
# Check models are equal
|
||||
for p1, p2 in zip(model.weights, new_model.weights):
|
||||
self.assertTrue(np.allclose(p1.numpy(), p2.numpy()))
|
||||
|
||||
@require_safetensors
|
||||
def test_safetensors_sharded_save_and_load(self):
|
||||
model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
model.save_pretrained(tmp_dir, safe_serialization=True, max_shard_size="150kB")
|
||||
# No tf weights or index file, only a safetensors index
|
||||
self.assertFalse(os.path.isfile(os.path.join(tmp_dir, SAFE_WEIGHTS_NAME)))
|
||||
self.assertFalse(os.path.isfile(os.path.join(tmp_dir, TF2_WEIGHTS_NAME)))
|
||||
self.assertTrue(os.path.isfile(os.path.join(tmp_dir, SAFE_WEIGHTS_INDEX_NAME)))
|
||||
self.assertFalse(os.path.isfile(os.path.join(tmp_dir, TF2_WEIGHTS_INDEX_NAME)))
|
||||
|
||||
new_model = TFBertModel.from_pretrained(tmp_dir)
|
||||
|
||||
# Check models are equal
|
||||
for p1, p2 in zip(model.weights, new_model.weights):
|
||||
self.assertTrue(np.allclose(p1.numpy(), p2.numpy()))
|
||||
|
||||
@is_pt_tf_cross_test
|
||||
def test_safetensors_save_and_load_pt_to_tf(self):
|
||||
model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
|
||||
pt_model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
pt_model.save_pretrained(tmp_dir, safe_serialization=True)
|
||||
# Check we have a model.safetensors file
|
||||
self.assertTrue(os.path.isfile(os.path.join(tmp_dir, SAFE_WEIGHTS_NAME)))
|
||||
|
||||
new_model = TFBertModel.from_pretrained(tmp_dir)
|
||||
|
||||
# Check models are equal
|
||||
for p1, p2 in zip(model.weights, new_model.weights):
|
||||
self.assertTrue(np.allclose(p1.numpy(), p2.numpy()))
|
||||
|
||||
@is_pt_tf_cross_test
|
||||
def test_sharded_safetensors_save_and_load_pt_to_tf(self):
|
||||
model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
|
||||
pt_model = BertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
pt_model.save_pretrained(tmp_dir, safe_serialization=True, max_shard_size="150kB")
|
||||
# Check we have a safetensors shard index file
|
||||
self.assertTrue(os.path.isfile(os.path.join(tmp_dir, SAFE_WEIGHTS_INDEX_NAME)))
|
||||
|
||||
new_model = TFBertModel.from_pretrained(tmp_dir)
|
||||
|
||||
# Check models are equal
|
||||
for p1, p2 in zip(model.weights, new_model.weights):
|
||||
self.assertTrue(np.allclose(p1.numpy(), p2.numpy()))
|
||||
|
||||
@require_safetensors
|
||||
def test_safetensors_load_from_hub(self):
|
||||
tf_model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert")
|
||||
|
||||
# Can load from the TF-formatted checkpoint
|
||||
safetensors_model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert-safetensors-tf")
|
||||
|
||||
# Check models are equal
|
||||
for p1, p2 in zip(safetensors_model.weights, tf_model.weights):
|
||||
self.assertTrue(np.allclose(p1.numpy(), p2.numpy()))
|
||||
|
||||
# Can load from the PyTorch-formatted checkpoint
|
||||
safetensors_model = TFBertModel.from_pretrained("hf-internal-testing/tiny-random-bert-safetensors")
|
||||
|
||||
# Check models are equal
|
||||
for p1, p2 in zip(safetensors_model.weights, tf_model.weights):
|
||||
self.assertTrue(np.allclose(p1.numpy(), p2.numpy()))
|
||||
|
||||
@require_safetensors
|
||||
def test_safetensors_tf_from_tf(self):
|
||||
model = TFBertModel.from_pretrained("hf-internal-testing/tiny-bert-tf-only")
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
model.save_pretrained(tmp_dir, safe_serialization=True)
|
||||
new_model = TFBertModel.from_pretrained(tmp_dir)
|
||||
|
||||
for p1, p2 in zip(model.weights, new_model.weights):
|
||||
self.assertTrue(np.allclose(p1.numpy(), p2.numpy()))
|
||||
|
||||
@require_safetensors
|
||||
@is_pt_tf_cross_test
|
||||
def test_safetensors_tf_from_torch(self):
|
||||
hub_model = TFBertModel.from_pretrained("hf-internal-testing/tiny-bert-tf-only")
|
||||
model = BertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-only")
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
model.save_pretrained(tmp_dir, safe_serialization=True)
|
||||
new_model = TFBertModel.from_pretrained(tmp_dir)
|
||||
|
||||
for p1, p2 in zip(hub_model.weights, new_model.weights):
|
||||
self.assertTrue(np.allclose(p1.numpy(), p2.numpy()))
|
||||
|
||||
@require_safetensors
|
||||
def test_safetensors_tf_from_sharded_h5_with_sharded_safetensors_local(self):
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
path = snapshot_download("hf-internal-testing/tiny-bert-tf-safetensors-h5-sharded", cache_dir=tmp_dir)
|
||||
|
||||
# This should not raise even if there are two types of sharded weights
|
||||
TFBertModel.from_pretrained(path)
|
||||
|
||||
@require_safetensors
|
||||
def test_safetensors_tf_from_sharded_h5_with_sharded_safetensors_hub(self):
|
||||
# Confirm that we can correctly load the safetensors weights from a sharded hub repo even when TF weights present
|
||||
TFBertModel.from_pretrained("hf-internal-testing/tiny-bert-tf-safetensors-h5-sharded", use_safetensors=True)
|
||||
# Confirm that we can access the TF weights too
|
||||
TFBertModel.from_pretrained("hf-internal-testing/tiny-bert-tf-safetensors-h5-sharded", use_safetensors=False)
|
||||
|
||||
@require_safetensors
|
||||
def test_safetensors_load_from_local(self):
|
||||
"""
|
||||
This test checks that we can load safetensors from a checkpoint that only has those on the Hub
|
||||
"""
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
location = snapshot_download("hf-internal-testing/tiny-bert-tf-only", cache_dir=tmp)
|
||||
tf_model = TFBertModel.from_pretrained(location)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
location = snapshot_download("hf-internal-testing/tiny-bert-tf-safetensors-only", cache_dir=tmp)
|
||||
safetensors_model = TFBertModel.from_pretrained(location)
|
||||
|
||||
for p1, p2 in zip(tf_model.weights, safetensors_model.weights):
|
||||
self.assertTrue(np.allclose(p1.numpy(), p2.numpy()))
|
||||
|
||||
@require_safetensors
|
||||
def test_safetensors_load_from_hub_from_safetensors_pt(self):
|
||||
"""
|
||||
This test checks that we can load safetensors from a checkpoint that only has those on the Hub.
|
||||
saved in the "pt" format.
|
||||
"""
|
||||
tf_model = TFBertModel.from_pretrained("hf-internal-testing/tiny-bert-h5")
|
||||
|
||||
# Can load from the PyTorch-formatted checkpoint
|
||||
safetensors_model = TFBertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-safetensors")
|
||||
for p1, p2 in zip(tf_model.weights, safetensors_model.weights):
|
||||
self.assertTrue(np.allclose(p1.numpy(), p2.numpy()))
|
||||
|
||||
@require_safetensors
|
||||
def test_safetensors_load_from_local_from_safetensors_pt(self):
|
||||
"""
|
||||
This test checks that we can load safetensors from a local checkpoint that only has those
|
||||
saved in the "pt" format.
|
||||
"""
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
location = snapshot_download("hf-internal-testing/tiny-bert-h5", cache_dir=tmp)
|
||||
tf_model = TFBertModel.from_pretrained(location)
|
||||
|
||||
# Can load from the PyTorch-formatted checkpoint
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
location = snapshot_download("hf-internal-testing/tiny-bert-pt-safetensors", cache_dir=tmp)
|
||||
safetensors_model = TFBertModel.from_pretrained(location)
|
||||
|
||||
for p1, p2 in zip(tf_model.weights, safetensors_model.weights):
|
||||
self.assertTrue(np.allclose(p1.numpy(), p2.numpy()))
|
||||
|
||||
@require_safetensors
|
||||
def test_safetensors_load_from_hub_h5_before_safetensors(self):
|
||||
"""
|
||||
This test checks that we'll first download h5 weights before safetensors
|
||||
The safetensors file on that repo is a pt safetensors and therefore cannot be loaded without PyTorch
|
||||
"""
|
||||
TFBertModel.from_pretrained("hf-internal-testing/tiny-bert-pt-safetensors-msgpack")
|
||||
|
||||
@require_safetensors
|
||||
def test_safetensors_load_from_local_h5_before_safetensors(self):
|
||||
"""
|
||||
This test checks that we'll first download h5 weights before safetensors
|
||||
The safetensors file on that repo is a pt safetensors and therefore cannot be loaded without PyTorch
|
||||
"""
|
||||
with tempfile.TemporaryDirectory() as tmp:
|
||||
location = snapshot_download("hf-internal-testing/tiny-bert-pt-safetensors-msgpack", cache_dir=tmp)
|
||||
TFBertModel.from_pretrained(location)
|
||||
|
||||
|
||||
@require_tf
|
||||
@is_staging_test
|
||||
class TFModelPushToHubTester(unittest.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls._token = TOKEN
|
||||
HfFolder.save_token(TOKEN)
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
try:
|
||||
delete_repo(token=cls._token, repo_id="test-model-tf")
|
||||
except HTTPError:
|
||||
pass
|
||||
|
||||
try:
|
||||
delete_repo(token=cls._token, repo_id="test-model-tf-callback")
|
||||
except HTTPError:
|
||||
pass
|
||||
|
||||
try:
|
||||
delete_repo(token=cls._token, repo_id="valid_org/test-model-tf-org")
|
||||
except HTTPError:
|
||||
pass
|
||||
|
||||
def test_push_to_hub(self):
|
||||
config = BertConfig(
|
||||
vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
|
||||
)
|
||||
model = TFBertModel(config)
|
||||
# Make sure model is properly initialized
|
||||
model.build_in_name_scope()
|
||||
|
||||
logging.set_verbosity_info()
|
||||
logger = logging.get_logger("transformers.utils.hub")
|
||||
with CaptureLogger(logger) as cl:
|
||||
model.push_to_hub("test-model-tf", token=self._token)
|
||||
logging.set_verbosity_warning()
|
||||
# Check the model card was created and uploaded.
|
||||
self.assertIn("Uploading the following files to __DUMMY_TRANSFORMERS_USER__/test-model-tf", cl.out)
|
||||
|
||||
new_model = TFBertModel.from_pretrained(f"{USER}/test-model-tf")
|
||||
models_equal = True
|
||||
for p1, p2 in zip(model.weights, new_model.weights):
|
||||
if not tf.math.reduce_all(p1 == p2):
|
||||
models_equal = False
|
||||
break
|
||||
self.assertTrue(models_equal)
|
||||
|
||||
# Reset repo
|
||||
delete_repo(token=self._token, repo_id="test-model-tf")
|
||||
|
||||
# Push to hub via save_pretrained
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
model.save_pretrained(tmp_dir, repo_id="test-model-tf", push_to_hub=True, token=self._token)
|
||||
|
||||
new_model = TFBertModel.from_pretrained(f"{USER}/test-model-tf")
|
||||
models_equal = True
|
||||
for p1, p2 in zip(model.weights, new_model.weights):
|
||||
if not tf.math.reduce_all(p1 == p2):
|
||||
models_equal = False
|
||||
break
|
||||
self.assertTrue(models_equal)
|
||||
|
||||
@is_pt_tf_cross_test
|
||||
def test_push_to_hub_callback(self):
|
||||
config = BertConfig(
|
||||
vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
|
||||
)
|
||||
model = TFBertForMaskedLM(config)
|
||||
model.compile()
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
push_to_hub_callback = PushToHubCallback(
|
||||
output_dir=tmp_dir,
|
||||
hub_model_id="test-model-tf-callback",
|
||||
hub_token=self._token,
|
||||
)
|
||||
model.fit(model.dummy_inputs, model.dummy_inputs, epochs=1, callbacks=[push_to_hub_callback])
|
||||
|
||||
new_model = TFBertForMaskedLM.from_pretrained(f"{USER}/test-model-tf-callback")
|
||||
models_equal = True
|
||||
for p1, p2 in zip(model.weights, new_model.weights):
|
||||
if not tf.math.reduce_all(p1 == p2):
|
||||
models_equal = False
|
||||
break
|
||||
self.assertTrue(models_equal)
|
||||
|
||||
tf_push_to_hub_params = dict(inspect.signature(TFPreTrainedModel.push_to_hub).parameters)
|
||||
tf_push_to_hub_params.pop("base_model_card_args")
|
||||
pt_push_to_hub_params = dict(inspect.signature(PreTrainedModel.push_to_hub).parameters)
|
||||
pt_push_to_hub_params.pop("deprecated_kwargs")
|
||||
self.assertDictEaual(tf_push_to_hub_params, pt_push_to_hub_params)
|
||||
|
||||
def test_push_to_hub_in_organization(self):
|
||||
config = BertConfig(
|
||||
vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
|
||||
)
|
||||
model = TFBertModel(config)
|
||||
# Make sure model is properly initialized
|
||||
model.build_in_name_scope()
|
||||
|
||||
model.push_to_hub("valid_org/test-model-tf-org", token=self._token)
|
||||
|
||||
new_model = TFBertModel.from_pretrained("valid_org/test-model-tf-org")
|
||||
models_equal = True
|
||||
for p1, p2 in zip(model.weights, new_model.weights):
|
||||
if not tf.math.reduce_all(p1 == p2):
|
||||
models_equal = False
|
||||
break
|
||||
self.assertTrue(models_equal)
|
||||
|
||||
# Reset repo
|
||||
delete_repo(token=self._token, repo_id="valid_org/test-model-tf-org")
|
||||
|
||||
# Push to hub via save_pretrained
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
model.save_pretrained(tmp_dir, push_to_hub=True, token=self._token, repo_id="valid_org/test-model-tf-org")
|
||||
|
||||
new_model = TFBertModel.from_pretrained("valid_org/test-model-tf-org")
|
||||
models_equal = True
|
||||
for p1, p2 in zip(model.weights, new_model.weights):
|
||||
if not tf.math.reduce_all(p1 == p2):
|
||||
models_equal = False
|
||||
break
|
||||
self.assertTrue(models_equal)
|
||||
2282
tests/utils/test_modeling_utils.py
Normal file
2282
tests/utils/test_modeling_utils.py
Normal file
File diff suppressed because it is too large
Load Diff
308
tests/utils/test_tokenization_utils.py
Normal file
308
tests/utils/test_tokenization_utils.py
Normal file
@@ -0,0 +1,308 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2019 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
import sys
|
||||
import tempfile
|
||||
import unittest
|
||||
import unittest.mock as mock
|
||||
from pathlib import Path
|
||||
|
||||
from huggingface_hub import HfFolder, delete_repo
|
||||
from huggingface_hub.file_download import http_get
|
||||
from requests.exceptions import HTTPError
|
||||
|
||||
from transformers import (
|
||||
AlbertTokenizer,
|
||||
AutoTokenizer,
|
||||
BertTokenizer,
|
||||
BertTokenizerFast,
|
||||
GPT2TokenizerFast,
|
||||
is_tokenizers_available,
|
||||
)
|
||||
from transformers.testing_utils import TOKEN, USER, is_staging_test, require_tokenizers
|
||||
from transformers.tokenization_utils import ExtensionsTrie, Trie
|
||||
|
||||
|
||||
sys.path.append(str(Path(__file__).parent.parent / "utils"))
|
||||
|
||||
from test_module.custom_tokenization import CustomTokenizer # noqa E402
|
||||
|
||||
|
||||
if is_tokenizers_available():
|
||||
from test_module.custom_tokenization_fast import CustomTokenizerFast
|
||||
|
||||
|
||||
class TokenizerUtilTester(unittest.TestCase):
|
||||
def test_cached_files_are_used_when_internet_is_down(self):
|
||||
# A mock response for an HTTP head request to emulate server down
|
||||
response_mock = mock.Mock()
|
||||
response_mock.status_code = 500
|
||||
response_mock.headers = {}
|
||||
response_mock.raise_for_status.side_effect = HTTPError
|
||||
response_mock.json.return_value = {}
|
||||
|
||||
# Download this model to make sure it's in the cache.
|
||||
_ = BertTokenizer.from_pretrained("hf-internal-testing/tiny-random-bert")
|
||||
|
||||
# Under the mock environment we get a 500 error when trying to reach the tokenizer.
|
||||
with mock.patch("requests.Session.request", return_value=response_mock) as mock_head:
|
||||
_ = BertTokenizer.from_pretrained("hf-internal-testing/tiny-random-bert")
|
||||
# This check we did call the fake head request
|
||||
mock_head.assert_called()
|
||||
|
||||
@require_tokenizers
|
||||
def test_cached_files_are_used_when_internet_is_down_missing_files(self):
|
||||
# A mock response for an HTTP head request to emulate server down
|
||||
response_mock = mock.Mock()
|
||||
response_mock.status_code = 500
|
||||
response_mock.headers = {}
|
||||
response_mock.raise_for_status.side_effect = HTTPError
|
||||
response_mock.json.return_value = {}
|
||||
|
||||
# Download this model to make sure it's in the cache.
|
||||
_ = GPT2TokenizerFast.from_pretrained("openai-community/gpt2")
|
||||
|
||||
# Under the mock environment we get a 500 error when trying to reach the tokenizer.
|
||||
with mock.patch("requests.Session.request", return_value=response_mock) as mock_head:
|
||||
_ = GPT2TokenizerFast.from_pretrained("openai-community/gpt2")
|
||||
# This check we did call the fake head request
|
||||
mock_head.assert_called()
|
||||
|
||||
def test_legacy_load_from_one_file(self):
|
||||
# This test is for deprecated behavior and can be removed in v5
|
||||
try:
|
||||
tmp_file = tempfile.mktemp()
|
||||
with open(tmp_file, "wb") as f:
|
||||
http_get("https://huggingface.co/albert/albert-base-v1/resolve/main/spiece.model", f)
|
||||
|
||||
_ = AlbertTokenizer.from_pretrained(tmp_file)
|
||||
finally:
|
||||
os.remove(tmp_file)
|
||||
|
||||
# Supporting this legacy load introduced a weird bug where the tokenizer would load local files if they are in
|
||||
# the current folder and have the right name.
|
||||
if os.path.isfile("tokenizer.json"):
|
||||
# We skip the test if the user has a `tokenizer.json` in this folder to avoid deleting it.
|
||||
self.skipTest(reason="Skipping test as there is a `tokenizer.json` file in the current folder.")
|
||||
try:
|
||||
with open("tokenizer.json", "wb") as f:
|
||||
http_get("https://huggingface.co/hf-internal-testing/tiny-random-bert/blob/main/tokenizer.json", f)
|
||||
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
|
||||
# The tiny random BERT has a vocab size of 1024, tiny openai-community/gpt2 as a vocab size of 1000
|
||||
self.assertEqual(tokenizer.vocab_size, 1000)
|
||||
# Tokenizer should depend on the remote checkpoint, not the local tokenizer.json file.
|
||||
|
||||
finally:
|
||||
os.remove("tokenizer.json")
|
||||
|
||||
|
||||
@is_staging_test
|
||||
class TokenizerPushToHubTester(unittest.TestCase):
|
||||
vocab_tokens = ["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]", "bla", "blou"]
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls._token = TOKEN
|
||||
HfFolder.save_token(TOKEN)
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
try:
|
||||
delete_repo(token=cls._token, repo_id="test-tokenizer")
|
||||
except HTTPError:
|
||||
pass
|
||||
|
||||
try:
|
||||
delete_repo(token=cls._token, repo_id="valid_org/test-tokenizer-org")
|
||||
except HTTPError:
|
||||
pass
|
||||
|
||||
try:
|
||||
delete_repo(token=cls._token, repo_id="test-dynamic-tokenizer")
|
||||
except HTTPError:
|
||||
pass
|
||||
|
||||
def test_push_to_hub(self):
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
vocab_file = os.path.join(tmp_dir, "vocab.txt")
|
||||
with open(vocab_file, "w", encoding="utf-8") as vocab_writer:
|
||||
vocab_writer.write("".join([x + "\n" for x in self.vocab_tokens]))
|
||||
tokenizer = BertTokenizer(vocab_file)
|
||||
|
||||
tokenizer.push_to_hub("test-tokenizer", token=self._token)
|
||||
new_tokenizer = BertTokenizer.from_pretrained(f"{USER}/test-tokenizer")
|
||||
self.assertDictEqual(new_tokenizer.vocab, tokenizer.vocab)
|
||||
|
||||
# Reset repo
|
||||
delete_repo(token=self._token, repo_id="test-tokenizer")
|
||||
|
||||
# Push to hub via save_pretrained
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
tokenizer.save_pretrained(tmp_dir, repo_id="test-tokenizer", push_to_hub=True, token=self._token)
|
||||
|
||||
new_tokenizer = BertTokenizer.from_pretrained(f"{USER}/test-tokenizer")
|
||||
self.assertDictEqual(new_tokenizer.vocab, tokenizer.vocab)
|
||||
|
||||
def test_push_to_hub_in_organization(self):
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
vocab_file = os.path.join(tmp_dir, "vocab.txt")
|
||||
with open(vocab_file, "w", encoding="utf-8") as vocab_writer:
|
||||
vocab_writer.write("".join([x + "\n" for x in self.vocab_tokens]))
|
||||
tokenizer = BertTokenizer(vocab_file)
|
||||
|
||||
tokenizer.push_to_hub("valid_org/test-tokenizer-org", token=self._token)
|
||||
new_tokenizer = BertTokenizer.from_pretrained("valid_org/test-tokenizer-org")
|
||||
self.assertDictEqual(new_tokenizer.vocab, tokenizer.vocab)
|
||||
|
||||
# Reset repo
|
||||
delete_repo(token=self._token, repo_id="valid_org/test-tokenizer-org")
|
||||
|
||||
# Push to hub via save_pretrained
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
tokenizer.save_pretrained(
|
||||
tmp_dir, repo_id="valid_org/test-tokenizer-org", push_to_hub=True, token=self._token
|
||||
)
|
||||
|
||||
new_tokenizer = BertTokenizer.from_pretrained("valid_org/test-tokenizer-org")
|
||||
self.assertDictEqual(new_tokenizer.vocab, tokenizer.vocab)
|
||||
|
||||
@require_tokenizers
|
||||
def test_push_to_hub_dynamic_tokenizer(self):
|
||||
CustomTokenizer.register_for_auto_class()
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
vocab_file = os.path.join(tmp_dir, "vocab.txt")
|
||||
with open(vocab_file, "w", encoding="utf-8") as vocab_writer:
|
||||
vocab_writer.write("".join([x + "\n" for x in self.vocab_tokens]))
|
||||
tokenizer = CustomTokenizer(vocab_file)
|
||||
|
||||
# No fast custom tokenizer
|
||||
tokenizer.push_to_hub("test-dynamic-tokenizer", token=self._token)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(f"{USER}/test-dynamic-tokenizer", trust_remote_code=True)
|
||||
# Can't make an isinstance check because the new_model.config is from the CustomTokenizer class of a dynamic module
|
||||
self.assertEqual(tokenizer.__class__.__name__, "CustomTokenizer")
|
||||
|
||||
# Fast and slow custom tokenizer
|
||||
CustomTokenizerFast.register_for_auto_class()
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
vocab_file = os.path.join(tmp_dir, "vocab.txt")
|
||||
with open(vocab_file, "w", encoding="utf-8") as vocab_writer:
|
||||
vocab_writer.write("".join([x + "\n" for x in self.vocab_tokens]))
|
||||
|
||||
bert_tokenizer = BertTokenizerFast.from_pretrained(tmp_dir)
|
||||
bert_tokenizer.save_pretrained(tmp_dir)
|
||||
tokenizer = CustomTokenizerFast.from_pretrained(tmp_dir)
|
||||
|
||||
tokenizer.push_to_hub("test-dynamic-tokenizer", token=self._token)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(f"{USER}/test-dynamic-tokenizer", trust_remote_code=True)
|
||||
# Can't make an isinstance check because the new_model.config is from the FakeConfig class of a dynamic module
|
||||
self.assertEqual(tokenizer.__class__.__name__, "CustomTokenizerFast")
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
f"{USER}/test-dynamic-tokenizer", use_fast=False, trust_remote_code=True
|
||||
)
|
||||
# Can't make an isinstance check because the new_model.config is from the FakeConfig class of a dynamic module
|
||||
self.assertEqual(tokenizer.__class__.__name__, "CustomTokenizer")
|
||||
|
||||
|
||||
class TrieTest(unittest.TestCase):
|
||||
def test_trie(self):
|
||||
trie = Trie()
|
||||
trie.add("Hello 友達")
|
||||
self.assertEqual(trie.data, {"H": {"e": {"l": {"l": {"o": {" ": {"友": {"達": {"": 1}}}}}}}}})
|
||||
trie.add("Hello")
|
||||
trie.data
|
||||
self.assertEqual(trie.data, {"H": {"e": {"l": {"l": {"o": {"": 1, " ": {"友": {"達": {"": 1}}}}}}}}})
|
||||
|
||||
def test_trie_split(self):
|
||||
trie = Trie()
|
||||
self.assertEqual(trie.split("[CLS] This is a extra_id_100"), ["[CLS] This is a extra_id_100"])
|
||||
trie.add("[CLS]")
|
||||
trie.add("extra_id_1")
|
||||
trie.add("extra_id_100")
|
||||
self.assertEqual(trie.split("[CLS] This is a extra_id_100"), ["[CLS]", " This is a ", "extra_id_100"])
|
||||
|
||||
def test_trie_single(self):
|
||||
trie = Trie()
|
||||
trie.add("A")
|
||||
self.assertEqual(trie.split("ABC"), ["A", "BC"])
|
||||
self.assertEqual(trie.split("BCA"), ["BC", "A"])
|
||||
|
||||
def test_trie_final(self):
|
||||
trie = Trie()
|
||||
trie.add("TOKEN]")
|
||||
trie.add("[SPECIAL_TOKEN]")
|
||||
self.assertEqual(trie.split("This is something [SPECIAL_TOKEN]"), ["This is something ", "[SPECIAL_TOKEN]"])
|
||||
|
||||
def test_trie_subtokens(self):
|
||||
trie = Trie()
|
||||
trie.add("A")
|
||||
trie.add("P")
|
||||
trie.add("[SPECIAL_TOKEN]")
|
||||
self.assertEqual(trie.split("This is something [SPECIAL_TOKEN]"), ["This is something ", "[SPECIAL_TOKEN]"])
|
||||
|
||||
def test_trie_suffix_tokens(self):
|
||||
trie = Trie()
|
||||
trie.add("AB")
|
||||
trie.add("B")
|
||||
trie.add("C")
|
||||
self.assertEqual(trie.split("ABC"), ["AB", "C"])
|
||||
|
||||
def test_trie_skip(self):
|
||||
trie = Trie()
|
||||
trie.add("ABC")
|
||||
trie.add("B")
|
||||
trie.add("CD")
|
||||
self.assertEqual(trie.split("ABCD"), ["ABC", "D"])
|
||||
|
||||
def test_cut_text_hardening(self):
|
||||
# Even if the offsets are wrong, we necessarily output correct string
|
||||
# parts.
|
||||
trie = Trie()
|
||||
parts = trie.cut_text("ABC", [0, 0, 2, 1, 2, 3])
|
||||
self.assertEqual(parts, ["AB", "C"])
|
||||
|
||||
|
||||
class ExtensionsTrieTest(unittest.TestCase):
|
||||
def test_extensions(self):
|
||||
# Test searching by prefix
|
||||
trie = ExtensionsTrie()
|
||||
trie.add("foo")
|
||||
trie.add("food")
|
||||
trie.add("foodie")
|
||||
trie.add("helium")
|
||||
self.assertEqual(trie.extensions("foo"), ["foo", "food", "foodie"])
|
||||
self.assertEqual(trie.extensions("helium"), ["helium"])
|
||||
|
||||
def test_empty_prefix(self):
|
||||
trie = ExtensionsTrie()
|
||||
# Test searching with an empty prefix returns all values
|
||||
trie.add("hello")
|
||||
trie.add("bye")
|
||||
self.assertEqual(trie.extensions(""), ["hello", "bye"])
|
||||
|
||||
def test_no_extension_match(self):
|
||||
trie = ExtensionsTrie()
|
||||
# Test searching for a prefix that doesn't match any key
|
||||
with self.assertRaises(KeyError):
|
||||
trie.extensions("unknown")
|
||||
|
||||
def test_update_value(self):
|
||||
trie = ExtensionsTrie()
|
||||
# Test updating the value of an existing key
|
||||
trie.add("hi")
|
||||
trie.add("hi")
|
||||
self.assertEqual(trie.extensions("hi"), ["hi"])
|
||||
Reference in New Issue
Block a user