Make Trainer evaluation handle dynamic seq_length (#8336)
* Make Trainer evaluation handle dynamic seq_length
* Document behavior.
* Fix test
* Better fix
* Fixes for realsies this time
* Address review comments
* Without forgetting to save...
@@ -1333,6 +1333,12 @@ class Trainer:
                 Dataset to run the predictions on. If it is an :obj:`datasets.Dataset`, columns not accepted by the
                 ``model.forward()`` method are automatically removed. Has to implement the method :obj:`__len__`
+
+        .. note::
+
+            If your predictions or labels have different sequence length (for instance because you're doing dynamic
+            padding in a token classification task) the predictions will be padded (on the right) to allow for
+            concatenation into one array. The padding index is -100.
 
         Returns: `NamedTuple` A namedtuple with the following keys:
 
             - predictions (:obj:`np.ndarray`): The predictions on :obj:`test_dataset`.
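Note on the documented behavior: because padded positions carry the value -100, a metric computed on the returned arrays presumably needs to skip them. The sketch below shows one way to do that with NumPy. The function name `token_accuracy` and the sample arrays are hypothetical and are not part of this commit.

```python
import numpy as np

PADDING_INDEX = -100  # padding value stated in the docstring note


def token_accuracy(logits, labels, padding_index=PADDING_INDEX):
    """Accuracy over the positions whose label is not the padding index."""
    predicted = logits.argmax(axis=-1)
    mask = labels != padding_index  # True only for real (unpadded) tokens
    return float((predicted[mask] == labels[mask]).mean())


# Hypothetical data: two sequences, the second right-padded from length 1 to 3.
logits = np.array(
    [
        [[2.0, 1.0], [0.0, 3.0], [5.0, 1.0]],
        [[0.0, 1.0], [-100.0, -100.0], [-100.0, -100.0]],
    ]
)
labels = np.array([[0, 1, 1], [1, -100, -100]])

accuracy = token_accuracy(logits, labels)  # 3 of the 4 real tokens match
```

Masking on the labels rather than on the predictions keeps the check independent of what `argmax` returns for a fully padded row.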
@@ -1412,9 +1418,9 @@ class Trainer:
                 losses = loss.repeat(batch_size)
                 losses_host = losses if losses_host is None else torch.cat((losses_host, losses), dim=0)
             if logits is not None:
-                preds_host = logits if preds_host is None else nested_concat(preds_host, logits, dim=0)
+                preds_host = logits if preds_host is None else nested_concat(preds_host, logits, padding_index=-100)
             if labels is not None:
-                labels_host = labels if labels_host is None else nested_concat(labels_host, labels, dim=0)
+                labels_host = labels if labels_host is None else nested_concat(labels_host, labels, padding_index=-100)
             self.control = self.callback_handler.on_prediction_step(self.args, self.state, self.control)
 
             # Gather all tensors and put them back on the CPU if we have done enough accumulation steps.
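Note on the change above: the body of `nested_concat` is not shown in this diff, so the following is only a minimal sketch of the pad-then-concatenate idea that the `padding_index=-100` argument suggests. It uses NumPy instead of torch to stay self-contained, and the name `pad_and_concatenate` is a hypothetical stand-in, not the helper used by the commit.

```python
import numpy as np


def pad_and_concatenate(a, b, padding_index=-100):
    """Concatenate two batches on axis 0, right-padding axis 1 to the longer length."""
    # 1-D arrays, or batches with equal sequence length, need no padding.
    if a.ndim == 1 or a.shape[1] == b.shape[1]:
        return np.concatenate((a, b), axis=0)

    # Allocate the result pre-filled with the padding index.
    new_shape = (a.shape[0] + b.shape[0], max(a.shape[1], b.shape[1])) + a.shape[2:]
    result = np.full(new_shape, padding_index, dtype=a.dtype)

    # Copy each batch into the left part of its rows; the rest stays padded.
    result[: a.shape[0], : a.shape[1]] = a
    result[a.shape[0] :, : b.shape[1]] = b
    return result


# Hypothetical batches with sequence lengths 3 and 5.
first = np.array([[1, 2, 3], [4, 5, 6]])
second = np.array([[7, 8, 9, 10, 11]])
merged = pad_and_concatenate(first, second)
```

Here `merged` has shape `(3, 5)`, with the two shorter rows filled on the right with -100. A plain concatenation along axis 0 would fail on these inputs because the sequence dimensions differ, which is the situation dynamic padding produces during evaluation.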