Flax dtype-dependent numerical masking (#21197)
This commit is contained in:
@@ -312,7 +312,7 @@ class Flax{{cookiecutter.camelcase_modelname}}SelfAttention(nn.Module):
|
||||
attention_bias = lax.select(
|
||||
attention_mask > 0,
|
||||
jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
|
||||
- jnp.full(attention_mask.shape, -1e10).astype(self.dtype),
|
||||
+ jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
|
||||
)
|
||||
else:
|
||||
attention_bias = None
|
||||
@@ -1859,7 +1859,7 @@ class Flax{{cookiecutter.camelcase_modelname}}Attention(nn.Module):
|
||||
attention_bias = lax.select(
|
||||
attention_mask > 0,
|
||||
jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
|
||||
- jnp.full(attention_mask.shape, float("-inf")).astype(self.dtype),
|
||||
+ jnp.full(attention_mask.shape, jnp.finfo(self.dtype).min).astype(self.dtype),
|
||||
)
|
||||
else:
|
||||
attention_bias = None
|
||||
|
||||
Reference in New Issue
Block a user