From bdbb2c756b87aea8e03107add321432f4815b107 Mon Sep 17 00:00:00 2001 From: Stas Bekman Date: Wed, 24 Feb 2021 08:32:52 -0800 Subject: [PATCH] [trainer] move secondary methods into a separate file (#10363) * move secondary methods into a separate file * cleanup * style --- src/transformers/trainer.py | 61 +-------------- src/transformers/trainer_pt_utils.py | 112 +++++++++++++++++++++------ 2 files changed, 91 insertions(+), 82 deletions(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 5e805b62db..2f2030a9dc 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -19,7 +19,6 @@ The Trainer class, to easily train a 🤗 Transformers from scratch or finetune import collections import gc import inspect -import json import math import os import re @@ -82,7 +81,6 @@ from .trainer_pt_utils import ( SequentialDistributedSampler, distributed_broadcast_scalars, distributed_concat, - get_learning_rate, nested_concat, nested_detach, nested_numpify, @@ -226,6 +224,8 @@ class Trainer: """ + from .trainer_pt_utils import _get_learning_rate, log_metrics, metrics_format, save_metrics + def __init__( self, model: Union[PreTrainedModel, torch.nn.Module] = None, @@ -1130,7 +1130,7 @@ class Trainer: tr_loss -= tr_loss logs["loss"] = round(tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged), 4) - logs["learning_rate"] = get_learning_rate(self) + logs["learning_rate"] = self._get_learning_rate() self._total_loss_scalar += tr_loss_scalar self._globalstep_last_logged = self.state.global_step @@ -1345,61 +1345,6 @@ class Trainer: self.state.log_history.append(output) self.control = self.callback_handler.on_log(self.args, self.state, self.control, logs) - def metrics_format(self, metrics: Dict[str, float]) -> Dict[str, float]: - """ - Reformat Trainer metrics values to a human-readable format - - Args: - metrics (:obj:`Dict[str, float]`): - The metrics returned from train/evaluate/predict - - Returns: - metrics (:obj:`Dict[str, float]`): The reformatted metrics - """ - - metrics_copy = metrics.copy() - for k, v in metrics_copy.items(): - if "_mem_" in k: - metrics_copy[k] = f"{ v >> 20 }MB" - elif k == "total_flos": - metrics_copy[k] = f"{ int(v) >> 30 }GF" - elif type(metrics_copy[k]) == float: - metrics_copy[k] = round(v, 4) - - return metrics_copy - - def log_metrics(self, split, metrics): - """ - Log metrics in a specially formatted way - - Args: - split (:obj:`str`): - Mode/split name: one of ``train``, ``eval``, ``test`` - metrics (:obj:`Dict[str, float]`): - The metrics returned from train/evaluate/predictmetrics: metrics dict - """ - - logger.info(f"***** {split} metrics *****") - metrics_formatted = self.metrics_format(metrics) - k_width = max(len(str(x)) for x in metrics_formatted.keys()) - v_width = max(len(str(x)) for x in metrics_formatted.values()) - for key in sorted(metrics_formatted.keys()): - logger.info(f" {key: <{k_width}} = {metrics_formatted[key]:>{v_width}}") - - def save_metrics(self, split, metrics): - """ - Save metrics into a json file for that split, e.g. ``train_results.json``. - - Args: - split (:obj:`str`): - Mode/split name: one of ``train``, ``eval``, ``test``, ``all`` - metrics (:obj:`Dict[str, float]`): - The metrics returned from train/evaluate/predict - """ - path = os.path.join(self.args.output_dir, f"{split}_results.json") - with open(path, "w") as f: - json.dump(metrics, f, indent=4, sort_keys=True) - def _prepare_inputs(self, inputs: Dict[str, Union[torch.Tensor, Any]]) -> Dict[str, Union[torch.Tensor, Any]]: """ Prepare :obj:`inputs` before feeding them to the model, converting them to tensors if they are not already and diff --git a/src/transformers/trainer_pt_utils.py b/src/transformers/trainer_pt_utils.py index ce4d400cc9..eac696ec35 100644 --- a/src/transformers/trainer_pt_utils.py +++ b/src/transformers/trainer_pt_utils.py @@ -16,11 +16,13 @@ Torch utilities for the Trainer class. """ +import json import math +import os import warnings from contextlib import contextmanager from dataclasses import dataclass -from typing import Iterator, List, Optional, Union +from typing import Dict, Iterator, List, Optional, Union import numpy as np import torch @@ -263,29 +265,6 @@ def _get_first_shape(arrays): return arrays.shape -def get_learning_rate(trainer): - if trainer.deepspeed: - # with deepspeed's fp16 and dynamic loss scale enabled the optimizer/scheduler steps may - # not run for the first few dozen steps while loss scale is too large, and thus during - # that time `get_last_lr` will fail if called during that warm up stage, so work around it: - try: - last_lr = trainer.lr_scheduler.get_last_lr()[0] - except AssertionError as e: - if "need to call step" in str(e): - logger.warn("tried to get lr value before scheduler/optimizer started stepping, returning lr=0") - last_lr = 0 - else: - raise - else: - last_lr = ( - # backward compatibility for pytorch schedulers - trainer.lr_scheduler.get_last_lr()[0] - if version.parse(torch.__version__) >= version.parse("1.4") - else trainer.lr_scheduler.get_lr()[0] - ) - return last_lr - - class DistributedTensorGatherer: """ A class responsible for properly gathering tensors (or nested list/tuple of tensors) on the CPU by chunks. @@ -563,3 +542,88 @@ class DistributedLengthGroupedSampler(DistributedSampler): assert len(indices) == self.num_samples return iter(indices) + + +# In order to keep `trainer.py` compact and easy to understand, place any secondary PT Trainer +# helper methods here + + +def _get_learning_rate(self): + if self.deepspeed: + # with deepspeed's fp16 and dynamic loss scale enabled the optimizer/scheduler steps may + # not run for the first few dozen steps while loss scale is too large, and thus during + # that time `get_last_lr` will fail if called during that warm up stage, so work around it: + try: + last_lr = self.lr_scheduler.get_last_lr()[0] + except AssertionError as e: + if "need to call step" in str(e): + logger.warn("tried to get lr value before scheduler/optimizer started stepping, returning lr=0") + last_lr = 0 + else: + raise + else: + last_lr = ( + # backward compatibility for pytorch schedulers + self.lr_scheduler.get_last_lr()[0] + if version.parse(torch.__version__) >= version.parse("1.4") + else self.lr_scheduler.get_lr()[0] + ) + return last_lr + + +def metrics_format(self, metrics: Dict[str, float]) -> Dict[str, float]: + """ + Reformat Trainer metrics values to a human-readable format + + Args: + metrics (:obj:`Dict[str, float]`): + The metrics returned from train/evaluate/predict + + Returns: + metrics (:obj:`Dict[str, float]`): The reformatted metrics + """ + + metrics_copy = metrics.copy() + for k, v in metrics_copy.items(): + if "_mem_" in k: + metrics_copy[k] = f"{ v >> 20 }MB" + elif k == "total_flos": + metrics_copy[k] = f"{ int(v) >> 30 }GF" + elif type(metrics_copy[k]) == float: + metrics_copy[k] = round(v, 4) + + return metrics_copy + + +def log_metrics(self, split, metrics): + """ + Log metrics in a specially formatted way + + Args: + split (:obj:`str`): + Mode/split name: one of ``train``, ``eval``, ``test`` + metrics (:obj:`Dict[str, float]`): + The metrics returned from train/evaluate/predictmetrics: metrics dict + """ + + logger.info(f"***** {split} metrics *****") + metrics_formatted = self.metrics_format(metrics) + k_width = max(len(str(x)) for x in metrics_formatted.keys()) + v_width = max(len(str(x)) for x in metrics_formatted.values()) + for key in sorted(metrics_formatted.keys()): + logger.info(f" {key: <{k_width}} = {metrics_formatted[key]:>{v_width}}") + + +def save_metrics(self, split, metrics): + """ + Save metrics into a json file for that split, e.g. ``train_results.json``. + + Args: + split (:obj:`str`): + Mode/split name: one of ``train``, ``eval``, ``test``, ``all`` + metrics (:obj:`Dict[str, float]`): + The metrics returned from train/evaluate/predict + """ + path = os.path.join(self.args.output_dir, f"{split}_results.json") + with open(path, "w") as f: + json.dump(metrics, f, indent=4, sort_keys=True)