Add predict step accumulation (#7767)
* Add eval_accumulation_step and clean distributed eval
* Add TPU test
* Add TPU stuff
* Fix arg name
* Fix Seq2SeqTrainer
* Fix total_size
* Update src/transformers/trainer_pt_utils.py

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>

* Doc and add test to TPU
* Add unit test
* Adapt name

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
This commit is contained in:
@@ -59,6 +59,7 @@ from .trainer_callback import (
|
||||
TrainerState,
|
||||
)
|
||||
from .trainer_pt_utils import (
|
||||
DistributedTensorGatherer,
|
||||
SequentialDistributedSampler,
|
||||
distributed_broadcast_scalars,
|
||||
distributed_concat,
|
||||
@@ -1266,18 +1267,29 @@ class Trainer:
|
||||
# multi-gpu eval
|
||||
if self.args.n_gpu > 1:
|
||||
model = torch.nn.DataParallel(model)
|
||||
else:
|
||||
model = self.model
|
||||
# Note: in torch.distributed mode, there's no point in wrapping the model
|
||||
# inside a DistributedDataParallel as we'll be under `no_grad` anyways.
|
||||
|
||||
batch_size = dataloader.batch_size
|
||||
num_examples = self.num_examples(dataloader)
|
||||
logger.info("***** Running %s *****", description)
|
||||
logger.info(" Num examples = %d", self.num_examples(dataloader))
|
||||
logger.info(" Num examples = %d", num_examples)
|
||||
logger.info(" Batch size = %d", batch_size)
|
||||
eval_losses: List[float] = []
|
||||
preds: torch.Tensor = None
|
||||
label_ids: torch.Tensor = None
|
||||
losses_host: torch.Tensor = None
|
||||
preds_host: Union[torch.Tensor, List[torch.Tensor]] = None
|
||||
labels_host: Union[torch.Tensor, List[torch.Tensor]] = None
|
||||
|
||||
world_size = 1
|
||||
if is_torch_tpu_available():
|
||||
world_size = xm.xrt_world_size()
|
||||
elif self.args.local_rank != -1:
|
||||
world_size = torch.distributed.get_world_size()
|
||||
world_size = max(1, world_size)
|
||||
|
||||
eval_losses_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=batch_size)
|
||||
preds_gatherer = DistributedTensorGatherer(world_size, num_examples)
|
||||
labels_gatherer = DistributedTensorGatherer(world_size, num_examples)
|
||||
|
||||
model.eval()
|
||||
|
||||
if is_torch_tpu_available():
|
||||
@@ -1288,55 +1300,46 @@ class Trainer:
|
||||
|
||||
self.callback_handler.eval_dataloader = dataloader
|
||||
|
||||
for inputs in dataloader:
|
||||
for step, inputs in enumerate(dataloader):
|
||||
loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only)
|
||||
batch_size = inputs[list(inputs.keys())[0]].shape[0]
|
||||
if loss is not None:
|
||||
eval_losses.extend([loss] * batch_size)
|
||||
losses = loss.repeat(batch_size)
|
||||
losses_host = losses if losses_host is None else torch.cat((losses_host, losses), dim=0)
|
||||
if logits is not None:
|
||||
preds = logits if preds is None else nested_concat(preds, logits, dim=0)
|
||||
preds_host = logits if preds_host is None else nested_concat(preds_host, logits, dim=0)
|
||||
if labels is not None:
|
||||
label_ids = labels if label_ids is None else nested_concat(label_ids, labels, dim=0)
|
||||
labels_host = labels if labels_host is None else nested_concat(labels_host, labels, dim=0)
|
||||
self.control = self.callback_handler.on_prediction_step(self.args, self.state, self.control)
|
||||
|
||||
# Gather all tensors and put them back on the CPU if we have done enough accumulation steps.
|
||||
if self.args.eval_accumulation_steps is not None and (step + 1) % self.args.eval_accumulation_steps == 0:
|
||||
eval_losses_gatherer.add_arrays(self._gather_and_numpify(losses_host, "eval_losses"))
|
||||
preds_gatherer.add_arrays(self._gather_and_numpify(preds_host, "eval_preds"))
|
||||
labels_gatherer.add_arrays(self._gather_and_numpify(labels_host, "eval_label_ids"))
|
||||
|
||||
# Set back to None to begin a new accumulation
|
||||
losses_host, preds_host, labels_host = None, None, None
|
||||
|
||||
if self.args.past_index and hasattr(self, "_past"):
|
||||
# Clean the state at the end of the evaluation loop
|
||||
delattr(self, "_past")
|
||||
|
||||
if self.args.local_rank != -1:
|
||||
# In distributed mode, concatenate all results from all nodes:
|
||||
if preds is not None:
|
||||
preds = distributed_concat(preds, num_total_examples=self.num_examples(dataloader))
|
||||
if label_ids is not None:
|
||||
label_ids = distributed_concat(label_ids, num_total_examples=self.num_examples(dataloader))
|
||||
elif is_torch_tpu_available():
|
||||
# tpu-comment: Get all predictions and labels from all worker shards of eval dataset
|
||||
if preds is not None:
|
||||
preds = nested_xla_mesh_reduce(preds, "eval_preds")
|
||||
if label_ids is not None:
|
||||
label_ids = nested_xla_mesh_reduce(label_ids, "eval_label_ids")
|
||||
if eval_losses is not None:
|
||||
eval_losses = xm.mesh_reduce("eval_losses", torch.tensor(eval_losses), torch.cat).tolist()
|
||||
# Gather all remaining tensors and put them back on the CPU
|
||||
eval_losses_gatherer.add_arrays(self._gather_and_numpify(losses_host, "eval_losses"))
|
||||
preds_gatherer.add_arrays(self._gather_and_numpify(preds_host, "eval_preds"))
|
||||
labels_gatherer.add_arrays(self._gather_and_numpify(labels_host, "eval_label_ids"))
|
||||
|
||||
# Finally, turn the aggregated tensors into numpy arrays.
|
||||
if preds is not None:
|
||||
preds = nested_numpify(preds)
|
||||
if label_ids is not None:
|
||||
label_ids = nested_numpify(label_ids)
|
||||
eval_loss = eval_losses_gatherer.finalize()
|
||||
preds = preds_gatherer.finalize()
|
||||
label_ids = labels_gatherer.finalize()
|
||||
|
||||
if self.compute_metrics is not None and preds is not None and label_ids is not None:
|
||||
metrics = self.compute_metrics(EvalPrediction(predictions=preds, label_ids=label_ids))
|
||||
else:
|
||||
metrics = {}
|
||||
if len(eval_losses) > 0:
|
||||
if self.args.local_rank != -1:
|
||||
metrics["eval_loss"] = (
|
||||
distributed_broadcast_scalars(eval_losses, num_total_examples=self.num_examples(dataloader))
|
||||
.mean()
|
||||
.item()
|
||||
)
|
||||
else:
|
||||
metrics["eval_loss"] = np.mean(eval_losses)
|
||||
|
||||
if eval_loss is not None:
|
||||
metrics["eval_loss"] = eval_loss.mean().item()
|
||||
|
||||
# Prefix all keys with eval_
|
||||
for key in list(metrics.keys()):
|
||||
@@ -1345,6 +1348,20 @@ class Trainer:
|
||||
|
||||
return PredictionOutput(predictions=preds, label_ids=label_ids, metrics=metrics)
|
||||
|
||||
def _gather_and_numpify(self, tensors, name):
|
||||
"""
|
||||
Gather value of `tensors` (tensor or list/tuple of nested tensors) and convert them to numpy before
|
||||
concatenating them to `gathered`
|
||||
"""
|
||||
if tensors is None:
|
||||
return
|
||||
if is_torch_tpu_available():
|
||||
tensors = nested_xla_mesh_reduce(tensors, name)
|
||||
elif self.args.local_rank != -1:
|
||||
tensors = distributed_concat(tensors)
|
||||
|
||||
return nested_numpify(tensors)
|
||||
|
||||
def prediction_step(
|
||||
self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]], prediction_loss_only: bool
|
||||
) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
|
||||
@@ -1374,8 +1391,7 @@ class Trainer:
|
||||
with torch.no_grad():
|
||||
outputs = model(**inputs)
|
||||
if has_labels:
|
||||
# The .mean() is to reduce in case of distributed training
|
||||
loss = outputs[0].mean().item()
|
||||
loss = outputs[0].mean().detach()
|
||||
logits = outputs[1:]
|
||||
else:
|
||||
loss = None
|
||||
|
||||
Reference in New Issue
Block a user