From 6287c929c162a417cd5d355c9113e9908710858d Mon Sep 17 00:00:00 2001 From: Stas Bekman Date: Tue, 25 May 2021 08:11:26 -0700 Subject: [PATCH] [lm examples] fix overflow in perplexity calc (#11855) * fix overflow in perplexity calc * use inf * fix --- examples/pytorch/language-modeling/run_clm.py | 5 ++++- examples/pytorch/language-modeling/run_clm_no_trainer.py | 5 ++++- examples/pytorch/language-modeling/run_mlm.py | 5 ++++- examples/pytorch/language-modeling/run_mlm_no_trainer.py | 5 ++++- examples/pytorch/language-modeling/run_plm.py | 5 ++++- 5 files changed, 20 insertions(+), 5 deletions(-) diff --git a/examples/pytorch/language-modeling/run_clm.py b/examples/pytorch/language-modeling/run_clm.py index 7aed40ed83..c3bf39ffce 100755 --- a/examples/pytorch/language-modeling/run_clm.py +++ b/examples/pytorch/language-modeling/run_clm.py @@ -440,7 +440,10 @@ def main(): max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset) metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset)) - perplexity = math.exp(metrics["eval_loss"]) + try: + perplexity = math.exp(metrics["eval_loss"]) + except OverflowError: + perplexity = float("inf") metrics["perplexity"] = perplexity trainer.log_metrics("eval", metrics) diff --git a/examples/pytorch/language-modeling/run_clm_no_trainer.py b/examples/pytorch/language-modeling/run_clm_no_trainer.py index 4584724667..4005e7883c 100755 --- a/examples/pytorch/language-modeling/run_clm_no_trainer.py +++ b/examples/pytorch/language-modeling/run_clm_no_trainer.py @@ -442,7 +442,10 @@ def main(): losses = torch.cat(losses) losses = losses[: len(eval_dataset)] - perplexity = math.exp(torch.mean(losses)) + try: + perplexity = math.exp(torch.mean(losses)) + except OverflowError: + perplexity = float("inf") logger.info(f"epoch {epoch}: perplexity: {perplexity}") diff --git a/examples/pytorch/language-modeling/run_mlm.py b/examples/pytorch/language-modeling/run_mlm.py index 32a4bb537f..60d315ef5f 100755 --- a/examples/pytorch/language-modeling/run_mlm.py +++ b/examples/pytorch/language-modeling/run_mlm.py @@ -469,7 +469,10 @@ def main(): max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset) metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset)) - perplexity = math.exp(metrics["eval_loss"]) + try: + perplexity = math.exp(metrics["eval_loss"]) + except OverflowError: + perplexity = float("inf") metrics["perplexity"] = perplexity trainer.log_metrics("eval", metrics) diff --git a/examples/pytorch/language-modeling/run_mlm_no_trainer.py b/examples/pytorch/language-modeling/run_mlm_no_trainer.py index 1cf1c242ab..1731b244da 100755 --- a/examples/pytorch/language-modeling/run_mlm_no_trainer.py +++ b/examples/pytorch/language-modeling/run_mlm_no_trainer.py @@ -486,7 +486,10 @@ def main(): losses = torch.cat(losses) losses = losses[: len(eval_dataset)] - perplexity = math.exp(torch.mean(losses)) + try: + perplexity = math.exp(torch.mean(losses)) + except OverflowError: + perplexity = float("inf") logger.info(f"epoch {epoch}: perplexity: {perplexity}") diff --git a/examples/pytorch/language-modeling/run_plm.py b/examples/pytorch/language-modeling/run_plm.py index f5cace2b6b..e8fab3c394 100755 --- a/examples/pytorch/language-modeling/run_plm.py +++ b/examples/pytorch/language-modeling/run_plm.py @@ -445,7 +445,10 @@ def main(): max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset) metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset)) - perplexity = math.exp(metrics["eval_loss"]) + try: + perplexity = math.exp(metrics["eval_loss"]) + except OverflowError: + perplexity = float("inf") metrics["perplexity"] = perplexity trainer.log_metrics("eval", metrics)