[JAX] Replace uses of jnp.array in types with jnp.ndarray. (#26703)
`jnp.array` is a function, not a type: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.array.html so it never makes sense to use `jnp.array` in a type annotation. Presumably the intent was to write `jnp.ndarray` aka `jax.Array`. Co-authored-by: Peter Hawkins <phawkins@google.com>
This commit is contained in:
@@ -251,7 +251,7 @@ class Flax{{cookiecutter.camelcase_modelname}}SelfAttention(nn.Module):
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
layer_head_mask,
|
||||
key_value_states: Optional[jnp.array] = None,
|
||||
key_value_states: Optional[jnp.ndarray] = None,
|
||||
init_cache: bool = False,
|
||||
deterministic=True,
|
||||
output_attentions: bool = False,
|
||||
|
||||
Reference in New Issue
Block a user