Add timing inside Trainer (#9196)

* Add timing inside Trainer

* Fix tests

* Add n_objs for train

* Sort logs
This commit is contained in:
Sylvain Gugger
2020-12-18 15:10:39 -05:00
committed by GitHub
parent 9a25c5bd3a
commit 1198ba8fba
6 changed files with 76 additions and 49 deletions

View File

@@ -22,6 +22,7 @@ import math
import os
import re
import shutil
import time
import warnings
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
@@ -89,6 +90,7 @@ from .trainer_utils import (
default_compute_objective,
default_hp_space,
set_seed,
speed_metrics,
)
from .training_args import TrainingArguments
from .utils import logging
@@ -707,6 +709,7 @@ class Trainer:
logger.info(f" Total optimization steps = {max_steps}")
self.state.epoch = 0
start_time = time.time()
epochs_trained = 0
steps_trained_in_current_epoch = 0
@@ -870,15 +873,17 @@ class Trainer:
state_dict = torch.load(os.path.join(self.state.best_model_checkpoint, WEIGHTS_NAME))
self.model.load_state_dict(state_dict)
metrics = speed_metrics("train", start_time, self.state.max_steps)
if self._total_flos is not None:
self.store_flos()
self.log({"total_flos": self.state.total_flos})
metrics["total_flos"] = self.state.total_flos
self.log(metrics)
self.control = self.callback_handler.on_train_end(self.args, self.state, self.control)
# add remaining tr_loss
self._total_loss_scalar += tr_loss.item()
return TrainOutput(self.state.global_step, self._total_loss_scalar / self.state.global_step)
return TrainOutput(self.state.global_step, self._total_loss_scalar / self.state.global_step, metrics)
def _maybe_log_save_evaluate(self, tr_loss, model, trial, epoch):
if self.control.should_log:
@@ -1317,6 +1322,7 @@ class Trainer:
raise ValueError("eval_dataset must implement __len__")
eval_dataloader = self.get_eval_dataloader(eval_dataset)
start_time = time.time()
output = self.prediction_loop(
eval_dataloader,
@@ -1328,6 +1334,8 @@ class Trainer:
metric_key_prefix=metric_key_prefix,
)
n_samples = len(eval_dataset if eval_dataset is not None else self.eval_dataset)
output.metrics.update(speed_metrics(metric_key_prefix, start_time, n_samples))
self.log(output.metrics)
if self.args.tpu_metrics_debug or self.args.debug:
@@ -1374,10 +1382,13 @@ class Trainer:
raise ValueError("test_dataset must implement __len__")
test_dataloader = self.get_test_dataloader(test_dataset)
start_time = time.time()
return self.prediction_loop(
output = self.prediction_loop(
test_dataloader, description="Prediction", ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix
)
output.metrics.update(speed_metrics(metric_key_prefix, start_time, len(test_dataset)))
return output
def prediction_loop(
self,

View File

@@ -18,6 +18,7 @@ Utilities for the Trainer and TFTrainer class. Should be independent from PyTorc
import copy
import random
import time
from typing import Any, Dict, NamedTuple, Optional, Tuple, Union
import numpy as np
@@ -70,6 +71,7 @@ class PredictionOutput(NamedTuple):
class TrainOutput(NamedTuple):
    """
    Named tuple bundling a step counter, a loss value and a dictionary of metrics.

    Fields are positional, so their order below is part of the interface.
    """

    # Step counter stored with the result.
    global_step: int
    # Loss value stored with the result.
    training_loss: float
    # Mapping from metric name to its float value.
    metrics: Dict[str, float]
PREFIX_CHECKPOINT_DIR = "checkpoint"
@@ -179,3 +181,23 @@ def total_processes_number(local_rank):
return torch.distributed.get_world_size()
return 1
def speed_metrics(split, start_time, num_samples=None):
    """
    Measure and return speed performance metrics.

    This function requires a time snapshot `start_time` (taken with `time.time()`) from before the operation to be
    measured starts, and it should be run immediately after the operation to be measured has completed.

    Args:
        split: name used to prefix the metric keys (like train, eval, test...)
        start_time: operation start time, as returned by `time.time()`
        num_samples: number of samples processed; when `None`, no throughput is reported

    Returns:
        A dict containing `{split}_runtime` (elapsed seconds, rounded to 4 decimals) and, when `num_samples` is given
        and the elapsed time is positive, `{split}_samples_per_second` (rounded to 3 decimals).
    """
    runtime = time.time() - start_time
    result = {f"{split}_runtime": round(runtime, 4)}
    # Throughput is only meaningful for a positive elapsed time. A coarse or adjusted clock can report an elapsed
    # time of zero (or less); in that case the metric is skipped instead of raising ZeroDivisionError.
    # Dividing num_samples by runtime directly also keeps an empty run (num_samples == 0) from dividing by zero.
    if num_samples is not None and runtime > 0:
        result[f"{split}_samples_per_second"] = round(num_samples / runtime, 3)
    return result

View File

@@ -12,10 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import dataclasses
import json
import os
from dataclasses import dataclass, field
from dataclasses import asdict, dataclass, field
from enum import Enum
from typing import Any, Dict, List, Optional, Tuple
@@ -411,7 +410,16 @@ class TrainingArguments:
self.run_name = self.output_dir
if is_torch_available() and self.device.type != "cuda" and self.fp16:
raise ValueError("AMP (`--fp16`) can only be used on CUDA devices.")
raise ValueError("Mixed precision training with AMP or APEX (`--fp16`) can only be used on CUDA devices.")
def __repr__(self):
    """
    Build a `ClassName(field=value, ...)` string from the dataclass fields, leaving out the deprecated
    `per_gpu_*_batch_size` arguments.
    """
    # The default dataclass repr is overridden only to hide deprecated arguments; this method should be removed
    # once those deprecated arguments are removed from TrainingArguments. (TODO: v5)
    visible_fields = asdict(self)
    for deprecated_name in ("per_gpu_train_batch_size", "per_gpu_eval_batch_size"):
        # pop() without a default still raises KeyError if the field is missing, like `del` would.
        visible_fields.pop(deprecated_name)
    rendered = ", ".join(f"{name}={value}" for name, value in visible_fields.items())
    return f"{self.__class__.__name__}({rendered})"
@property
def train_batch_size(self) -> int:
@@ -523,7 +531,7 @@ class TrainingArguments:
"""
Serializes this instance while replace `Enum` by their values (for JSON serialization support).
"""
d = dataclasses.asdict(self)
d = asdict(self)
for k, v in d.items():
if isinstance(v, Enum):
d[k] = v.value