TF: tests for (de)serializable models with resized tokens (#19013)

* resized models that we can actually load

* separate embeddings check

* add test for embeddings out of bounds

* add fake slows
This commit is contained in:
Joao Gante
2022-09-16 16:38:08 +01:00
committed by GitHub
parent 70ba10e6d4
commit 658010c739
4 changed files with 123 additions and 4 deletions

View File

@@ -375,6 +375,7 @@ class TFViTMAEModelTest(TFModelTesterMixin, unittest.TestCase):
# overwrite from common since TFViTMAEForPretraining has random masking, we need to fix the noise
# to generate masks during test
@slow
def test_save_load(self):
# make mask reproducible
np.random.seed(2)

View File

@@ -1162,7 +1162,7 @@ class TFModelTesterMixin:
for model_class in self.all_model_classes:
for size in [config.vocab_size - 10, config.vocab_size + 10, None]:
# build the embeddings
model = model_class(config=config)
model = model_class(config=copy.deepcopy(config)) # `resize_token_embeddings` mutates `config`
old_input_embeddings = _get_word_embedding_weight(model, model.get_input_embeddings())
old_bias = model.get_bias()
old_output_embeddings = _get_word_embedding_weight(model, model.get_output_embeddings())
@@ -1203,6 +1203,65 @@ class TFModelTesterMixin:
models_equal = False
self.assertTrue(models_equal)
# TODO (Joao): this test is not slow, but it's tagged as such to keep track of failures on the scheduled CI runs,
# while passing push CI. Fix the underlying issues and remove the tag.
@slow
def test_save_load_after_resize_token_embeddings(self):
if not self.test_resize_embeddings:
return
config, original_inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
# create a model with resized (expended) embeddings
new_tokens_size = 10
old_total_size = config.vocab_size
new_total_size = old_total_size + new_tokens_size
model = model_class(config=copy.deepcopy(config)) # `resize_token_embeddings` mutates `config`
model(model.dummy_inputs) # builds the embeddings layer
model.resize_token_embeddings(new_total_size)
# fetch the output for an input exclusively made of new members of the vocabulary
inputs_dict = copy.deepcopy(original_inputs_dict)
new_vocab_input_ids = ids_tensor(inputs_dict["input_ids"].shape, new_tokens_size)
new_vocab_input_ids += old_total_size
if "input_ids" in inputs_dict:
inputs_dict["input_ids"] = new_vocab_input_ids
if "decoder_input_ids" in inputs_dict:
inputs_dict["decoder_input_ids"] = new_vocab_input_ids
prepared_inputs = self._prepare_for_class(inputs_dict, model_class)
outputs = model(**prepared_inputs)
# save and load the model
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname, saved_model=False)
model = model_class.from_pretrained(tmpdirname)
restored_model_outputs = model(**prepared_inputs)
# check that the output for the restored model is the same
self.assert_outputs_same(restored_model_outputs, outputs)
@unittest.skipIf(
not is_tf_available() or len(tf.config.list_physical_devices("GPU")) == 0,
reason="This test always passes on CPU.",
)
def test_embeddings_out_of_bounds_raise_exception(self):
# TF embeddings layers don't raise an exception when an index is out of bounds on GPU, so we manually raise it.
# This test should only fail on GPU for models where we haven't added the safety check.
if not self.test_resize_embeddings:
return
config, original_inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config=config)
inputs_dict = copy.deepcopy(original_inputs_dict)
if "input_ids" in inputs_dict:
inputs_dict["input_ids"] = inputs_dict["input_ids"] * int(1e9)
if "decoder_input_ids" in inputs_dict:
inputs_dict["decoder_input_ids"] = inputs_dict["decoder_input_ids"] * int(1e9)
prepared_inputs = self._prepare_for_class(inputs_dict, model_class)
with self.assertRaises(tf.errors.InvalidArgumentError):
model(**prepared_inputs)
def test_lm_head_model_random_no_beam_search_generate(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
input_ids = inputs_dict.get("input_ids", None)