Add timing inside Trainer (#9196)
* Add timing inside Trainer * Fix tests * Add n_objs for train * Sort logs
This commit is contained in:
@@ -22,6 +22,7 @@ import math
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
import time
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
@@ -89,6 +90,7 @@ from .trainer_utils import (
|
||||
default_compute_objective,
|
||||
default_hp_space,
|
||||
set_seed,
|
||||
speed_metrics,
|
||||
)
|
||||
from .training_args import TrainingArguments
|
||||
from .utils import logging
|
||||
@@ -707,6 +709,7 @@ class Trainer:
|
||||
logger.info(f" Total optimization steps = {max_steps}")
|
||||
|
||||
self.state.epoch = 0
|
||||
start_time = time.time()
|
||||
epochs_trained = 0
|
||||
steps_trained_in_current_epoch = 0
|
||||
|
||||
@@ -870,15 +873,17 @@ class Trainer:
|
||||
state_dict = torch.load(os.path.join(self.state.best_model_checkpoint, WEIGHTS_NAME))
|
||||
self.model.load_state_dict(state_dict)
|
||||
|
||||
metrics = speed_metrics("train", start_time, self.state.max_steps)
|
||||
if self._total_flos is not None:
|
||||
self.store_flos()
|
||||
self.log({"total_flos": self.state.total_flos})
|
||||
metrics["total_flos"] = self.state.total_flos
|
||||
self.log(metrics)
|
||||
|
||||
self.control = self.callback_handler.on_train_end(self.args, self.state, self.control)
|
||||
# add remaining tr_loss
|
||||
self._total_loss_scalar += tr_loss.item()
|
||||
|
||||
return TrainOutput(self.state.global_step, self._total_loss_scalar / self.state.global_step)
|
||||
return TrainOutput(self.state.global_step, self._total_loss_scalar / self.state.global_step, metrics)
|
||||
|
||||
def _maybe_log_save_evaluate(self, tr_loss, model, trial, epoch):
|
||||
if self.control.should_log:
|
||||
@@ -1317,6 +1322,7 @@ class Trainer:
|
||||
raise ValueError("eval_dataset must implement __len__")
|
||||
|
||||
eval_dataloader = self.get_eval_dataloader(eval_dataset)
|
||||
start_time = time.time()
|
||||
|
||||
output = self.prediction_loop(
|
||||
eval_dataloader,
|
||||
@@ -1328,6 +1334,8 @@ class Trainer:
|
||||
metric_key_prefix=metric_key_prefix,
|
||||
)
|
||||
|
||||
n_samples = len(eval_dataset if eval_dataset is not None else self.eval_dataset)
|
||||
output.metrics.update(speed_metrics(metric_key_prefix, start_time, n_samples))
|
||||
self.log(output.metrics)
|
||||
|
||||
if self.args.tpu_metrics_debug or self.args.debug:
|
||||
@@ -1374,10 +1382,13 @@ class Trainer:
|
||||
raise ValueError("test_dataset must implement __len__")
|
||||
|
||||
test_dataloader = self.get_test_dataloader(test_dataset)
|
||||
start_time = time.time()
|
||||
|
||||
return self.prediction_loop(
|
||||
output = self.prediction_loop(
|
||||
test_dataloader, description="Prediction", ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix
|
||||
)
|
||||
output.metrics.update(speed_metrics(metric_key_prefix, start_time, len(test_dataset)))
|
||||
return output
|
||||
|
||||
def prediction_loop(
|
||||
self,
|
||||
|
||||
@@ -18,6 +18,7 @@ Utilities for the Trainer and TFTrainer class. Should be independent from PyTorc
|
||||
|
||||
import copy
|
||||
import random
|
||||
import time
|
||||
from typing import Any, Dict, NamedTuple, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
@@ -70,6 +71,7 @@ class PredictionOutput(NamedTuple):
|
||||
class TrainOutput(NamedTuple):
|
||||
global_step: int
|
||||
training_loss: float
|
||||
metrics: Dict[str, float]
|
||||
|
||||
|
||||
PREFIX_CHECKPOINT_DIR = "checkpoint"
|
||||
@@ -179,3 +181,23 @@ def total_processes_number(local_rank):
|
||||
|
||||
return torch.distributed.get_world_size()
|
||||
return 1
|
||||
|
||||
|
||||
def speed_metrics(split, start_time, num_samples=None):
|
||||
"""
|
||||
Measure and return speed performance metrics.
|
||||
|
||||
This function requires a time snapshot `start_time` before the operation to be measured starts and this function
|
||||
should be run immediately after the operation to be measured has completed.
|
||||
|
||||
Args:
|
||||
- split: name to prefix metric (like train, eval, test...)
|
||||
- start_time: operation start time
|
||||
- num_samples: number of samples processed
|
||||
"""
|
||||
runtime = time.time() - start_time
|
||||
result = {f"{split}_runtime": round(runtime, 4)}
|
||||
if num_samples is not None:
|
||||
samples_per_second = 1 / (runtime / num_samples)
|
||||
result[f"{split}_samples_per_second"] = round(samples_per_second, 3)
|
||||
return result
|
||||
|
||||
@@ -12,10 +12,9 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import dataclasses
|
||||
import json
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from dataclasses import asdict, dataclass, field
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
@@ -411,7 +410,16 @@ class TrainingArguments:
|
||||
self.run_name = self.output_dir
|
||||
|
||||
if is_torch_available() and self.device.type != "cuda" and self.fp16:
|
||||
raise ValueError("AMP (`--fp16`) can only be used on CUDA devices.")
|
||||
raise ValueError("Mixed precision training with AMP or APEX (`--fp16`) can only be used on CUDA devices.")
|
||||
|
||||
def __repr__(self):
|
||||
# We override the default repr to remove deprecated arguments from the repr. This method should be removed once
|
||||
# those deprecated arguments are removed form TrainingArguments. (TODO: v5)
|
||||
self_as_dict = asdict(self)
|
||||
del self_as_dict["per_gpu_train_batch_size"]
|
||||
del self_as_dict["per_gpu_eval_batch_size"]
|
||||
attrs_as_str = [f"{k}={v}" for k, v in self_as_dict.items()]
|
||||
return f"{self.__class__.__name__}({', '.join(attrs_as_str)})"
|
||||
|
||||
@property
|
||||
def train_batch_size(self) -> int:
|
||||
@@ -523,7 +531,7 @@ class TrainingArguments:
|
||||
"""
|
||||
Serializes this instance while replace `Enum` by their values (for JSON serialization support).
|
||||
"""
|
||||
d = dataclasses.asdict(self)
|
||||
d = asdict(self)
|
||||
for k, v in d.items():
|
||||
if isinstance(v, Enum):
|
||||
d[k] = v.value
|
||||
|
||||
Reference in New Issue
Block a user