Make Flax pt-flax equivalence test more aggressive (#15841)
* Make test_equivalence_pt_to_flax more aggressive
* Make test_equivalence_flax_to_pt more aggressive
* don't use to_tuple
* clean-up
* fix missing test cases + testing on GPU
* fix conversion
* fix `ValueError: assignment destination is read-only`
* Add type checking
* commit to revert later
* Fix
* fix
* fix device
* better naming
* clean-up

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
@@ -26,7 +26,15 @@ from huggingface_hub import delete_repo, login
|
|||||||
from requests.exceptions import HTTPError
|
from requests.exceptions import HTTPError
|
||||||
from transformers import BertConfig, is_flax_available, is_torch_available
|
from transformers import BertConfig, is_flax_available, is_torch_available
|
||||||
from transformers.models.auto import get_values
|
from transformers.models.auto import get_values
|
||||||
from transformers.testing_utils import PASS, USER, CaptureLogger, is_pt_flax_cross_test, is_staging_test, require_flax
|
from transformers.testing_utils import (
|
||||||
|
PASS,
|
||||||
|
USER,
|
||||||
|
CaptureLogger,
|
||||||
|
is_pt_flax_cross_test,
|
||||||
|
is_staging_test,
|
||||||
|
require_flax,
|
||||||
|
torch_device,
|
||||||
|
)
|
||||||
from transformers.utils import logging
|
from transformers.utils import logging
|
||||||
|
|
||||||
|
|
||||||
@@ -160,15 +168,64 @@ class FlaxModelTesterMixin:
|
|||||||
dict_inputs = self._prepare_for_class(inputs_dict, model_class)
|
dict_inputs = self._prepare_for_class(inputs_dict, model_class)
|
||||||
check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True})
|
check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True})
|
||||||
|
|
||||||
|
def check_outputs(self, fx_outputs, pt_outputs, model_class, names):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
model_class: The class of the model that is currently testing. For example, ..., etc.
|
||||||
|
Currently unused, but it could make debugging easier and faster.
|
||||||
|
|
||||||
|
names: A string, or a list of strings. These specify what fx_outputs/pt_outputs represent in the model outputs.
|
||||||
|
Currently unused, but in the future, we could use this information to make the error message clearer
|
||||||
|
by giving the name(s) of the output tensor(s) with large difference(s) between PT and Flax.
|
||||||
|
"""
|
||||||
|
if type(fx_outputs) in [tuple, list]:
|
||||||
|
self.assertEqual(type(fx_outputs), type(pt_outputs))
|
||||||
|
self.assertEqual(len(fx_outputs), len(pt_outputs))
|
||||||
|
if type(names) == tuple:
|
||||||
|
for fo, po, name in zip(fx_outputs, pt_outputs, names):
|
||||||
|
self.check_outputs(fo, po, model_class, names=name)
|
||||||
|
elif type(names) == str:
|
||||||
|
for idx, (fo, po) in enumerate(zip(fx_outputs, pt_outputs)):
|
||||||
|
self.check_outputs(fo, po, model_class, names=f"{names}_{idx}")
|
||||||
|
else:
|
||||||
|
raise ValueError(f"`names` should be a `tuple` or a string. Got {type(names)} instead.")
|
||||||
|
elif isinstance(fx_outputs, jnp.ndarray):
|
||||||
|
self.assertTrue(isinstance(pt_outputs, torch.Tensor))
|
||||||
|
|
||||||
|
# Using `np.asarray` gives `ValueError: assignment destination is read-only` at the line `fx_outputs[fx_nans] = 0`.
|
||||||
|
fx_outputs = np.array(fx_outputs)
|
||||||
|
pt_outputs = pt_outputs.detach().to("cpu").numpy()
|
||||||
|
|
||||||
|
fx_nans = np.isnan(fx_outputs)
|
||||||
|
pt_nans = np.isnan(pt_outputs)
|
||||||
|
|
||||||
|
pt_outputs[fx_nans] = 0
|
||||||
|
fx_outputs[fx_nans] = 0
|
||||||
|
pt_outputs[pt_nans] = 0
|
||||||
|
fx_outputs[pt_nans] = 0
|
||||||
|
|
||||||
|
max_diff = np.amax(np.abs(fx_outputs - pt_outputs))
|
||||||
|
self.assertLessEqual(max_diff, 1e-5)
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"`fx_outputs` should be a `tuple` or an instance of `jnp.ndarray`. Got {type(fx_outputs)} instead."
|
||||||
|
)
|
||||||
|
|
||||||
@is_pt_flax_cross_test
|
@is_pt_flax_cross_test
|
||||||
def test_equivalence_pt_to_flax(self):
|
def test_equivalence_pt_to_flax(self):
|
||||||
|
# It might be better to put this inside the for loop below (because we modify the config there).
|
||||||
|
# But logically, it is fine.
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|
||||||
for model_class in self.all_model_classes:
|
for model_class in self.all_model_classes:
|
||||||
with self.subTest(model_class.__name__):
|
with self.subTest(model_class.__name__):
|
||||||
|
|
||||||
|
# Output all for aggressive testing
|
||||||
|
config.output_hidden_states = True
|
||||||
|
|
||||||
# prepare inputs
|
# prepare inputs
|
||||||
prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
|
prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
|
||||||
pt_inputs = {k: torch.tensor(v.tolist()) for k, v in prepared_inputs_dict.items()}
|
pt_inputs = {k: torch.tensor(v.tolist(), device=torch_device) for k, v in prepared_inputs_dict.items()}
|
||||||
|
|
||||||
# load corresponding PyTorch class
|
# load corresponding PyTorch class
|
||||||
pt_model_class_name = model_class.__name__[4:] # Skip the "Flax" at the beginning
|
pt_model_class_name = model_class.__name__[4:] # Skip the "Flax" at the beginning
|
||||||
@@ -183,24 +240,30 @@ class FlaxModelTesterMixin:
|
|||||||
fx_state = convert_pytorch_state_dict_to_flax(pt_model.state_dict(), fx_model)
|
fx_state = convert_pytorch_state_dict_to_flax(pt_model.state_dict(), fx_model)
|
||||||
fx_model.params = fx_state
|
fx_model.params = fx_state
|
||||||
|
|
||||||
with torch.no_grad():
|
# send pytorch model to the correct device
|
||||||
pt_outputs = pt_model(**pt_inputs).to_tuple()
|
pt_model.to(torch_device)
|
||||||
|
|
||||||
fx_outputs = fx_model(**prepared_inputs_dict).to_tuple()
|
with torch.no_grad():
|
||||||
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
|
pt_outputs = pt_model(**pt_inputs)
|
||||||
for fx_output, pt_output in zip(fx_outputs, pt_outputs):
|
fx_outputs = fx_model(**prepared_inputs_dict)
|
||||||
self.assert_almost_equals(fx_output, pt_output.numpy(), 4e-2)
|
|
||||||
|
fx_keys = tuple([k for k, v in fx_outputs.items() if v is not None])
|
||||||
|
pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None])
|
||||||
|
|
||||||
|
self.assertEqual(fx_keys, pt_keys)
|
||||||
|
self.check_outputs(fx_outputs.to_tuple(), pt_outputs.to_tuple(), model_class, names=fx_keys)
|
||||||
|
|
||||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||||
pt_model.save_pretrained(tmpdirname)
|
pt_model.save_pretrained(tmpdirname)
|
||||||
fx_model_loaded = model_class.from_pretrained(tmpdirname, from_pt=True)
|
fx_model_loaded = model_class.from_pretrained(tmpdirname, from_pt=True)
|
||||||
|
|
||||||
fx_outputs_loaded = fx_model_loaded(**prepared_inputs_dict).to_tuple()
|
fx_outputs_loaded = fx_model_loaded(**prepared_inputs_dict)
|
||||||
self.assertEqual(
|
|
||||||
len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch"
|
fx_keys = tuple([k for k, v in fx_outputs_loaded.items() if v is not None])
|
||||||
)
|
pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None])
|
||||||
for fx_output_loaded, pt_output in zip(fx_outputs_loaded, pt_outputs):
|
|
||||||
self.assert_almost_equals(fx_output_loaded, pt_output.numpy(), 4e-2)
|
self.assertEqual(fx_keys, pt_keys)
|
||||||
|
self.check_outputs(fx_outputs_loaded.to_tuple(), pt_outputs.to_tuple(), model_class, names=fx_keys)
|
||||||
|
|
||||||
@is_pt_flax_cross_test
|
@is_pt_flax_cross_test
|
||||||
def test_equivalence_flax_to_pt(self):
|
def test_equivalence_flax_to_pt(self):
|
||||||
@@ -208,9 +271,14 @@ class FlaxModelTesterMixin:
|
|||||||
|
|
||||||
for model_class in self.all_model_classes:
|
for model_class in self.all_model_classes:
|
||||||
with self.subTest(model_class.__name__):
|
with self.subTest(model_class.__name__):
|
||||||
|
|
||||||
|
# Output all for aggressive testing
|
||||||
|
config.output_hidden_states = True
|
||||||
|
# Pure convolutional models have no attention
|
||||||
|
|
||||||
# prepare inputs
|
# prepare inputs
|
||||||
prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
|
prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
|
||||||
pt_inputs = {k: torch.tensor(v.tolist()) for k, v in prepared_inputs_dict.items()}
|
pt_inputs = {k: torch.tensor(v.tolist(), device=torch_device) for k, v in prepared_inputs_dict.items()}
|
||||||
|
|
||||||
# load corresponding PyTorch class
|
# load corresponding PyTorch class
|
||||||
pt_model_class_name = model_class.__name__[4:] # Skip the "Flax" at the beginning
|
pt_model_class_name = model_class.__name__[4:] # Skip the "Flax" at the beginning
|
||||||
@@ -227,27 +295,34 @@ class FlaxModelTesterMixin:
|
|||||||
# make sure weights are tied in PyTorch
|
# make sure weights are tied in PyTorch
|
||||||
pt_model.tie_weights()
|
pt_model.tie_weights()
|
||||||
|
|
||||||
|
# send pytorch model to the correct device
|
||||||
|
pt_model.to(torch_device)
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
pt_outputs = pt_model(**pt_inputs).to_tuple()
|
pt_outputs = pt_model(**pt_inputs)
|
||||||
|
fx_outputs = fx_model(**prepared_inputs_dict)
|
||||||
|
|
||||||
fx_outputs = fx_model(**prepared_inputs_dict).to_tuple()
|
fx_keys = tuple([k for k, v in fx_outputs.items() if v is not None])
|
||||||
self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
|
pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None])
|
||||||
|
|
||||||
for fx_output, pt_output in zip(fx_outputs, pt_outputs):
|
self.assertEqual(fx_keys, pt_keys)
|
||||||
self.assert_almost_equals(fx_output, pt_output.numpy(), 4e-2)
|
self.check_outputs(fx_outputs.to_tuple(), pt_outputs.to_tuple(), model_class, names=fx_keys)
|
||||||
|
|
||||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||||
fx_model.save_pretrained(tmpdirname)
|
fx_model.save_pretrained(tmpdirname)
|
||||||
pt_model_loaded = pt_model_class.from_pretrained(tmpdirname, from_flax=True)
|
pt_model_loaded = pt_model_class.from_pretrained(tmpdirname, from_flax=True)
|
||||||
|
|
||||||
with torch.no_grad():
|
# send pytorch model to the correct device
|
||||||
pt_outputs_loaded = pt_model_loaded(**pt_inputs).to_tuple()
|
pt_model_loaded.to(torch_device)
|
||||||
|
|
||||||
self.assertEqual(
|
with torch.no_grad():
|
||||||
len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch"
|
pt_outputs_loaded = pt_model_loaded(**pt_inputs)
|
||||||
)
|
|
||||||
for fx_output, pt_output in zip(fx_outputs, pt_outputs_loaded):
|
fx_keys = tuple([k for k, v in fx_outputs.items() if v is not None])
|
||||||
self.assert_almost_equals(fx_output, pt_output.numpy(), 4e-2)
|
pt_keys = tuple([k for k, v in pt_outputs_loaded.items() if v is not None])
|
||||||
|
|
||||||
|
self.assertEqual(fx_keys, pt_keys)
|
||||||
|
self.check_outputs(fx_outputs.to_tuple(), pt_outputs_loaded.to_tuple(), model_class, names=fx_keys)
|
||||||
|
|
||||||
def test_from_pretrained_save_pretrained(self):
|
def test_from_pretrained_save_pretrained(self):
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|||||||
Reference in New Issue
Block a user