Add timing inside Trainer (#9196)
* Add timing inside Trainer
* Fix tests
* Add n_objs for train
* Sort logs
This commit is contained in:
@@ -22,6 +22,7 @@ import math
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
import time
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
@@ -89,6 +90,7 @@ from .trainer_utils import (
|
||||
default_compute_objective,
|
||||
default_hp_space,
|
||||
set_seed,
|
||||
speed_metrics,
|
||||
)
|
||||
from .training_args import TrainingArguments
|
||||
from .utils import logging
|
||||
@@ -707,6 +709,7 @@ class Trainer:
|
||||
logger.info(f" Total optimization steps = {max_steps}")
|
||||
|
||||
self.state.epoch = 0
|
||||
start_time = time.time()
|
||||
epochs_trained = 0
|
||||
steps_trained_in_current_epoch = 0
|
||||
|
||||
@@ -870,15 +873,17 @@ class Trainer:
|
||||
state_dict = torch.load(os.path.join(self.state.best_model_checkpoint, WEIGHTS_NAME))
|
||||
self.model.load_state_dict(state_dict)
|
||||
|
||||
metrics = speed_metrics("train", start_time, self.state.max_steps)
|
||||
if self._total_flos is not None:
|
||||
self.store_flos()
|
||||
self.log({"total_flos": self.state.total_flos})
|
||||
metrics["total_flos"] = self.state.total_flos
|
||||
self.log(metrics)
|
||||
|
||||
self.control = self.callback_handler.on_train_end(self.args, self.state, self.control)
|
||||
# add remaining tr_loss
|
||||
self._total_loss_scalar += tr_loss.item()
|
||||
|
||||
return TrainOutput(self.state.global_step, self._total_loss_scalar / self.state.global_step)
|
||||
return TrainOutput(self.state.global_step, self._total_loss_scalar / self.state.global_step, metrics)
|
||||
|
||||
def _maybe_log_save_evaluate(self, tr_loss, model, trial, epoch):
|
||||
if self.control.should_log:
|
||||
@@ -1317,6 +1322,7 @@ class Trainer:
|
||||
raise ValueError("eval_dataset must implement __len__")
|
||||
|
||||
eval_dataloader = self.get_eval_dataloader(eval_dataset)
|
||||
start_time = time.time()
|
||||
|
||||
output = self.prediction_loop(
|
||||
eval_dataloader,
|
||||
@@ -1328,6 +1334,8 @@ class Trainer:
|
||||
metric_key_prefix=metric_key_prefix,
|
||||
)
|
||||
|
||||
n_samples = len(eval_dataset if eval_dataset is not None else self.eval_dataset)
|
||||
output.metrics.update(speed_metrics(metric_key_prefix, start_time, n_samples))
|
||||
self.log(output.metrics)
|
||||
|
||||
if self.args.tpu_metrics_debug or self.args.debug:
|
||||
@@ -1374,10 +1382,13 @@ class Trainer:
|
||||
raise ValueError("test_dataset must implement __len__")
|
||||
|
||||
test_dataloader = self.get_test_dataloader(test_dataset)
|
||||
start_time = time.time()
|
||||
|
||||
return self.prediction_loop(
|
||||
output = self.prediction_loop(
|
||||
test_dataloader, description="Prediction", ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix
|
||||
)
|
||||
output.metrics.update(speed_metrics(metric_key_prefix, start_time, len(test_dataset)))
|
||||
return output
|
||||
|
||||
def prediction_loop(
|
||||
self,
|
||||
|
||||
Reference in New Issue
Block a user