Better filtering of the model outputs in Trainer (#8633)
* Better filtering of the model outputs in Trainer * Fix examples tests * Add test for Lysandre
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
from typing import Any, Dict, Optional, Tuple, Union
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@@ -153,7 +153,11 @@ class Seq2SeqTrainer(Trainer):
|
||||
return loss
|
||||
|
||||
def prediction_step(
|
||||
self, model: nn.Module, inputs: Dict[str, Union[torch.Tensor, Any]], prediction_loss_only: bool
|
||||
self,
|
||||
model: nn.Module,
|
||||
inputs: Dict[str, Union[torch.Tensor, Any]],
|
||||
prediction_loss_only: bool,
|
||||
ignore_keys: Optional[List[str]] = None,
|
||||
) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]:
|
||||
"""
|
||||
Perform an evaluation step on :obj:`model` using obj:`inputs`.
|
||||
|
||||
Reference in New Issue
Block a user