Cleaning TensorFlow models (#5229)

* Cleaning TensorFlow models

Update all classes


stylr

* Don't average loss
This commit is contained in:
Lysandre Debut
2020-06-24 11:37:20 -04:00
committed by GitHub
parent 609e0c583f
commit cf10d4cfdd
13 changed files with 483 additions and 126 deletions

View File

@@ -15,6 +15,7 @@
import copy
import inspect
import os
import random
import tempfile
@@ -35,6 +36,9 @@ if is_tf_available():
TFAdaptiveEmbedding,
TFSharedEmbeddings,
TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING,
TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING,
TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
)
if _tf_gpu_memory_limit is not None:
@@ -71,14 +75,25 @@ class TFModelTesterMixin:
test_resize_embeddings = True
is_encoder_decoder = False
def _prepare_for_class(self, inputs_dict, model_class):
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
if model_class in TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING.values():
return {
inputs_dict = {
k: tf.tile(tf.expand_dims(v, 1), (1, self.model_tester.num_choices, 1))
if isinstance(v, tf.Tensor) and v.ndim != 0
else v
for k, v in inputs_dict.items()
}
if return_labels:
if model_class in TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING.values():
inputs_dict["labels"] = tf.ones(self.model_tester.batch_size)
elif model_class in TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING.values():
inputs_dict["start_positions"] = tf.zeros(self.model_tester.batch_size)
inputs_dict["end_positions"] = tf.zeros(self.model_tester.batch_size)
elif model_class in TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING.values():
inputs_dict["labels"] = tf.zeros(self.model_tester.batch_size)
elif model_class in TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING.values():
inputs_dict["labels"] = tf.zeros((self.model_tester.batch_size, self.model_tester.seq_length))
return inputs_dict
def test_initialization(self):
@@ -572,6 +587,51 @@ class TFModelTesterMixin:
generated_ids = output_tokens[:, input_ids.shape[-1] :]
self.assertFalse(self._check_match_tokens(generated_ids.numpy().tolist(), bad_words_ids))
def test_loss_computation(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
if getattr(model, "compute_loss", None):
# The number of elements in the loss should be the same as the number of elements in the label
prepared_for_class = self._prepare_for_class(inputs_dict.copy(), model_class, return_labels=True)
added_label = prepared_for_class[list(prepared_for_class.keys() - inputs_dict.keys())[0]]
loss_size = tf.size(added_label)
# Test that model correctly compute the loss with kwargs
prepared_for_class = self._prepare_for_class(inputs_dict.copy(), model_class, return_labels=True)
input_ids = prepared_for_class.pop("input_ids")
loss = model(input_ids, **prepared_for_class)[0]
self.assertEqual(loss.shape, [loss_size])
# Test that model correctly compute the loss with a dict
prepared_for_class = self._prepare_for_class(inputs_dict.copy(), model_class, return_labels=True)
loss = model(prepared_for_class)[0]
self.assertEqual(loss.shape, [loss_size])
# Test that model correctly compute the loss with a tuple
prepared_for_class = self._prepare_for_class(inputs_dict.copy(), model_class, return_labels=True)
# Get keys that were added with the _prepare_for_class function
label_keys = prepared_for_class.keys() - inputs_dict.keys()
signature = inspect.getfullargspec(model.call)[0]
# Create a dictionary holding the location of the tensors in the tuple
tuple_index_mapping = {1: "input_ids"}
for label_key in label_keys:
label_key_index = signature.index(label_key)
tuple_index_mapping[label_key_index] = label_key
sorted_tuple_index_mapping = sorted(tuple_index_mapping.items())
# Initialize a list with None, update the values and convert to a tuple
list_input = [None] * sorted_tuple_index_mapping[-1][0]
for index, value in sorted_tuple_index_mapping:
list_input[index - 1] = prepared_for_class[value]
tuple_input = tuple(list_input)
# Send to model
loss = model(tuple_input)[0]
self.assertEqual(loss.shape, [loss_size])
def _generate_random_bad_tokens(self, num_bad_tokens, model):
# special tokens cannot be bad tokens
special_tokens = []