From d481b6414d5c07e7ae76ac19e32e1305d3b85b5f Mon Sep 17 00:00:00 2001 From: Yih-Dar <2521628+ydshieh@users.noreply.github.com> Date: Fri, 18 Mar 2022 18:15:36 +0100 Subject: [PATCH] Make Flax pt-flax equivalence test more aggressive (#15841) * Make test_equivalence_pt_to_flax more aggressive * Make test_equivalence_flax_to_pt more aggressive * don't use to_tuple * clean-up * fix missing test cases + testing on GPU * fix conversion * fix `ValueError: assignment destination is read-only` * Add type checking * commit to revert later * Fix * fix * fix device * better naming * clean-up Co-authored-by: ydshieh --- tests/test_modeling_flax_common.py | 129 +++++++++++++++++++++++------ 1 file changed, 102 insertions(+), 27 deletions(-) diff --git a/tests/test_modeling_flax_common.py b/tests/test_modeling_flax_common.py index 1edd41aab0..0de5005fe6 100644 --- a/tests/test_modeling_flax_common.py +++ b/tests/test_modeling_flax_common.py @@ -26,7 +26,15 @@ from huggingface_hub import delete_repo, login from requests.exceptions import HTTPError from transformers import BertConfig, is_flax_available, is_torch_available from transformers.models.auto import get_values -from transformers.testing_utils import PASS, USER, CaptureLogger, is_pt_flax_cross_test, is_staging_test, require_flax +from transformers.testing_utils import ( + PASS, + USER, + CaptureLogger, + is_pt_flax_cross_test, + is_staging_test, + require_flax, + torch_device, +) from transformers.utils import logging @@ -160,15 +168,64 @@ class FlaxModelTesterMixin: dict_inputs = self._prepare_for_class(inputs_dict, model_class) check_equivalence(model, tuple_inputs, dict_inputs, {"output_hidden_states": True}) + def check_outputs(self, fx_outputs, pt_outputs, model_class, names): + """ + Args: + model_class: The class of the model that is currently being tested. For example, ..., etc. + Currently unused, but it could make debugging easier and faster. + + names: A string, or a tuple of strings. 
These specify what fx_outputs/pt_outputs represent in the model outputs. + Currently unused, but in the future, we could use this information to make the error message clearer + by giving the name(s) of the output tensor(s) with large difference(s) between PT and Flax. + """ + if type(fx_outputs) in [tuple, list]: + self.assertEqual(type(fx_outputs), type(pt_outputs)) + self.assertEqual(len(fx_outputs), len(pt_outputs)) + if type(names) == tuple: + for fo, po, name in zip(fx_outputs, pt_outputs, names): + self.check_outputs(fo, po, model_class, names=name) + elif type(names) == str: + for idx, (fo, po) in enumerate(zip(fx_outputs, pt_outputs)): + self.check_outputs(fo, po, model_class, names=f"{names}_{idx}") + else: + raise ValueError(f"`names` should be a `tuple` or a string. Got {type(names)} instead.") + elif isinstance(fx_outputs, jnp.ndarray): + self.assertTrue(isinstance(pt_outputs, torch.Tensor)) + + # Using `np.asarray` gives `ValueError: assignment destination is read-only` at the line `fx_outputs[fx_nans] = 0`. + fx_outputs = np.array(fx_outputs) + pt_outputs = pt_outputs.detach().to("cpu").numpy() + + fx_nans = np.isnan(fx_outputs) + pt_nans = np.isnan(pt_outputs) + + pt_outputs[fx_nans] = 0 + fx_outputs[fx_nans] = 0 + pt_outputs[pt_nans] = 0 + fx_outputs[pt_nans] = 0 + + max_diff = np.amax(np.abs(fx_outputs - pt_outputs)) + self.assertLessEqual(max_diff, 1e-5) + else: + raise ValueError( + f"`fx_outputs` should be a `tuple` or an instance of `jnp.ndarray`. Got {type(fx_outputs)} instead." + ) + @is_pt_flax_cross_test def test_equivalence_pt_to_flax(self): + # It might be better to put this inside the for loop below (because we modify the config there). + # But logically, it is fine. 
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() for model_class in self.all_model_classes: with self.subTest(model_class.__name__): + + # Output all for aggressive testing + config.output_hidden_states = True + # prepare inputs prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class) - pt_inputs = {k: torch.tensor(v.tolist()) for k, v in prepared_inputs_dict.items()} + pt_inputs = {k: torch.tensor(v.tolist(), device=torch_device) for k, v in prepared_inputs_dict.items()} # load corresponding PyTorch class pt_model_class_name = model_class.__name__[4:] # Skip the "Flax" at the beginning @@ -183,24 +240,30 @@ class FlaxModelTesterMixin: fx_state = convert_pytorch_state_dict_to_flax(pt_model.state_dict(), fx_model) fx_model.params = fx_state - with torch.no_grad(): - pt_outputs = pt_model(**pt_inputs).to_tuple() + # send pytorch model to the correct device + pt_model.to(torch_device) - fx_outputs = fx_model(**prepared_inputs_dict).to_tuple() - self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch") - for fx_output, pt_output in zip(fx_outputs, pt_outputs): - self.assert_almost_equals(fx_output, pt_output.numpy(), 4e-2) + with torch.no_grad(): + pt_outputs = pt_model(**pt_inputs) + fx_outputs = fx_model(**prepared_inputs_dict) + + fx_keys = tuple([k for k, v in fx_outputs.items() if v is not None]) + pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None]) + + self.assertEqual(fx_keys, pt_keys) + self.check_outputs(fx_outputs.to_tuple(), pt_outputs.to_tuple(), model_class, names=fx_keys) with tempfile.TemporaryDirectory() as tmpdirname: pt_model.save_pretrained(tmpdirname) fx_model_loaded = model_class.from_pretrained(tmpdirname, from_pt=True) - fx_outputs_loaded = fx_model_loaded(**prepared_inputs_dict).to_tuple() - self.assertEqual( - len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch" - ) - for fx_output_loaded, pt_output in 
zip(fx_outputs_loaded, pt_outputs): - self.assert_almost_equals(fx_output_loaded, pt_output.numpy(), 4e-2) + fx_outputs_loaded = fx_model_loaded(**prepared_inputs_dict) + + fx_keys = tuple([k for k, v in fx_outputs_loaded.items() if v is not None]) + pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None]) + + self.assertEqual(fx_keys, pt_keys) + self.check_outputs(fx_outputs_loaded.to_tuple(), pt_outputs.to_tuple(), model_class, names=fx_keys) @is_pt_flax_cross_test def test_equivalence_flax_to_pt(self): @@ -208,9 +271,14 @@ class FlaxModelTesterMixin: for model_class in self.all_model_classes: with self.subTest(model_class.__name__): + + # Output all for aggressive testing + config.output_hidden_states = True + # Pure convolutional models have no attention + # prepare inputs prepared_inputs_dict = self._prepare_for_class(inputs_dict, model_class) - pt_inputs = {k: torch.tensor(v.tolist()) for k, v in prepared_inputs_dict.items()} + pt_inputs = {k: torch.tensor(v.tolist(), device=torch_device) for k, v in prepared_inputs_dict.items()} # load corresponding PyTorch class pt_model_class_name = model_class.__name__[4:] # Skip the "Flax" at the beginning @@ -227,27 +295,34 @@ class FlaxModelTesterMixin: # make sure weights are tied in PyTorch pt_model.tie_weights() + # send pytorch model to the correct device + pt_model.to(torch_device) + with torch.no_grad(): - pt_outputs = pt_model(**pt_inputs).to_tuple() + pt_outputs = pt_model(**pt_inputs) + fx_outputs = fx_model(**prepared_inputs_dict) - fx_outputs = fx_model(**prepared_inputs_dict).to_tuple() - self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch") + fx_keys = tuple([k for k, v in fx_outputs.items() if v is not None]) + pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None]) - for fx_output, pt_output in zip(fx_outputs, pt_outputs): - self.assert_almost_equals(fx_output, pt_output.numpy(), 4e-2) + self.assertEqual(fx_keys, pt_keys) + 
self.check_outputs(fx_outputs.to_tuple(), pt_outputs.to_tuple(), model_class, names=fx_keys) with tempfile.TemporaryDirectory() as tmpdirname: fx_model.save_pretrained(tmpdirname) pt_model_loaded = pt_model_class.from_pretrained(tmpdirname, from_flax=True) - with torch.no_grad(): - pt_outputs_loaded = pt_model_loaded(**pt_inputs).to_tuple() + # send pytorch model to the correct device + pt_model_loaded.to(torch_device) - self.assertEqual( - len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch" - ) - for fx_output, pt_output in zip(fx_outputs, pt_outputs_loaded): - self.assert_almost_equals(fx_output, pt_output.numpy(), 4e-2) + with torch.no_grad(): + pt_outputs_loaded = pt_model_loaded(**pt_inputs) + + fx_keys = tuple([k for k, v in fx_outputs.items() if v is not None]) + pt_keys = tuple([k for k, v in pt_outputs_loaded.items() if v is not None]) + + self.assertEqual(fx_keys, pt_keys) + self.check_outputs(fx_outputs.to_tuple(), pt_outputs_loaded.to_tuple(), model_class, names=fx_keys) def test_from_pretrained_save_pretrained(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()