Floating-point operations logging in trainer (#6768)

* neFLOs calculation, logging, and reloading (#1)

* testing distributed consecutive batches

* fixed AttributeError from DataParallel

* removed verbosity

* rotate with use_mtime=True

* removed print

* fixed interaction with gradient accumulation

* indent formatting

* distributed neflo counting

* fixed typo

* fixed typo

* mean distributed losses

* exporting log history

* moved a few functions

* floating_point_ops clarification for transformers with parameter-reuse

* code quality

* double import

* made flo estimation more task-agnostic

* only logging flos if computed

* code quality

* unused import

* Update src/transformers/trainer.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/modeling_utils.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Sylvain review

* Update src/transformers/modeling_utils.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* black

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
Teven
2020-09-08 16:00:56 +02:00
committed by GitHub
parent d155b38d6e
commit 01d340adfa
3 changed files with 187 additions and 39 deletions

View File

@@ -1,4 +1,5 @@
import inspect
import json
import math
import os
import re
@@ -42,6 +43,8 @@ from .trainer_utils import (
TrainOutput,
default_compute_objective,
default_hp_space,
distributed_broadcast_scalars,
distributed_concat,
set_seed,
)
from .training_args import TrainingArguments
@@ -146,7 +149,7 @@ class SequentialDistributedSampler(Sampler):
indices = indices[self.rank * self.num_samples : (self.rank + 1) * self.num_samples]
assert (
len(indices) == self.num_samples
), f"Indices length {len(indices)} and and sample number {self.num_samples} mismatched"
), f"Indices length {len(indices)} and sample number {self.num_samples} mismatched"
return iter(indices)
@@ -241,6 +244,7 @@ class Trainer:
"You should subclass `Trainer` and override the `create_optimizer_and_scheduler` method."
)
self.tb_writer = tb_writer
self.log_history = []
if "prediction_loss_only" in kwargs:
warnings.warn(
"Passing `prediction_loss_only` as a keyword argument is deprecated and won't be possible in a future version. Use `args.prediction_loss_only` instead.",
@@ -284,6 +288,7 @@ class Trainer:
self.global_step = None
self.epoch = None
self.total_flos = None
if self.args.fp16 and _use_native_amp:
self.scaler = torch.cuda.amp.GradScaler()
self.hp_search_backend = None
@@ -461,7 +466,11 @@ class Trainer:
logger.info(
'Automatic Weights & Biases logging enabled, to disable set os.environ["WANDB_DISABLED"] = "true"'
)
combined_dict = {**self.model.config.to_dict(), **self.args.to_sanitized_dict()}
try:
combined_dict = {**self.model.config.to_dict(), **self.args.to_sanitized_dict()}
except AttributeError:
# in case the model has no config
combined_dict = {**self.args.to_sanitized_dict()}
wandb.init(
project=os.getenv("WANDB_PROJECT", "huggingface"), config=combined_dict, name=self.args.run_name
)
@@ -663,6 +672,7 @@ class Trainer:
self.global_step = 0
self.epoch = 0
self.total_flos = 0
epochs_trained = 0
steps_trained_in_current_epoch = 0
# Check if continuing training from a checkpoint
@@ -670,6 +680,8 @@ class Trainer:
# set global_step to global_step of last saved checkpoint from model path
try:
self.global_step = int(model_path.split("-")[-1].split(os.path.sep)[0])
self.total_flos = getattr(model.config, "total_flos", 0)
epochs_trained = self.global_step // (len(train_dataloader) // self.args.gradient_accumulation_steps)
steps_trained_in_current_epoch = self.global_step % (
len(train_dataloader) // self.args.gradient_accumulation_steps
@@ -678,9 +690,11 @@ class Trainer:
logger.info(" Continuing training from checkpoint, will skip to saved global_step")
logger.info(" Continuing training from epoch %d", epochs_trained)
logger.info(" Continuing training from global step %d", self.global_step)
logger.info(" Continuing training from %d non-embedding floating-point operations", self.total_flos)
logger.info(" Will skip the first %d steps in the first epoch", steps_trained_in_current_epoch)
except ValueError:
self.global_step = 0
self.total_flos = 0
logger.info(" Starting fine-tuning.")
tr_loss = torch.tensor(0.0).to(self.args.device)
@@ -714,6 +728,7 @@ class Trainer:
continue
tr_loss += self.training_step(model, inputs)
self.total_flos += self.floating_point_ops(inputs)
if (step + 1) % self.args.gradient_accumulation_steps == 0 or (
# last step in epoch but step is always smaller than gradient_accumulation_steps
@@ -784,7 +799,7 @@ class Trainer:
self.save_model(output_dir)
if self.is_world_process_zero():
self._rotate_checkpoints()
self._rotate_checkpoints(use_mtime=True)
if is_torch_tpu_available():
xm.rendezvous("saving_optimizer_states")
@@ -924,6 +939,13 @@ class Trainer:
if self.epoch is not None:
logs["epoch"] = self.epoch
if self.total_flos is not None:
if self.args.local_rank != -1:
total_flos = distributed_broadcast_scalars([self.total_flos]).sum().item()
else:
total_flos = self.total_flos
if total_flos > 0:
logs["total_flos"] = self.total_flos
if self.global_step is None:
# when logging evaluation metrics without training
self.global_step = 0
@@ -951,6 +973,8 @@ class Trainer:
if experiment is not None:
experiment._log_metrics(logs, step=self.global_step, epoch=self.epoch, framework="transformers")
output = {**logs, **{"step": self.global_step}}
if self.is_world_process_zero():
self.log_history.append(output)
if iterator is not None:
iterator.write(output)
else:
@@ -1089,6 +1113,9 @@ class Trainer:
if xm.is_master_ordinal():
os.makedirs(output_dir, exist_ok=True)
torch.save(self.args, os.path.join(output_dir, "training_args.bin"))
json.dump(
self.log_history, open(os.path.join(output_dir, "log_history.json"), "w"), indent=2, ensure_ascii=False
)
# Save a trained model and configuration using `save_pretrained()`.
# They can then be reloaded using `from_pretrained()`
@@ -1096,6 +1123,7 @@ class Trainer:
raise ValueError("Trainer.model appears to not be a PreTrainedModel")
xm.rendezvous("saving_checkpoint")
self._store_flos()
self.model.save_pretrained(output_dir)
if self.tokenizer is not None:
self.tokenizer.save_pretrained(output_dir)
@@ -1108,12 +1136,26 @@ class Trainer:
# They can then be reloaded using `from_pretrained()`
if not isinstance(self.model, PreTrainedModel):
raise ValueError("Trainer.model appears to not be a PreTrainedModel")
self._store_flos()
self.model.save_pretrained(output_dir)
if self.tokenizer is not None:
self.tokenizer.save_pretrained(output_dir)
# Good practice: save your training arguments together with the trained model
torch.save(self.args, os.path.join(output_dir, "training_args.bin"))
json.dump(
self.log_history, open(os.path.join(output_dir, "log_history.json"), "w"), indent=2, ensure_ascii=False
)
def _store_flos(self):
# Storing the number of floating-point operations that went into the model
if self.total_flos is not None:
if self.args.local_rank != -1:
total_flos = distributed_broadcast_scalars([self.total_flos]).sum().item()
else:
total_flos = self.total_flos
if total_flos > 0:
self.model.config.total_flos = total_flos
def _sorted_checkpoints(self, checkpoint_prefix=PREFIX_CHECKPOINT_DIR, use_mtime=False) -> List[str]:
ordering_and_checkpoint_path = []
@@ -1245,13 +1287,11 @@ class Trainer:
self._past = None
disable_tqdm = not self.is_local_process_zero() or self.args.disable_tqdm
samples_count = 0
for inputs in tqdm(dataloader, desc=description, disable=disable_tqdm):
loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only)
batch_size = inputs[list(inputs.keys())[0]].shape[0]
samples_count += batch_size
if loss is not None:
eval_losses.append(loss * batch_size)
eval_losses.extend([loss] * batch_size)
if logits is not None:
preds = logits if preds is None else torch.cat((preds, logits), dim=0)
if labels is not None:
@@ -1264,9 +1304,9 @@ class Trainer:
if self.args.local_rank != -1:
# In distributed mode, concatenate all results from all nodes:
if preds is not None:
preds = self.distributed_concat(preds, num_total_examples=self.num_examples(dataloader))
preds = distributed_concat(preds, num_total_examples=self.num_examples(dataloader))
if label_ids is not None:
label_ids = self.distributed_concat(label_ids, num_total_examples=self.num_examples(dataloader))
label_ids = distributed_concat(label_ids, num_total_examples=self.num_examples(dataloader))
elif is_torch_tpu_available():
# tpu-comment: Get all predictions and labels from all worker shards of eval dataset
if preds is not None:
@@ -1289,7 +1329,14 @@ class Trainer:
else:
metrics = {}
if len(eval_losses) > 0:
metrics["eval_loss"] = np.sum(eval_losses) / samples_count
if self.args.local_rank != -1:
metrics["eval_loss"] = (
distributed_broadcast_scalars(eval_losses, num_total_examples=self.num_examples(dataloader))
.mean()
.item()
)
else:
metrics["eval_loss"] = np.mean(eval_losses)
# Prefix all keys with eval_
for key in list(metrics.keys()):
@@ -1298,18 +1345,6 @@ class Trainer:
return PredictionOutput(predictions=preds, label_ids=label_ids, metrics=metrics)
def distributed_concat(self, tensor: torch.Tensor, num_total_examples: int) -> torch.Tensor:
assert self.args.local_rank != -1
output_tensors = [tensor.clone() for _ in range(torch.distributed.get_world_size())]
torch.distributed.all_gather(output_tensors, tensor)
concat = torch.cat(output_tensors, dim=0)
# truncate the dummy elements added by SequentialDistributedSampler
output = concat[:num_total_examples]
return output
def prediction_step(
self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]], prediction_loss_only: bool
) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
@@ -1355,3 +1390,32 @@ class Trainer:
if labels is not None:
labels = labels.detach()
return (loss, logits.detach(), labels)
def floating_point_ops(self, inputs: Dict[str, Union[torch.Tensor, Any]]):
"""
For models that inherit from :class:`~transformers.PretrainedModel`, uses
that method to compute the number of floating point operations for every backward + forward pass. If using
another model, either implement such a method in the model or subclass and override this method.
Args:
model (:obj:`nn.Module`):
The model to evaluate.
inputs (:obj:`Dict[str, Union[torch.Tensor, Any]]`):
The inputs and targets of the model.
Returns:
:obj:`int`: The number of floating-point operations.
"""
if isinstance(self.model, torch.nn.DataParallel) or isinstance(
self.model, torch.nn.parallel.DistributedDataParallel
):
model = self.model.module
else:
model = self.model
if hasattr(model, "floating_point_ops"):
return model.floating_point_ops(inputs)
else:
return 0