[JAX] Replace uses of jnp.array in types with jnp.ndarray. (#26703)

`jnp.array` is a function, not a type:
https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.array.html
so it never makes sense to use `jnp.array` in a type annotation. Presumably the intent was to write `jnp.ndarray` aka `jax.Array`.

Co-authored-by: Peter Hawkins <phawkins@google.com>
This commit is contained in:
Roy Hvaara
2023-10-10 12:35:16 -07:00
committed by GitHub
parent 3eceaa3637
commit fc63914399
25 changed files with 28 additions and 28 deletions

View File

@@ -288,7 +288,7 @@ def create_train_state(
def create_learning_rate_fn(
train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float
) -> Callable[[int], jnp.array]:
) -> Callable[[int], jnp.ndarray]:
"""Returns a linear warmup, linear_decay learning rate function."""
steps_per_epoch = train_ds_size // train_batch_size
num_train_steps = steps_per_epoch * num_train_epochs