TF: tests for (de)serializable models with resized tokens (#19013)
* resized models that we can actually load * separate embeddings check * add test for embeddings out of bounds * add fake slows
This commit is contained in:
@@ -1162,7 +1162,7 @@ class TFModelTesterMixin:
|
||||
for model_class in self.all_model_classes:
|
||||
for size in [config.vocab_size - 10, config.vocab_size + 10, None]:
|
||||
# build the embeddings
|
||||
model = model_class(config=config)
|
||||
model = model_class(config=copy.deepcopy(config)) # `resize_token_embeddings` mutates `config`
|
||||
old_input_embeddings = _get_word_embedding_weight(model, model.get_input_embeddings())
|
||||
old_bias = model.get_bias()
|
||||
old_output_embeddings = _get_word_embedding_weight(model, model.get_output_embeddings())
|
||||
@@ -1203,6 +1203,65 @@ class TFModelTesterMixin:
|
||||
models_equal = False
|
||||
self.assertTrue(models_equal)
|
||||
|
||||
# TODO (Joao): this test is not slow, but it's tagged as such to keep track of failures on the scheduled CI runs,
|
||||
# while passing push CI. Fix the underlying issues and remove the tag.
|
||||
@slow
|
||||
def test_save_load_after_resize_token_embeddings(self):
|
||||
if not self.test_resize_embeddings:
|
||||
return
|
||||
config, original_inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
# create a model with resized (expended) embeddings
|
||||
new_tokens_size = 10
|
||||
old_total_size = config.vocab_size
|
||||
new_total_size = old_total_size + new_tokens_size
|
||||
model = model_class(config=copy.deepcopy(config)) # `resize_token_embeddings` mutates `config`
|
||||
model(model.dummy_inputs) # builds the embeddings layer
|
||||
model.resize_token_embeddings(new_total_size)
|
||||
|
||||
# fetch the output for an input exclusively made of new members of the vocabulary
|
||||
inputs_dict = copy.deepcopy(original_inputs_dict)
|
||||
new_vocab_input_ids = ids_tensor(inputs_dict["input_ids"].shape, new_tokens_size)
|
||||
new_vocab_input_ids += old_total_size
|
||||
if "input_ids" in inputs_dict:
|
||||
inputs_dict["input_ids"] = new_vocab_input_ids
|
||||
if "decoder_input_ids" in inputs_dict:
|
||||
inputs_dict["decoder_input_ids"] = new_vocab_input_ids
|
||||
prepared_inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||
outputs = model(**prepared_inputs)
|
||||
|
||||
# save and load the model
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
model.save_pretrained(tmpdirname, saved_model=False)
|
||||
model = model_class.from_pretrained(tmpdirname)
|
||||
restored_model_outputs = model(**prepared_inputs)
|
||||
|
||||
# check that the output for the restored model is the same
|
||||
self.assert_outputs_same(restored_model_outputs, outputs)
|
||||
|
||||
@unittest.skipIf(
|
||||
not is_tf_available() or len(tf.config.list_physical_devices("GPU")) == 0,
|
||||
reason="This test always passes on CPU.",
|
||||
)
|
||||
def test_embeddings_out_of_bounds_raise_exception(self):
|
||||
# TF embeddings layers don't raise an exception when an index is out of bounds on GPU, so we manually raise it.
|
||||
# This test should only fail on GPU for models where we haven't added the safety check.
|
||||
if not self.test_resize_embeddings:
|
||||
return
|
||||
config, original_inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config=config)
|
||||
inputs_dict = copy.deepcopy(original_inputs_dict)
|
||||
if "input_ids" in inputs_dict:
|
||||
inputs_dict["input_ids"] = inputs_dict["input_ids"] * int(1e9)
|
||||
if "decoder_input_ids" in inputs_dict:
|
||||
inputs_dict["decoder_input_ids"] = inputs_dict["decoder_input_ids"] * int(1e9)
|
||||
prepared_inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||
with self.assertRaises(tf.errors.InvalidArgumentError):
|
||||
model(**prepared_inputs)
|
||||
|
||||
def test_lm_head_model_random_no_beam_search_generate(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
input_ids = inputs_dict.get("input_ids", None)
|
||||
|
||||
Reference in New Issue
Block a user