[JAX] Replace all jax.tree_* calls with jax.tree_util.tree_* (#18361)

* [JAX] Replace all jax.tree_* calls with jax.tree_util.tree_*

* fix double tree_util
This commit is contained in:
Sanchit Gandhi
2022-09-09 14:18:56 +01:00
committed by GitHub
parent 22f7218560
commit e6f221c8d4
17 changed files with 49 additions and 49 deletions

View File

@@ -674,9 +674,9 @@ if __name__ == "__main__":
eval_metrics.append(metrics)
eval_metrics_np = get_metrics(eval_metrics)
eval_metrics_np = jax.tree_map(jnp.sum, eval_metrics_np)
eval_metrics_np = jax.tree_util.tree_map(jnp.sum, eval_metrics_np)
eval_normalizer = eval_metrics_np.pop("normalizer")
eval_summary = jax.tree_map(lambda x: x / eval_normalizer, eval_metrics_np)
eval_summary = jax.tree_util.tree_map(lambda x: x / eval_normalizer, eval_metrics_np)
# Update progress bar
epochs.desc = (