Add predict step accumulation (#7767)

* Add eval_accumulation_step and clean distributed eval

* Add TPU test

* Add TPU stuff

* Fix arg name

* Fix Seq2SeqTrainer

* Fix total_size

* Update src/transformers/trainer_pt_utils.py

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>

* Doc and add test to TPU

* Add unit test

* Adapt name

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
This commit is contained in:
Sylvain Gugger
2020-10-14 11:41:45 -04:00
committed by GitHub
parent 8feb0cc967
commit a1d1b332d0
10 changed files with 413 additions and 47 deletions

View File

@@ -59,6 +59,7 @@ from .trainer_callback import (
TrainerState,
)
from .trainer_pt_utils import (
DistributedTensorGatherer,
SequentialDistributedSampler,
distributed_broadcast_scalars,
distributed_concat,
@@ -1266,18 +1267,29 @@ class Trainer:
# multi-gpu eval
if self.args.n_gpu > 1:
model = torch.nn.DataParallel(model)
else:
model = self.model
# Note: in torch.distributed mode, there's no point in wrapping the model
# inside a DistributedDataParallel as we'll be under `no_grad` anyways.
batch_size = dataloader.batch_size
num_examples = self.num_examples(dataloader)
logger.info("***** Running %s *****", description)
logger.info(" Num examples = %d", self.num_examples(dataloader))
logger.info(" Num examples = %d", num_examples)
logger.info(" Batch size = %d", batch_size)
eval_losses: List[float] = []
preds: torch.Tensor = None
label_ids: torch.Tensor = None
losses_host: torch.Tensor = None
preds_host: Union[torch.Tensor, List[torch.Tensor]] = None
labels_host: Union[torch.Tensor, List[torch.Tensor]] = None
world_size = 1
if is_torch_tpu_available():
world_size = xm.xrt_world_size()
elif self.args.local_rank != -1:
world_size = torch.distributed.get_world_size()
world_size = max(1, world_size)
eval_losses_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=batch_size)
preds_gatherer = DistributedTensorGatherer(world_size, num_examples)
labels_gatherer = DistributedTensorGatherer(world_size, num_examples)
model.eval()
if is_torch_tpu_available():
@@ -1288,55 +1300,46 @@ class Trainer:
self.callback_handler.eval_dataloader = dataloader
for inputs in dataloader:
for step, inputs in enumerate(dataloader):
loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only)
batch_size = inputs[list(inputs.keys())[0]].shape[0]
if loss is not None:
eval_losses.extend([loss] * batch_size)
losses = loss.repeat(batch_size)
losses_host = losses if losses_host is None else torch.cat((losses_host, losses), dim=0)
if logits is not None:
preds = logits if preds is None else nested_concat(preds, logits, dim=0)
preds_host = logits if preds_host is None else nested_concat(preds_host, logits, dim=0)
if labels is not None:
label_ids = labels if label_ids is None else nested_concat(label_ids, labels, dim=0)
labels_host = labels if labels_host is None else nested_concat(labels_host, labels, dim=0)
self.control = self.callback_handler.on_prediction_step(self.args, self.state, self.control)
# Gather all tensors and put them back on the CPU if we have done enough accumulation steps.
if self.args.eval_accumulation_steps is not None and (step + 1) % self.args.eval_accumulation_steps == 0:
eval_losses_gatherer.add_arrays(self._gather_and_numpify(losses_host, "eval_losses"))
preds_gatherer.add_arrays(self._gather_and_numpify(preds_host, "eval_preds"))
labels_gatherer.add_arrays(self._gather_and_numpify(labels_host, "eval_label_ids"))
# Set back to None to begin a new accumulation
losses_host, preds_host, labels_host = None, None, None
if self.args.past_index and hasattr(self, "_past"):
# Clean the state at the end of the evaluation loop
delattr(self, "_past")
if self.args.local_rank != -1:
# In distributed mode, concatenate all results from all nodes:
if preds is not None:
preds = distributed_concat(preds, num_total_examples=self.num_examples(dataloader))
if label_ids is not None:
label_ids = distributed_concat(label_ids, num_total_examples=self.num_examples(dataloader))
elif is_torch_tpu_available():
# tpu-comment: Get all predictions and labels from all worker shards of eval dataset
if preds is not None:
preds = nested_xla_mesh_reduce(preds, "eval_preds")
if label_ids is not None:
label_ids = nested_xla_mesh_reduce(label_ids, "eval_label_ids")
if eval_losses is not None:
eval_losses = xm.mesh_reduce("eval_losses", torch.tensor(eval_losses), torch.cat).tolist()
# Gather all remaining tensors and put them back on the CPU
eval_losses_gatherer.add_arrays(self._gather_and_numpify(losses_host, "eval_losses"))
preds_gatherer.add_arrays(self._gather_and_numpify(preds_host, "eval_preds"))
labels_gatherer.add_arrays(self._gather_and_numpify(labels_host, "eval_label_ids"))
# Finally, turn the aggregated tensors into numpy arrays.
if preds is not None:
preds = nested_numpify(preds)
if label_ids is not None:
label_ids = nested_numpify(label_ids)
eval_loss = eval_losses_gatherer.finalize()
preds = preds_gatherer.finalize()
label_ids = labels_gatherer.finalize()
if self.compute_metrics is not None and preds is not None and label_ids is not None:
metrics = self.compute_metrics(EvalPrediction(predictions=preds, label_ids=label_ids))
else:
metrics = {}
if len(eval_losses) > 0:
if self.args.local_rank != -1:
metrics["eval_loss"] = (
distributed_broadcast_scalars(eval_losses, num_total_examples=self.num_examples(dataloader))
.mean()
.item()
)
else:
metrics["eval_loss"] = np.mean(eval_losses)
if eval_loss is not None:
metrics["eval_loss"] = eval_loss.mean().item()
# Prefix all keys with eval_
for key in list(metrics.keys()):
@@ -1345,6 +1348,20 @@ class Trainer:
return PredictionOutput(predictions=preds, label_ids=label_ids, metrics=metrics)
def _gather_and_numpify(self, tensors, name):
    """
    Gather the value of `tensors` (tensor or list/tuple of nested tensors) and convert the result to numpy.

    On TPU the gathering goes through `nested_xla_mesh_reduce`; when `self.args.local_rank != -1`
    it goes through `distributed_concat`; otherwise `tensors` is converted as is.

    Args:
        tensors: The tensor, or nested list/tuple of tensors, to gather. May be `None`.
        name: Name forwarded to `nested_xla_mesh_reduce`; only used on the TPU path.

    Returns:
        The output of `nested_numpify` applied to the gathered tensors, or `None` when `tensors` is `None`.
    """
    if tensors is None:
        # Nothing to gather: return before touching any distributed helper.
        return None
    if is_torch_tpu_available():
        tensors = nested_xla_mesh_reduce(tensors, name)
    elif self.args.local_rank != -1:
        tensors = distributed_concat(tensors)
    return nested_numpify(tensors)
def prediction_step(
self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]], prediction_loss_only: bool
) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
@@ -1374,8 +1391,7 @@ class Trainer:
with torch.no_grad():
outputs = model(**inputs)
if has_labels:
# The .mean() is to reduce in case of distributed training
loss = outputs[0].mean().item()
loss = outputs[0].mean().detach()
logits = outputs[1:]
else:
loss = None