Raise error when using AMP on non-CUDA device (#7869)

* Raise error when using AMP on non-CUDA device

* make style

* make style
This commit is contained in:
Bram Vanroy
2020-10-19 21:59:30 +02:00
committed by GitHub
parent e3d2bee8d0
commit 55bcd0cb59

View File

@@ -351,6 +351,9 @@ class TrainingArguments:
if self.run_name is None:
self.run_name = self.output_dir
if self.device.type != "cuda" and self.fp16:
raise ValueError("AMP (`--fp16`) can only be used on CUDA devices.")
@property
def train_batch_size(self) -> int:
"""