diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index 12f0496b53..42a9ee2fbc 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -53,6 +53,7 @@ from .utils import ( logging, requires_backends, ) +from .utils.generic import strtobool from .utils.import_utils import is_optimum_neuron_available @@ -1720,7 +1721,7 @@ class TrainingArguments: self.distributed_state = None if not self.use_ipex and "ACCELERATE_USE_IPEX" not in os.environ: os.environ["ACCELERATE_USE_IPEX"] = "false" - if self.use_cpu or os.environ.get("ACCELERATE_USE_CPU", False): + if self.use_cpu or strtobool(os.environ.get("ACCELERATE_USE_CPU", "False")): self.distributed_state = PartialState(cpu=True, backend=self.ddp_backend) self._n_gpu = 0 elif is_sagemaker_mp_enabled():