Improve test_pt_tf_model_equivalence on PT side (#16731)
* Update test_pt_tf_model_equivalence on PT side Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
@@ -56,7 +56,14 @@ from transformers.testing_utils import (
|
||||
slow,
|
||||
torch_device,
|
||||
)
|
||||
from transformers.utils import WEIGHTS_INDEX_NAME, WEIGHTS_NAME, is_flax_available, is_torch_fx_available
|
||||
from transformers.utils import (
|
||||
WEIGHTS_INDEX_NAME,
|
||||
WEIGHTS_NAME,
|
||||
is_flax_available,
|
||||
is_tf_available,
|
||||
is_torch_fx_available,
|
||||
)
|
||||
from transformers.utils.generic import ModelOutput
|
||||
|
||||
|
||||
sys.path.append(str(Path(__file__).parent.parent / "utils"))
|
||||
@@ -94,6 +101,9 @@ if is_torch_available():
|
||||
)
|
||||
from transformers.modeling_utils import shard_checkpoint
|
||||
|
||||
if is_tf_available():
|
||||
import tensorflow as tf
|
||||
|
||||
if is_flax_available():
|
||||
import jax.numpy as jnp
|
||||
from transformers.modeling_flax_pytorch_utils import (
|
||||
@@ -1478,237 +1488,240 @@ class ModelTesterMixin:
|
||||
model, tuple_inputs, dict_inputs, {"output_hidden_states": True, "output_attentions": True}
|
||||
)
|
||||
|
||||
@is_pt_tf_cross_test
|
||||
def test_pt_tf_model_equivalence(self):
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
# Don't copy this method to model specific test file!
|
||||
# TODO: remove this method once the issues are all fixed!
|
||||
def _make_attention_mask_non_null(self, inputs_dict):
|
||||
"""Make sure no sequence has all zeros as attention mask"""
|
||||
|
||||
import transformers
|
||||
for k in ["attention_mask", "encoder_attention_mask", "decoder_attention_mask"]:
|
||||
if k in inputs_dict:
|
||||
attention_mask = inputs_dict[k]
|
||||
|
||||
def prepare_tf_inputs_from_pt_inputs(pt_inputs_dict):
|
||||
|
||||
tf_inputs_dict = {}
|
||||
for key, tensor in pt_inputs_dict.items():
|
||||
# skip key that does not exist in tf
|
||||
if type(tensor) == bool:
|
||||
tf_inputs_dict[key] = tensor
|
||||
elif key == "input_values":
|
||||
tf_inputs_dict[key] = tf.convert_to_tensor(tensor.cpu().numpy(), dtype=tf.float32)
|
||||
elif key == "pixel_values":
|
||||
tf_inputs_dict[key] = tf.convert_to_tensor(tensor.cpu().numpy(), dtype=tf.float32)
|
||||
elif key == "input_features":
|
||||
tf_inputs_dict[key] = tf.convert_to_tensor(tensor.cpu().numpy(), dtype=tf.float32)
|
||||
# To deal with the edge cases from `TFTapasForQuestionAnswering`.
|
||||
# PyTorch can deal with type casting automatically, but TensorFlow is more strict!
|
||||
# TODO: find a clean/better way to deal with these extra keys that are not common.
|
||||
elif key in ["float_answer", "numeric_values", "numeric_values_scale"]:
|
||||
tf_inputs_dict[key] = tf.convert_to_tensor(tensor.cpu().numpy(), dtype=tf.float32)
|
||||
else:
|
||||
tf_inputs_dict[key] = tf.convert_to_tensor(tensor.cpu().numpy(), dtype=tf.int32)
|
||||
|
||||
return tf_inputs_dict
|
||||
|
||||
def check_outputs(tf_outputs, pt_outputs, model_class, names):
|
||||
"""
|
||||
Args:
|
||||
model_class: The class of the model that is currently testing. For example, `TFBertModel`,
|
||||
TFBertForMaskedLM`, `TFBertForSequenceClassification`, etc. Currently unused, but it could make
|
||||
debugging easier and faster.
|
||||
|
||||
names: A string, or a tuple of strings. These specify what tf_outputs/pt_outputs represent in the model outputs.
|
||||
Currently unused, but in the future, we could use this information to make the error message clearer
|
||||
by giving the name(s) of the output tensor(s) with large difference(s) between PT and TF.
|
||||
"""
|
||||
|
||||
# Some issue (`about past_key_values`) to solve (e.g. `TFPegasusForConditionalGeneration`) in a separate PR.
|
||||
if names == "past_key_values":
|
||||
return
|
||||
|
||||
# Allow `list` because `(TF)TransfoXLModelOutput.mems` is a list of tensors.
|
||||
if type(tf_outputs) in [tuple, list]:
|
||||
self.assertEqual(type(tf_outputs), type(pt_outputs))
|
||||
self.assertEqual(len(tf_outputs), len(pt_outputs))
|
||||
if type(names) == tuple:
|
||||
for tf_output, pt_output, name in zip(tf_outputs, pt_outputs, names):
|
||||
check_outputs(tf_output, pt_output, model_class, names=name)
|
||||
elif type(names) == str:
|
||||
for idx, (tf_output, pt_output) in enumerate(zip(tf_outputs, pt_outputs)):
|
||||
check_outputs(tf_output, pt_output, model_class, names=f"{names}_{idx}")
|
||||
else:
|
||||
raise ValueError(f"`names` should be a `tuple` or a string. Got {type(names)} instead.")
|
||||
elif isinstance(tf_outputs, tf.Tensor):
|
||||
self.assertTrue(isinstance(pt_outputs, torch.Tensor))
|
||||
|
||||
tf_outputs = tf_outputs.numpy()
|
||||
pt_outputs = pt_outputs.detach().to("cpu").numpy()
|
||||
|
||||
tf_nans = np.isnan(tf_outputs)
|
||||
pt_nans = np.isnan(pt_outputs)
|
||||
|
||||
pt_outputs[tf_nans] = 0
|
||||
tf_outputs[tf_nans] = 0
|
||||
pt_outputs[pt_nans] = 0
|
||||
tf_outputs[pt_nans] = 0
|
||||
|
||||
max_diff = np.amax(np.abs(tf_outputs - pt_outputs))
|
||||
self.assertLessEqual(max_diff, 1e-5)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"`tf_outputs` should be a `tuple` or an instance of `tf.Tensor`. Got {type(tf_outputs)} instead."
|
||||
# Make sure no all 0s attention masks - to avoid failure at this moment.
|
||||
# Put `1` at the beginning of sequences to make it still work when combining causal attention masks.
|
||||
# TODO: remove this line once a fix regarding large negative values for attention mask is done.
|
||||
attention_mask = torch.cat(
|
||||
[torch.ones_like(attention_mask[:, :1], dtype=attention_mask.dtype), attention_mask[:, 1:]], dim=-1
|
||||
)
|
||||
|
||||
def check_pt_tf_models(tf_model, pt_model, pt_inputs_dict, pt_inputs_dict_maybe_with_labels):
|
||||
# Here we make the first sequence with all 0s as attention mask.
|
||||
# Currently, this will fail for `TFWav2Vec2Model`. This is caused by the different large negative
|
||||
# values, like `1e-4`, `1e-9`, `1e-30` and `-inf` for attention mask across models/frameworks.
|
||||
# TODO: enable this block once the large negative values thing is cleaned up.
|
||||
# (see https://github.com/huggingface/transformers/issues/14859)
|
||||
# attention_mask = torch.cat(
|
||||
# [torch.zeros_like(attention_mask[:1], dtype=attention_mask.dtype), attention_mask[1:]],
|
||||
# dim=0
|
||||
# )
|
||||
|
||||
# send pytorch model to the correct device
|
||||
pt_model.to(torch_device)
|
||||
inputs_dict[k] = attention_mask
|
||||
|
||||
# Check predictions on first output (logits/hidden-states) are close enough given low-level computational differences
|
||||
pt_model.eval()
|
||||
# Don't copy this method to model specific test file!
|
||||
# TODO: remove this method once the issues are all fixed!
|
||||
def _postprocessing_to_ignore_test_cases(self, tf_outputs, pt_outputs, model_class):
|
||||
"""For temporarily ignoring some failed test cases (issues to be fixed)"""
|
||||
|
||||
tf_inputs_dict = prepare_tf_inputs_from_pt_inputs(pt_inputs_dict)
|
||||
tf_inputs_dict_maybe_with_labels = prepare_tf_inputs_from_pt_inputs(pt_inputs_dict_maybe_with_labels)
|
||||
tf_keys = set([k for k, v in tf_outputs.items() if v is not None])
|
||||
pt_keys = set([k for k, v in pt_outputs.items() if v is not None])
|
||||
|
||||
# send pytorch inputs to the correct device
|
||||
pt_inputs_dict = {
|
||||
k: v.to(device=torch_device) if isinstance(v, torch.Tensor) else v for k, v in pt_inputs_dict.items()
|
||||
}
|
||||
pt_inputs_dict_maybe_with_labels = {
|
||||
k: v.to(device=torch_device) if isinstance(v, torch.Tensor) else v
|
||||
for k, v in pt_inputs_dict_maybe_with_labels.items()
|
||||
}
|
||||
key_differences = tf_keys.symmetric_difference(pt_keys)
|
||||
|
||||
# Original test: check without `labels`
|
||||
with torch.no_grad():
|
||||
pt_outputs = pt_model(**pt_inputs_dict)
|
||||
tf_outputs = tf_model(tf_inputs_dict)
|
||||
if model_class.__name__ in [
|
||||
"FlaubertWithLMHeadModel",
|
||||
"FunnelForPreTraining",
|
||||
"ElectraForPreTraining",
|
||||
"XLMWithLMHeadModel",
|
||||
"TransfoXLLMHeadModel",
|
||||
]:
|
||||
for k in key_differences:
|
||||
if k in ["loss", "losses"]:
|
||||
tf_keys.discard(k)
|
||||
pt_keys.discard(k)
|
||||
elif model_class.__name__.startswith("GPT2"):
|
||||
# `TFGPT2` has `past_key_values` as a tensor while `GPT2` has it as a tuple.
|
||||
tf_keys.discard("past_key_values")
|
||||
pt_keys.discard("past_key_values")
|
||||
|
||||
# create new outputs from the remaining fields
|
||||
new_tf_outputs = type(tf_outputs)(**{k: tf_outputs[k] for k in tf_keys})
|
||||
new_pt_outputs = type(pt_outputs)(**{k: pt_outputs[k] for k in pt_keys})
|
||||
|
||||
return new_tf_outputs, new_pt_outputs
|
||||
|
||||
# Copied from tests.test_modeling_tf_common.TFModelTesterMixin.check_pt_tf_outputs
|
||||
def check_pt_tf_outputs(self, tf_outputs, pt_outputs, model_class, tol=1e-5, name="outputs", attributes=None):
|
||||
"""Check the outputs from PyTorch and TensorFlow models are closed enough. Checks are done in a recursive way.
|
||||
|
||||
Args:
|
||||
model_class: The class of the model that is currently testing. For example, `TFBertModel`,
|
||||
TFBertForMaskedLM`, `TFBertForSequenceClassification`, etc. Mainly used for providing more informative
|
||||
error messages.
|
||||
name (`str`): The name of the output. For example, `output.hidden_states`, `output.attentions`, etc.
|
||||
attributes (`Tuple[str]`): The names of the output's element if the output is a tuple/list with each element
|
||||
being a named field in the output.
|
||||
"""
|
||||
|
||||
self.assertEqual(type(name), str)
|
||||
if attributes is not None:
|
||||
self.assertEqual(type(attributes), tuple, f"{name}: The argument `attributes` should be a `tuple`")
|
||||
|
||||
# Allow `ModelOutput` (e.g. `CLIPOutput` has `text_model_output` and `vision_model_output`).
|
||||
if isinstance(tf_outputs, ModelOutput):
|
||||
self.assertTrue(
|
||||
isinstance(pt_outputs, ModelOutput),
|
||||
f"{name}: `pt_outputs` should an instance of `ModelOutput` when `tf_outputs` is",
|
||||
)
|
||||
|
||||
# Don't copy this block to model specific test file!
|
||||
# TODO: remove this method and this line after issues are fixed
|
||||
tf_outputs, pt_outputs = self._postprocessing_to_ignore_test_cases(tf_outputs, pt_outputs, model_class)
|
||||
|
||||
tf_keys = tuple([k for k, v in tf_outputs.items() if v is not None])
|
||||
pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None])
|
||||
|
||||
self.assertEqual(tf_keys, pt_keys)
|
||||
check_outputs(tf_outputs.to_tuple(), pt_outputs.to_tuple(), model_class, names=tf_keys)
|
||||
self.assertEqual(tf_keys, pt_keys, f"{name}: Output keys differ between TF and PyTorch")
|
||||
|
||||
# check the case where `labels` is passed
|
||||
has_labels = any(
|
||||
x in tf_inputs_dict_maybe_with_labels for x in ["labels", "next_sentence_label", "start_positions"]
|
||||
# convert to the case of `tuple`
|
||||
# appending each key to the current (string) `names`
|
||||
attributes = tuple([f"{name}.{k}" for k in tf_keys])
|
||||
self.check_pt_tf_outputs(
|
||||
tf_outputs.to_tuple(), pt_outputs.to_tuple(), model_class, tol=tol, name=name, attributes=attributes
|
||||
)
|
||||
if has_labels:
|
||||
|
||||
with torch.no_grad():
|
||||
pt_outputs = pt_model(**pt_inputs_dict_maybe_with_labels)
|
||||
tf_outputs = tf_model(tf_inputs_dict_maybe_with_labels)
|
||||
# Allow `list` (e.g. `TransfoXLModelOutput.mems` is a list of tensors.)
|
||||
elif type(tf_outputs) in [tuple, list]:
|
||||
self.assertEqual(type(tf_outputs), type(pt_outputs), f"{name}: Output types differ between TF and PyTorch")
|
||||
self.assertEqual(len(tf_outputs), len(pt_outputs), f"{name}: Output lengths differ between TF and PyTorch")
|
||||
|
||||
# Some models' output class don't have `loss` attribute despite `labels` is used.
|
||||
# TODO: identify which models
|
||||
tf_loss = getattr(tf_outputs, "loss", None)
|
||||
pt_loss = getattr(pt_outputs, "loss", None)
|
||||
if attributes is not None:
|
||||
# case 1: each output has assigned name (e.g. a tuple form of a `ModelOutput`)
|
||||
self.assertEqual(
|
||||
len(attributes),
|
||||
len(tf_outputs),
|
||||
f"{name}: The tuple `names` should have the same length as `tf_outputs`",
|
||||
)
|
||||
else:
|
||||
# case 2: each output has no assigned name (e.g. hidden states of each layer) -> add an index to `names`
|
||||
attributes = tuple([f"{name}_{idx}" for idx in range(len(tf_outputs))])
|
||||
|
||||
# Some PT models return loss while the corresponding TF models don't (i.e. `None` for `loss`).
|
||||
# - FlaubertWithLMHeadModel
|
||||
# - FunnelForPreTraining
|
||||
# - ElectraForPreTraining
|
||||
# - XLMWithLMHeadModel
|
||||
# TODO: Fix PT/TF diff -> remove this condition to fail the test if a diff occurs
|
||||
if not ((tf_loss is None and pt_loss is None) or (tf_loss is not None and pt_loss is not None)):
|
||||
if model_class.__name__ not in [
|
||||
"FlaubertWithLMHeadModel",
|
||||
"FunnelForPreTraining",
|
||||
"ElectraForPreTraining",
|
||||
"XLMWithLMHeadModel",
|
||||
"TransfoXLLMHeadModel",
|
||||
]:
|
||||
self.assertEqual(tf_loss is None, pt_loss is None)
|
||||
for tf_output, pt_output, attr in zip(tf_outputs, pt_outputs, attributes):
|
||||
self.check_pt_tf_outputs(tf_output, pt_output, model_class, tol=tol, name=attr)
|
||||
|
||||
tf_keys = tuple([k for k, v in tf_outputs.items() if v is not None])
|
||||
pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None])
|
||||
elif isinstance(tf_outputs, tf.Tensor):
|
||||
self.assertTrue(
|
||||
isinstance(pt_outputs, torch.Tensor), f"{name}: `pt_outputs` should a tensor when `tf_outputs` is"
|
||||
)
|
||||
|
||||
# TODO: remove these 2 conditions once the above TODOs (above loss) are implemented
|
||||
# (Also, `TFTransfoXLLMHeadModel` has no `loss` while `TransfoXLLMHeadModel` return `losses`)
|
||||
if tf_keys != pt_keys:
|
||||
if model_class.__name__ not in [
|
||||
"FlaubertWithLMHeadModel",
|
||||
"FunnelForPreTraining",
|
||||
"ElectraForPreTraining",
|
||||
"XLMWithLMHeadModel",
|
||||
"TransfoXLLMHeadModel",
|
||||
]:
|
||||
self.assertEqual(tf_keys, pt_keys)
|
||||
tf_outputs = tf_outputs.numpy()
|
||||
pt_outputs = pt_outputs.detach().to("cpu").numpy()
|
||||
|
||||
# Since we deliberately make some tests pass above (regarding the `loss`), let's still try to test
|
||||
# some remaining attributes in the outputs.
|
||||
# TODO: remove this block of `index` computing once the above TODOs (above loss) are implemented
|
||||
# compute the 1st `index` where `tf_keys` and `pt_keys` is different
|
||||
index = 0
|
||||
for _ in range(min(len(tf_keys), len(pt_keys))):
|
||||
if tf_keys[index] == pt_keys[index]:
|
||||
index += 1
|
||||
else:
|
||||
break
|
||||
if tf_keys[:index] != pt_keys[:index]:
|
||||
self.assertEqual(tf_keys, pt_keys)
|
||||
self.assertEqual(
|
||||
tf_outputs.shape, pt_outputs.shape, f"{name}: Output shapes differ between TF and PyTorch"
|
||||
)
|
||||
|
||||
# Some models require extra condition to return loss. For example, `(TF)BertForPreTraining` requires
|
||||
# both`labels` and `next_sentence_label`.
|
||||
if tf_loss is not None and pt_loss is not None:
|
||||
# deal with NumPy's scalars to make replacing nan values by 0 work.
|
||||
if np.isscalar(tf_outputs):
|
||||
tf_outputs = np.array([tf_outputs])
|
||||
pt_outputs = np.array([pt_outputs])
|
||||
|
||||
# check anything else than `loss`
|
||||
keys = tuple([k for k in tf_keys])
|
||||
check_outputs(tf_outputs[1:index], pt_outputs[1:index], model_class, names=keys[1:index])
|
||||
tf_nans = np.isnan(tf_outputs)
|
||||
pt_nans = np.isnan(pt_outputs)
|
||||
|
||||
# check `loss`
|
||||
pt_outputs[tf_nans] = 0
|
||||
tf_outputs[tf_nans] = 0
|
||||
pt_outputs[pt_nans] = 0
|
||||
tf_outputs[pt_nans] = 0
|
||||
|
||||
# tf models returned loss is usually a tensor rather than a scalar.
|
||||
# (see `hf_compute_loss`: it uses `tf.keras.losses.Reduction.NONE`)
|
||||
# Change it here to a scalar to match PyTorch models' loss
|
||||
tf_loss = tf.math.reduce_mean(tf_loss).numpy()
|
||||
pt_loss = pt_loss.detach().to("cpu").numpy()
|
||||
max_diff = np.amax(np.abs(tf_outputs - pt_outputs))
|
||||
self.assertLessEqual(max_diff, tol, f"{name}: Difference between torch and tf is {max_diff} (>= {tol}).")
|
||||
else:
|
||||
raise ValueError(
|
||||
f"`tf_outputs` should be an instance of `tf.Tensor`, a `tuple`, or an instance of `tf.Tensor`. Got {type(tf_outputs)} instead."
|
||||
)
|
||||
|
||||
tf_nans = np.isnan(tf_loss)
|
||||
pt_nans = np.isnan(pt_loss)
|
||||
# the 2 losses need to be both nan or both not nan
|
||||
self.assertEqual(tf_nans, pt_nans)
|
||||
def prepare_tf_inputs_from_pt_inputs(self, pt_inputs_dict):
|
||||
|
||||
if not tf_nans:
|
||||
max_diff = np.amax(np.abs(tf_loss - pt_loss))
|
||||
self.assertLessEqual(max_diff, 1e-5)
|
||||
tf_inputs_dict = {}
|
||||
for key, tensor in pt_inputs_dict.items():
|
||||
# skip key that does not exist in tf
|
||||
if type(tensor) == bool:
|
||||
tf_inputs_dict[key] = tensor
|
||||
elif key == "input_values":
|
||||
tf_inputs_dict[key] = tf.convert_to_tensor(tensor.cpu().numpy(), dtype=tf.float32)
|
||||
elif key == "pixel_values":
|
||||
tf_inputs_dict[key] = tf.convert_to_tensor(tensor.cpu().numpy(), dtype=tf.float32)
|
||||
elif key == "input_features":
|
||||
tf_inputs_dict[key] = tf.convert_to_tensor(tensor.cpu().numpy(), dtype=tf.float32)
|
||||
# other general float inputs
|
||||
elif tensor.is_floating_point():
|
||||
tf_inputs_dict[key] = tf.convert_to_tensor(tensor.cpu().numpy(), dtype=tf.float32)
|
||||
else:
|
||||
tf_inputs_dict[key] = tf.convert_to_tensor(tensor.cpu().numpy(), dtype=tf.int32)
|
||||
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
return tf_inputs_dict
|
||||
|
||||
def check_pt_tf_models(self, tf_model, pt_model, pt_inputs_dict):
|
||||
|
||||
tf_inputs_dict = self.prepare_tf_inputs_from_pt_inputs(pt_inputs_dict)
|
||||
|
||||
# send pytorch inputs to the correct device
|
||||
pt_inputs_dict = {
|
||||
k: v.to(device=torch_device) if isinstance(v, torch.Tensor) else v for k, v in pt_inputs_dict.items()
|
||||
}
|
||||
|
||||
# send pytorch model to the correct device
|
||||
pt_model.to(torch_device)
|
||||
|
||||
# Check predictions on first output (logits/hidden-states) are close enough given low-level computational differences
|
||||
pt_model.eval()
|
||||
|
||||
with torch.no_grad():
|
||||
pt_outputs = pt_model(**pt_inputs_dict)
|
||||
tf_outputs = tf_model(tf_inputs_dict)
|
||||
|
||||
# tf models returned loss is usually a tensor rather than a scalar.
|
||||
# (see `hf_compute_loss`: it uses `tf.keras.losses.Reduction.NONE`)
|
||||
# Change it here to a scalar to match PyTorch models' loss
|
||||
tf_loss = getattr(tf_outputs, "loss", None)
|
||||
if tf_loss is not None:
|
||||
tf_outputs.loss = tf.math.reduce_mean(tf_loss)
|
||||
|
||||
self.check_pt_tf_outputs(tf_outputs, pt_outputs, type(pt_model))
|
||||
|
||||
@is_pt_tf_cross_test
|
||||
def test_pt_tf_model_equivalence(self):
|
||||
import transformers
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
tf_model_class_name = "TF" + model_class.__name__ # Add the "TF" at the beginning
|
||||
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
tf_model_class_name = "TF" + model_class.__name__ # Add the "TF" at the beginning
|
||||
if not hasattr(transformers, tf_model_class_name):
|
||||
# transformers does not have TF version yet
|
||||
# transformers does not have this model in TF version yet
|
||||
return
|
||||
|
||||
# Output all for aggressive testing
|
||||
config.output_hidden_states = True
|
||||
config.output_attentions = self.has_attentions
|
||||
|
||||
for k in ["attention_mask", "encoder_attention_mask", "decoder_attention_mask"]:
|
||||
if k in inputs_dict:
|
||||
attention_mask = inputs_dict[k]
|
||||
# make sure no all 0s attention masks - to avoid failure at this moment.
|
||||
# TODO: remove this line once the TODO below is implemented.
|
||||
attention_mask = torch.ones_like(attention_mask, dtype=torch.int32)
|
||||
# Here we make the first sequence with all 0s as attention mask.
|
||||
# Currently, this will fail for `TFWav2Vec2Model`. This is caused by the different large negative
|
||||
# values, like `1e-4`, `1e-9`, `1e-30` and `-inf` for attention mask across models/frameworks.
|
||||
# TODO: enable this block once the large negative values thing is cleaned up.
|
||||
# (see https://github.com/huggingface/transformers/issues/14859)
|
||||
# attention_mask = torch.cat(
|
||||
# [
|
||||
# torch.zeros_like(attention_mask[:1], dtype=torch.int32),
|
||||
# attention_mask[1:].type(dtype=torch.int32)
|
||||
# ],
|
||||
# dim=0
|
||||
# )
|
||||
inputs_dict[k] = attention_mask
|
||||
# Make sure no sequence has all zeros as attention mask, otherwise some tests fail due to the inconsistency
|
||||
# of the usage `1e-4`, `1e-9`, `1e-30`, `-inf`.
|
||||
# TODO: Use a uniform value for all models, make sure all tests pass without this processing, and remove it.
|
||||
self._make_attention_mask_non_null(inputs_dict)
|
||||
|
||||
tf_model_class = getattr(transformers, tf_model_class_name)
|
||||
|
||||
tf_model = tf_model_class(config)
|
||||
pt_model = model_class(config)
|
||||
tf_model = tf_model_class(config)
|
||||
|
||||
pt_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
|
||||
pt_inputs_dict_with_labels = self._prepare_for_class(
|
||||
inputs_dict,
|
||||
model_class,
|
||||
# Not all models accept "labels" in the forward pass (yet :) )
|
||||
return_labels=True if "labels" in inspect.signature(model_class.forward).parameters.keys() else False,
|
||||
)
|
||||
|
||||
# make sure only tf inputs are forward that actually exist in function args
|
||||
tf_input_keys = set(inspect.signature(tf_model.call).parameters.keys())
|
||||
@@ -1718,20 +1731,25 @@ class ModelTesterMixin:
|
||||
tf_input_keys.discard("cross_attn_head_mask")
|
||||
tf_input_keys.discard("decoder_head_mask")
|
||||
|
||||
pt_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
|
||||
pt_inputs_dict_maybe_with_labels = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
|
||||
|
||||
pt_inputs_dict = {k: v for k, v in pt_inputs_dict.items() if k in tf_input_keys}
|
||||
pt_inputs_dict_maybe_with_labels = {
|
||||
k: v for k, v in pt_inputs_dict_maybe_with_labels.items() if k in tf_input_keys
|
||||
}
|
||||
pt_inputs_dict_with_labels = {k: v for k, v in pt_inputs_dict_with_labels.items() if k in tf_input_keys}
|
||||
|
||||
# For some models (e.g. base models), there is no label returned.
|
||||
# Set the input dict to `None` to avoid check outputs twice for the same input dicts.
|
||||
if set(pt_inputs_dict_with_labels.keys()).symmetric_difference(pt_inputs_dict.keys()):
|
||||
pt_inputs_dict_with_labels = None
|
||||
|
||||
# Check we can load pt model in tf and vice-versa with model => model functions
|
||||
tf_inputs_dict = prepare_tf_inputs_from_pt_inputs(pt_inputs_dict)
|
||||
# Here requires `tf_inputs_dict` to build `tf_model`
|
||||
tf_inputs_dict = self.prepare_tf_inputs_from_pt_inputs(pt_inputs_dict)
|
||||
tf_model = transformers.load_pytorch_model_in_tf2_model(tf_model, pt_model, tf_inputs=tf_inputs_dict)
|
||||
pt_model = transformers.load_tf2_model_in_pytorch_model(pt_model, tf_model)
|
||||
|
||||
check_pt_tf_models(tf_model, pt_model, pt_inputs_dict, pt_inputs_dict_maybe_with_labels)
|
||||
# Original test: check without `labels`
|
||||
self.check_pt_tf_models(tf_model, pt_model, pt_inputs_dict)
|
||||
# check with `labels`
|
||||
if pt_inputs_dict_with_labels:
|
||||
self.check_pt_tf_models(tf_model, pt_model, pt_inputs_dict_with_labels)
|
||||
|
||||
# Check we can load pt model in tf and vice-versa with checkpoint => model functions
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
@@ -1742,9 +1760,12 @@ class ModelTesterMixin:
|
||||
tf_checkpoint_path = os.path.join(tmpdirname, "tf_model.h5")
|
||||
tf_model.save_weights(tf_checkpoint_path)
|
||||
pt_model = transformers.load_tf2_checkpoint_in_pytorch_model(pt_model, tf_checkpoint_path)
|
||||
pt_model = pt_model.to(torch_device)
|
||||
|
||||
check_pt_tf_models(tf_model, pt_model, pt_inputs_dict, pt_inputs_dict_maybe_with_labels)
|
||||
# Original test: check without `labels`
|
||||
self.check_pt_tf_models(tf_model, pt_model, pt_inputs_dict)
|
||||
# check with `labels`
|
||||
if pt_inputs_dict_with_labels:
|
||||
self.check_pt_tf_models(tf_model, pt_model, pt_inputs_dict_with_labels)
|
||||
|
||||
def assert_almost_equals(self, a: np.ndarray, b: np.ndarray, tol: float):
|
||||
diff = np.abs((a - b)).max()
|
||||
|
||||
Reference in New Issue
Block a user