Make Flax pt-flax equivalence test more aggressive (#15841)
* Make test_equivalence_pt_to_flax more aggressive
* Make test_equivalence_flax_to_pt more aggressive
* don't use to_tuple
* clean-up
* fix missing test cases + testing on GPU
* fix conversion
* fix `ValueError: assignment destination is read-only`
* Add type checking
* commit to revert later
* Fix
* fix
* fix device
* better naming
* clean-up

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
@@ -26,7 +26,15 @@ from huggingface_hub import delete_repo, login
|
||||
from requests.exceptions import HTTPError
|
||||
from transformers import BertConfig, is_flax_available, is_torch_available
|
||||
from transformers.models.auto import get_values
|
||||
from transformers.testing_utils import PASS, USER, CaptureLogger, is_pt_flax_cross_test, is_staging_test, require_flax
|
||||
from transformers.testing_utils import (
|
||||
PASS,
|
||||
USER,
|
||||
CaptureLogger,
|
||||
is_pt_flax_cross_test,
|
||||
is_staging_test,
|
||||
require_flax,
|
||||
torch_device,
|
||||
)
|
||||
from transformers.utils import logging
|
||||
|
||||
|
||||
@@ -160,15 +168,64 @@ class FlaxModelTesterMixin:
|
||||
dict_inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||
check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True})
|
||||
|
||||
def check_outputs(self, fx_outputs, pt_outputs, model_class, names):
    """
    Recursively assert that Flax outputs and PyTorch outputs are numerically equivalent.

    Tuples/lists are walked element by element; leaf `jnp.ndarray` / `torch.Tensor` pairs are compared with a
    maximum absolute difference of 1e-5. Positions that are NaN in either output are ignored.

    Args:
        fx_outputs: A `jnp.ndarray`, or an arbitrarily nested tuple/list of them.

        pt_outputs: A `torch.Tensor`, or a container of the same type and length as `fx_outputs`. It is never
            modified by this method.

        model_class: The class of the model that is currently testing. For example, ..., etc.
            Currently unused, but it could make debugging easier and faster.

        names: A string, or a tuple of strings. These specify what fx_outputs/pt_outputs represent in the model
            outputs. They are used in failure messages to point at the output that differs between PT and Flax.

    Raises:
        ValueError: If `names` is neither a tuple nor a string, or if `fx_outputs` is neither a tuple/list nor a
            `jnp.ndarray`.
    """
    if isinstance(fx_outputs, (tuple, list)):
        self.assertEqual(type(fx_outputs), type(pt_outputs))
        self.assertEqual(len(fx_outputs), len(pt_outputs))
        if isinstance(names, tuple):
            # `zip` stops at the shortest input: without this check a short `names` would silently skip outputs.
            self.assertEqual(
                len(names), len(fx_outputs), f"Got {len(names)} names for {len(fx_outputs)} outputs: {names}"
            )
            for fo, po, name in zip(fx_outputs, pt_outputs, names):
                self.check_outputs(fo, po, model_class, names=name)
        elif isinstance(names, str):
            for idx, (fo, po) in enumerate(zip(fx_outputs, pt_outputs)):
                self.check_outputs(fo, po, model_class, names=f"{names}_{idx}")
        else:
            raise ValueError(f"`names` should be a `tuple` or a string. Got {type(names)} instead.")
    elif isinstance(fx_outputs, jnp.ndarray):
        self.assertTrue(
            isinstance(pt_outputs, torch.Tensor),
            f"{names}: expected a `torch.Tensor` on the PyTorch side. Got {type(pt_outputs)} instead.",
        )

        # Using `np.asarray` gives `ValueError: assignment destination is read-only` when zeroing NaNs below.
        fx_array = np.array(fx_outputs)
        # `Tensor.numpy()` shares memory with a CPU tensor: copy so that zeroing NaNs does not alter the caller's
        # tensor (the same PyTorch outputs are compared again after the save/load round trip).
        pt_array = np.array(pt_outputs.detach().to("cpu").numpy())

        # Fail with a readable message instead of an `IndexError` from the boolean masks below.
        self.assertEqual(
            fx_array.shape,
            pt_array.shape,
            f"{names}: output shapes differ between Flax {fx_array.shape} and PyTorch {pt_array.shape}",
        )

        # Nothing to compare (and `np.amax` raises on empty arrays).
        if fx_array.size == 0:
            return

        # Ignore every position that is NaN in either framework.
        nan_mask = np.isnan(fx_array) | np.isnan(pt_array)
        fx_array[nan_mask] = 0
        pt_array[nan_mask] = 0

        max_diff = np.amax(np.abs(fx_array - pt_array))
        self.assertLessEqual(
            max_diff, 1e-5, f"{names}: max difference between Flax and PyTorch is {max_diff} (> 1e-5)"
        )
    else:
        raise ValueError(
            f"`fx_outputs` should be a `tuple` or an instance of `jnp.ndarray`. Got {type(fx_outputs)} instead."
        )
|
||||
|
||||
@is_pt_flax_cross_test
|
||||
def test_equivalence_pt_to_flax(self):
|
||||
# It might be better to put this inside the for loop below (because we modify the config there).
|
||||
# But logically, it is fine.
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
with self.subTest(model_class.__name__):
|
||||
|
||||
# Output all for aggressive testing
|
||||
config.output_hidden_states = True
|
||||
|
||||
# prepare inputs
|
||||
prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
|
||||
pt_inputs = {k: torch.tensor(v.tolist()) for k, v in prepared_inputs_dict.items()}
|
||||
pt_inputs = {k: torch.tensor(v.tolist(), device=torch_device) for k, v in prepared_inputs_dict.items()}
|
||||
|
||||
# load corresponding PyTorch class
|
||||
pt_model_class_name = model_class.__name__[4:] # Skip the "Flax" at the beginning
|
||||
@@ -183,24 +240,30 @@ class FlaxModelTesterMixin:
|
||||
fx_state = convert_pytorch_state_dict_to_flax(pt_model.state_dict(), fx_model)
|
||||
fx_model.params = fx_state
|
||||
|
||||
with torch.no_grad():
|
||||
pt_outputs = pt_model(**pt_inputs).to_tuple()
|
||||
# send pytorch model to the correct device
|
||||
pt_model.to(torch_device)
|
||||
|
||||
fx_outputs = fx_model(**prepared_inputs_dict).to_tuple()
|
||||
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
|
||||
for fx_output, pt_output in zip(fx_outputs, pt_outputs):
|
||||
self.assert_almost_equals(fx_output, pt_output.numpy(), 4e-2)
|
||||
with torch.no_grad():
|
||||
pt_outputs = pt_model(**pt_inputs)
|
||||
fx_outputs = fx_model(**prepared_inputs_dict)
|
||||
|
||||
fx_keys = tuple([k for k, v in fx_outputs.items() if v is not None])
|
||||
pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None])
|
||||
|
||||
self.assertEqual(fx_keys, pt_keys)
|
||||
self.check_outputs(fx_outputs.to_tuple(), pt_outputs.to_tuple(), model_class, names=fx_keys)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
pt_model.save_pretrained(tmpdirname)
|
||||
fx_model_loaded = model_class.from_pretrained(tmpdirname, from_pt=True)
|
||||
|
||||
fx_outputs_loaded = fx_model_loaded(**prepared_inputs_dict).to_tuple()
|
||||
self.assertEqual(
|
||||
len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch"
|
||||
)
|
||||
for fx_output_loaded, pt_output in zip(fx_outputs_loaded, pt_outputs):
|
||||
self.assert_almost_equals(fx_output_loaded, pt_output.numpy(), 4e-2)
|
||||
fx_outputs_loaded = fx_model_loaded(**prepared_inputs_dict)
|
||||
|
||||
fx_keys = tuple([k for k, v in fx_outputs_loaded.items() if v is not None])
|
||||
pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None])
|
||||
|
||||
self.assertEqual(fx_keys, pt_keys)
|
||||
self.check_outputs(fx_outputs_loaded.to_tuple(), pt_outputs.to_tuple(), model_class, names=fx_keys)
|
||||
|
||||
@is_pt_flax_cross_test
|
||||
def test_equivalence_flax_to_pt(self):
|
||||
@@ -208,9 +271,14 @@ class FlaxModelTesterMixin:
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
with self.subTest(model_class.__name__):
|
||||
|
||||
# Output all for aggressive testing
|
||||
config.output_hidden_states = True
|
||||
# Pure convolutional models have no attention
|
||||
|
||||
# prepare inputs
|
||||
prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
|
||||
pt_inputs = {k: torch.tensor(v.tolist()) for k, v in prepared_inputs_dict.items()}
|
||||
pt_inputs = {k: torch.tensor(v.tolist(), device=torch_device) for k, v in prepared_inputs_dict.items()}
|
||||
|
||||
# load corresponding PyTorch class
|
||||
pt_model_class_name = model_class.__name__[4:] # Skip the "Flax" at the beginning
|
||||
@@ -227,27 +295,34 @@ class FlaxModelTesterMixin:
|
||||
# make sure weights are tied in PyTorch
|
||||
pt_model.tie_weights()
|
||||
|
||||
# send pytorch model to the correct device
|
||||
pt_model.to(torch_device)
|
||||
|
||||
with torch.no_grad():
|
||||
pt_outputs = pt_model(**pt_inputs).to_tuple()
|
||||
pt_outputs = pt_model(**pt_inputs)
|
||||
fx_outputs = fx_model(**prepared_inputs_dict)
|
||||
|
||||
fx_outputs = fx_model(**prepared_inputs_dict).to_tuple()
|
||||
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
|
||||
fx_keys = tuple([k for k, v in fx_outputs.items() if v is not None])
|
||||
pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None])
|
||||
|
||||
for fx_output, pt_output in zip(fx_outputs, pt_outputs):
|
||||
self.assert_almost_equals(fx_output, pt_output.numpy(), 4e-2)
|
||||
self.assertEqual(fx_keys, pt_keys)
|
||||
self.check_outputs(fx_outputs.to_tuple(), pt_outputs.to_tuple(), model_class, names=fx_keys)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
fx_model.save_pretrained(tmpdirname)
|
||||
pt_model_loaded = pt_model_class.from_pretrained(tmpdirname, from_flax=True)
|
||||
|
||||
with torch.no_grad():
|
||||
pt_outputs_loaded = pt_model_loaded(**pt_inputs).to_tuple()
|
||||
# send pytorch model to the correct device
|
||||
pt_model_loaded.to(torch_device)
|
||||
|
||||
self.assertEqual(
|
||||
len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch"
|
||||
)
|
||||
for fx_output, pt_output in zip(fx_outputs, pt_outputs_loaded):
|
||||
self.assert_almost_equals(fx_output, pt_output.numpy(), 4e-2)
|
||||
with torch.no_grad():
|
||||
pt_outputs_loaded = pt_model_loaded(**pt_inputs)
|
||||
|
||||
fx_keys = tuple([k for k, v in fx_outputs.items() if v is not None])
|
||||
pt_keys = tuple([k for k, v in pt_outputs_loaded.items() if v is not None])
|
||||
|
||||
self.assertEqual(fx_keys, pt_keys)
|
||||
self.check_outputs(fx_outputs.to_tuple(), pt_outputs_loaded.to_tuple(), model_class, names=fx_keys)
|
||||
|
||||
def test_from_pretrained_save_pretrained(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
Reference in New Issue
Block a user