Ensure metric results are JSON-serializable (#10632)
This commit is contained in:
@@ -101,6 +101,7 @@ from .trainer_utils import (
|
||||
TrainOutput,
|
||||
default_compute_objective,
|
||||
default_hp_space,
|
||||
denumpify_detensorize,
|
||||
get_last_checkpoint,
|
||||
set_seed,
|
||||
speed_metrics,
|
||||
@@ -1831,6 +1832,9 @@ class Trainer:
|
||||
else:
|
||||
metrics = {}
|
||||
|
||||
# To be JSON-serializable, we need to remove numpy types or zero-d tensors
|
||||
metrics = denumpify_detensorize(metrics)
|
||||
|
||||
if eval_loss is not None:
|
||||
metrics[f"{metric_key_prefix}_loss"] = eval_loss.mean().item()
|
||||
|
||||
|
||||
@@ -38,6 +38,13 @@ from .file_utils import (
|
||||
)
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
if is_tf_available():
|
||||
import tensorflow as tf
|
||||
|
||||
|
||||
def set_seed(seed: int):
|
||||
"""
|
||||
Helper function for reproducible behavior to set the seed in ``random``, ``numpy``, ``torch`` and/or ``tf`` (if
|
||||
@@ -49,14 +56,10 @@ def set_seed(seed: int):
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
torch.manual_seed(seed)
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
# ^^ safe to call this function even if cuda is not available
|
||||
if is_tf_available():
|
||||
import tensorflow as tf
|
||||
|
||||
tf.random.set_seed(seed)
|
||||
|
||||
|
||||
@@ -423,6 +426,21 @@ class TrainerMemoryTracker:
|
||||
self.update_metrics(stage, metrics)
|
||||
|
||||
|
||||
def denumpify_detensorize(metrics):
|
||||
"""
|
||||
Recursively calls `.item()` on the element of the dictionary passed
|
||||
"""
|
||||
if isinstance(metrics, (list, tuple)):
|
||||
return type(metrics)(denumpify_detensorize(m) for m in metrics)
|
||||
elif isinstance(metrics, dict):
|
||||
return type(metrics)({k: denumpify_detensorize(v) for k, v in metrics.items()})
|
||||
elif isinstance(metrics, np.generic):
|
||||
return metrics.item()
|
||||
elif is_torch_available() and isinstance(metrics, torch.Tensor) and metrics.numel() == 1:
|
||||
return metrics.item()
|
||||
return metrics
|
||||
|
||||
|
||||
class ShardedDDPOption(ExplicitEnum):
|
||||
SIMPLE = "simple"
|
||||
ZERO_DP_2 = "zero_dp_2"
|
||||
|
||||
Reference in New Issue
Block a user