diff --git a/src/transformers/models/albert/modeling_flax_albert.py b/src/transformers/models/albert/modeling_flax_albert.py index 5fcec8d5bf..50c46b4040 100644 --- a/src/transformers/models/albert/modeling_flax_albert.py +++ b/src/transformers/models/albert/modeling_flax_albert.py @@ -314,35 +314,13 @@ class FlaxAlbertLayer(nn.Module): return outputs -class FlaxAlbertLayers(nn.Module): - config: AlbertConfig - dtype: jnp.dtype = jnp.float32 # the dtype of the computation - layer_index: Optional[str] = None - - def setup(self): - self.albert_layers = FlaxAlbertLayer(self.config, name=self.layer_index, dtype=self.dtype) - - def __call__( - self, - hidden_states, - attention_mask, - deterministic: bool = True, - output_attentions: bool = False, - ): - outputs = self.albert_layers( - hidden_states, attention_mask, deterministic=deterministic, output_attentions=output_attentions - ) - return outputs - - class FlaxAlbertLayerCollection(nn.Module): config: AlbertConfig dtype: jnp.dtype = jnp.float32 # the dtype of the computation def setup(self): self.layers = [ - FlaxAlbertLayers(self.config, name="albert_layers", layer_index=str(i), dtype=self.dtype) - for i in range(self.config.inner_group_num) + FlaxAlbertLayer(self.config, name=str(i), dtype=self.dtype) for i in range(self.config.inner_group_num) ] def __call__( @@ -385,7 +363,7 @@ class FlaxAlbertLayerCollections(nn.Module): layer_index: Optional[str] = None def setup(self): - self.albert_layer_groups = FlaxAlbertLayerCollection(self.config, name=self.layer_index, dtype=self.dtype) + self.albert_layers = FlaxAlbertLayerCollection(self.config, dtype=self.dtype) def __call__( self, @@ -395,7 +373,7 @@ class FlaxAlbertLayerCollections(nn.Module): output_attentions: bool = False, output_hidden_states: bool = False, ): - outputs = self.albert_layer_groups( + outputs = self.albert_layers( hidden_states, attention_mask, deterministic=deterministic, @@ -405,19 +383,13 @@ class FlaxAlbertLayerCollections(nn.Module): return outputs -class FlaxAlbertEncoder(nn.Module): +class FlaxAlbertLayerGroups(nn.Module): config: AlbertConfig dtype: jnp.dtype = jnp.float32 # the dtype of the computation def setup(self): - self.layers_per_group = int(self.config.num_hidden_layers / self.config.num_hidden_groups) - self.embedding_hidden_mapping_in = nn.Dense( - self.config.hidden_size, - kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype), - dtype=self.dtype, - ) self.layers = [ - FlaxAlbertLayerCollections(self.config, name="albert_layer_groups", layer_index=str(i), dtype=self.dtype) + FlaxAlbertLayerCollections(self.config, name=str(i), layer_index=str(i), dtype=self.dtype) for i in range(self.config.num_hidden_groups) ] @@ -430,7 +402,6 @@ class FlaxAlbertEncoder(nn.Module): output_hidden_states: bool = False, return_dict: bool = True, ): - hidden_states = self.embedding_hidden_mapping_in(hidden_states) all_attentions = () if output_attentions else None all_hidden_states = (hidden_states,) if output_hidden_states else None @@ -459,6 +430,37 @@ class FlaxAlbertEncoder(nn.Module): ) +class FlaxAlbertEncoder(nn.Module): + config: AlbertConfig + dtype: jnp.dtype = jnp.float32 # the dtype of the computation + + def setup(self): + self.embedding_hidden_mapping_in = nn.Dense( + self.config.hidden_size, + kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype), + dtype=self.dtype, + ) + self.albert_layer_groups = FlaxAlbertLayerGroups(self.config, dtype=self.dtype) + + def __call__( + self, + hidden_states, + attention_mask, + deterministic: bool = True, + output_attentions: bool = False, + output_hidden_states: bool = False, + return_dict: bool = True, + ): + hidden_states = self.embedding_hidden_mapping_in(hidden_states) + return self.albert_layer_groups( + hidden_states, + attention_mask, + deterministic=deterministic, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + + class FlaxAlbertOnlyMLMHead(nn.Module): config: AlbertConfig dtype: jnp.dtype = jnp.float32 diff --git a/src/transformers/models/hubert/modeling_tf_hubert.py b/src/transformers/models/hubert/modeling_tf_hubert.py index 44867fe586..39640c43b7 100644 --- a/src/transformers/models/hubert/modeling_tf_hubert.py +++ b/src/transformers/models/hubert/modeling_tf_hubert.py @@ -1222,7 +1222,10 @@ class TFHubertMainLayer(tf.keras.layers.Layer): if inputs["attention_mask"] is not None: # compute real output lengths according to convolution formula output_lengths = self._get_feat_extract_output_lengths(tf.reduce_sum(inputs["attention_mask"], -1)) - attention_mask = tf.sequence_mask(output_lengths, dtype=hidden_states.dtype) + + attention_mask = tf.sequence_mask( + output_lengths, maxlen=shape_list(hidden_states)[1], dtype=hidden_states.dtype + ) hidden_states = self.feature_projection(hidden_states, training=inputs["training"]) diff --git a/tests/test_modeling_big_bird.py b/tests/test_modeling_big_bird.py index 3a4a0b870c..d0859fc08f 100644 --- a/tests/test_modeling_big_bird.py +++ b/tests/test_modeling_big_bird.py @@ -59,7 +59,7 @@ class BigBirdModelTester: num_hidden_layers=2, num_attention_heads=4, intermediate_size=37, - hidden_act="gelu_fast", + hidden_act="gelu_new", hidden_dropout_prob=0.1, attention_probs_dropout_prob=0.1, max_position_embeddings=256, diff --git a/tests/test_modeling_clip.py b/tests/test_modeling_clip.py index dc3c495248..b5e960bea2 100644 --- a/tests/test_modeling_clip.py +++ b/tests/test_modeling_clip.py @@ -23,9 +23,17 @@ import unittest import numpy as np import requests +import transformers from transformers import CLIPConfig, CLIPTextConfig, CLIPVisionConfig from transformers.file_utils import is_torch_available, is_vision_available -from transformers.testing_utils import require_torch, require_vision, slow, torch_device +from transformers.testing_utils import ( + is_flax_available, + is_pt_flax_cross_test, + require_torch, + require_vision, + slow, + torch_device, +) from .test_configuration_common import ConfigTester from .test_modeling_common import ModelTesterMixin, _config_zero_init, floats_tensor, ids_tensor, random_attention_mask @@ -45,6 +53,14 @@ if is_vision_available(): from transformers import CLIPProcessor +if is_flax_available(): + import jax.numpy as jnp + from transformers.modeling_flax_pytorch_utils import ( + convert_pytorch_state_dict_to_flax, + load_flax_weights_in_pytorch_model, + ) + + class CLIPVisionModelTester: def __init__( self, @@ -330,6 +346,13 @@ class CLIPTextModelTester: if self.use_input_mask: input_mask = random_attention_mask([self.batch_size, self.seq_length]) + if input_mask is not None: + batch_size, seq_length = input_mask.shape + rnd_start_indices = np.random.randint(1, seq_length - 1, size=(batch_size,)) + for batch_idx, start_index in enumerate(rnd_start_indices): + input_mask[batch_idx, :start_index] = 1 + input_mask[batch_idx, start_index:] = 0 + config = self.get_config() return config, input_ids, input_mask @@ -558,6 +581,125 @@ class CLIPModelTest(ModelTesterMixin, unittest.TestCase): self.assertTrue(models_equal) + # overwrite from common since FlaxCLIPModel returns nested output + # which is not supported in the common test + @is_pt_flax_cross_test + def test_equivalence_pt_to_flax(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + with self.subTest(model_class.__name__): + + # load PyTorch class + pt_model = model_class(config).eval() + # Flax models don't use the `use_cache` option and cache is not returned as a default. + # So we disable `use_cache` here for PyTorch model. + pt_model.config.use_cache = False + + fx_model_class_name = "Flax" + model_class.__name__ + + if not hasattr(transformers, fx_model_class_name): + return + + fx_model_class = getattr(transformers, fx_model_class_name) + + # load Flax class + fx_model = fx_model_class(config, dtype=jnp.float32) + # make sure only flax inputs are forward that actually exist in function args + fx_input_keys = inspect.signature(fx_model.__call__).parameters.keys() + + # prepare inputs + pt_inputs = self._prepare_for_class(inputs_dict, model_class) + + # remove function args that don't exist in Flax + pt_inputs = {k: v for k, v in pt_inputs.items() if k in fx_input_keys} + + fx_state = convert_pytorch_state_dict_to_flax(pt_model.state_dict(), fx_model) + fx_model.params = fx_state + + with torch.no_grad(): + pt_outputs = pt_model(**pt_inputs).to_tuple() + + # convert inputs to Flax + fx_inputs = {k: np.array(v) for k, v in pt_inputs.items() if torch.is_tensor(v)} + fx_outputs = fx_model(**fx_inputs).to_tuple() + self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch") + for fx_output, pt_output in zip(fx_outputs[:4], pt_outputs[:4]): + self.assert_almost_equals(fx_output, pt_output.numpy(), 4e-2) + + with tempfile.TemporaryDirectory() as tmpdirname: + pt_model.save_pretrained(tmpdirname) + fx_model_loaded = fx_model_class.from_pretrained(tmpdirname, from_pt=True) + + fx_outputs_loaded = fx_model_loaded(**fx_inputs).to_tuple() + self.assertEqual( + len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch" + ) + for fx_output_loaded, pt_output in zip(fx_outputs_loaded[:4], pt_outputs[:4]): + self.assert_almost_equals(fx_output_loaded, pt_output.numpy(), 4e-2) + + # overwrite from common since FlaxCLIPModel returns nested output + # which is not supported in the common test + @is_pt_flax_cross_test + def test_equivalence_flax_to_pt(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + with self.subTest(model_class.__name__): + # load corresponding PyTorch class + pt_model = model_class(config).eval() + + # So we disable `use_cache` here for PyTorch model. + pt_model.config.use_cache = False + + fx_model_class_name = "Flax" + model_class.__name__ + + if not hasattr(transformers, fx_model_class_name): + # no flax model exists for this class + return + + fx_model_class = getattr(transformers, fx_model_class_name) + + # load Flax class + fx_model = fx_model_class(config, dtype=jnp.float32) + # make sure only flax inputs are forward that actually exist in function args + fx_input_keys = inspect.signature(fx_model.__call__).parameters.keys() + + pt_model = load_flax_weights_in_pytorch_model(pt_model, fx_model.params) + + # make sure weights are tied in PyTorch + pt_model.tie_weights() + + # prepare inputs + pt_inputs = self._prepare_for_class(inputs_dict, model_class) + + # remove function args that don't exist in Flax + pt_inputs = {k: v for k, v in pt_inputs.items() if k in fx_input_keys} + + with torch.no_grad(): + pt_outputs = pt_model(**pt_inputs).to_tuple() + + fx_inputs = {k: np.array(v) for k, v in pt_inputs.items() if torch.is_tensor(v)} + + fx_outputs = fx_model(**fx_inputs).to_tuple() + self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch") + + for fx_output, pt_output in zip(fx_outputs[:4], pt_outputs[:4]): + self.assert_almost_equals(fx_output, pt_output.numpy(), 4e-2) + + with tempfile.TemporaryDirectory() as tmpdirname: + fx_model.save_pretrained(tmpdirname) + pt_model_loaded = model_class.from_pretrained(tmpdirname, from_flax=True) + + with torch.no_grad(): + pt_outputs_loaded = pt_model_loaded(**pt_inputs).to_tuple() + + self.assertEqual( + len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch" + ) + for fx_output, pt_output in zip(fx_outputs[:4], pt_outputs_loaded[:4]): + self.assert_almost_equals(fx_output, pt_output.numpy(), 4e-2) + @slow def test_model_from_pretrained(self): for model_name in CLIP_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 889c2e830f..3890198edb 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -25,10 +25,13 @@ import unittest import warnings from typing import Dict, List, Tuple +import numpy as np + +import transformers from huggingface_hub import HfApi, Repository from requests.exceptions import HTTPError from transformers import AutoModel, AutoModelForSequenceClassification, is_torch_available, logging -from transformers.file_utils import WEIGHTS_NAME, is_torch_fx_available +from transformers.file_utils import WEIGHTS_NAME, is_flax_available, is_torch_fx_available from transformers.models.auto import get_values from transformers.testing_utils import ( ENDPOINT_STAGING, @@ -36,6 +39,8 @@ from transformers.testing_utils import ( USER, CaptureLogger, TestCasePlus, + is_pt_flax_cross_test, + is_pt_tf_cross_test, is_staging_test, require_torch, require_torch_multi_gpu, @@ -45,7 +50,6 @@ from transformers.testing_utils import ( if is_torch_available(): - import numpy as np import torch from torch import nn @@ -70,6 +74,13 @@ if is_torch_available(): T5ForConditionalGeneration, ) +if is_flax_available(): + import jax.numpy as jnp + from transformers.modeling_flax_pytorch_utils import ( + convert_pytorch_state_dict_to_flax, + load_flax_weights_in_pytorch_model, + ) + if is_torch_fx_available(): from transformers.utils.fx import symbolic_trace @@ -1417,6 +1428,241 @@ class ModelTesterMixin: model, tuple_inputs, dict_inputs, {"output_hidden_states": True, "output_attentions": True} ) + @is_pt_tf_cross_test + def test_pt_tf_model_equivalence(self): + import numpy as np + import tensorflow as tf + + import transformers + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + tf_model_class_name = "TF" + model_class.__name__ # Add the "TF" at the beginning + + if not hasattr(transformers, tf_model_class_name): + # transformers does not have TF version yet + return + + tf_model_class = getattr(transformers, tf_model_class_name) + + config.output_hidden_states = True + + tf_model = tf_model_class(config) + pt_model = model_class(config) + + # make sure only tf inputs are forward that actually exist in function args + tf_input_keys = set(inspect.signature(tf_model.call).parameters.keys()) + + # remove all head masks + tf_input_keys.discard("head_mask") + tf_input_keys.discard("cross_attn_head_mask") + tf_input_keys.discard("decoder_head_mask") + + pt_inputs = self._prepare_for_class(inputs_dict, model_class) + pt_inputs = {k: v for k, v in pt_inputs.items() if k in tf_input_keys} + + # Check predictions on first output (logits/hidden-states) are close enought given low-level computational differences + pt_model.eval() + tf_inputs_dict = {} + for key, tensor in pt_inputs.items(): + # skip key that does not exist in tf + if type(tensor) == bool: + tf_inputs_dict[key] = tensor + elif key == "input_values": + tf_inputs_dict[key] = tf.convert_to_tensor(tensor.numpy(), dtype=tf.float32) + else: + tf_inputs_dict[key] = tf.convert_to_tensor(tensor.numpy(), dtype=tf.int32) + + # Check we can load pt model in tf and vice-versa with model => model functions + tf_model = transformers.load_pytorch_model_in_tf2_model(tf_model, pt_model, tf_inputs=tf_inputs_dict) + pt_model = transformers.load_tf2_model_in_pytorch_model(pt_model, tf_model) + + # need to rename encoder-decoder "inputs" for PyTorch + # if "inputs" in pt_inputs_dict and self.is_encoder_decoder: + # pt_inputs_dict["input_ids"] = pt_inputs_dict.pop("inputs") + + with torch.no_grad(): + pto = pt_model(**pt_inputs) + tfo = tf_model(tf_inputs_dict, training=False) + + tf_hidden_states = tfo[0].numpy() + pt_hidden_states = pto[0].numpy() + + tf_nans = np.copy(np.isnan(tf_hidden_states)) + pt_nans = np.copy(np.isnan(pt_hidden_states)) + + pt_hidden_states[tf_nans] = 0 + tf_hidden_states[tf_nans] = 0 + pt_hidden_states[pt_nans] = 0 + tf_hidden_states[pt_nans] = 0 + + max_diff = np.amax(np.abs(tf_hidden_states - pt_hidden_states)) + self.assertLessEqual(max_diff, 4e-2) + + # Check we can load pt model in tf and vice-versa with checkpoint => model functions + with tempfile.TemporaryDirectory() as tmpdirname: + pt_checkpoint_path = os.path.join(tmpdirname, "pt_model.bin") + torch.save(pt_model.state_dict(), pt_checkpoint_path) + tf_model = transformers.load_pytorch_checkpoint_in_tf2_model(tf_model, pt_checkpoint_path) + + tf_checkpoint_path = os.path.join(tmpdirname, "tf_model.h5") + tf_model.save_weights(tf_checkpoint_path) + pt_model = transformers.load_tf2_checkpoint_in_pytorch_model(pt_model, tf_checkpoint_path) + + # Check predictions on first output (logits/hidden-states) are close enought given low-level computational differences + pt_model.eval() + tf_inputs_dict = {} + for key, tensor in pt_inputs.items(): + # skip key that does not exist in tf + if type(tensor) == bool: + tensor = np.array(tensor, dtype=bool) + tf_inputs_dict[key] = tf.convert_to_tensor(tensor, dtype=tf.int32) + elif key == "input_values": + tf_inputs_dict[key] = tf.convert_to_tensor(tensor.numpy(), dtype=tf.float32) + else: + tf_inputs_dict[key] = tf.convert_to_tensor(tensor.numpy(), dtype=tf.int32) + + # need to rename encoder-decoder "inputs" for PyTorch + # if "inputs" in pt_inputs_dict and self.is_encoder_decoder: + # pt_inputs_dict["input_ids"] = pt_inputs_dict.pop("inputs") + + with torch.no_grad(): + pto = pt_model(**pt_inputs) + + tfo = tf_model(tf_inputs_dict) + tfo = tfo[0].numpy() + pto = pto[0].numpy() + tf_nans = np.copy(np.isnan(tfo)) + pt_nans = np.copy(np.isnan(pto)) + + pto[tf_nans] = 0 + tfo[tf_nans] = 0 + pto[pt_nans] = 0 + tfo[pt_nans] = 0 + + max_diff = np.amax(np.abs(tfo - pto)) + self.assertLessEqual(max_diff, 4e-2) + + def assert_almost_equals(self, a: np.ndarray, b: np.ndarray, tol: float): + diff = np.abs((a - b)).max() + self.assertLessEqual(diff, tol, f"Difference between torch and flax is {diff} (>= {tol}).") + + @is_pt_flax_cross_test + def test_equivalence_pt_to_flax(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + with self.subTest(model_class.__name__): + + # load PyTorch class + pt_model = model_class(config).eval() + # Flax models don't use the `use_cache` option and cache is not returned as a default. + # So we disable `use_cache` here for PyTorch model. + pt_model.config.use_cache = False + + fx_model_class_name = "Flax" + model_class.__name__ + + if not hasattr(transformers, fx_model_class_name): + return + + fx_model_class = getattr(transformers, fx_model_class_name) + + # load Flax class + fx_model = fx_model_class(config, dtype=jnp.float32) + # make sure only flax inputs are forward that actually exist in function args + fx_input_keys = inspect.signature(fx_model.__call__).parameters.keys() + + # prepare inputs + pt_inputs = self._prepare_for_class(inputs_dict, model_class) + + # remove function args that don't exist in Flax + pt_inputs = {k: v for k, v in pt_inputs.items() if k in fx_input_keys} + + fx_state = convert_pytorch_state_dict_to_flax(pt_model.state_dict(), fx_model) + fx_model.params = fx_state + + with torch.no_grad(): + pt_outputs = pt_model(**pt_inputs).to_tuple() + + # convert inputs to Flax + fx_inputs = {k: np.array(v) for k, v in pt_inputs.items() if torch.is_tensor(v)} + fx_outputs = fx_model(**fx_inputs).to_tuple() + self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch") + for fx_output, pt_output in zip(fx_outputs, pt_outputs): + self.assert_almost_equals(fx_output, pt_output.numpy(), 4e-2) + + with tempfile.TemporaryDirectory() as tmpdirname: + pt_model.save_pretrained(tmpdirname) + fx_model_loaded = fx_model_class.from_pretrained(tmpdirname, from_pt=True) + + fx_outputs_loaded = fx_model_loaded(**fx_inputs).to_tuple() + self.assertEqual( + len(fx_outputs_loaded), len(pt_outputs), "Output lengths differ between Flax and PyTorch" + ) + for fx_output_loaded, pt_output in zip(fx_outputs_loaded, pt_outputs): + self.assert_almost_equals(fx_output_loaded, pt_output.numpy(), 4e-2) + + @is_pt_flax_cross_test + def test_equivalence_flax_to_pt(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + with self.subTest(model_class.__name__): + # load corresponding PyTorch class + pt_model = model_class(config).eval() + + # So we disable `use_cache` here for PyTorch model. + pt_model.config.use_cache = False + + fx_model_class_name = "Flax" + model_class.__name__ + + if not hasattr(transformers, fx_model_class_name): + # no flax model exists for this class + return + + fx_model_class = getattr(transformers, fx_model_class_name) + + # load Flax class + fx_model = fx_model_class(config, dtype=jnp.float32) + # make sure only flax inputs are forward that actually exist in function args + fx_input_keys = inspect.signature(fx_model.__call__).parameters.keys() + + pt_model = load_flax_weights_in_pytorch_model(pt_model, fx_model.params) + + # make sure weights are tied in PyTorch + pt_model.tie_weights() + + # prepare inputs + pt_inputs = self._prepare_for_class(inputs_dict, model_class) + + # remove function args that don't exist in Flax + pt_inputs = {k: v for k, v in pt_inputs.items() if k in fx_input_keys} + + with torch.no_grad(): + pt_outputs = pt_model(**pt_inputs).to_tuple() + + fx_inputs = {k: np.array(v) for k, v in pt_inputs.items() if torch.is_tensor(v)} + + fx_outputs = fx_model(**fx_inputs).to_tuple() + self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch") + + for fx_output, pt_output in zip(fx_outputs, pt_outputs): + self.assert_almost_equals(fx_output, pt_output.numpy(), 4e-2) + + with tempfile.TemporaryDirectory() as tmpdirname: + fx_model.save_pretrained(tmpdirname) + pt_model_loaded = model_class.from_pretrained(tmpdirname, from_flax=True) + + with torch.no_grad(): + pt_outputs_loaded = pt_model_loaded(**pt_inputs).to_tuple() + + self.assertEqual( + len(fx_outputs), len(pt_outputs_loaded), "Output lengths differ between Flax and PyTorch" + ) + for fx_output, pt_output in zip(fx_outputs, pt_outputs_loaded): + self.assert_almost_equals(fx_output, pt_output.numpy(), 4e-2) + def test_inputs_embeds(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() diff --git a/tests/test_modeling_longformer.py b/tests/test_modeling_longformer.py index 7bff80ed13..2e506348e4 100644 --- a/tests/test_modeling_longformer.py +++ b/tests/test_modeling_longformer.py @@ -72,7 +72,7 @@ class LongformerModelTester: # because its local attention only attends to `self.attention_window + 1` locations # (assuming no token with global attention, otherwise the last dimension of attentions # is x + self.attention_window + 1, where x is the number of tokens with global attention) - self.key_length = self.attention_window + 1 + self.key_length = self.attention_window + 2 # because of padding `encoder_seq_length`, is different from `seq_length`. Relevant for # the `test_attention_outputs` and `test_hidden_states_output` tests @@ -243,6 +243,8 @@ class LongformerModelTester: choice_labels, ) = config_and_inputs global_attention_mask = torch.zeros_like(input_ids) + global_attention_mask[:, -1] = 1 + inputs_dict = { "input_ids": input_ids, "token_type_ids": token_type_ids, diff --git a/tests/test_modeling_lxmert.py b/tests/test_modeling_lxmert.py index e881f8ccfd..bbc123275e 100644 --- a/tests/test_modeling_lxmert.py +++ b/tests/test_modeling_lxmert.py @@ -15,13 +15,16 @@ import copy +import os +import tempfile import unittest import numpy as np -from transformers import LxmertConfig, is_torch_available +import transformers +from transformers import LxmertConfig, is_tf_available, is_torch_available from transformers.models.auto import get_values -from transformers.testing_utils import require_torch, slow, torch_device +from transformers.testing_utils import is_pt_tf_cross_test, require_torch, slow, torch_device from .test_configuration_common import ConfigTester from .test_modeling_common import ModelTesterMixin, ids_tensor @@ -40,6 +43,10 @@ if is_torch_available(): from transformers.models.lxmert.modeling_lxmert import LXMERT_PRETRAINED_MODEL_ARCHIVE_LIST +if is_tf_available(): + import tensorflow as tf + + class LxmertModelTester: """You can also import this e.g from .test_modeling_bart import BartModelTester""" @@ -496,7 +503,7 @@ class LxmertModelTester: result_pretrain_more.question_answering_score.shape, (self.batch_size, num_large_labels) ) - def prepare_config_and_inputs_for_common(self): + def prepare_config_and_inputs_for_common(self, return_obj_labels=False): config_and_inputs = self.prepare_config_and_inputs() ( config, @@ -520,6 +527,9 @@ class LxmertModelTester: "attention_mask": input_mask, } + if return_obj_labels: + inputs_dict["obj_labels"] = obj_labels + return config, inputs_dict @@ -732,6 +742,128 @@ class LxmertModelTest(ModelTesterMixin, unittest.TestCase): self.assertIsNotNone(hidden_states_vision.grad) self.assertIsNotNone(attentions_vision.grad) + @is_pt_tf_cross_test + def test_pt_tf_model_equivalence(self): + for model_class in self.all_model_classes: + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common( + return_obj_labels="PreTraining" in model_class.__name__ + ) + + tf_model_class_name = "TF" + model_class.__name__ # Add the "TF" at the beginning + + if not hasattr(transformers, tf_model_class_name): + # transformers does not have TF version yet + return + + tf_model_class = getattr(transformers, tf_model_class_name) + + config.output_hidden_states = True + config.task_obj_predict = False + + pt_model = model_class(config) + tf_model = tf_model_class(config) + + # Check we can load pt model in tf and vice-versa with model => model functions + pt_inputs = self._prepare_for_class(inputs_dict, model_class) + + def recursive_numpy_convert(iterable): + return_dict = {} + for key, value in iterable.items(): + if type(value) == bool: + return_dict[key] = value + if isinstance(value, dict): + return_dict[key] = recursive_numpy_convert(value) + else: + if isinstance(value, (list, tuple)): + return_dict[key] = ( + tf.convert_to_tensor(iter_value.numpy(), dtype=tf.int32) for iter_value in value + ) + else: + return_dict[key] = tf.convert_to_tensor(value.numpy(), dtype=tf.int32) + return return_dict + + tf_inputs_dict = recursive_numpy_convert(pt_inputs) + + tf_model = transformers.load_pytorch_model_in_tf2_model(tf_model, pt_model, tf_inputs=tf_inputs_dict) + pt_model = transformers.load_tf2_model_in_pytorch_model(pt_model, tf_model) + + # Check predictions on first output (logits/hidden-states) are close enought given low-level computational differences + pt_model.eval() + + # Delete obj labels as we want to compute the hidden states and not the loss + + if "obj_labels" in inputs_dict: + del inputs_dict["obj_labels"] + + def torch_type(key): + if key in ("visual_feats", "visual_pos"): + return torch.float32 + else: + return torch.long + + pt_inputs = self._prepare_for_class(inputs_dict, model_class) + tf_inputs_dict = recursive_numpy_convert(pt_inputs) + + with torch.no_grad(): + pto = pt_model(**pt_inputs) + tfo = tf_model(tf_inputs_dict, training=False) + tf_hidden_states = tfo[0].numpy() + pt_hidden_states = pto[0].numpy() + + tf_nans = np.copy(np.isnan(tf_hidden_states)) + pt_nans = np.copy(np.isnan(pt_hidden_states)) + + pt_hidden_states[tf_nans] = 0 + tf_hidden_states[tf_nans] = 0 + pt_hidden_states[pt_nans] = 0 + tf_hidden_states[pt_nans] = 0 + + max_diff = np.amax(np.abs(tf_hidden_states - pt_hidden_states)) + # Debug info (remove when fixed) + if max_diff >= 2e-2: + print("===") + print(model_class) + print(config) + print(inputs_dict) + print(pt_inputs) + self.assertLessEqual(max_diff, 6e-2) + + # Check we can load pt model in tf and vice-versa with checkpoint => model functions + with tempfile.TemporaryDirectory() as tmpdirname: + pt_checkpoint_path = os.path.join(tmpdirname, "pt_model.bin") + torch.save(pt_model.state_dict(), pt_checkpoint_path) + tf_model = transformers.load_pytorch_checkpoint_in_tf2_model(tf_model, pt_checkpoint_path) + + tf_checkpoint_path = os.path.join(tmpdirname, "tf_model.h5") + tf_model.save_weights(tf_checkpoint_path) + pt_model = transformers.load_tf2_checkpoint_in_pytorch_model(pt_model, tf_checkpoint_path) + + # Check predictions on first output (logits/hidden-states) are close enought given low-level computational differences + pt_model.eval() + + for key, value in pt_inputs.items(): + if key in ("visual_feats", "visual_pos"): + pt_inputs[key] = value.to(torch.float32) + else: + pt_inputs[key] = value.to(torch.long) + + with torch.no_grad(): + pto = pt_model(**pt_inputs) + + tfo = tf_model(tf_inputs_dict) + tfo = tfo[0].numpy() + pto = pto[0].numpy() + tf_nans = np.copy(np.isnan(tfo)) + pt_nans = np.copy(np.isnan(pto)) + + pto[tf_nans] = 0 + tfo[tf_nans] = 0 + pto[pt_nans] = 0 + tfo[pt_nans] = 0 + + max_diff = np.amax(np.abs(tfo - pto)) + self.assertLessEqual(max_diff, 6e-2) + @require_torch class LxmertModelIntegrationTest(unittest.TestCase): diff --git a/tests/test_modeling_tf_common.py b/tests/test_modeling_tf_common.py index 8a6897b038..2709b4ff76 100644 --- a/tests/test_modeling_tf_common.py +++ b/tests/test_modeling_tf_common.py @@ -431,7 +431,6 @@ class TFModelTesterMixin: @is_pt_tf_cross_test def test_pt_tf_model_equivalence(self): - import torch import transformers diff --git a/tests/test_modeling_wav2vec2.py b/tests/test_modeling_wav2vec2.py index 164e661ff6..c68acbeff1 100644 --- a/tests/test_modeling_wav2vec2.py +++ b/tests/test_modeling_wav2vec2.py @@ -22,7 +22,14 @@ import pytest from tests.test_modeling_common import floats_tensor, ids_tensor, random_attention_mask from transformers import Wav2Vec2Config, is_torch_available -from transformers.testing_utils import require_datasets, require_soundfile, require_torch, slow, torch_device +from transformers.testing_utils import ( + is_pt_flax_cross_test, + require_datasets, + require_soundfile, + require_torch, + slow, + torch_device, +) from .test_configuration_common import ConfigTester from .test_modeling_common import ModelTesterMixin, _config_zero_init @@ -131,6 +138,7 @@ class Wav2Vec2ModelTester: hidden_dropout_prob=self.hidden_dropout_prob, intermediate_size=self.intermediate_size, layer_norm_eps=self.layer_norm_eps, + do_stable_layer_norm=self.do_stable_layer_norm, hidden_act=self.hidden_act, initializer_range=self.initializer_range, vocab_size=self.vocab_size, @@ -357,6 +365,16 @@ class Wav2Vec2ModelTest(ModelTesterMixin, unittest.TestCase): def test_model_common_attributes(self): pass + @is_pt_flax_cross_test + # non-robust architecture does not exist in Flax + def test_equivalence_flax_to_pt(self): + pass + + @is_pt_flax_cross_test + # non-robust architecture does not exist in Flax + def test_equivalence_pt_to_flax(self): + pass + def test_retain_grad_hidden_states_attentions(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() config.output_hidden_states = True