[FlaxBert] Fix non-broadcastable attention mask for batched forward-passes (#8791)
* [FlaxBert] Fix non-broadcastable attention mask for batched forward-passes * [FlaxRoberta] Fix non-broadcastable attention mask * Use jax.numpy instead of ordinary numpy (otherwise not jit-able) * Partially revert "Use jax.numpy ..." * Add tests for batched forward passes * Avoid unnecessary OOMs due to preallocation of GPU memory by XLA * Auto-fix style * Re-enable GPU memory preallocation but with mem fraction < 1/paralleism
This commit is contained in:
committed by
GitHub
parent
cb7602b38d
commit
f8eda599bd
@@ -186,6 +186,10 @@ class FlaxRobertaAttention(nn.Module):
|
||||
|
||||
@nn.compact
|
||||
def __call__(self, hidden_state, attention_mask):
|
||||
# Attention mask comes in as attention_mask.shape == (*batch_sizes, kv_length)
|
||||
# FLAX expects: attention_mask.shape == (*batch_sizes, 1, 1, kv_length) such that it is broadcastable
|
||||
# with attn_weights.shape == (*batch_sizes, num_heads, q_length, kv_length)
|
||||
attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))
|
||||
self_att = nn.attention.SelfAttention(num_heads=self.num_heads, qkv_features=self.head_size, name="self")(
|
||||
hidden_state, attention_mask
|
||||
)
|
||||
@@ -416,8 +420,8 @@ class FlaxRobertaModel(FlaxPreTrainedModel):
|
||||
token_type_ids = jnp.ones_like(input_ids)
|
||||
|
||||
if position_ids is None:
|
||||
position_ids = np.arange(
|
||||
self.config.pad_token_id + 1, np.atleast_2d(input_ids).shape[-1] + self.config.pad_token_id + 1
|
||||
position_ids = jnp.arange(
|
||||
self.config.pad_token_id + 1, jnp.atleast_2d(input_ids).shape[-1] + self.config.pad_token_id + 1
|
||||
)
|
||||
|
||||
if attention_mask is None:
|
||||
|
||||
Reference in New Issue
Block a user