[JAX] Replace all jax.tree_* calls with jax.tree_util.tree_* (#18361)

* [JAX] Replace all jax.tree_* calls with jax.tree_util.tree_*

* fix double tree_util
This commit is contained in:
Sanchit Gandhi
2022-09-09 14:18:56 +01:00
committed by GitHub
parent 22f7218560
commit e6f221c8d4
17 changed files with 49 additions and 49 deletions

View File

@@ -1011,7 +1011,7 @@ def main():
# save checkpoint after each epoch and push checkpoint to the hub
if jax.process_index() == 0:
-params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
+params = jax.device_get(jax.tree_util.tree_map(lambda x: x[0], state.params))
model.save_pretrained(os.path.join(training_args.output_dir, ckpt_dir), params=params)
tokenizer.save_pretrained(os.path.join(training_args.output_dir, ckpt_dir))
if training_args.push_to_hub:
@@ -1064,7 +1064,7 @@ def main():
if metrics:
# normalize metrics
metrics = get_metrics(metrics)
-metrics = jax.tree_map(jnp.mean, metrics)
+metrics = jax.tree_util.tree_map(jnp.mean, metrics)
# compute ROUGE metrics
generations = []

View File

@@ -781,7 +781,7 @@ def main():
# normalize eval metrics
eval_metrics = get_metrics(eval_metrics)
-eval_metrics = jax.tree_map(jnp.mean, eval_metrics)
+eval_metrics = jax.tree_util.tree_map(jnp.mean, eval_metrics)
try:
eval_metrics["perplexity"] = math.exp(eval_metrics["loss"])
@@ -824,7 +824,7 @@ def main():
# normalize eval metrics
eval_metrics = get_metrics(eval_metrics)
-eval_metrics = jax.tree_map(lambda x: jnp.mean(x).item(), eval_metrics)
+eval_metrics = jax.tree_util.tree_map(lambda x: jnp.mean(x).item(), eval_metrics)
try:
eval_metrics["perplexity"] = math.exp(eval_metrics["loss"])

View File

@@ -827,9 +827,9 @@ def main():
# normalize eval metrics
eval_metrics = get_metrics(eval_metrics)
-eval_metrics = jax.tree_map(jnp.sum, eval_metrics)
+eval_metrics = jax.tree_util.tree_map(jnp.sum, eval_metrics)
eval_normalizer = eval_metrics.pop("normalizer")
-eval_metrics = jax.tree_map(lambda x: x / eval_normalizer, eval_metrics)
+eval_metrics = jax.tree_util.tree_map(lambda x: x / eval_normalizer, eval_metrics)
# Update progress bar
epochs.desc = f"Step... ({cur_step} | Loss: {eval_metrics['loss']}, Acc: {eval_metrics['accuracy']})"
@@ -841,7 +841,7 @@ def main():
if cur_step % training_args.save_steps == 0 and cur_step > 0:
# save checkpoint after each epoch and push checkpoint to the hub
if jax.process_index() == 0:
-params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
+params = jax.device_get(jax.tree_util.tree_map(lambda x: x[0], state.params))
model.save_pretrained(training_args.output_dir, params=params)
tokenizer.save_pretrained(training_args.output_dir)
if training_args.push_to_hub:
@@ -867,9 +867,9 @@ def main():
# normalize eval metrics
eval_metrics = get_metrics(eval_metrics)
-eval_metrics = jax.tree_map(lambda metric: jnp.sum(metric).item(), eval_metrics)
+eval_metrics = jax.tree_util.tree_map(lambda metric: jnp.sum(metric).item(), eval_metrics)
eval_normalizer = eval_metrics.pop("normalizer")
-eval_metrics = jax.tree_map(lambda x: x / eval_normalizer, eval_metrics)
+eval_metrics = jax.tree_util.tree_map(lambda x: x / eval_normalizer, eval_metrics)
try:
perplexity = math.exp(eval_metrics["loss"])

View File

@@ -940,7 +940,7 @@ def main():
# get eval metrics
eval_metrics = get_metrics(eval_metrics)
-eval_metrics = jax.tree_map(jnp.mean, eval_metrics)
+eval_metrics = jax.tree_util.tree_map(jnp.mean, eval_metrics)
# Update progress bar
epochs.write(f"Step... ({cur_step} | Loss: {eval_metrics['loss']}, Acc: {eval_metrics['accuracy']})")
@@ -952,7 +952,7 @@ def main():
if cur_step % training_args.save_steps == 0 and cur_step > 0:
# save checkpoint after each epoch and push checkpoint to the hub
if jax.process_index() == 0:
-params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
+params = jax.device_get(jax.tree_util.tree_map(lambda x: x[0], state.params))
model.save_pretrained(training_args.output_dir, params=params)
tokenizer.save_pretrained(training_args.output_dir)
if training_args.push_to_hub:
@@ -978,7 +978,7 @@ def main():
# get eval metrics
eval_metrics = get_metrics(eval_metrics)
-eval_metrics = jax.tree_map(lambda metric: jnp.mean(metric).item(), eval_metrics)
+eval_metrics = jax.tree_util.tree_map(lambda metric: jnp.mean(metric).item(), eval_metrics)
if jax.process_index() == 0:
eval_metrics = {f"eval_{metric_name}": value for metric_name, value in eval_metrics.items()}

View File

@@ -902,7 +902,7 @@ def main():
# normalize eval metrics
eval_metrics = get_metrics(eval_metrics)
-eval_metrics = jax.tree_map(jnp.mean, eval_metrics)
+eval_metrics = jax.tree_util.tree_map(jnp.mean, eval_metrics)
# compute ROUGE metrics
rouge_desc = ""
@@ -923,7 +923,7 @@ def main():
# save checkpoint after each epoch and push checkpoint to the hub
if jax.process_index() == 0:
-params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
+params = jax.device_get(jax.tree_util.tree_map(lambda x: x[0], state.params))
model.save_pretrained(training_args.output_dir, params=params)
tokenizer.save_pretrained(training_args.output_dir)
if training_args.push_to_hub:
@@ -957,7 +957,7 @@ def main():
# normalize prediction metrics
pred_metrics = get_metrics(pred_metrics)
-pred_metrics = jax.tree_map(jnp.mean, pred_metrics)
+pred_metrics = jax.tree_util.tree_map(jnp.mean, pred_metrics)
# compute ROUGE metrics
rouge_desc = ""

View File

@@ -542,7 +542,7 @@ def main():
# normalize eval metrics
eval_metrics = get_metrics(eval_metrics)
-eval_metrics = jax.tree_map(jnp.mean, eval_metrics)
+eval_metrics = jax.tree_util.tree_map(jnp.mean, eval_metrics)
# Print metrics and update progress bar
eval_step_progress_bar.close()
@@ -560,7 +560,7 @@ def main():
# save checkpoint after each epoch and push checkpoint to the hub
if jax.process_index() == 0:
-params = jax.device_get(jax.tree_map(lambda x: x[0], state.params))
+params = jax.device_get(jax.tree_util.tree_map(lambda x: x[0], state.params))
model.save_pretrained(training_args.output_dir, params=params)
if training_args.push_to_hub:
repo.push_to_hub(commit_message=f"Saving weights and logs of epoch {epoch}", blocking=False)