From c153bcc5c86014cdf821872a5b3ecc2d3109e046 Mon Sep 17 00:00:00 2001 From: luyug Date: Mon, 26 Oct 2020 08:12:31 -0400 Subject: [PATCH] Add mixed precision evaluation (#8036) * Add mixed precision evaluation * use original flag --- src/transformers/trainer.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 400527a225..e1d1947ecb 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -1466,7 +1466,11 @@ class Trainer: inputs = self._prepare_inputs(inputs) with torch.no_grad(): - outputs = model(**inputs) + if self.args.fp16 and _use_native_amp: + with autocast(): + outputs = model(**inputs) + else: + outputs = model(**inputs) if has_labels: loss = outputs[0].mean().detach() logits = outputs[1:]