[JAX] Replace all jax.tree_* calls with jax.tree_util.tree_* (#18361)
* [JAX] Replace all jax.tree_* calls with jax.tree_util.tree_* * fix double tree_util
This commit is contained in:
@@ -542,7 +542,7 @@ def main():
|
||||
|
||||
# normalize eval metrics
|
||||
eval_metrics = get_metrics(eval_metrics)
|
||||
eval_metrics = jax.tree_map(jnp.mean, eval_metrics)
|
||||
eval_metrics = jax.tree_util.tree_map(jnp.mean, eval_metrics)
|
||||
|
||||
# Print metrics and update progress bar
|
||||
eval_step_progress_bar.close()
|
||||
@@ -560,7 +560,7 @@ def main():
|
||||
|
||||
# save checkpoint after each epoch and push checkpoint to the hub
|
||||
if jax.process_index() == 0:
|
||||
params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
|
||||
params = jax.device_get(jax.tree_util.tree_map(lambda x: x[0], state.params))
|
||||
model.save_pretrained(training_args.output_dir, params=params)
|
||||
if training_args.push_to_hub:
|
||||
repo.push_to_hub(commit_message=f"Saving weights and logs of epoch {epoch}", blocking=False)
|
||||
|
||||
Reference in New Issue
Block a user