This commit is contained in:
Patrick von Platen
2020-12-27 21:57:50 +01:00
committed by GitHub
parent 61443cd7d9
commit 8e74eca7f2

View File

@@ -171,7 +171,9 @@ class Seq2SeqTrainer(Trainer):
"""
if not self.args.predict_with_generate or prediction_loss_only:
return super()(self, model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys)
return super().prediction_step(
model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys
)
has_labels = "labels" in inputs
inputs = self._prepare_inputs(inputs)