Add timing inside Trainer (#9196)

* Add timing inside Trainer

* Fix tests

* Add n_objs for train

* Sort logs
This commit is contained in:
Sylvain Gugger
2020-12-18 15:10:39 -05:00
committed by GitHub
parent 9a25c5bd3a
commit 1198ba8fba
6 changed files with 76 additions and 49 deletions

View File

@@ -22,6 +22,7 @@ import math
import os
import re
import shutil
import time
import warnings
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
@@ -89,6 +90,7 @@ from .trainer_utils import (
default_compute_objective,
default_hp_space,
set_seed,
speed_metrics,
)
from .training_args import TrainingArguments
from .utils import logging
@@ -707,6 +709,7 @@ class Trainer:
logger.info(f" Total optimization steps = {max_steps}")
self.state.epoch = 0
start_time = time.time()
epochs_trained = 0
steps_trained_in_current_epoch = 0
@@ -870,15 +873,17 @@ class Trainer:
  state_dict = torch.load(os.path.join(self.state.best_model_checkpoint, WEIGHTS_NAME))
  self.model.load_state_dict(state_dict)
+ metrics = speed_metrics("train", start_time, self.state.max_steps)
  if self._total_flos is not None:
  self.store_flos()
- self.log({"total_flos": self.state.total_flos})
+ metrics["total_flos"] = self.state.total_flos
+ self.log(metrics)
  self.control = self.callback_handler.on_train_end(self.args, self.state, self.control)
  # add remaining tr_loss
  self._total_loss_scalar += tr_loss.item()
- return TrainOutput(self.state.global_step, self._total_loss_scalar / self.state.global_step)
+ return TrainOutput(self.state.global_step, self._total_loss_scalar / self.state.global_step, metrics)
  def _maybe_log_save_evaluate(self, tr_loss, model, trial, epoch):
  if self.control.should_log:
@@ -1317,6 +1322,7 @@ class Trainer:
raise ValueError("eval_dataset must implement __len__")
eval_dataloader = self.get_eval_dataloader(eval_dataset)
start_time = time.time()
output = self.prediction_loop(
eval_dataloader,
@@ -1328,6 +1334,8 @@ class Trainer:
metric_key_prefix=metric_key_prefix,
)
n_samples = len(eval_dataset if eval_dataset is not None else self.eval_dataset)
output.metrics.update(speed_metrics(metric_key_prefix, start_time, n_samples))
self.log(output.metrics)
if self.args.tpu_metrics_debug or self.args.debug:
@@ -1374,10 +1382,13 @@ class Trainer:
  raise ValueError("test_dataset must implement __len__")
  test_dataloader = self.get_test_dataloader(test_dataset)
+ start_time = time.time()
- return self.prediction_loop(
+ output = self.prediction_loop(
  test_dataloader, description="Prediction", ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix
  )
+ output.metrics.update(speed_metrics(metric_key_prefix, start_time, len(test_dataset)))
+ return output
  def prediction_loop(
  self,