[JAX] Replace all jax.tree_* calls with jax.tree_util.tree_* (#18361)
* [JAX] Replace all jax.tree_* calls with jax.tree_util.tree_* * fix double tree_util
This commit is contained in:
@@ -902,7 +902,7 @@ def main():
|
||||
|
||||
# normalize eval metrics
|
||||
eval_metrics = get_metrics(eval_metrics)
|
||||
eval_metrics = jax.tree_map(jnp.mean, eval_metrics)
|
||||
eval_metrics = jax.tree_util.tree_map(jnp.mean, eval_metrics)
|
||||
|
||||
# compute ROUGE metrics
|
||||
rouge_desc = ""
|
||||
@@ -923,7 +923,7 @@ def main():
|
||||
|
||||
# save checkpoint after each epoch and push checkpoint to the hub
|
||||
if jax.process_index() == 0:
|
||||
params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
|
||||
params = jax.device_get(jax.tree_util.tree_map(lambda x: x[0], state.params))
|
||||
model.save_pretrained(training_args.output_dir, params=params)
|
||||
tokenizer.save_pretrained(training_args.output_dir)
|
||||
if training_args.push_to_hub:
|
||||
@@ -957,7 +957,7 @@ def main():
|
||||
|
||||
# normalize prediction metrics
|
||||
pred_metrics = get_metrics(pred_metrics)
|
||||
pred_metrics = jax.tree_map(jnp.mean, pred_metrics)
|
||||
pred_metrics = jax.tree_util.tree_map(jnp.mean, pred_metrics)
|
||||
|
||||
# compute ROUGE metrics
|
||||
rouge_desc = ""
|
||||
|
||||
Reference in New Issue
Block a user