Make TF pt-tf equivalence test more aggressive (#15839)
* Make TF pt-tf equivalence test more aggressive * Fix for TFConvNextModelTest and TFTransfoXLModelTest * fix kwargs for outputs * clean-up * Add docstring for check_outputs() * remove: need to rename encoder-decoder * clean-up * send PyTorch things to the correct device * Add back the accidentally removed test case in test_pt_tf_model_equivalence() * Fix: change to tuple before calling check_outputs() * Fix: tfo could be a list * use to_tuple() * allow tfo only to be tuple or tensor * allow tfo to be list or tuple for now + style change * minor fix * remove np.copy and update comments * tfo -> tf_output, same for pt * Add more detailed comment * remove the incorrect comment Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
@@ -39,10 +39,14 @@ from transformers.testing_utils import (
|
|||||||
require_tf,
|
require_tf,
|
||||||
require_tf2onnx,
|
require_tf2onnx,
|
||||||
slow,
|
slow,
|
||||||
|
torch_device,
|
||||||
)
|
)
|
||||||
from transformers.utils import logging
|
from transformers.utils import logging
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
if is_tf_available():
|
if is_tf_available():
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
@@ -348,27 +352,10 @@ class TFModelTesterMixin:
|
|||||||
|
|
||||||
import transformers
|
import transformers
|
||||||
|
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
def prepare_pt_inputs_from_tf_inputs(tf_inputs_dict):
|
||||||
|
|
||||||
for model_class in self.all_model_classes:
|
|
||||||
pt_model_class_name = model_class.__name__[2:] # Skip the "TF" at the beginning
|
|
||||||
pt_model_class = getattr(transformers, pt_model_class_name)
|
|
||||||
|
|
||||||
config.output_hidden_states = True
|
|
||||||
|
|
||||||
tf_model = model_class(config)
|
|
||||||
pt_model = pt_model_class(config)
|
|
||||||
|
|
||||||
# Check we can load pt model in tf and vice-versa with model => model functions
|
|
||||||
tf_model = transformers.load_pytorch_model_in_tf2_model(
|
|
||||||
tf_model, pt_model, tf_inputs=self._prepare_for_class(inputs_dict, model_class)
|
|
||||||
)
|
|
||||||
pt_model = transformers.load_tf2_model_in_pytorch_model(pt_model, tf_model)
|
|
||||||
|
|
||||||
# Check predictions on first output (logits/hidden-states) are close enought given low-level computational differences
|
|
||||||
pt_model.eval()
|
|
||||||
pt_inputs_dict = {}
|
pt_inputs_dict = {}
|
||||||
for name, key in self._prepare_for_class(inputs_dict, model_class).items():
|
for name, key in tf_inputs_dict.items():
|
||||||
if type(key) == bool:
|
if type(key) == bool:
|
||||||
pt_inputs_dict[name] = key
|
pt_inputs_dict[name] = key
|
||||||
elif name == "input_values":
|
elif name == "input_values":
|
||||||
@@ -380,23 +367,217 @@ class TFModelTesterMixin:
|
|||||||
else:
|
else:
|
||||||
pt_inputs_dict[name] = torch.from_numpy(key.numpy()).to(torch.long)
|
pt_inputs_dict[name] = torch.from_numpy(key.numpy()).to(torch.long)
|
||||||
|
|
||||||
|
return pt_inputs_dict
|
||||||
|
|
||||||
|
def check_outputs(tf_outputs, pt_outputs, model_class, names):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
model_class: The class of the model that is currently being tested. For example, `TFBertModel`,
|
||||||
|
`TFBertForMaskedLM`, `TFBertForSequenceClassification`, etc. Currently unused, but it could make
|
||||||
|
debugging easier and faster.
|
||||||
|
|
||||||
|
names: A string, or a tuple of strings. These specify what tf_outputs/pt_outputs represent in the model outputs.
|
||||||
|
Currently unused, but in the future, we could use this information to make the error message clearer
|
||||||
|
by giving the name(s) of the output tensor(s) with large difference(s) between PT and TF.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# There is an issue about `past_key_values` (e.g. in `TFPegasusForConditionalGeneration`) to be solved in a separate PR.
|
||||||
|
if names == "past_key_values":
|
||||||
|
return
|
||||||
|
|
||||||
|
# Allow `list` because `(TF)TransfoXLModelOutput.mems` is a list of tensors.
|
||||||
|
if type(tf_outputs) in [tuple, list]:
|
||||||
|
self.assertEqual(type(tf_outputs), type(pt_outputs))
|
||||||
|
self.assertEqual(len(tf_outputs), len(pt_outputs))
|
||||||
|
if type(names) == tuple:
|
||||||
|
for tf_output, pt_output, name in zip(tf_outputs, pt_outputs, names):
|
||||||
|
check_outputs(tf_output, pt_output, model_class, names=name)
|
||||||
|
elif type(names) == str:
|
||||||
|
for idx, (tf_output, pt_output) in enumerate(zip(tf_outputs, pt_outputs)):
|
||||||
|
check_outputs(tf_output, pt_output, model_class, names=f"{names}_{idx}")
|
||||||
|
else:
|
||||||
|
raise ValueError(f"`names` should be a `tuple` or a string. Got {type(names)} instead.")
|
||||||
|
elif isinstance(tf_outputs, tf.Tensor):
|
||||||
|
self.assertTrue(isinstance(pt_outputs, torch.Tensor))
|
||||||
|
|
||||||
|
tf_outputs = tf_outputs.numpy()
|
||||||
|
pt_outputs = pt_outputs.detach().to("cpu").numpy()
|
||||||
|
|
||||||
|
tf_nans = np.isnan(tf_outputs)
|
||||||
|
pt_nans = np.isnan(pt_outputs)
|
||||||
|
|
||||||
|
pt_outputs[tf_nans] = 0
|
||||||
|
tf_outputs[tf_nans] = 0
|
||||||
|
pt_outputs[pt_nans] = 0
|
||||||
|
tf_outputs[pt_nans] = 0
|
||||||
|
|
||||||
|
max_diff = np.amax(np.abs(tf_outputs - pt_outputs))
|
||||||
|
self.assertLessEqual(max_diff, 1e-5)
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"`tf_outputs` should be a `tuple` or an instance of `tf.Tensor`. Got {type(tf_outputs)} instead."
|
||||||
|
)
|
||||||
|
|
||||||
|
def check_pt_tf_models(tf_model, pt_model):
|
||||||
|
|
||||||
|
# send pytorch model to the correct device
|
||||||
|
pt_model.to(torch_device)
|
||||||
|
|
||||||
|
# Check predictions on first output (logits/hidden-states) are close enough given low-level computational differences
|
||||||
|
pt_model.eval()
|
||||||
|
|
||||||
|
pt_inputs_dict = prepare_pt_inputs_from_tf_inputs(tf_inputs_dict)
|
||||||
|
pt_inputs_dict_maybe_with_labels = prepare_pt_inputs_from_tf_inputs(tf_inputs_dict_maybe_with_labels)
|
||||||
|
|
||||||
|
# send pytorch inputs to the correct device
|
||||||
|
pt_inputs_dict = {
|
||||||
|
k: v.to(device=torch_device) if isinstance(v, torch.Tensor) else v for k, v in pt_inputs_dict.items()
|
||||||
|
}
|
||||||
|
pt_inputs_dict_maybe_with_labels = {
|
||||||
|
k: v.to(device=torch_device) if isinstance(v, torch.Tensor) else v
|
||||||
|
for k, v in pt_inputs_dict_maybe_with_labels.items()
|
||||||
|
}
|
||||||
|
|
||||||
|
# Original test: check without `labels`
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
pto = pt_model(**pt_inputs_dict)
|
pt_outputs = pt_model(**pt_inputs_dict)
|
||||||
tfo = tf_model(self._prepare_for_class(inputs_dict, model_class), training=False)
|
tf_outputs = tf_model(tf_inputs_dict)
|
||||||
|
|
||||||
tf_hidden_states = tfo[0].numpy()
|
tf_keys = tuple([k for k, v in tf_outputs.items() if v is not None])
|
||||||
pt_hidden_states = pto[0].numpy()
|
pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None])
|
||||||
|
|
||||||
tf_nans = np.copy(np.isnan(tf_hidden_states))
|
self.assertEqual(tf_keys, pt_keys)
|
||||||
pt_nans = np.copy(np.isnan(pt_hidden_states))
|
check_outputs(tf_outputs.to_tuple(), pt_outputs.to_tuple(), model_class, names=tf_keys)
|
||||||
|
|
||||||
pt_hidden_states[tf_nans] = 0
|
# check the case where `labels` is passed
|
||||||
tf_hidden_states[tf_nans] = 0
|
has_labels = any(
|
||||||
pt_hidden_states[pt_nans] = 0
|
x in tf_inputs_dict_maybe_with_labels for x in ["labels", "next_sentence_label", "start_positions"]
|
||||||
tf_hidden_states[pt_nans] = 0
|
)
|
||||||
|
if has_labels:
|
||||||
|
|
||||||
max_diff = np.amax(np.abs(tf_hidden_states - pt_hidden_states))
|
with torch.no_grad():
|
||||||
self.assertLessEqual(max_diff, 4e-2)
|
pt_outputs = pt_model(**pt_inputs_dict_maybe_with_labels)
|
||||||
|
tf_outputs = tf_model(tf_inputs_dict_maybe_with_labels)
|
||||||
|
|
||||||
|
# Some models' output classes don't have a `loss` attribute even though `labels` is used.
|
||||||
|
# TODO: identify which models
|
||||||
|
tf_loss = getattr(tf_outputs, "loss", None)
|
||||||
|
pt_loss = getattr(pt_outputs, "loss", None)
|
||||||
|
|
||||||
|
# Some PT models return loss while the corresponding TF models don't (i.e. `None` for `loss`).
|
||||||
|
# - TFFlaubertWithLMHeadModel
|
||||||
|
# - TFFunnelForPreTraining
|
||||||
|
# - TFElectraForPreTraining
|
||||||
|
# - TFXLMWithLMHeadModel
|
||||||
|
# TODO: Fix PT/TF diff -> remove this condition to fail the test if a diff occurs
|
||||||
|
if not ((tf_loss is None and pt_loss is None) or (tf_loss is not None and pt_loss is not None)):
|
||||||
|
if model_class.__name__ not in [
|
||||||
|
"TFFlaubertWithLMHeadModel",
|
||||||
|
"TFFunnelForPreTraining",
|
||||||
|
"TFElectraForPreTraining",
|
||||||
|
"TFXLMWithLMHeadModel",
|
||||||
|
]:
|
||||||
|
self.assertEqual(tf_loss is None, pt_loss is None)
|
||||||
|
|
||||||
|
tf_keys = tuple([k for k, v in tf_outputs.items() if v is not None])
|
||||||
|
pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None])
|
||||||
|
|
||||||
|
# TODO: remove these 2 conditions once the above TODOs (above loss) are implemented
|
||||||
|
# (Also, `TFTransfoXLLMHeadModel` has no `loss` while `TransfoXLLMHeadModel` return `losses`)
|
||||||
|
if tf_keys != pt_keys:
|
||||||
|
if model_class.__name__ not in [
|
||||||
|
"TFFlaubertWithLMHeadModel",
|
||||||
|
"TFFunnelForPreTraining",
|
||||||
|
"TFElectraForPreTraining",
|
||||||
|
"TFXLMWithLMHeadModel",
|
||||||
|
] + ["TFTransfoXLLMHeadModel"]:
|
||||||
|
self.assertEqual(tf_keys, pt_keys)
|
||||||
|
|
||||||
|
# Since we deliberately make some tests pass above (regarding the `loss`), let's still try to test
|
||||||
|
# some remaining attributes in the outputs.
|
||||||
|
# TODO: remove this block of `index` computing once the above TODOs (above loss) are implemented
|
||||||
|
# compute the 1st `index` where `tf_keys` and `pt_keys` are different
|
||||||
|
index = 0
|
||||||
|
for _ in range(min(len(tf_keys), len(pt_keys))):
|
||||||
|
if tf_keys[index] == pt_keys[index]:
|
||||||
|
index += 1
|
||||||
|
else:
|
||||||
|
break
|
||||||
|
if tf_keys[:index] != pt_keys[:index]:
|
||||||
|
self.assertEqual(tf_keys, pt_keys)
|
||||||
|
|
||||||
|
# Some models require extra condition to return loss. For example, `(TF)BertForPreTraining` requires
|
||||||
|
# both `labels` and `next_sentence_label`.
|
||||||
|
if tf_loss is not None and pt_loss is not None:
|
||||||
|
|
||||||
|
# check anything else than `loss`
|
||||||
|
keys = tuple([k for k in tf_keys])
|
||||||
|
check_outputs(tf_outputs[1:index], pt_outputs[1:index], model_class, names=keys[1:index])
|
||||||
|
|
||||||
|
# check `loss`
|
||||||
|
|
||||||
|
# The loss returned by TF models is usually a tensor rather than a scalar.
|
||||||
|
# (see `hf_compute_loss`: it uses `tf.keras.losses.Reduction.NONE`)
|
||||||
|
# Change it here to a scalar to match PyTorch models' loss
|
||||||
|
tf_loss = tf.math.reduce_mean(tf_loss).numpy()
|
||||||
|
pt_loss = pt_loss.detach().to("cpu").numpy()
|
||||||
|
|
||||||
|
tf_nans = np.isnan(tf_loss)
|
||||||
|
pt_nans = np.isnan(pt_loss)
|
||||||
|
# the 2 losses need to be both nan or both not nan
|
||||||
|
self.assertEqual(tf_nans, pt_nans)
|
||||||
|
|
||||||
|
if not tf_nans:
|
||||||
|
max_diff = np.amax(np.abs(tf_loss - pt_loss))
|
||||||
|
self.assertLessEqual(max_diff, 1e-5)
|
||||||
|
|
||||||
|
for model_class in self.all_model_classes:
|
||||||
|
|
||||||
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|
||||||
|
# Output all for aggressive testing
|
||||||
|
config.output_hidden_states = True
|
||||||
|
# Pure convolutional models have no attention
|
||||||
|
# TODO: use a better and more general criterion
|
||||||
|
if "TFConvNext" not in model_class.__name__:
|
||||||
|
config.output_attentions = True
|
||||||
|
|
||||||
|
for k in ["attention_mask", "encoder_attention_mask", "decoder_attention_mask"]:
|
||||||
|
if k in inputs_dict:
|
||||||
|
attention_mask = inputs_dict[k]
|
||||||
|
# Make sure there are no all-0s attention masks, to avoid failures for now.
|
||||||
|
# TODO: remove this line once the TODO below is implemented.
|
||||||
|
attention_mask = tf.ones_like(attention_mask, dtype=tf.int32)
|
||||||
|
# Here we make the first sequence with all 0s as attention mask.
|
||||||
|
# Currently, this will fail for `TFWav2Vec2Model`. This is caused by the different large negative
|
||||||
|
# values, like `-1e4`, `-1e9`, `-1e30` and `-inf` for attention mask across models/frameworks.
|
||||||
|
# TODO: enable this block once the large negative values thing is cleaned up.
|
||||||
|
# (see https://github.com/huggingface/transformers/issues/14859)
|
||||||
|
# attention_mask = tf.concat(
|
||||||
|
# [
|
||||||
|
# tf.zeros_like(attention_mask[:1], dtype=tf.int32),
|
||||||
|
# tf.cast(attention_mask[1:], dtype=tf.int32)
|
||||||
|
# ],
|
||||||
|
# axis=0
|
||||||
|
# )
|
||||||
|
inputs_dict[k] = attention_mask
|
||||||
|
|
||||||
|
pt_model_class_name = model_class.__name__[2:] # Skip the "TF" at the beginning
|
||||||
|
pt_model_class = getattr(transformers, pt_model_class_name)
|
||||||
|
|
||||||
|
config.output_hidden_states = True
|
||||||
|
|
||||||
|
tf_model = model_class(config)
|
||||||
|
pt_model = pt_model_class(config)
|
||||||
|
|
||||||
|
tf_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
|
||||||
|
tf_inputs_dict_maybe_with_labels = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
|
||||||
|
|
||||||
|
# Check we can load pt model in tf and vice-versa with model => model functions
|
||||||
|
tf_model = transformers.load_pytorch_model_in_tf2_model(tf_model, pt_model, tf_inputs=tf_inputs_dict)
|
||||||
|
pt_model = transformers.load_tf2_model_in_pytorch_model(pt_model, tf_model)
|
||||||
|
|
||||||
|
check_pt_tf_models(tf_model, pt_model)
|
||||||
|
|
||||||
# Check we can load pt model in tf and vice-versa with checkpoint => model functions
|
# Check we can load pt model in tf and vice-versa with checkpoint => model functions
|
||||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||||
@@ -408,37 +589,7 @@ class TFModelTesterMixin:
|
|||||||
tf_model.save_weights(tf_checkpoint_path)
|
tf_model.save_weights(tf_checkpoint_path)
|
||||||
pt_model = transformers.load_tf2_checkpoint_in_pytorch_model(pt_model, tf_checkpoint_path)
|
pt_model = transformers.load_tf2_checkpoint_in_pytorch_model(pt_model, tf_checkpoint_path)
|
||||||
|
|
||||||
# Check predictions on first output (logits/hidden-states) are close enought given low-level computational differences
|
check_pt_tf_models(tf_model, pt_model)
|
||||||
pt_model.eval()
|
|
||||||
pt_inputs_dict = {}
|
|
||||||
for name, key in self._prepare_for_class(inputs_dict, model_class).items():
|
|
||||||
if type(key) == bool:
|
|
||||||
key = np.array(key, dtype=bool)
|
|
||||||
pt_inputs_dict[name] = torch.from_numpy(key).to(torch.long)
|
|
||||||
elif name == "input_values":
|
|
||||||
pt_inputs_dict[name] = torch.from_numpy(key.numpy()).to(torch.float32)
|
|
||||||
elif name == "pixel_values":
|
|
||||||
pt_inputs_dict[name] = torch.from_numpy(key.numpy()).to(torch.float32)
|
|
||||||
elif name == "input_features":
|
|
||||||
pt_inputs_dict[name] = torch.from_numpy(key.numpy()).to(torch.float32)
|
|
||||||
else:
|
|
||||||
pt_inputs_dict[name] = torch.from_numpy(key.numpy()).to(torch.long)
|
|
||||||
|
|
||||||
with torch.no_grad():
|
|
||||||
pto = pt_model(**pt_inputs_dict)
|
|
||||||
tfo = tf_model(self._prepare_for_class(inputs_dict, model_class))
|
|
||||||
tfo = tfo[0].numpy()
|
|
||||||
pto = pto[0].numpy()
|
|
||||||
tf_nans = np.copy(np.isnan(tfo))
|
|
||||||
pt_nans = np.copy(np.isnan(pto))
|
|
||||||
|
|
||||||
pto[tf_nans] = 0
|
|
||||||
tfo[tf_nans] = 0
|
|
||||||
pto[pt_nans] = 0
|
|
||||||
tfo[pt_nans] = 0
|
|
||||||
|
|
||||||
max_diff = np.amax(np.abs(tfo - pto))
|
|
||||||
self.assertLessEqual(max_diff, 4e-2)
|
|
||||||
|
|
||||||
def test_compile_tf_model(self):
|
def test_compile_tf_model(self):
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|||||||
Reference in New Issue
Block a user