[FlaxBert] Fix non-broadcastable attention mask for batched forward-passes (#8791)

* [FlaxBert] Fix non-broadcastable attention mask for batched forward-passes

* [FlaxRoberta] Fix non-broadcastable attention mask

* Use jax.numpy instead of ordinary numpy (otherwise not jit-able)

* Partially revert "Use jax.numpy ..."

* Add tests for batched forward passes

* Avoid unnecessary OOMs due to preallocation of GPU memory by XLA

* Auto-fix style

* Re-enable GPU memory preallocation but with mem fraction < 1/paralleism
This commit is contained in:
Kristian Holsheimer
2020-11-27 23:21:19 +11:00
committed by GitHub
parent cb7602b38d
commit f8eda599bd
4 changed files with 68 additions and 2 deletions

View File

@@ -186,6 +186,10 @@ class FlaxRobertaAttention(nn.Module):
@nn.compact
def __call__(self, hidden_state, attention_mask):
# Attention mask comes in as attention_mask.shape == (*batch_sizes, kv_length)
# FLAX expects: attention_mask.shape == (*batch_sizes, 1, 1, kv_length) such that it is broadcastable
# with attn_weights.shape == (*batch_sizes, num_heads, q_length, kv_length)
attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))
self_att = nn.attention.SelfAttention(num_heads=self.num_heads, qkv_features=self.head_size, name="self")(
hidden_state, attention_mask
)
@@ -416,8 +420,8 @@ class FlaxRobertaModel(FlaxPreTrainedModel):
token_type_ids = jnp.ones_like(input_ids)
if position_ids is None:
position_ids = np.arange(
self.config.pad_token_id + 1, np.atleast_2d(input_ids).shape[-1] + self.config.pad_token_id + 1
position_ids = jnp.arange(
self.config.pad_token_id + 1, jnp.atleast_2d(input_ids).shape[-1] + self.config.pad_token_id + 1
)
if attention_mask is None: