Test model outputs equivalence (#6445)
* Test model outputs equivalence * Fix failing tests * From dict to kwargs * DistilBERT * Addressing @sgugger and @patrickvonplaten's comments
This commit is contained in:
@@ -18,7 +18,7 @@ import os.path
|
||||
import random
|
||||
import tempfile
|
||||
import unittest
|
||||
from typing import List
|
||||
from typing import List, Tuple
|
||||
|
||||
from transformers import is_torch_available
|
||||
from transformers.testing_utils import require_multigpu, require_torch, slow, torch_device
|
||||
@@ -37,6 +37,11 @@ if is_torch_available():
|
||||
BERT_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
MODEL_FOR_MULTIPLE_CHOICE_MAPPING,
|
||||
MODEL_FOR_QUESTION_ANSWERING_MAPPING,
|
||||
MODEL_FOR_CAUSAL_LM_MAPPING,
|
||||
MODEL_FOR_MASKED_LM_MAPPING,
|
||||
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING,
|
||||
MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
|
||||
MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
|
||||
top_k_top_p_filtering,
|
||||
)
|
||||
|
||||
@@ -63,14 +68,39 @@ class ModelTesterMixin:
|
||||
test_chunking = False
|
||||
is_encoder_decoder = False
|
||||
|
||||
def _prepare_for_class(self, inputs_dict, model_class):
|
||||
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
|
||||
inputs_dict = copy.deepcopy(inputs_dict)
|
||||
if model_class in MODEL_FOR_MULTIPLE_CHOICE_MAPPING.values():
|
||||
return {
|
||||
inputs_dict = {
|
||||
k: v.unsqueeze(1).expand(-1, self.model_tester.num_choices, -1).contiguous()
|
||||
if isinstance(v, torch.Tensor) and v.ndim > 1
|
||||
else v
|
||||
for k, v in inputs_dict.items()
|
||||
}
|
||||
|
||||
if return_labels:
|
||||
if model_class in MODEL_FOR_MULTIPLE_CHOICE_MAPPING.values():
|
||||
inputs_dict["labels"] = torch.ones(self.model_tester.batch_size, dtype=torch.long, device=torch_device)
|
||||
elif model_class in MODEL_FOR_QUESTION_ANSWERING_MAPPING.values():
|
||||
inputs_dict["start_positions"] = torch.zeros(
|
||||
self.model_tester.batch_size, dtype=torch.long, device=torch_device
|
||||
)
|
||||
inputs_dict["end_positions"] = torch.zeros(
|
||||
self.model_tester.batch_size, dtype=torch.long, device=torch_device
|
||||
)
|
||||
elif model_class in MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING.values():
|
||||
inputs_dict["labels"] = torch.zeros(
|
||||
self.model_tester.batch_size, dtype=torch.long, device=torch_device
|
||||
)
|
||||
elif model_class in [
|
||||
*MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING.values(),
|
||||
*MODEL_FOR_CAUSAL_LM_MAPPING.values(),
|
||||
*MODEL_FOR_MASKED_LM_MAPPING.values(),
|
||||
*MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING.values(),
|
||||
]:
|
||||
inputs_dict["labels"] = torch.zeros(
|
||||
(self.model_tester.batch_size, self.model_tester.seq_length), dtype=torch.long, device=torch_device
|
||||
)
|
||||
return inputs_dict
|
||||
|
||||
def test_save_load(self):
|
||||
@@ -663,6 +693,64 @@ class ModelTesterMixin:
|
||||
# self.assertTrue(model.transformer.wte.weight.shape, model.lm_head.weight.shape)
|
||||
# self.assertTrue(check_same_values(model.transformer.wte, model.lm_head))
|
||||
|
||||
def test_model_outputs_equivalence(self):
|
||||
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
def check_equivalence(model, tuple_inputs, dict_inputs, additional_kwargs={}):
|
||||
with torch.no_grad():
|
||||
tuple_output = model(**tuple_inputs, return_dict=False, **additional_kwargs)
|
||||
dict_output = model(**dict_inputs, return_dict=True, **additional_kwargs).to_tuple()
|
||||
|
||||
def recursive_check(tuple_object, dict_object):
|
||||
if isinstance(tuple_object, (List, Tuple)):
|
||||
for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object):
|
||||
recursive_check(tuple_iterable_value, dict_iterable_value)
|
||||
elif tuple_object is None:
|
||||
return
|
||||
else:
|
||||
self.assertTrue(
|
||||
torch.allclose(tuple_object, dict_object, atol=1e-5),
|
||||
msg=f"Tuple and dict output are not equal. Difference: {torch.max(torch.abs(tuple_object - dict_object))}",
|
||||
)
|
||||
|
||||
recursive_check(tuple_output, dict_output)
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
tuple_inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||
dict_inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||
check_equivalence(model, tuple_inputs, dict_inputs)
|
||||
|
||||
tuple_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
|
||||
dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
|
||||
check_equivalence(model, tuple_inputs, dict_inputs)
|
||||
|
||||
tuple_inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||
dict_inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||
check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True})
|
||||
|
||||
tuple_inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||
dict_inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||
check_equivalence(model, tuple_inputs, dict_inputs, {"output_attentions": True})
|
||||
|
||||
tuple_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
|
||||
dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
|
||||
check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True})
|
||||
|
||||
tuple_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
|
||||
dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
|
||||
check_equivalence(model, tuple_inputs, dict_inputs, {"output_attentions": True})
|
||||
|
||||
tuple_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
|
||||
dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
|
||||
check_equivalence(
|
||||
model, tuple_inputs, dict_inputs, {"output_hidden_states": True, "output_attentions": True}
|
||||
)
|
||||
|
||||
def test_inputs_embeds(self):
|
||||
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
@@ -37,6 +37,8 @@ class T5ModelTester:
|
||||
self.batch_size = 13
|
||||
self.encoder_seq_length = 7
|
||||
self.decoder_seq_length = 9
|
||||
# For common tests
|
||||
self.seq_length = self.decoder_seq_length
|
||||
self.is_training = True
|
||||
self.use_attention_mask = True
|
||||
self.use_labels = True
|
||||
|
||||
@@ -21,6 +21,7 @@ import random
|
||||
import tempfile
|
||||
import unittest
|
||||
from importlib import import_module
|
||||
from typing import List, Tuple
|
||||
|
||||
from transformers import is_tf_available, is_torch_available
|
||||
from transformers.testing_utils import _tf_gpu_memory_limit, require_tf, slow
|
||||
@@ -78,6 +79,8 @@ class TFModelTesterMixin:
|
||||
is_encoder_decoder = False
|
||||
|
||||
def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
|
||||
inputs_dict = copy.deepcopy(inputs_dict)
|
||||
|
||||
if model_class in TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING.values():
|
||||
inputs_dict = {
|
||||
k: tf.tile(tf.expand_dims(v, 1), (1, self.model_tester.num_choices) + (1,) * (v.ndim - 1))
|
||||
@@ -88,20 +91,21 @@ class TFModelTesterMixin:
|
||||
|
||||
if return_labels:
|
||||
if model_class in TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING.values():
|
||||
inputs_dict["labels"] = tf.ones(self.model_tester.batch_size)
|
||||
inputs_dict["labels"] = tf.ones(self.model_tester.batch_size, dtype=tf.int32)
|
||||
elif model_class in TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING.values():
|
||||
inputs_dict["start_positions"] = tf.zeros(self.model_tester.batch_size)
|
||||
inputs_dict["end_positions"] = tf.zeros(self.model_tester.batch_size)
|
||||
inputs_dict["start_positions"] = tf.zeros(self.model_tester.batch_size, dtype=tf.int32)
|
||||
inputs_dict["end_positions"] = tf.zeros(self.model_tester.batch_size, dtype=tf.int32)
|
||||
elif model_class in TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING.values():
|
||||
inputs_dict["labels"] = tf.zeros(self.model_tester.batch_size)
|
||||
elif model_class in TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING.values():
|
||||
inputs_dict["labels"] = tf.zeros((self.model_tester.batch_size, self.model_tester.seq_length))
|
||||
elif model_class in TF_MODEL_FOR_CAUSAL_LM_MAPPING.values():
|
||||
inputs_dict["labels"] = tf.zeros((self.model_tester.batch_size, self.model_tester.seq_length))
|
||||
elif model_class in TF_MODEL_FOR_MASKED_LM_MAPPING.values():
|
||||
inputs_dict["labels"] = tf.zeros((self.model_tester.batch_size, self.model_tester.seq_length))
|
||||
elif model_class in TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING.values():
|
||||
inputs_dict["labels"] = tf.zeros((self.model_tester.batch_size, self.model_tester.seq_length))
|
||||
inputs_dict["labels"] = tf.zeros(self.model_tester.batch_size, dtype=tf.int32)
|
||||
elif model_class in [
|
||||
*TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING.values(),
|
||||
*TF_MODEL_FOR_CAUSAL_LM_MAPPING.values(),
|
||||
*TF_MODEL_FOR_MASKED_LM_MAPPING.values(),
|
||||
*TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING.values(),
|
||||
]:
|
||||
inputs_dict["labels"] = tf.zeros(
|
||||
(self.model_tester.batch_size, self.model_tester.seq_length), dtype=tf.int32
|
||||
)
|
||||
return inputs_dict
|
||||
|
||||
def test_initialization(self):
|
||||
@@ -517,6 +521,61 @@ class TFModelTesterMixin:
|
||||
max_diff = np.amax(np.abs(out_1 - out_2))
|
||||
self.assertLessEqual(max_diff, 1e-5)
|
||||
|
||||
def test_model_outputs_equivalence(self):
|
||||
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
def check_equivalence(model, tuple_inputs, dict_inputs, additional_kwargs={}):
|
||||
tuple_output = model(tuple_inputs, return_dict=False, **additional_kwargs)
|
||||
dict_output = model(dict_inputs, return_dict=True, **additional_kwargs).to_tuple()
|
||||
|
||||
def recursive_check(tuple_object, dict_object):
|
||||
if isinstance(tuple_object, (List, Tuple)):
|
||||
for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object):
|
||||
recursive_check(tuple_iterable_value, dict_iterable_value)
|
||||
elif tuple_object is None:
|
||||
return
|
||||
else:
|
||||
self.assertTrue(
|
||||
all(tf.equal(tuple_object, dict_object)),
|
||||
msg=f"Tuple and dict output are not equal. Difference: {tf.math.reduce_max(tf.abs(tuple_object - dict_object))}",
|
||||
)
|
||||
|
||||
recursive_check(tuple_output, dict_output)
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
|
||||
tuple_inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||
dict_inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||
check_equivalence(model, tuple_inputs, dict_inputs)
|
||||
|
||||
tuple_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
|
||||
dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
|
||||
check_equivalence(model, tuple_inputs, dict_inputs)
|
||||
|
||||
tuple_inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||
dict_inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||
check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True})
|
||||
|
||||
tuple_inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||
dict_inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||
check_equivalence(model, tuple_inputs, dict_inputs, {"output_attentions": True})
|
||||
|
||||
tuple_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
|
||||
dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
|
||||
check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True})
|
||||
|
||||
tuple_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
|
||||
dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
|
||||
check_equivalence(model, tuple_inputs, dict_inputs, {"output_attentions": True})
|
||||
|
||||
tuple_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
|
||||
dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
|
||||
check_equivalence(
|
||||
model, tuple_inputs, dict_inputs, {"output_hidden_states": True, "output_attentions": True}
|
||||
)
|
||||
|
||||
def _get_embeds(self, wte, input_ids):
|
||||
# ^^ In our TF models, the input_embeddings can take slightly different forms,
|
||||
# so we try a few of them.
|
||||
|
||||
Reference in New Issue
Block a user