Better filtering of the model outputs in Trainer (#8633)

* Better filtering of the model outputs in Trainer

* Fix examples tests

* Add test for Lysandre
This commit is contained in:
Sylvain Gugger
2020-11-19 10:43:15 -05:00
committed by GitHub
parent f2e07e7272
commit 4208f496ee
16 changed files with 119 additions and 15 deletions

View File

@@ -1098,10 +1098,11 @@ class Trainer:
"""
outputs = model(**inputs)
# Save past state if it exists
# TODO: this needs to be fixed and made cleaner later.
if self.args.past_index >= 0:
self._past = outputs[self.args.past_index]
# We don't use .loss here since the model may return tuples instead of ModelOutput.
return outputs[0]
return outputs["loss"] if isinstance(outputs, dict) else outputs[0]
def is_local_process_zero(self) -> bool:
"""
@@ -1220,7 +1221,9 @@ class Trainer:
logger.info("Deleting older checkpoint [{}] due to args.save_total_limit".format(checkpoint))
shutil.rmtree(checkpoint)
def evaluate(self, eval_dataset: Optional[Dataset] = None) -> Dict[str, float]:
def evaluate(
self, eval_dataset: Optional[Dataset] = None, ignore_keys: Optional[List[str]] = None
) -> Dict[str, float]:
"""
Run evaluation and returns metrics.
@@ -1234,6 +1237,9 @@ class Trainer:
Pass a dataset if you wish to override :obj:`self.eval_dataset`. If it is an :obj:`datasets.Dataset`,
columns not accepted by the ``model.forward()`` method are automatically removed. It must implement the
:obj:`__len__` method.
ignore_keys (:obj:`Lst[str]`, `optional`):
A list of keys in the output of your model (if it is a dictionary) that should be ignored when
gathering predictions.
Returns:
A dictionary containing the evaluation loss and the potential metrics computed from the predictions. The
@@ -1250,6 +1256,7 @@ class Trainer:
# No point gathering the predictions if there are no metrics, otherwise we defer to
# self.args.prediction_loss_only
prediction_loss_only=True if self.compute_metrics is None else None,
ignore_keys=ignore_keys,
)
self.log(output.metrics)
@@ -1261,7 +1268,7 @@ class Trainer:
self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, output.metrics)
return output.metrics
def predict(self, test_dataset: Dataset) -> PredictionOutput:
def predict(self, test_dataset: Dataset, ignore_keys: Optional[List[str]] = None) -> PredictionOutput:
"""
Run prediction and returns predictions and potential metrics.
@@ -1272,6 +1279,9 @@ class Trainer:
test_dataset (:obj:`Dataset`):
Dataset to run the predictions on. If it is an :obj:`datasets.Dataset`, columns not accepted by the
``model.forward()`` method are automatically removed. Has to implement the method :obj:`__len__`
ignore_keys (:obj:`Lst[str]`, `optional`):
A list of keys in the output of your model (if it is a dictionary) that should be ignored when
gathering predictions.
.. note::
@@ -1291,10 +1301,14 @@ class Trainer:
test_dataloader = self.get_test_dataloader(test_dataset)
return self.prediction_loop(test_dataloader, description="Prediction")
return self.prediction_loop(test_dataloader, description="Prediction", ignore_keys=ignore_keys)
def prediction_loop(
self, dataloader: DataLoader, description: str, prediction_loss_only: Optional[bool] = None
self,
dataloader: DataLoader,
description: str,
prediction_loss_only: Optional[bool] = None,
ignore_keys: Optional[List[str]] = None,
) -> PredictionOutput:
"""
Prediction/evaluation loop, shared by :obj:`Trainer.evaluate()` and :obj:`Trainer.predict()`.
@@ -1346,7 +1360,7 @@ class Trainer:
self.callback_handler.eval_dataloader = dataloader
for step, inputs in enumerate(dataloader):
loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only)
loss, logits, labels = self.prediction_step(model, inputs, prediction_loss_only, ignore_keys=ignore_keys)
if loss is not None:
losses = loss.repeat(batch_size)
losses_host = losses if losses_host is None else torch.cat((losses_host, losses), dim=0)
@@ -1410,7 +1424,11 @@ class Trainer:
return nested_numpify(tensors)
def prediction_step(
self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]], prediction_loss_only: bool
self,
model: nn.Module,
inputs: Dict[str, Union[torch.Tensor, Any]],
prediction_loss_only: bool,
ignore_keys: Optional[List[str]] = None,
) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
"""
Perform an evaluation step on :obj:`model` using obj:`inputs`.
@@ -1427,6 +1445,9 @@ class Trainer:
argument :obj:`labels`. Check your model's documentation for all accepted arguments.
prediction_loss_only (:obj:`bool`):
Whether or not to return the loss only.
ignore_keys (:obj:`Lst[str]`, `optional`):
A list of keys in the output of your model (if it is a dictionary) that should be ignored when
gathering predictions.
Return:
Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]: A tuple with the loss, logits and
@@ -1434,6 +1455,11 @@ class Trainer:
"""
has_labels = all(inputs.get(k) is not None for k in self.label_names)
inputs = self._prepare_inputs(inputs)
if ignore_keys is None:
if hasattr(self.model, "config"):
ignore_keys = getattr(self.model.config, "keys_to_ignore_at_inference", [])
else:
ignore_keys = []
with torch.no_grad():
if self.args.fp16 and _use_native_amp:
@@ -1442,16 +1468,21 @@ class Trainer:
else:
outputs = model(**inputs)
if has_labels:
loss = outputs[0].mean().detach()
logits = outputs[1:]
if isinstance(outputs, dict):
loss = outputs["loss"].mean().detach()
logits = tuple(v for k, v in outputs.items() if k not in ignore_keys + ["loss"])
else:
loss = outputs[0].mean().detach()
logits = outputs[1:]
else:
loss = None
# Slicing so we get a tuple even if `outputs` is a `ModelOutput`.
logits = outputs[:]
if isinstance(outputs, dict):
logits = tuple(v for k, v in outputs.items() if k not in ignore_keys)
else:
logits = outputs
# TODO: this needs to be fixed and made cleaner later.
if self.args.past_index >= 0:
self._past = outputs[self.args.past_index if has_labels else self.args.past_index - 1]
# Remove the past from the logits.
logits = logits[: self.args.past_index - 1] + logits[self.args.past_index :]
if prediction_loss_only:
return (loss, None, None)