diff --git a/src/transformers/models/bert/modeling_flax_bert.py b/src/transformers/models/bert/modeling_flax_bert.py index a1cbbb87de..d4e9334378 100644 --- a/src/transformers/models/bert/modeling_flax_bert.py +++ b/src/transformers/models/bert/modeling_flax_bert.py @@ -183,6 +183,10 @@ class FlaxBertAttention(nn.Module): @nn.compact def __call__(self, hidden_state, attention_mask): + # Attention mask comes in as attention_mask.shape == (*batch_sizes, kv_length) + # FLAX expects: attention_mask.shape == (*batch_sizes, 1, 1, kv_length) such that it is broadcastable + # with attn_weights.shape == (*batch_sizes, num_heads, q_length, kv_length) + attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2)) self_att = nn.attention.SelfAttention(num_heads=self.num_heads, qkv_features=self.head_size, name="self")( hidden_state, attention_mask ) diff --git a/src/transformers/models/roberta/modeling_flax_roberta.py b/src/transformers/models/roberta/modeling_flax_roberta.py index 1e2a76c6b6..c21556f03e 100644 --- a/src/transformers/models/roberta/modeling_flax_roberta.py +++ b/src/transformers/models/roberta/modeling_flax_roberta.py @@ -186,6 +186,10 @@ class FlaxRobertaAttention(nn.Module): @nn.compact def __call__(self, hidden_state, attention_mask): + # Attention mask comes in as attention_mask.shape == (*batch_sizes, kv_length) + # FLAX expects: attention_mask.shape == (*batch_sizes, 1, 1, kv_length) such that it is broadcastable + # with attn_weights.shape == (*batch_sizes, num_heads, q_length, kv_length) + attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2)) self_att = nn.attention.SelfAttention(num_heads=self.num_heads, qkv_features=self.head_size, name="self")( hidden_state, attention_mask ) @@ -416,8 +420,8 @@ class FlaxRobertaModel(FlaxPreTrainedModel): token_type_ids = jnp.ones_like(input_ids) if position_ids is None: - position_ids = np.arange( - self.config.pad_token_id + 1, np.atleast_2d(input_ids).shape[-1] + self.config.pad_token_id + 1 + position_ids = jnp.arange( + self.config.pad_token_id + 1, jnp.atleast_2d(input_ids).shape[-1] + self.config.pad_token_id + 1 ) if attention_mask is None: diff --git a/tests/test_modeling_flax_bert.py b/tests/test_modeling_flax_bert.py index c8c2da1ff1..42d6bfed03 100644 --- a/tests/test_modeling_flax_bert.py +++ b/tests/test_modeling_flax_bert.py @@ -1,5 +1,6 @@ import unittest +import pytest from numpy import ndarray from transformers import BertTokenizerFast, TensorType, is_flax_available, is_torch_available @@ -7,6 +8,11 @@ from transformers.testing_utils import require_flax, require_torch if is_flax_available(): + import os + + os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.12" # assumed parallelism: 8 + + import jax from transformers.models.bert.modeling_flax_bert import FlaxBertModel if is_torch_available(): @@ -39,3 +45,26 @@ class FlaxBertModelTest(unittest.TestCase): def assert_almost_equals(self, a: ndarray, b: ndarray, tol: float): diff = (a - b).sum() self.assertLessEqual(diff, tol, "Difference between torch and flax is {} (>= {})".format(diff, tol)) + + +@require_flax +@pytest.mark.parametrize("jit", ["disable_jit", "enable_jit"]) +def test_multiple_sentences(jit): + tokenizer = BertTokenizerFast.from_pretrained("bert-base-cased") + model = FlaxBertModel.from_pretrained("bert-base-cased") + + sentences = ["this is an example sentence", "this is another", "and a third one"] + encodings = tokenizer(sentences, return_tensors=TensorType.JAX, padding=True, truncation=True) + + @jax.jit + def model_jitted(input_ids, attention_mask, token_type_ids): + return model(input_ids, attention_mask, token_type_ids) + + if jit == "disable_jit": + with jax.disable_jit(): + tokens, pooled = model_jitted(**encodings) + else: + tokens, pooled = model_jitted(**encodings) + + assert tokens.shape == (3, 7, 768) + assert pooled.shape == (3, 768) diff --git a/tests/test_modeling_flax_roberta.py b/tests/test_modeling_flax_roberta.py index 7bfdb54a12..f058d54528 100644 --- a/tests/test_modeling_flax_roberta.py +++ b/tests/test_modeling_flax_roberta.py @@ -1,5 +1,6 @@ import unittest +import pytest from numpy import ndarray from transformers import RobertaTokenizerFast, TensorType, is_flax_available, is_torch_available @@ -7,6 +8,11 @@ from transformers.testing_utils import require_flax, require_torch if is_flax_available(): + import os + + os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.12" # assumed parallelism: 8 + + import jax from transformers.models.roberta.modeling_flax_roberta import FlaxRobertaModel if is_torch_available(): @@ -39,3 +45,26 @@ class FlaxRobertaModelTest(unittest.TestCase): def assert_almost_equals(self, a: ndarray, b: ndarray, tol: float): diff = (a - b).sum() self.assertLessEqual(diff, tol, "Difference between torch and flax is {} (>= {})".format(diff, tol)) + + +@require_flax +@pytest.mark.parametrize("jit", ["disable_jit", "enable_jit"]) +def test_multiple_sentences(jit): + tokenizer = RobertaTokenizerFast.from_pretrained("roberta-base") + model = FlaxRobertaModel.from_pretrained("roberta-base") + + sentences = ["this is an example sentence", "this is another", "and a third one"] + encodings = tokenizer(sentences, return_tensors=TensorType.JAX, padding=True, truncation=True) + + @jax.jit + def model_jitted(input_ids, attention_mask): + return model(input_ids, attention_mask) + + if jit == "disable_jit": + with jax.disable_jit(): + tokens, pooled = model_jitted(**encodings) + else: + tokens, pooled = model_jitted(**encodings) + + assert tokens.shape == (3, 7, 768) + assert pooled.shape == (3, 768)