From 63c295ac05962b03701bdda87a90595b5f864075 Mon Sep 17 00:00:00 2001 From: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Date: Thu, 11 Mar 2021 09:00:23 -0500 Subject: [PATCH] Ensure metric results are JSON-serializable (#10632) --- src/transformers/trainer.py | 4 ++++ src/transformers/trainer_utils.py | 26 ++++++++++++++++++++++---- 2 files changed, 26 insertions(+), 4 deletions(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 42d3648c92..7e2df0bf55 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -101,6 +101,7 @@ from .trainer_utils import ( TrainOutput, default_compute_objective, default_hp_space, + denumpify_detensorize, get_last_checkpoint, set_seed, speed_metrics, @@ -1831,6 +1832,9 @@ class Trainer: else: metrics = {} + # To be JSON-serializable, we need to remove numpy types or zero-d tensors + metrics = denumpify_detensorize(metrics) + if eval_loss is not None: metrics[f"{metric_key_prefix}_loss"] = eval_loss.mean().item() diff --git a/src/transformers/trainer_utils.py b/src/transformers/trainer_utils.py index d375523b06..5d7deed2e8 100644 --- a/src/transformers/trainer_utils.py +++ b/src/transformers/trainer_utils.py @@ -38,6 +38,13 @@ from .file_utils import ( ) +if is_torch_available(): + import torch + +if is_tf_available(): + import tensorflow as tf + + def set_seed(seed: int): """ Helper function for reproducible behavior to set the seed in ``random``, ``numpy``, ``torch`` and/or ``tf`` (if @@ -49,14 +56,10 @@ def set_seed(seed: int): random.seed(seed) np.random.seed(seed) if is_torch_available(): - import torch - torch.manual_seed(seed) torch.cuda.manual_seed_all(seed) # ^^ safe to call this function even if cuda is not available if is_tf_available(): - import tensorflow as tf - tf.random.set_seed(seed) @@ -423,6 +426,21 @@ class TrainerMemoryTracker: self.update_metrics(stage, metrics) +def denumpify_detensorize(metrics): + """ + Recursively calls `.item()` on the element of the dictionary passed + """ + if isinstance(metrics, (list, tuple)): + return type(metrics)(denumpify_detensorize(m) for m in metrics) + elif isinstance(metrics, dict): + return type(metrics)({k: denumpify_detensorize(v) for k, v in metrics.items()}) + elif isinstance(metrics, np.generic): + return metrics.item() + elif is_torch_available() and isinstance(metrics, torch.Tensor) and metrics.numel() == 1: + return metrics.item() + return metrics + + class ShardedDDPOption(ExplicitEnum): SIMPLE = "simple" ZERO_DP_2 = "zero_dp_2"