[JAX] Replace all jax.tree_* calls with jax.tree_util.tree_* (#18361)
* [JAX] Replace all jax.tree_* calls with jax.tree_util.tree_* * fix double tree_util
This commit is contained in:
@@ -1011,7 +1011,7 @@ def main():
|
||||
|
||||
# save checkpoint after each epoch and push checkpoint to the hub
|
||||
if jax.process_index() == 0:
|
||||
params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
|
||||
params = jax.device_get(jax.tree_util.tree_map(lambda x: x[0], state.params))
|
||||
model.save_pretrained(os.path.join(training_args.output_dir, ckpt_dir), params=params)
|
||||
tokenizer.save_pretrained(os.path.join(training_args.output_dir, ckpt_dir))
|
||||
if training_args.push_to_hub:
|
||||
@@ -1064,7 +1064,7 @@ def main():
|
||||
if metrics:
|
||||
# normalize metrics
|
||||
metrics = get_metrics(metrics)
|
||||
metrics = jax.tree_map(jnp.mean, metrics)
|
||||
metrics = jax.tree_util.tree_map(jnp.mean, metrics)
|
||||
|
||||
# compute ROUGE metrics
|
||||
generations = []
|
||||
|
||||
Reference in New Issue
Block a user