Fix distributed evaluation (#10795)
* Fix distributed evaluation * Use logger
This commit is contained in:
@@ -690,7 +690,7 @@ class Trainer:
|
|||||||
"""
|
"""
|
||||||
Helper to get number of samples in a :class:`~torch.utils.data.DataLoader` by accessing its dataset.
|
Helper to get number of samples in a :class:`~torch.utils.data.DataLoader` by accessing its dataset.
|
||||||
|
|
||||||
Will raise an exception if the underlying dataset dese not implement method :obj:`__len__`
|
Will raise an exception if the underlying dataset does not implement method :obj:`__len__`
|
||||||
"""
|
"""
|
||||||
return len(dataloader.dataset)
|
return len(dataloader.dataset)
|
||||||
|
|
||||||
@@ -1812,8 +1812,13 @@ class Trainer:
|
|||||||
|
|
||||||
eval_losses_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=batch_size)
|
eval_losses_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=batch_size)
|
||||||
if not prediction_loss_only:
|
if not prediction_loss_only:
|
||||||
preds_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=batch_size)
|
# The actual number of eval_sample can be greater than num_examples in distributed settings (when we pass
|
||||||
labels_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=batch_size)
|
# a batch size to the sampler)
|
||||||
|
make_multiple_of = None
|
||||||
|
if hasattr(dataloader, "sampler") and isinstance(dataloader.sampler, SequentialDistributedSampler):
|
||||||
|
make_multiple_of = dataloader.sampler.batch_size
|
||||||
|
preds_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=make_multiple_of)
|
||||||
|
labels_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=make_multiple_of)
|
||||||
|
|
||||||
model.eval()
|
model.eval()
|
||||||
|
|
||||||
|
|||||||
@@ -97,6 +97,11 @@ if __name__ == "__main__":
|
|||||||
def compute_metrics(p: EvalPrediction) -> Dict:
|
def compute_metrics(p: EvalPrediction) -> Dict:
|
||||||
sequential = list(range(len(dataset)))
|
sequential = list(range(len(dataset)))
|
||||||
success = p.predictions.tolist() == sequential and p.label_ids.tolist() == sequential
|
success = p.predictions.tolist() == sequential and p.label_ids.tolist() == sequential
|
||||||
|
if not success and training_args.local_rank == 0:
|
||||||
|
logger.warning(
|
||||||
|
"Predictions and/or labels do not match expected results:\n - predictions: "
|
||||||
|
f"{p.predictions.tolist()}\n - labels: {p.label_ids.tolist()}\n - expected: {sequential}"
|
||||||
|
)
|
||||||
return {"success": success}
|
return {"success": success}
|
||||||
|
|
||||||
trainer = Trainer(
|
trainer = Trainer(
|
||||||
|
|||||||
Reference in New Issue
Block a user