Make Flax pt-flax equivalence test more aggressive (#15841)

* Make test_equivalence_pt_to_flax more aggressive

* Make test_equivalence_flax_to_pt more aggressive

* don't use to_tuple

* clean-up

* fix missing test cases + testing on GPU

* fix conversion

* fix `ValueError: assignment destination is read-only`

* Add type checking

* commit to revert later

* Fix

* fix

* fix device

* better naming

* clean-up

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
Yih-Dar
2022-03-18 18:15:36 +01:00
committed by GitHub
parent c03b6e4259
commit d481b6414d

View File

@@ -26,7 +26,15 @@ from huggingface_hub import delete_repo, login
from requests.exceptions import HTTPError
from transformers import BertConfig, is_flax_available, is_torch_available
from transformers.models.auto import get_values
from transformers.testing_utils import PASS, USER, CaptureLogger, is_pt_flax_cross_test, is_staging_test, require_flax
from transformers.testing_utils import (
PASS,
USER,
CaptureLogger,
is_pt_flax_cross_test,
is_staging_test,
require_flax,
torch_device,
)
from transformers.utils import logging
@@ -160,15 +168,64 @@ class FlaxModelTesterMixin:
dict_inputs = self._prepare_for_class(inputs_dict, model_class)
check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True})
def check_outputs(self, fx_outputs, pt_outputs, model_class, names):
"""
Args:
model_class: The class of the model that is currently testing. For example, ..., etc.
Currently unused, but it could make debugging easier and faster.
names: A string, or a list of strings. These specify what fx_outputs/pt_outputs represent in the model outputs.
Currently unused, but in the future, we could use this information to make the error message clearer
by giving the name(s) of the output tensor(s) with large difference(s) between PT and Flax.
"""
if type(fx_outputs) in [tuple, list]:
self.assertEqual(type(fx_outputs), type(pt_outputs))
self.assertEqual(len(fx_outputs), len(pt_outputs))
if type(names) == tuple:
for fo, po, name in zip(fx_outputs, pt_outputs, names):
self.check_outputs(fo, po, model_class, names=name)
elif type(names) == str:
for idx, (fo, po) in enumerate(zip(fx_outputs, pt_outputs)):
self.check_outputs(fo, po, model_class, names=f"{names}_{idx}")
else:
raise ValueError(f"`names` should be a `tuple` or a string. Got {type(names)} instead.")
elif isinstance(fx_outputs, jnp.ndarray):
self.assertTrue(isinstance(pt_outputs, torch.Tensor))
# Using `np.asarray` gives `ValueError: assignment destination is read-only` at the line `fx_outputs[fx_nans] = 0`.
fx_outputs = np.array(fx_outputs)
pt_outputs = pt_outputs.detach().to("cpu").numpy()
fx_nans = np.isnan(fx_outputs)
pt_nans = np.isnan(pt_outputs)
pt_outputs[fx_nans] = 0
fx_outputs[fx_nans] = 0
pt_outputs[pt_nans] = 0
fx_outputs[pt_nans] = 0
max_diff = np.amax(np.abs(fx_outputs - pt_outputs))
self.assertLessEqual(max_diff, 1e-5)
else:
raise ValueError(
f"`fx_outputs` should be a `tuple` or an instance of `jnp.ndarray`. Got {type(fx_outputs)} instead."
)
@is_pt_flax_cross_test
def test_equivalence_pt_to_flax(self):
# It might be better to put this inside the for loop below (because we modify the config there).
# But logically, it is fine.
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
with self.subTest(model_class.__name__):
# Output all for aggressive testing
config.output_hidden_states = True
# prepare inputs
prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
pt_inputs = {k: torch.tensor(v.tolist()) for k, v in prepared_inputs_dict.items()}
pt_inputs = {k: torch.tensor(v.tolist(), device=torch_device) for k, v in prepared_inputs_dict.items()}
# load corresponding PyTorch class
pt_model_class_name = model_class.__name__[4:] # Skip the "Flax" at the beginning
@@ -183,24 +240,30 @@ class FlaxModelTesterMixin:
fx_state = convert_pytorch_state_dict_to_flax(pt_model.state_dict(), fx_model)
fx_model.params = fx_state
with torch.no_grad():
pt_outputs = pt_model(**pt_inputs).to_tuple()
# send pytorch model to the correct device
pt_model.to(torch_device)
fx_outputs = fx_model(**prepared_inputs_dict).to_tuple()
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
for fx_output, pt_output in zip(fx_outputs, pt_outputs):
self.assert_almost_equals(fx_output, pt_output.numpy(), 4e-2)
with torch.no_grad():
pt_outputs = pt_model(**pt_inputs)
fx_outputs = fx_model(**prepared_inputs_dict)
fx_keys = tuple([k for k, v in fx_outputs.items() if v is not None])
pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None])
self.assertEqual(fx_keys, pt_keys)
self.check_outputs(fx_outputs.to_tuple(), pt_outputs.to_tuple(), model_class, names=fx_keys)
with tempfile.TemporaryDirectory() as tmpdirname:
pt_model.save_pretrained(tmpdirname)
fx_model_loaded = model_class.from_pretrained(tmpdirname, from_pt=True)
fx_outputs_loaded = fx_model_loaded(**prepared_inputs_dict).to_tuple()
self.assertEqual(
len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch"
)
for fx_output_loaded, pt_output in zip(fx_outputs_loaded, pt_outputs):
self.assert_almost_equals(fx_output_loaded, pt_output.numpy(), 4e-2)
fx_outputs_loaded = fx_model_loaded(**prepared_inputs_dict)
fx_keys = tuple([k for k, v in fx_outputs_loaded.items() if v is not None])
pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None])
self.assertEqual(fx_keys, pt_keys)
self.check_outputs(fx_outputs_loaded.to_tuple(), pt_outputs.to_tuple(), model_class, names=fx_keys)
@is_pt_flax_cross_test
def test_equivalence_flax_to_pt(self):
@@ -208,9 +271,14 @@ class FlaxModelTesterMixin:
for model_class in self.all_model_classes:
with self.subTest(model_class.__name__):
# Output all for aggressive testing
config.output_hidden_states = True
# Pure convolutional models have no attention
# prepare inputs
prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
pt_inputs = {k: torch.tensor(v.tolist()) for k, v in prepared_inputs_dict.items()}
pt_inputs = {k: torch.tensor(v.tolist(), device=torch_device) for k, v in prepared_inputs_dict.items()}
# load corresponding PyTorch class
pt_model_class_name = model_class.__name__[4:] # Skip the "Flax" at the beginning
@@ -227,27 +295,34 @@ class FlaxModelTesterMixin:
# make sure weights are tied in PyTorch
pt_model.tie_weights()
# send pytorch model to the correct device
pt_model.to(torch_device)
with torch.no_grad():
pt_outputs = pt_model(**pt_inputs).to_tuple()
pt_outputs = pt_model(**pt_inputs)
fx_outputs = fx_model(**prepared_inputs_dict)
fx_outputs = fx_model(**prepared_inputs_dict).to_tuple()
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
fx_keys = tuple([k for k, v in fx_outputs.items() if v is not None])
pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None])
for fx_output, pt_output in zip(fx_outputs, pt_outputs):
self.assert_almost_equals(fx_output, pt_output.numpy(), 4e-2)
self.assertEqual(fx_keys, pt_keys)
self.check_outputs(fx_outputs.to_tuple(), pt_outputs.to_tuple(), model_class, names=fx_keys)
with tempfile.TemporaryDirectory() as tmpdirname:
fx_model.save_pretrained(tmpdirname)
pt_model_loaded = pt_model_class.from_pretrained(tmpdirname, from_flax=True)
with torch.no_grad():
pt_outputs_loaded = pt_model_loaded(**pt_inputs).to_tuple()
# send pytorch model to the correct device
pt_model_loaded.to(torch_device)
self.assertEqual(
len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch"
)
for fx_output, pt_output in zip(fx_outputs, pt_outputs_loaded):
self.assert_almost_equals(fx_output, pt_output.numpy(), 4e-2)
with torch.no_grad():
pt_outputs_loaded = pt_model_loaded(**pt_inputs)
fx_keys = tuple([k for k, v in fx_outputs.items() if v is not None])
pt_keys = tuple([k for k, v in pt_outputs_loaded.items() if v is not None])
self.assertEqual(fx_keys, pt_keys)
self.check_outputs(fx_outputs.to_tuple(), pt_outputs_loaded.to_tuple(), model_class, names=fx_keys)
def test_from_pretrained_save_pretrained(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()