[trainer] move secondary methods into a separate file (#10363)
* move secondary methods into a separate file * cleanup * style
This commit is contained in:
@@ -19,7 +19,6 @@ The Trainer class, to easily train a 🤗 Transformers from scratch or finetune
|
|||||||
import collections
|
import collections
|
||||||
import gc
|
import gc
|
||||||
import inspect
|
import inspect
|
||||||
import json
|
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
@@ -82,7 +81,6 @@ from .trainer_pt_utils import (
|
|||||||
SequentialDistributedSampler,
|
SequentialDistributedSampler,
|
||||||
distributed_broadcast_scalars,
|
distributed_broadcast_scalars,
|
||||||
distributed_concat,
|
distributed_concat,
|
||||||
get_learning_rate,
|
|
||||||
nested_concat,
|
nested_concat,
|
||||||
nested_detach,
|
nested_detach,
|
||||||
nested_numpify,
|
nested_numpify,
|
||||||
@@ -226,6 +224,8 @@ class Trainer:
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
from .trainer_pt_utils import _get_learning_rate, log_metrics, metrics_format, save_metrics
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
model: Union[PreTrainedModel, torch.nn.Module] = None,
|
model: Union[PreTrainedModel, torch.nn.Module] = None,
|
||||||
@@ -1130,7 +1130,7 @@ class Trainer:
|
|||||||
tr_loss -= tr_loss
|
tr_loss -= tr_loss
|
||||||
|
|
||||||
logs["loss"] = round(tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged), 4)
|
logs["loss"] = round(tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged), 4)
|
||||||
logs["learning_rate"] = get_learning_rate(self)
|
logs["learning_rate"] = self._get_learning_rate()
|
||||||
|
|
||||||
self._total_loss_scalar += tr_loss_scalar
|
self._total_loss_scalar += tr_loss_scalar
|
||||||
self._globalstep_last_logged = self.state.global_step
|
self._globalstep_last_logged = self.state.global_step
|
||||||
@@ -1345,61 +1345,6 @@ class Trainer:
|
|||||||
self.state.log_history.append(output)
|
self.state.log_history.append(output)
|
||||||
self.control = self.callback_handler.on_log(self.args, self.state, self.control, logs)
|
self.control = self.callback_handler.on_log(self.args, self.state, self.control, logs)
|
||||||
|
|
||||||
def metrics_format(self, metrics: Dict[str, float]) -> Dict[str, float]:
|
|
||||||
"""
|
|
||||||
Reformat Trainer metrics values to a human-readable format
|
|
||||||
|
|
||||||
Args:
|
|
||||||
metrics (:obj:`Dict[str, float]`):
|
|
||||||
The metrics returned from train/evaluate/predict
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
metrics (:obj:`Dict[str, float]`): The reformatted metrics
|
|
||||||
"""
|
|
||||||
|
|
||||||
metrics_copy = metrics.copy()
|
|
||||||
for k, v in metrics_copy.items():
|
|
||||||
if "_mem_" in k:
|
|
||||||
metrics_copy[k] = f"{ v >> 20 }MB"
|
|
||||||
elif k == "total_flos":
|
|
||||||
metrics_copy[k] = f"{ int(v) >> 30 }GF"
|
|
||||||
elif type(metrics_copy[k]) == float:
|
|
||||||
metrics_copy[k] = round(v, 4)
|
|
||||||
|
|
||||||
return metrics_copy
|
|
||||||
|
|
||||||
def log_metrics(self, split, metrics):
|
|
||||||
"""
|
|
||||||
Log metrics in a specially formatted way
|
|
||||||
|
|
||||||
Args:
|
|
||||||
split (:obj:`str`):
|
|
||||||
Mode/split name: one of ``train``, ``eval``, ``test``
|
|
||||||
metrics (:obj:`Dict[str, float]`):
|
|
||||||
The metrics returned from train/evaluate/predictmetrics: metrics dict
|
|
||||||
"""
|
|
||||||
|
|
||||||
logger.info(f"***** {split} metrics *****")
|
|
||||||
metrics_formatted = self.metrics_format(metrics)
|
|
||||||
k_width = max(len(str(x)) for x in metrics_formatted.keys())
|
|
||||||
v_width = max(len(str(x)) for x in metrics_formatted.values())
|
|
||||||
for key in sorted(metrics_formatted.keys()):
|
|
||||||
logger.info(f" {key: <{k_width}} = {metrics_formatted[key]:>{v_width}}")
|
|
||||||
|
|
||||||
def save_metrics(self, split, metrics):
|
|
||||||
"""
|
|
||||||
Save metrics into a json file for that split, e.g. ``train_results.json``.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
split (:obj:`str`):
|
|
||||||
Mode/split name: one of ``train``, ``eval``, ``test``, ``all``
|
|
||||||
metrics (:obj:`Dict[str, float]`):
|
|
||||||
The metrics returned from train/evaluate/predict
|
|
||||||
"""
|
|
||||||
path = os.path.join(self.args.output_dir, f"{split}_results.json")
|
|
||||||
with open(path, "w") as f:
|
|
||||||
json.dump(metrics, f, indent=4, sort_keys=True)
|
|
||||||
|
|
||||||
def _prepare_inputs(self, inputs: Dict[str, Union[torch.Tensor, Any]]) -> Dict[str, Union[torch.Tensor, Any]]:
|
def _prepare_inputs(self, inputs: Dict[str, Union[torch.Tensor, Any]]) -> Dict[str, Union[torch.Tensor, Any]]:
|
||||||
"""
|
"""
|
||||||
Prepare :obj:`inputs` before feeding them to the model, converting them to tensors if they are not already and
|
Prepare :obj:`inputs` before feeding them to the model, converting them to tensors if they are not already and
|
||||||
|
|||||||
@@ -16,11 +16,13 @@
|
|||||||
Torch utilities for the Trainer class.
|
Torch utilities for the Trainer class.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
import math
|
import math
|
||||||
|
import os
|
||||||
import warnings
|
import warnings
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Iterator, List, Optional, Union
|
from typing import Dict, Iterator, List, Optional, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
@@ -263,29 +265,6 @@ def _get_first_shape(arrays):
|
|||||||
return arrays.shape
|
return arrays.shape
|
||||||
|
|
||||||
|
|
||||||
def get_learning_rate(trainer):
|
|
||||||
if trainer.deepspeed:
|
|
||||||
# with deepspeed's fp16 and dynamic loss scale enabled the optimizer/scheduler steps may
|
|
||||||
# not run for the first few dozen steps while loss scale is too large, and thus during
|
|
||||||
# that time `get_last_lr` will fail if called during that warm up stage, so work around it:
|
|
||||||
try:
|
|
||||||
last_lr = trainer.lr_scheduler.get_last_lr()[0]
|
|
||||||
except AssertionError as e:
|
|
||||||
if "need to call step" in str(e):
|
|
||||||
logger.warn("tried to get lr value before scheduler/optimizer started stepping, returning lr=0")
|
|
||||||
last_lr = 0
|
|
||||||
else:
|
|
||||||
raise
|
|
||||||
else:
|
|
||||||
last_lr = (
|
|
||||||
# backward compatibility for pytorch schedulers
|
|
||||||
trainer.lr_scheduler.get_last_lr()[0]
|
|
||||||
if version.parse(torch.__version__) >= version.parse("1.4")
|
|
||||||
else trainer.lr_scheduler.get_lr()[0]
|
|
||||||
)
|
|
||||||
return last_lr
|
|
||||||
|
|
||||||
|
|
||||||
class DistributedTensorGatherer:
|
class DistributedTensorGatherer:
|
||||||
"""
|
"""
|
||||||
A class responsible for properly gathering tensors (or nested list/tuple of tensors) on the CPU by chunks.
|
A class responsible for properly gathering tensors (or nested list/tuple of tensors) on the CPU by chunks.
|
||||||
@@ -563,3 +542,88 @@ class DistributedLengthGroupedSampler(DistributedSampler):
|
|||||||
assert len(indices) == self.num_samples
|
assert len(indices) == self.num_samples
|
||||||
|
|
||||||
return iter(indices)
|
return iter(indices)
|
||||||
|
|
||||||
|
|
||||||
|
# In order to keep `trainer.py` compact and easy to understand, place any secondary PT Trainer
|
||||||
|
# helper methods here
|
||||||
|
|
||||||
|
|
||||||
|
def _get_learning_rate(self):
|
||||||
|
if self.deepspeed:
|
||||||
|
# with deepspeed's fp16 and dynamic loss scale enabled the optimizer/scheduler steps may
|
||||||
|
# not run for the first few dozen steps while loss scale is too large, and thus during
|
||||||
|
# that time `get_last_lr` will fail if called during that warm up stage, so work around it:
|
||||||
|
try:
|
||||||
|
last_lr = self.lr_scheduler.get_last_lr()[0]
|
||||||
|
except AssertionError as e:
|
||||||
|
if "need to call step" in str(e):
|
||||||
|
logger.warn("tried to get lr value before scheduler/optimizer started stepping, returning lr=0")
|
||||||
|
last_lr = 0
|
||||||
|
else:
|
||||||
|
raise
|
||||||
|
else:
|
||||||
|
last_lr = (
|
||||||
|
# backward compatibility for pytorch schedulers
|
||||||
|
self.lr_scheduler.get_last_lr()[0]
|
||||||
|
if version.parse(torch.__version__) >= version.parse("1.4")
|
||||||
|
else self.lr_scheduler.get_lr()[0]
|
||||||
|
)
|
||||||
|
return last_lr
|
||||||
|
|
||||||
|
|
||||||
|
def metrics_format(self, metrics: Dict[str, float]) -> Dict[str, float]:
    """
    Reformat Trainer metrics values to a human-readable format

    Keys containing ``_mem_`` are rendered as whole mebibytes (``"<n>MB"``), ``total_flos`` is rendered as
    ``"<n>GF"`` (value shifted right by 30 bits), other float values are rounded to 4 decimals and everything else is
    passed through unchanged. The input dict is not modified.

    Args:
        metrics (:obj:`Dict[str, float]`):
            The metrics returned from train/evaluate/predict

    Returns:
        metrics (:obj:`Dict[str, float]`): The reformatted metrics
    """
    metrics_copy = metrics.copy()
    # iterate over the input and write into the copy, so the dict being looped over is never mutated
    for name, value in metrics.items():
        if "_mem_" in name:
            # coerce like the `total_flos` branch does: `>>` is undefined for floats
            metrics_copy[name] = f"{int(value) >> 20}MB"
        elif name == "total_flos":
            metrics_copy[name] = f"{int(value) >> 30}GF"
        elif isinstance(value, float):
            metrics_copy[name] = round(value, 4)

    return metrics_copy
|
||||||
|
|
||||||
|
|
||||||
|
def log_metrics(self, split, metrics):
    """
    Log metrics in a specially formatted way

    The metrics are first passed through ``self.metrics_format``, then logged one per line in sorted key order, with
    keys left-aligned and values right-aligned to the widest entry. Only the header line is logged when there are no
    metrics.

    Args:
        split (:obj:`str`):
            Mode/split name: one of ``train``, ``eval``, ``test``
        metrics (:obj:`Dict[str, float]`):
            The metrics returned from train/evaluate/predict
    """
    logger.info(f"***** {split} metrics *****")
    metrics_formatted = self.metrics_format(metrics)
    if not metrics_formatted:
        # `max()` raises ValueError on an empty iterable and there is nothing to align anyway
        return

    k_width = max(len(str(key)) for key in metrics_formatted)
    v_width = max(len(str(value)) for value in metrics_formatted.values())
    for key in sorted(metrics_formatted):
        # format the `str()` form: it is what the width was measured on, and it also works for
        # values (e.g. ``None``) whose own `__format__` rejects an alignment spec
        logger.info(f" {key: <{k_width}} = {str(metrics_formatted[key]):>{v_width}}")
|
||||||
|
|
||||||
|
|
||||||
|
def save_metrics(self, split, metrics):
    """
    Save metrics into a json file for that split, e.g. ``train_results.json``.

    The file is written to ``self.args.output_dir`` (created if it does not exist yet), with keys sorted and an
    indent of 4. An existing file for the same split is overwritten.

    Args:
        split (:obj:`str`):
            Mode/split name: one of ``train``, ``eval``, ``test``, ``all``
        metrics (:obj:`Dict[str, float]`):
            The metrics returned from train/evaluate/predict
    """
    output_dir = self.args.output_dir
    if output_dir:
        # a fresh run may not have created the directory yet; `open` would fail with FileNotFoundError
        os.makedirs(output_dir, exist_ok=True)

    path = os.path.join(output_dir, f"{split}_results.json")
    with open(path, "w", encoding="utf-8") as f:
        json.dump(metrics, f, indent=4, sort_keys=True)
|
||||||
|
|||||||
Reference in New Issue
Block a user