[trainer] make generate work with multigpu (#8716)

* make generate work with multigpu

* better fix - thanks @sgugger
This commit is contained in:
Stas Bekman
2020-11-23 10:57:27 -08:00
committed by GitHub
parent 900024273b
commit 1e45bef0a7
2 changed files with 2 additions and 9 deletions

View File

@@ -189,7 +189,7 @@ class Seq2SeqTrainer(Trainer):
}
if self.args.predict_with_generate and not self.args.prediction_loss_only:
generated_tokens = model.generate(
generated_tokens = self.model.generate(
inputs["input_ids"],
attention_mask=inputs["attention_mask"],
**gen_kwargs,