[JAX] Replace uses of jnp.array in types with jnp.ndarray. (#26703)
`jnp.array` is a function, not a type: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.array.html so it never makes sense to use `jnp.array` in a type annotation. Presumably the intent was to write `jnp.ndarray` aka `jax.Array`. Co-authored-by: Peter Hawkins <phawkins@google.com>
This commit is contained in:
@@ -381,7 +381,7 @@ def write_metric(summary_writer, metrics, train_time, step, metric_key_prefix="t
|
||||
|
||||
def create_learning_rate_fn(
|
||||
train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float
|
||||
) -> Callable[[int], jnp.array]:
|
||||
) -> Callable[[int], jnp.ndarray]:
|
||||
"""Returns a linear warmup, linear_decay learning rate function."""
|
||||
steps_per_epoch = train_ds_size // train_batch_size
|
||||
num_train_steps = steps_per_epoch * num_train_epochs
|
||||
|
||||
Reference in New Issue
Block a user