diff --git a/tests/big_bird/test_modeling_big_bird.py b/tests/big_bird/test_modeling_big_bird.py index df2ca8a157..24b88fd423 100644 --- a/tests/big_bird/test_modeling_big_bird.py +++ b/tests/big_bird/test_modeling_big_bird.py @@ -596,6 +596,15 @@ class BigBirdModelTest(ModelTesterMixin, unittest.TestCase): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_for_change_to_full_attn(*config_and_inputs) + # overwrite from common in order to skip the check on `attentions` + def check_outputs(self, fx_outputs, pt_outputs, model_class, names): + # `bigbird_block_sparse_attention` in `FlaxBigBird` returns `attention_probs = None`, while in PyTorch version, + # an effort was done to return `attention_probs` (yet to be verified). + if type(names) == str and names.startswith("attentions"): + return + else: + super().check_outputs(fx_outputs, pt_outputs, model_class, names) + @require_torch @slow diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 07cf086fae..6ed8dd3c97 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -1660,8 +1660,9 @@ class ModelTesterMixin: # transformers does not have TF version yet return - if self.has_attentions: - config.output_attentions = True + # Output all for aggressive testing + config.output_hidden_states = True + config.output_attentions = self.has_attentions for k in ["attention_mask", "encoder_attention_mask", "decoder_attention_mask"]: if k in inputs_dict: @@ -1728,12 +1729,65 @@ class ModelTesterMixin: diff = np.abs((a - b)).max() self.assertLessEqual(diff, tol, f"Difference between torch and flax is {diff} (>= {tol}).") + def check_outputs(self, fx_outputs, pt_outputs, model_class, names): + """ + Args: + model_class: The class of the model that is currently testing. For example, ..., etc. + Currently unused, but it could make debugging easier and faster. + + names: A string, or a list of strings. These specify what fx_outputs/pt_outputs represent in the model outputs. + Currently unused, but in the future, we could use this information to make the error message clearer + by giving the name(s) of the output tensor(s) with large difference(s) between PT and Flax. + """ + if type(fx_outputs) in [tuple, list]: + self.assertEqual(type(fx_outputs), type(pt_outputs)) + self.assertEqual(len(fx_outputs), len(pt_outputs)) + if type(names) == tuple: + for fo, po, name in zip(fx_outputs, pt_outputs, names): + self.check_outputs(fo, po, model_class, names=name) + elif type(names) == str: + for idx, (fo, po) in enumerate(zip(fx_outputs, pt_outputs)): + self.check_outputs(fo, po, model_class, names=f"{names}_{idx}") + else: + raise ValueError(f"`names` should be a `tuple` or a string. Got {type(names)} instead.") + elif isinstance(fx_outputs, jnp.ndarray): + self.assertTrue(isinstance(pt_outputs, torch.Tensor)) + + # Using `np.asarray` gives `ValueError: assignment destination is read-only` at the line `fx_outputs[fx_nans] = 0`. + fx_outputs = np.array(fx_outputs) + pt_outputs = pt_outputs.detach().to("cpu").numpy() + + fx_nans = np.isnan(fx_outputs) + pt_nans = np.isnan(pt_outputs) + + pt_outputs[fx_nans] = 0 + fx_outputs[fx_nans] = 0 + pt_outputs[pt_nans] = 0 + fx_outputs[pt_nans] = 0 + + self.assert_almost_equals(fx_outputs, pt_outputs, 1e-5) + else: + raise ValueError( + f"`fx_outputs` should be a `tuple` or an instance of `jnp.ndarray`. Got {type(fx_outputs)} instead." + ) + @is_pt_flax_cross_test def test_equivalence_pt_to_flax(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() for model_class in self.all_model_classes: with self.subTest(model_class.__name__): + fx_model_class_name = "Flax" + model_class.__name__ + + if not hasattr(transformers, fx_model_class_name): + # no flax model exists for this class + return + + # Output all for aggressive testing + config.output_hidden_states = True + config.output_attentions = self.has_attentions + + fx_model_class = getattr(transformers, fx_model_class_name) # load PyTorch class pt_model = model_class(config).eval() @@ -1741,15 +1795,9 @@ class ModelTesterMixin: # So we disable `use_cache` here for PyTorch model. pt_model.config.use_cache = False - fx_model_class_name = "Flax" + model_class.__name__ - - if not hasattr(transformers, fx_model_class_name): - return - - fx_model_class = getattr(transformers, fx_model_class_name) - # load Flax class fx_model = fx_model_class(config, dtype=jnp.float32) + # make sure only flax inputs are forward that actually exist in function args fx_input_keys = inspect.signature(fx_model.__call__).parameters.keys() @@ -1759,29 +1807,41 @@ class ModelTesterMixin: # remove function args that don't exist in Flax pt_inputs = {k: v for k, v in pt_inputs.items() if k in fx_input_keys} - fx_state = convert_pytorch_state_dict_to_flax(pt_model.state_dict(), fx_model) - fx_model.params = fx_state - - with torch.no_grad(): - pt_outputs = pt_model(**pt_inputs).to_tuple() + # send pytorch inputs to the correct device + pt_inputs = { + k: v.to(device=torch_device) if isinstance(v, torch.Tensor) else v for k, v in pt_inputs.items() + } # convert inputs to Flax fx_inputs = {k: np.array(v) for k, v in pt_inputs.items() if torch.is_tensor(v)} - fx_outputs = fx_model(**fx_inputs).to_tuple() - self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch") - for fx_output, pt_output in zip(fx_outputs, pt_outputs): - self.assert_almost_equals(fx_output, pt_output.numpy(), 4e-2) + + fx_state = convert_pytorch_state_dict_to_flax(pt_model.state_dict(), fx_model) + fx_model.params = fx_state + + # send pytorch model to the correct device + pt_model.to(torch_device) + + with torch.no_grad(): + pt_outputs = pt_model(**pt_inputs) + fx_outputs = fx_model(**fx_inputs) + + fx_keys = tuple([k for k, v in fx_outputs.items() if v is not None]) + pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None]) + + self.assertEqual(fx_keys, pt_keys) + self.check_outputs(fx_outputs.to_tuple(), pt_outputs.to_tuple(), model_class, names=fx_keys) with tempfile.TemporaryDirectory() as tmpdirname: pt_model.save_pretrained(tmpdirname) fx_model_loaded = fx_model_class.from_pretrained(tmpdirname, from_pt=True) - fx_outputs_loaded = fx_model_loaded(**fx_inputs).to_tuple() - self.assertEqual( - len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch" - ) - for fx_output_loaded, pt_output in zip(fx_outputs_loaded, pt_outputs): - self.assert_almost_equals(fx_output_loaded, pt_output.numpy(), 4e-2) + fx_outputs_loaded = fx_model_loaded(**fx_inputs) + + fx_keys = tuple([k for k, v in fx_outputs_loaded.items() if v is not None]) + pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None]) + + self.assertEqual(fx_keys, pt_keys) + self.check_outputs(fx_outputs_loaded.to_tuple(), pt_outputs.to_tuple(), model_class, names=fx_keys) @is_pt_flax_cross_test def test_equivalence_flax_to_pt(self): @@ -1789,59 +1849,78 @@ class ModelTesterMixin: for model_class in self.all_model_classes: with self.subTest(model_class.__name__): - # load corresponding PyTorch class - pt_model = model_class(config).eval() - - # So we disable `use_cache` here for PyTorch model. - pt_model.config.use_cache = False - fx_model_class_name = "Flax" + model_class.__name__ if not hasattr(transformers, fx_model_class_name): # no flax model exists for this class return + # Output all for aggressive testing + config.output_hidden_states = True + config.output_attentions = self.has_attentions + fx_model_class = getattr(transformers, fx_model_class_name) + # load PyTorch class + pt_model = model_class(config).eval() + # Flax models don't use the `use_cache` option and cache is not returned as a default. + # So we disable `use_cache` here for PyTorch model. + pt_model.config.use_cache = False + # load Flax class fx_model = fx_model_class(config, dtype=jnp.float32) + # make sure only flax inputs are forward that actually exist in function args fx_input_keys = inspect.signature(fx_model.__call__).parameters.keys() - pt_model = load_flax_weights_in_pytorch_model(pt_model, fx_model.params) - - # make sure weights are tied in PyTorch - pt_model.tie_weights() - # prepare inputs pt_inputs = self._prepare_for_class(inputs_dict, model_class) # remove function args that don't exist in Flax pt_inputs = {k: v for k, v in pt_inputs.items() if k in fx_input_keys} - with torch.no_grad(): - pt_outputs = pt_model(**pt_inputs).to_tuple() + # send pytorch inputs to the correct device + pt_inputs = { + k: v.to(device=torch_device) if isinstance(v, torch.Tensor) else v for k, v in pt_inputs.items() + } + # convert inputs to Flax fx_inputs = {k: np.array(v) for k, v in pt_inputs.items() if torch.is_tensor(v)} - fx_outputs = fx_model(**fx_inputs).to_tuple() - self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch") + pt_model = load_flax_weights_in_pytorch_model(pt_model, fx_model.params) - for fx_output, pt_output in zip(fx_outputs, pt_outputs): - self.assert_almost_equals(fx_output, pt_output.numpy(), 4e-2) + # make sure weights are tied in PyTorch + pt_model.tie_weights() + + # send pytorch model to the correct device + pt_model.to(torch_device) + + with torch.no_grad(): + pt_outputs = pt_model(**pt_inputs) + fx_outputs = fx_model(**fx_inputs) + + fx_keys = tuple([k for k, v in fx_outputs.items() if v is not None]) + pt_keys = tuple([k for k, v in pt_outputs.items() if v is not None]) + + self.assertEqual(fx_keys, pt_keys) + self.check_outputs(fx_outputs.to_tuple(), pt_outputs.to_tuple(), model_class, names=fx_keys) with tempfile.TemporaryDirectory() as tmpdirname: fx_model.save_pretrained(tmpdirname) pt_model_loaded = model_class.from_pretrained(tmpdirname, from_flax=True) - with torch.no_grad(): - pt_outputs_loaded = pt_model_loaded(**pt_inputs).to_tuple() + # send pytorch model to the correct device + pt_model_loaded.to(torch_device) + pt_model_loaded.eval() - self.assertEqual( - len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch" - ) - for fx_output, pt_output in zip(fx_outputs, pt_outputs_loaded): - self.assert_almost_equals(fx_output, pt_output.numpy(), 4e-2) + with torch.no_grad(): + pt_outputs_loaded = pt_model_loaded(**pt_inputs) + + fx_keys = tuple([k for k, v in fx_outputs.items() if v is not None]) + pt_keys = tuple([k for k, v in pt_outputs_loaded.items() if v is not None]) + + self.assertEqual(fx_keys, pt_keys) + self.check_outputs(fx_outputs.to_tuple(), pt_outputs_loaded.to_tuple(), model_class, names=fx_keys) def test_inputs_embeds(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()