Ensure metric results are JSON-serializable (#10632)
This commit is contained in:
@@ -101,6 +101,7 @@ from .trainer_utils import (
|
|||||||
TrainOutput,
|
TrainOutput,
|
||||||
default_compute_objective,
|
default_compute_objective,
|
||||||
default_hp_space,
|
default_hp_space,
|
||||||
|
denumpify_detensorize,
|
||||||
get_last_checkpoint,
|
get_last_checkpoint,
|
||||||
set_seed,
|
set_seed,
|
||||||
speed_metrics,
|
speed_metrics,
|
||||||
@@ -1831,6 +1832,9 @@ class Trainer:
|
|||||||
else:
|
else:
|
||||||
metrics = {}
|
metrics = {}
|
||||||
|
|
||||||
|
# To be JSON-serializable, we need to remove numpy types or zero-d tensors
|
||||||
|
metrics = denumpify_detensorize(metrics)
|
||||||
|
|
||||||
if eval_loss is not None:
|
if eval_loss is not None:
|
||||||
metrics[f"{metric_key_prefix}_loss"] = eval_loss.mean().item()
|
metrics[f"{metric_key_prefix}_loss"] = eval_loss.mean().item()
|
||||||
|
|
||||||
|
|||||||
@@ -38,6 +38,13 @@ from .file_utils import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if is_torch_available():
|
||||||
|
import torch
|
||||||
|
|
||||||
|
if is_tf_available():
|
||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
|
|
||||||
def set_seed(seed: int):
|
def set_seed(seed: int):
|
||||||
"""
|
"""
|
||||||
Helper function for reproducible behavior to set the seed in ``random``, ``numpy``, ``torch`` and/or ``tf`` (if
|
Helper function for reproducible behavior to set the seed in ``random``, ``numpy``, ``torch`` and/or ``tf`` (if
|
||||||
@@ -49,14 +56,10 @@ def set_seed(seed: int):
|
|||||||
random.seed(seed)
|
random.seed(seed)
|
||||||
np.random.seed(seed)
|
np.random.seed(seed)
|
||||||
if is_torch_available():
|
if is_torch_available():
|
||||||
import torch
|
|
||||||
|
|
||||||
torch.manual_seed(seed)
|
torch.manual_seed(seed)
|
||||||
torch.cuda.manual_seed_all(seed)
|
torch.cuda.manual_seed_all(seed)
|
||||||
# ^^ safe to call this function even if cuda is not available
|
# ^^ safe to call this function even if cuda is not available
|
||||||
if is_tf_available():
|
if is_tf_available():
|
||||||
import tensorflow as tf
|
|
||||||
|
|
||||||
tf.random.set_seed(seed)
|
tf.random.set_seed(seed)
|
||||||
|
|
||||||
|
|
||||||
@@ -423,6 +426,21 @@ class TrainerMemoryTracker:
|
|||||||
self.update_metrics(stage, metrics)
|
self.update_metrics(stage, metrics)
|
||||||
|
|
||||||
|
|
||||||
|
def denumpify_detensorize(metrics):
|
||||||
|
"""
|
||||||
|
Recursively calls `.item()` on the element of the dictionary passed
|
||||||
|
"""
|
||||||
|
if isinstance(metrics, (list, tuple)):
|
||||||
|
return type(metrics)(denumpify_detensorize(m) for m in metrics)
|
||||||
|
elif isinstance(metrics, dict):
|
||||||
|
return type(metrics)({k: denumpify_detensorize(v) for k, v in metrics.items()})
|
||||||
|
elif isinstance(metrics, np.generic):
|
||||||
|
return metrics.item()
|
||||||
|
elif is_torch_available() and isinstance(metrics, torch.Tensor) and metrics.numel() == 1:
|
||||||
|
return metrics.item()
|
||||||
|
return metrics
|
||||||
|
|
||||||
|
|
||||||
class ShardedDDPOption(ExplicitEnum):
|
class ShardedDDPOption(ExplicitEnum):
|
||||||
SIMPLE = "simple"
|
SIMPLE = "simple"
|
||||||
ZERO_DP_2 = "zero_dp_2"
|
ZERO_DP_2 = "zero_dp_2"
|
||||||
|
|||||||
Reference in New Issue
Block a user