From 55bcd0cb5909fdbd1b9e1c1123dcfe972f123db9 Mon Sep 17 00:00:00 2001 From: Bram Vanroy Date: Mon, 19 Oct 2020 21:59:30 +0200 Subject: [PATCH] Raise error when using AMP on non-CUDA device (#7869) * Raise error when using AMP on non-CUDA device * make style * make style --- src/transformers/training_args.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index 17ea24ba2e..b86a1cbc2b 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -351,6 +351,9 @@ class TrainingArguments: if self.run_name is None: self.run_name = self.output_dir + if self.device.type != "cuda" and self.fp16: + raise ValueError("AMP (`--fp16`) can only be used on CUDA devices.") + @property def train_batch_size(self) -> int: """