Make Trainer evaluation handle dynamic seq_length (#8336)

* Make Trainer evaluation handle dynamic seq_length

* Document behavior.

* Fix test

* Better fix

* Fixes for realsies this time

* Address review comments

* Without forgetting to save...
This commit is contained in:
Sylvain Gugger
2020-11-05 15:13:51 -05:00
committed by GitHub
parent 27b402cab0
commit 04e442d575
3 changed files with 135 additions and 13 deletions

View File

@@ -1333,6 +1333,12 @@ class Trainer:
Dataset to run the predictions on. If it is an :obj:`datasets.Dataset`, columns not accepted by the
``model.forward()`` method are automatically removed. Has to implement the method :obj:`__len__`
.. note::
If your predictions or labels have different sequence length (for instance because you're doing dynamic
padding in a token classification task) the predictions will be padded (on the right) to allow for
concatenation into one array. The padding index is -100.
Returns: `NamedTuple` A namedtuple with the following keys:
- predictions (:obj:`np.ndarray`): The predictions on :obj:`test_dataset`.
@@ -1412,9 +1418,9 @@ class Trainer:
losses = loss.repeat(batch_size)
losses_host = losses if losses_host is None else torch.cat((losses_host, losses), dim=0)
if logits is not None:
preds_host = logits if preds_host is None else nested_concat(preds_host, logits, dim=0)
preds_host = logits if preds_host is None else nested_concat(preds_host, logits, padding_index=-100)
if labels is not None:
labels_host = labels if labels_host is None else nested_concat(labels_host, labels, dim=0)
labels_host = labels if labels_host is None else nested_concat(labels_host, labels, padding_index=-100)
self.control = self.callback_handler.on_prediction_step(self.args, self.state, self.control)
# Gather all tensors and put them back on the CPU if we have done enough accumulation steps.