Floating-point operations logging in trainer (#6768)
* neFLOs calculation, logging, and reloading (#1) * testing distributed consecutive batches * fixed AttributeError from DataParallel * removed verbosity * rotate with use_mtime=True * removed print * fixed interaction with gradient accumulation * indent formatting * distributed neflo counting * fixed typo * fixed typo * mean distributed losses * exporting log history * moved a few functions * floating_point_ops clarification for transformers with parameter-reuse * code quality * double import * made flo estimation more task-agnostic * only logging flos if computed * code quality * unused import * Update src/transformers/trainer.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update src/transformers/modeling_utils.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Sylvain review * Update src/transformers/modeling_utils.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * black Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
import inspect
|
||||
import json
|
||||
import math
|
||||
import os
|
||||
import re
|
||||
@@ -42,6 +43,8 @@ from .trainer_utils import (
|
||||
TrainOutput,
|
||||
default_compute_objective,
|
||||
default_hp_space,
|
||||
distributed_broadcast_scalars,
|
||||
distributed_concat,
|
||||
set_seed,
|
||||
)
|
||||
from .training_args import TrainingArguments
|
||||
@@ -146,7 +149,7 @@ class SequentialDistributedSampler(Sampler):
|
||||
indices = indices[self.rank * self.num_samples : (self.rank + 1) * self.num_samples]
|
||||
assert (
|
||||
len(indices) == self.num_samples
|
||||
), f"Indices length {len(indices)} and and sample number {self.num_samples} mismatched"
|
||||
), f"Indices length {len(indices)} and sample number {self.num_samples} mismatched"
|
||||
|
||||
return iter(indices)
|
||||
|
||||
@@ -241,6 +244,7 @@ class Trainer:
|
||||
"You should subclass `Trainer` and override the `create_optimizer_and_scheduler` method."
|
||||
)
|
||||
self.tb_writer = tb_writer
|
||||
self.log_history = []
|
||||
if "prediction_loss_only" in kwargs:
|
||||
warnings.warn(
|
||||
"Passing `prediction_loss_only` as a keyword argument is deprecated and won't be possible in a future version. Use `args.prediction_loss_only` instead.",
|
||||
@@ -284,6 +288,7 @@ class Trainer:
|
||||
|
||||
self.global_step = None
|
||||
self.epoch = None
|
||||
self.total_flos = None
|
||||
if self.args.fp16 and _use_native_amp:
|
||||
self.scaler = torch.cuda.amp.GradScaler()
|
||||
self.hp_search_backend = None
|
||||
@@ -461,7 +466,11 @@ class Trainer:
|
||||
logger.info(
|
||||
'Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"'
|
||||
)
|
||||
combined_dict = {**self.model.config.to_dict(), **self.args.to_sanitized_dict()}
|
||||
try:
|
||||
combined_dict = {**self.model.config.to_dict(), **self.args.to_sanitized_dict()}
|
||||
except AttributeError:
|
||||
# in case the model has no config
|
||||
combined_dict = {**self.args.to_sanitized_dict()}
|
||||
wandb.init(
|
||||
project=os.getenv("WANDB_PROJECT", "huggingface"), config=combined_dict, name=self.args.run_name
|
||||
)
|
||||
@@ -663,6 +672,7 @@ class Trainer:
|
||||
|
||||
self.global_step = 0
|
||||
self.epoch = 0
|
||||
self.total_flos = 0
|
||||
epochs_trained = 0
|
||||
steps_trained_in_current_epoch = 0
|
||||
# Check if continuing training from a checkpoint
|
||||
@@ -670,6 +680,8 @@ class Trainer:
|
||||
# set global_step to global_step of last saved checkpoint from model path
|
||||
try:
|
||||
self.global_step = int(model_path.split("-")[-1].split(os.path.sep)[0])
|
||||
self.total_flos = getattr(model.config, "total_flos", 0)
|
||||
|
||||
epochs_trained = self.global_step // (len(train_dataloader) // self.args.gradient_accumulation_steps)
|
||||
steps_trained_in_current_epoch = self.global_step % (
|
||||
len(train_dataloader) // self.args.gradient_accumulation_steps
|
||||
@@ -678,9 +690,11 @@ class Trainer:
|
||||
logger.info(" Continuing training from checkpoint, will skip to saved global_step")
|
||||
logger.info(" Continuing training from epoch %d", epochs_trained)
|
||||
logger.info(" Continuing training from global step %d", self.global_step)
|
||||
logger.info(" Continuing training from %d non-embedding floating-point operations", self.total_flos)
|
||||
logger.info(" Will skip the first %d steps in the first epoch", steps_trained_in_current_epoch)
|
||||
except ValueError:
|
||||
self.global_step = 0
|
||||
self.total_flos = 0
|
||||
logger.info(" Starting fine-tuning.")
|
||||
|
||||
tr_loss = torch.tensor(0.0).to(self.args.device)
|
||||
@@ -714,6 +728,7 @@ class Trainer:
|
||||
continue
|
||||
|
||||
tr_loss += self.training_step(model, inputs)
|
||||
self.total_flos += self.floating_point_ops(inputs)
|
||||
|
||||
if (step + 1) % self.args.gradient_accumulation_steps == 0 or (
|
||||
# last step in epoch but step is always smaller than gradient_accumulation_steps
|
||||
@@ -784,7 +799,7 @@ class Trainer:
|
||||
self.save_model(output_dir)
|
||||
|
||||
if self.is_world_process_zero():
|
||||
self._rotate_checkpoints()
|
||||
self._rotate_checkpoints(use_mtime=True)
|
||||
|
||||
if is_torch_tpu_available():
|
||||
xm.rendezvous("saving_optimizer_states")
|
||||
@@ -924,6 +939,13 @@ class Trainer:
|
||||
|
||||
if self.epoch is not None:
|
||||
logs["epoch"] = self.epoch
|
||||
if self.total_flos is not None:
|
||||
if self.args.local_rank != -1:
|
||||
total_flos = distributed_broadcast_scalars([self.total_flos]).sum().item()
|
||||
else:
|
||||
total_flos = self.total_flos
|
||||
if total_flos > 0:
|
||||
logs["total_flos"] = self.total_flos
|
||||
if self.global_step is None:
|
||||
# when logging evaluation metrics without training
|
||||
self.global_step = 0
|
||||
@@ -951,6 +973,8 @@ class Trainer:
|
||||
if experiment is not None:
|
||||
experiment._log_metrics(logs, step=self.global_step, epoch=self.epoch, framework="transformers")
|
||||
output = {**logs, **{"step": self.global_step}}
|
||||
if self.is_world_process_zero():
|
||||
self.log_history.append(output)
|
||||
if iterator is not None:
|
||||
iterator.write(output)
|
||||
else:
|
||||
@@ -1089,6 +1113,9 @@ class Trainer:
|
||||
if xm.is_master_ordinal():
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
torch.save(self.args, os.path.join(output_dir, "training_args.bin"))
|
||||
json.dump(
|
||||
self.log_history, open(os.path.join(output_dir, "log_history.json"), "w"), indent=2, ensure_ascii=False
|
||||
)
|
||||
|
||||
# Save a trained model and configuration using `save_pretrained()`.
|
||||
# They can then be reloaded using `from_pretrained()`
|
||||
@@ -1096,6 +1123,7 @@ class Trainer:
|
||||
raise ValueError("Trainer.model appears to not be a PreTrainedModel")
|
||||
|
||||
xm.rendezvous("saving_checkpoint")
|
||||
self._store_flos()
|
||||
self.model.save_pretrained(output_dir)
|
||||
if self.tokenizer is not None:
|
||||
self.tokenizer.save_pretrained(output_dir)
|
||||
@@ -1108,12 +1136,26 @@ class Trainer:
|
||||
# They can then be reloaded using `from_pretrained()`
|
||||
if not isinstance(self.model, PreTrainedModel):
|
||||
raise ValueError("Trainer.model appears to not be a PreTrainedModel")
|
||||
self._store_flos()
|
||||
self.model.save_pretrained(output_dir)
|
||||
if self.tokenizer is not None:
|
||||
self.tokenizer.save_pretrained(output_dir)
|
||||
|
||||
# Good practice: save your training arguments together with the trained model
|
||||
torch.save(self.args, os.path.join(output_dir, "training_args.bin"))
|
||||
json.dump(
|
||||
self.log_history, open(os.path.join(output_dir, "log_history.json"), "w"), indent=2, ensure_ascii=False
|
||||
)
|
||||
|
||||
def _store_flos(self):
|
||||
# Storing the number of floating-point operations that went into the model
|
||||
if self.total_flos is not None:
|
||||
if self.args.local_rank != -1:
|
||||
total_flos = distributed_broadcast_scalars([self.total_flos]).sum().item()
|
||||
else:
|
||||
total_flos = self.total_flos
|
||||
if total_flos > 0:
|
||||
self.model.config.total_flos = total_flos
|
||||
|
||||
def _sorted_checkpoints(self, checkpoint_prefix=PREFIX_CHECKPOINT_DIR, use_mtime=False) -> List[str]:
|
||||
ordering_and_checkpoint_path = []
|
||||
@@ -1245,13 +1287,11 @@ class Trainer:
|
||||
self._past = None
|
||||
|
||||
disable_tqdm = not self.is_local_process_zero() or self.args.disable_tqdm
|
||||
samples_count = 0
|
||||
for inputs in tqdm(dataloader, desc=description, disable=disable_tqdm):
|
||||
loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only)
|
||||
batch_size = inputs[list(inputs.keys())[0]].shape[0]
|
||||
samples_count += batch_size
|
||||
if loss is not None:
|
||||
eval_losses.append(loss * batch_size)
|
||||
eval_losses.extend([loss] * batch_size)
|
||||
if logits is not None:
|
||||
preds = logits if preds is None else torch.cat((preds, logits), dim=0)
|
||||
if labels is not None:
|
||||
@@ -1264,9 +1304,9 @@ class Trainer:
|
||||
if self.args.local_rank != -1:
|
||||
# In distributed mode, concatenate all results from all nodes:
|
||||
if preds is not None:
|
||||
preds = self.distributed_concat(preds, num_total_examples=self.num_examples(dataloader))
|
||||
preds = distributed_concat(preds, num_total_examples=self.num_examples(dataloader))
|
||||
if label_ids is not None:
|
||||
label_ids = self.distributed_concat(label_ids, num_total_examples=self.num_examples(dataloader))
|
||||
label_ids = distributed_concat(label_ids, num_total_examples=self.num_examples(dataloader))
|
||||
elif is_torch_tpu_available():
|
||||
# tpu-comment: Get all predictions and labels from all worker shards of eval dataset
|
||||
if preds is not None:
|
||||
@@ -1289,7 +1329,14 @@ class Trainer:
|
||||
else:
|
||||
metrics = {}
|
||||
if len(eval_losses) > 0:
|
||||
metrics["eval_loss"] = np.sum(eval_losses) / samples_count
|
||||
if self.args.local_rank != -1:
|
||||
metrics["eval_loss"] = (
|
||||
distributed_broadcast_scalars(eval_losses, num_total_examples=self.num_examples(dataloader))
|
||||
.mean()
|
||||
.item()
|
||||
)
|
||||
else:
|
||||
metrics["eval_loss"] = np.mean(eval_losses)
|
||||
|
||||
# Prefix all keys with eval_
|
||||
for key in list(metrics.keys()):
|
||||
@@ -1298,18 +1345,6 @@ class Trainer:
|
||||
|
||||
return PredictionOutput(predictions=preds, label_ids=label_ids, metrics=metrics)
|
||||
|
||||
def distributed_concat(self, tensor: torch.Tensor, num_total_examples: int) -> torch.Tensor:
|
||||
assert self.args.local_rank != -1
|
||||
|
||||
output_tensors = [tensor.clone() for _ in range(torch.distributed.get_world_size())]
|
||||
torch.distributed.all_gather(output_tensors, tensor)
|
||||
|
||||
concat = torch.cat(output_tensors, dim=0)
|
||||
|
||||
# truncate the dummy elements added by SequentialDistributedSampler
|
||||
output = concat[:num_total_examples]
|
||||
return output
|
||||
|
||||
def prediction_step(
|
||||
self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]], prediction_loss_only: bool
|
||||
) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
|
||||
@@ -1355,3 +1390,32 @@ class Trainer:
|
||||
if labels is not None:
|
||||
labels = labels.detach()
|
||||
return (loss, logits.detach(), labels)
|
||||
|
||||
def floating_point_ops(self, inputs: Dict[str, Union[torch.Tensor, Any]]):
|
||||
"""
|
||||
For models that inherit from :class:`~transformers.PretrainedModel`, uses
|
||||
that method to compute the number of floating point operations for every backward + forward pass. If using
|
||||
another model, either implement such a method in the model or subclass and override this method.
|
||||
|
||||
Args:
|
||||
model (:obj:`nn.Module`):
|
||||
The model to evaluate.
|
||||
inputs (:obj:`Dict[str, Union[torch.Tensor, Any]]`):
|
||||
The inputs and targets of the model.
|
||||
|
||||
Returns:
|
||||
:obj:`int`: The number of floating-point operations.
|
||||
"""
|
||||
|
||||
if isinstance(self.model, torch.nn.DataParallel) or isinstance(
|
||||
self.model, torch.nn.parallel.DistributedDataParallel
|
||||
):
|
||||
model = self.model.module
|
||||
else:
|
||||
model = self.model
|
||||
|
||||
if hasattr(model, "floating_point_ops"):
|
||||
return model.floating_point_ops(inputs)
|
||||
|
||||
else:
|
||||
return 0
|
||||
|
||||
Reference in New Issue
Block a user