Better filtering of the model outputs in Trainer (#8633)

* Better filtering of the model outputs in Trainer

* Fix examples tests

* Add test for Lysandre
This commit is contained in:
Sylvain Gugger
2020-11-19 10:43:15 -05:00
committed by GitHub
parent f2e07e7272
commit 4208f496ee
16 changed files with 119 additions and 15 deletions

View File

@@ -1,4 +1,4 @@
from typing import Any, Dict, Optional, Tuple, Union
from typing import Any, Dict, List, Optional, Tuple, Union
import torch
from torch import nn
@@ -153,7 +153,11 @@ class Seq2SeqTrainer(Trainer):
return loss
def prediction_step(
self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]], prediction_loss_only: bool
self,
model: nn.Module,
inputs: Dict[str, Union[torch.Tensor, Any]],
prediction_loss_only: bool,
ignore_keys: Optional[List[str]] = None,
) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
"""
Perform an evaluation step on :obj:`model` using obj:`inputs`.