deprecate use_mps_device (#24239)
This commit is contained in:
committed by
GitHub
parent
3e142cb0f5
commit
3723329d01
@@ -581,7 +581,7 @@ class TrainingArguments:
|
||||
(https://pytorch.org/docs/stable/distributed.html#torch.distributed.init_process_group) for more
|
||||
information.
|
||||
use_mps_device (`bool`, *optional*, defaults to `False`):
|
||||
Whether to use Apple Silicon chip based `mps` device.
|
||||
This argument is deprecated.`mps` device will be used if it is available similar to `cuda` device.
|
||||
torch_compile (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to compile the model using PyTorch 2.0
|
||||
[`torch.compile`](https://pytorch.org/get-started/pytorch-2.0/).
|
||||
@@ -780,7 +780,11 @@ class TrainingArguments:
|
||||
)
|
||||
no_cuda: bool = field(default=False, metadata={"help": "Do not use CUDA even when it is available"})
|
||||
use_mps_device: bool = field(
|
||||
default=False, metadata={"help": "Whether to use Apple Silicon chip based `mps` device."}
|
||||
default=False,
|
||||
metadata={
|
||||
"help": "This argument is deprecated. `mps` device will be used if available similar to `cuda` device."
|
||||
" It will be removed in version 5.0 of 🤗 Transformers"
|
||||
},
|
||||
)
|
||||
seed: int = field(default=42, metadata={"help": "Random seed that will be set at the beginning of training."})
|
||||
data_seed: Optional[int] = field(default=None, metadata={"help": "Random seed to be used with data samplers."})
|
||||
@@ -1714,29 +1718,18 @@ class TrainingArguments:
|
||||
pass
|
||||
elif self.distributed_state.distributed_type == DistributedType.NO:
|
||||
if self.use_mps_device:
|
||||
if not torch.backends.mps.is_available():
|
||||
if not torch.backends.mps.is_built():
|
||||
raise AssertionError(
|
||||
"MPS not available because the current PyTorch install was not "
|
||||
"built with MPS enabled. Please install torch version >=1.12.0 on "
|
||||
"your Apple silicon Mac running macOS 12.3 or later with a native "
|
||||
"version (arm64) of Python"
|
||||
)
|
||||
else:
|
||||
raise AssertionError(
|
||||
"MPS not available because the current MacOS version is not 12.3+ "
|
||||
"and/or you do not have an MPS-enabled device on this machine."
|
||||
)
|
||||
else:
|
||||
if not version.parse(version.parse(torch.__version__).base_version) > version.parse("1.12.0"):
|
||||
warnings.warn(
|
||||
"We strongly recommend to install PyTorch >= 1.13 (nightly version at the time of writing)"
|
||||
" on your MacOS machine. It has major fixes related to model correctness and performance"
|
||||
" improvements for transformer based models. Please refer to"
|
||||
" https://github.com/pytorch/pytorch/issues/82707 for more details."
|
||||
)
|
||||
device = torch.device("mps")
|
||||
self._n_gpu = 1
|
||||
warnings.warn(
|
||||
"`use_mps_device` is deprecated and will be removed in version 5.0 of 🤗 Transformers."
|
||||
"`mps` device will be used by default if available similar to the way `cuda` device is used."
|
||||
"Therefore, no action from user is required. "
|
||||
)
|
||||
if device.type != "mps":
|
||||
raise ValueError(
|
||||
"Either you do not have an MPS-enabled device on this machine or MacOS version is not 12.3+ "
|
||||
"or current PyTorch install was not built with MPS enabled."
|
||||
)
|
||||
if device.type == "mps":
|
||||
self._n_gpu = 1
|
||||
elif self.no_cuda:
|
||||
device = torch.device("cpu")
|
||||
self._n_gpu = 0
|
||||
|
||||
Reference in New Issue
Block a user