[Trainer] Report both steps and num samples per second (#11818)
* [Trainer] Report both steps and num samples per second * Fix batch number * Update src/transformers/trainer_utils.py Co-authored-by: Stas Bekman <stas00@users.noreply.github.com> * Address review comments Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>
This commit is contained in:
@@ -518,6 +518,7 @@ def parse_log_history(log_history):
|
||||
step = metrics.pop("step", None)
|
||||
_ = metrics.pop("eval_runtime", None)
|
||||
_ = metrics.pop("eval_samples_per_second", None)
|
||||
_ = metrics.pop("eval_steps_per_second", None)
|
||||
values = {"Training Loss": training_loss, "Epoch": epoch, "Step": step}
|
||||
for k, v in metrics.items():
|
||||
if k == "eval_loss":
|
||||
@@ -537,7 +538,7 @@ def parse_log_history(log_history):
|
||||
for key, value in log_history[idx].items():
|
||||
if key.startswith("eval_"):
|
||||
key = key[5:]
|
||||
if key not in ["runtime", "samples_per_second", "epoch", "step"]:
|
||||
if key not in ["runtime", "samples_per_second", "steps_per_second", "epoch", "step"]:
|
||||
camel_cased_key = " ".join([part.capitalize() for part in key.split("_")])
|
||||
eval_results[camel_cased_key] = value
|
||||
return train_log, lines, eval_results
|
||||
|
||||
@@ -1077,6 +1077,7 @@ class Trainer:
|
||||
# number of training epochs: num_train_epochs
|
||||
# number of training steps per epoch: num_update_steps_per_epoch
|
||||
# total number of training steps to execute: max_steps
|
||||
total_train_batch_size = args.train_batch_size * args.gradient_accumulation_steps * args.world_size
|
||||
if train_dataset_is_sized:
|
||||
num_update_steps_per_epoch = len(train_dataloader) // args.gradient_accumulation_steps
|
||||
num_update_steps_per_epoch = max(num_update_steps_per_epoch, 1)
|
||||
@@ -1085,14 +1086,19 @@ class Trainer:
|
||||
num_train_epochs = args.max_steps // num_update_steps_per_epoch + int(
|
||||
args.max_steps % num_update_steps_per_epoch > 0
|
||||
)
|
||||
# May be slightly incorrect if the last batch in the training datalaoder has a smaller size but it's
|
||||
# the best we can do.
|
||||
num_train_samples = args.max_steps * total_train_batch_size
|
||||
else:
|
||||
max_steps = math.ceil(args.num_train_epochs * num_update_steps_per_epoch)
|
||||
num_train_epochs = math.ceil(args.num_train_epochs)
|
||||
num_train_samples = len(self.train_dataset) * args.num_train_epochs
|
||||
else:
|
||||
# see __init__. max_steps is set when the dataset has no __len__
|
||||
max_steps = args.max_steps
|
||||
num_train_epochs = int(args.num_train_epochs)
|
||||
num_update_steps_per_epoch = max_steps
|
||||
num_train_samples = args.max_steps * total_train_batch_size
|
||||
|
||||
if DebugOption.UNDERFLOW_OVERFLOW in self.args.debug:
|
||||
debug_overflow = DebugUnderflowOverflow(self.model) # noqa
|
||||
@@ -1130,14 +1136,6 @@ class Trainer:
|
||||
# self.model_wrapped is DDP(Transformers Model), Deepspeed(Transformers Model), etc.
|
||||
|
||||
# Train!
|
||||
if is_torch_tpu_available():
|
||||
world_size = xm.xrt_world_size()
|
||||
elif args.local_rank != -1:
|
||||
world_size = dist.get_world_size()
|
||||
else:
|
||||
world_size = 1
|
||||
|
||||
total_train_batch_size = args.train_batch_size * args.gradient_accumulation_steps * world_size
|
||||
num_examples = (
|
||||
self.num_examples(train_dataloader) if train_dataset_is_sized else total_train_batch_size * args.max_steps
|
||||
)
|
||||
@@ -1359,7 +1357,7 @@ class Trainer:
|
||||
self.state.best_model_checkpoint, load_optimizer_states=False, load_lr_scheduler_states=False
|
||||
)
|
||||
|
||||
metrics = speed_metrics("train", start_time, self.state.max_steps)
|
||||
metrics = speed_metrics("train", start_time, num_samples=num_train_samples, num_steps=self.state.max_steps)
|
||||
self.store_flos()
|
||||
metrics["total_flos"] = self.state.total_flos
|
||||
self.log(metrics)
|
||||
@@ -2009,7 +2007,15 @@ class Trainer:
|
||||
metric_key_prefix=metric_key_prefix,
|
||||
)
|
||||
|
||||
output.metrics.update(speed_metrics(metric_key_prefix, start_time, output.num_samples))
|
||||
total_batch_size = self.args.eval_batch_size * self.args.world_size
|
||||
output.metrics.update(
|
||||
speed_metrics(
|
||||
metric_key_prefix,
|
||||
start_time,
|
||||
num_samples=output.num_samples,
|
||||
num_steps=math.ceil(output.num_samples / total_batch_size),
|
||||
)
|
||||
)
|
||||
|
||||
self.log(output.metrics)
|
||||
|
||||
@@ -2066,7 +2072,15 @@ class Trainer:
|
||||
output = eval_loop(
|
||||
test_dataloader, description="Prediction", ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix
|
||||
)
|
||||
output.metrics.update(speed_metrics(metric_key_prefix, start_time, output.num_samples))
|
||||
total_batch_size = self.args.eval_batch_size * self.args.world_size
|
||||
output.metrics.update(
|
||||
speed_metrics(
|
||||
metric_key_prefix,
|
||||
start_time,
|
||||
num_samples=output.num_samples,
|
||||
num_steps=math.ceil(output.num_samples / total_batch_size),
|
||||
)
|
||||
)
|
||||
|
||||
self._memory_tracker.stop_and_update_metrics(output.metrics)
|
||||
|
||||
|
||||
@@ -158,7 +158,7 @@ def default_compute_objective(metrics: Dict[str, float]) -> float:
|
||||
loss = metrics.pop("eval_loss", None)
|
||||
_ = metrics.pop("epoch", None)
|
||||
# Remove speed metrics
|
||||
speed_metrics = [m for m in metrics.keys() if m.endswith("_runtime") or m.endswith("_samples_per_second")]
|
||||
speed_metrics = [m for m in metrics.keys() if m.endswith("_runtime") or m.endswith("_per_second")]
|
||||
for sm in speed_metrics:
|
||||
_ = metrics.pop(sm, None)
|
||||
return loss if len(metrics) == 0 else sum(metrics.values())
|
||||
@@ -232,7 +232,7 @@ def total_processes_number(local_rank):
|
||||
return 1
|
||||
|
||||
|
||||
def speed_metrics(split, start_time, num_samples=None):
|
||||
def speed_metrics(split, start_time, num_samples=None, num_steps=None):
|
||||
"""
|
||||
Measure and return speed performance metrics.
|
||||
|
||||
@@ -248,8 +248,11 @@ def speed_metrics(split, start_time, num_samples=None):
|
||||
runtime = time.time() - start_time
|
||||
result = {f"{split}_runtime": round(runtime, 4)}
|
||||
if num_samples is not None:
|
||||
samples_per_second = 1 / (runtime / num_samples)
|
||||
samples_per_second = num_samples / runtime
|
||||
result[f"{split}_samples_per_second"] = round(samples_per_second, 3)
|
||||
if num_steps is not None:
|
||||
steps_per_second = num_steps / runtime
|
||||
result[f"{split}_steps_per_second"] = round(steps_per_second, 3)
|
||||
return result
|
||||
|
||||
|
||||
|
||||
@@ -327,6 +327,7 @@ class NotebookProgressCallback(TrainerCallback):
|
||||
_ = metrics.pop("epoch", None)
|
||||
_ = metrics.pop(f"{metric_key_prefix}_runtime", None)
|
||||
_ = metrics.pop(f"{metric_key_prefix}_samples_per_second", None)
|
||||
_ = metrics.pop(f"{metric_key_prefix}_steps_per_second", None)
|
||||
for k, v in metrics.items():
|
||||
if k == f"{metric_key_prefix}_loss":
|
||||
values["Validation Loss"] = v
|
||||
|
||||
@@ -316,6 +316,8 @@ class TrainerIntegrationCommon:
|
||||
_ = log1.pop("train_runtime", None)
|
||||
_ = log.pop("train_samples_per_second", None)
|
||||
_ = log1.pop("train_samples_per_second", None)
|
||||
_ = log.pop("train_steps_per_second", None)
|
||||
_ = log1.pop("train_steps_per_second", None)
|
||||
self.assertEqual(log, log1)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user