[JAX] Replace all jax.tree_* calls with jax.tree_util.tree_* (#18361)

* [JAX] Replace all jax.tree_* calls with jax.tree_util.tree_*

* fix double tree_util
This commit is contained in:
Sanchit Gandhi
2022-09-09 14:18:56 +01:00
committed by GitHub
parent 22f7218560
commit e6f221c8d4
17 changed files with 49 additions and 49 deletions

View File

@@ -1011,7 +1011,7 @@ def main():
# save checkpoint after each epoch and push checkpoint to the hub
if jax.process_index() == 0:
params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
params = jax.device_get(jax.tree_util.tree_map(lambda x: x[0], state.params))
model.save_pretrained(os.path.join(training_args.output_dir, ckpt_dir), params=params)
tokenizer.save_pretrained(os.path.join(training_args.output_dir, ckpt_dir))
if training_args.push_to_hub:
@@ -1064,7 +1064,7 @@ def main():
if metrics:
# normalize metrics
metrics = get_metrics(metrics)
metrics = jax.tree_map(jnp.mean, metrics)
metrics = jax.tree_util.tree_map(jnp.mean, metrics)
# compute ROUGE metrics
generations = []