[FlaxBert] Fix non-broadcastable attention mask for batched forward-passes (#8791)
* [FlaxBert] Fix non-broadcastable attention mask for batched forward-passes * [FlaxRoberta] Fix non-broadcastable attention mask * Use jax.numpy instead of ordinary numpy (otherwise not jit-able) * Partially revert "Use jax.numpy ..." * Add tests for batched forward passes * Avoid unnecessary OOMs due to preallocation of GPU memory by XLA * Auto-fix style * Re-enable GPU memory preallocation but with mem fraction < 1/paralleism
This commit is contained in:
committed by
Lysandre
parent
563efd36ab
commit
911d8486e8
@@ -183,6 +183,10 @@ class FlaxBertAttention(nn.Module):
|
||||
|
||||
@nn.compact
|
||||
def __call__(self, hidden_state, attention_mask):
|
||||
# Attention mask comes in as attention_mask.shape == (*batch_sizes, kv_length)
|
||||
# FLAX expects: attention_mask.shape == (*batch_sizes, 1, 1, kv_length) such that it is broadcastable
|
||||
# with attn_weights.shape == (*batch_sizes, num_heads, q_length, kv_length)
|
||||
attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))
|
||||
self_att = nn.attention.SelfAttention(num_heads=self.num_heads, qkv_features=self.head_size, name="self")(
|
||||
hidden_state, attention_mask
|
||||
)
|
||||
|
||||
@@ -186,6 +186,10 @@ class FlaxRobertaAttention(nn.Module):
|
||||
|
||||
@nn.compact
|
||||
def __call__(self, hidden_state, attention_mask):
|
||||
# Attention mask comes in as attention_mask.shape == (*batch_sizes, kv_length)
|
||||
# FLAX expects: attention_mask.shape == (*batch_sizes, 1, 1, kv_length) such that it is broadcastable
|
||||
# with attn_weights.shape == (*batch_sizes, num_heads, q_length, kv_length)
|
||||
attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))
|
||||
self_att = nn.attention.SelfAttention(num_heads=self.num_heads, qkv_features=self.head_size, name="self")(
|
||||
hidden_state, attention_mask
|
||||
)
|
||||
@@ -416,8 +420,8 @@ class FlaxRobertaModel(FlaxPreTrainedModel):
|
||||
token_type_ids = jnp.ones_like(input_ids)
|
||||
|
||||
if position_ids is None:
|
||||
position_ids = np.arange(
|
||||
self.config.pad_token_id + 1, np.atleast_2d(input_ids).shape[-1] + self.config.pad_token_id + 1
|
||||
position_ids = jnp.arange(
|
||||
self.config.pad_token_id + 1, jnp.atleast_2d(input_ids).shape[-1] + self.config.pad_token_id + 1
|
||||
)
|
||||
|
||||
if attention_mask is None:
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import unittest
|
||||
|
||||
import pytest
|
||||
from numpy import ndarray
|
||||
|
||||
from transformers import BertTokenizerFast, TensorType, is_flax_available, is_torch_available
|
||||
@@ -7,6 +8,11 @@ from transformers.testing_utils import require_flax, require_torch
|
||||
|
||||
|
||||
if is_flax_available():
|
||||
import os
|
||||
|
||||
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.12" # assumed parallelism: 8
|
||||
|
||||
import jax
|
||||
from transformers.models.bert.modeling_flax_bert import FlaxBertModel
|
||||
|
||||
if is_torch_available():
|
||||
@@ -39,3 +45,26 @@ class FlaxBertModelTest(unittest.TestCase):
|
||||
def assert_almost_equals(self, a: ndarray, b: ndarray, tol: float):
|
||||
diff = (a - b).sum()
|
||||
self.assertLessEqual(diff, tol, "Difference between torch and flax is {} (>= {})".format(diff, tol))
|
||||
|
||||
|
||||
@require_flax
|
||||
@pytest.mark.parametrize("jit", ["disable_jit", "enable_jit"])
|
||||
def test_multiple_sentences(jit):
|
||||
tokenizer = BertTokenizerFast.from_pretrained("bert-base-cased")
|
||||
model = FlaxBertModel.from_pretrained("bert-base-cased")
|
||||
|
||||
sentences = ["this is an example sentence", "this is another", "and a third one"]
|
||||
encodings = tokenizer(sentences, return_tensors=TensorType.JAX, padding=True, truncation=True)
|
||||
|
||||
@jax.jit
|
||||
def model_jitted(input_ids, attention_mask, token_type_ids):
|
||||
return model(input_ids, attention_mask, token_type_ids)
|
||||
|
||||
if jit == "disable_jit":
|
||||
with jax.disable_jit():
|
||||
tokens, pooled = model_jitted(**encodings)
|
||||
else:
|
||||
tokens, pooled = model_jitted(**encodings)
|
||||
|
||||
assert tokens.shape == (3, 7, 768)
|
||||
assert pooled.shape == (3, 768)
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import unittest
|
||||
|
||||
import pytest
|
||||
from numpy import ndarray
|
||||
|
||||
from transformers import RobertaTokenizerFast, TensorType, is_flax_available, is_torch_available
|
||||
@@ -7,6 +8,11 @@ from transformers.testing_utils import require_flax, require_torch
|
||||
|
||||
|
||||
if is_flax_available():
|
||||
import os
|
||||
|
||||
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.12" # assumed parallelism: 8
|
||||
|
||||
import jax
|
||||
from transformers.models.roberta.modeling_flax_roberta import FlaxRobertaModel
|
||||
|
||||
if is_torch_available():
|
||||
@@ -39,3 +45,26 @@ class FlaxRobertaModelTest(unittest.TestCase):
|
||||
def assert_almost_equals(self, a: ndarray, b: ndarray, tol: float):
|
||||
diff = (a - b).sum()
|
||||
self.assertLessEqual(diff, tol, "Difference between torch and flax is {} (>= {})".format(diff, tol))
|
||||
|
||||
|
||||
@require_flax
|
||||
@pytest.mark.parametrize("jit", ["disable_jit", "enable_jit"])
|
||||
def test_multiple_sentences(jit):
|
||||
tokenizer = RobertaTokenizerFast.from_pretrained("roberta-base")
|
||||
model = FlaxRobertaModel.from_pretrained("roberta-base")
|
||||
|
||||
sentences = ["this is an example sentence", "this is another", "and a third one"]
|
||||
encodings = tokenizer(sentences, return_tensors=TensorType.JAX, padding=True, truncation=True)
|
||||
|
||||
@jax.jit
|
||||
def model_jitted(input_ids, attention_mask):
|
||||
return model(input_ids, attention_mask)
|
||||
|
||||
if jit == "disable_jit":
|
||||
with jax.disable_jit():
|
||||
tokens, pooled = model_jitted(**encodings)
|
||||
else:
|
||||
tokens, pooled = model_jitted(**encodings)
|
||||
|
||||
assert tokens.shape == (3, 7, 768)
|
||||
assert pooled.shape == (3, 768)
|
||||
|
||||
Reference in New Issue
Block a user