Flax dtype-dependent numerical masking (#21197)

This commit is contained in:
Joao Gante
2023-01-19 16:43:42 +00:00
committed by GitHub
parent 0b86e330b1
commit cbaaa2f6ac
21 changed files with 22 additions and 22 deletions

View File

@@ -312,7 +312,7 @@ class Flax{{cookiecutter.camelcase_modelname}}SelfAttention(nn.Module):
attention_bias = lax.select(
attention_mask > 0,
jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
jnp.full(attention_mask.shape, -1e10).astype(self.dtype),
jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
)
else:
attention_bias = None
@@ -1859,7 +1859,7 @@ class Flax{{cookiecutter.camelcase_modelname}}Attention(nn.Module):
attention_bias = lax.select(
attention_mask > 0,
jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
jnp.full(attention_mask.shape, float("-inf")).astype(self.dtype),
jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
)
else:
attention_bias = None