Raise error when using AMP on non-CUDA device (#7869)
* Raise error when using AMP on non-CUDA device * make style * make style
This commit is contained in:
@@ -351,6 +351,9 @@ class TrainingArguments:
|
|||||||
if self.run_name is None:
|
if self.run_name is None:
|
||||||
self.run_name = self.output_dir
|
self.run_name = self.output_dir
|
||||||
|
|
||||||
|
if self.device.type != "cuda" and self.fp16:
|
||||||
|
raise ValueError("AMP (`--fp16`) can only be used on CUDA devices.")
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def train_batch_size(self) -> int:
|
def train_batch_size(self) -> int:
|
||||||
"""
|
"""
|
||||||
|
|||||||
Reference in New Issue
Block a user