Correct Logging of Eval metric to Tensorboard (#16825)
* Correct Logging of Eval metric to Tensorboard

  An empty dictionary ``eval_metrics`` was being logged; it is replaced by ``eval_metric``, which is the output dictionary of ``metric.compute()``.

* Remove unused variable
This commit is contained in:
@@ -592,7 +592,6 @@ def main():
|
|||||||
|
|
||||||
if (cur_step % training_args.eval_steps == 0 or cur_step % steps_per_epoch == 0) and cur_step > 0:
|
if (cur_step % training_args.eval_steps == 0 or cur_step % steps_per_epoch == 0) and cur_step > 0:
|
||||||
|
|
||||||
eval_metrics = {}
|
|
||||||
# evaluate
|
# evaluate
|
||||||
eval_loader = glue_eval_data_collator(eval_dataset, eval_batch_size)
|
eval_loader = glue_eval_data_collator(eval_dataset, eval_batch_size)
|
||||||
for batch in tqdm(
|
for batch in tqdm(
|
||||||
@@ -623,7 +622,7 @@ def main():
|
|||||||
logger.info(f"Step... ({cur_step}/{total_steps} | Eval metrics: {eval_metric})")
|
logger.info(f"Step... ({cur_step}/{total_steps} | Eval metrics: {eval_metric})")
|
||||||
|
|
||||||
if has_tensorboard and jax.process_index() == 0:
|
if has_tensorboard and jax.process_index() == 0:
|
||||||
write_eval_metric(summary_writer, eval_metrics, cur_step)
|
write_eval_metric(summary_writer, eval_metric, cur_step)
|
||||||
|
|
||||||
if (cur_step % training_args.save_steps == 0 and cur_step > 0) or (cur_step == total_steps):
|
if (cur_step % training_args.save_steps == 0 and cur_step > 0) or (cur_step == total_steps):
|
||||||
# save checkpoint after each epoch and push checkpoint to the hub
|
# save checkpoint after each epoch and push checkpoint to the hub
|
||||||
|
|||||||
Reference in New Issue
Block a user