deprecate use_mps_device (#24239)

This commit is contained in:
Sourab Mangrulkar
2023-06-13 19:48:36 +05:30
committed by GitHub
parent 3e142cb0f5
commit 3723329d01
2 changed files with 20 additions and 27 deletions

View File

@@ -581,7 +581,7 @@ class TrainingArguments:
(https://pytorch.org/docs/stable/distributed.html#torch.distributed.init_process_group) for more
information.
use_mps_device (`bool`, *optional*, defaults to `False`):
Whether to use Apple Silicon chip based `mps` device.
This argument is deprecated.`mps` device will be used if it is available similar to `cuda` device.
torch_compile (`bool`, *optional*, defaults to `False`):
Whether or not to compile the model using PyTorch 2.0
[`torch.compile`](https://pytorch.org/get-started/pytorch-2.0/).
@@ -780,7 +780,11 @@ class TrainingArguments:
)
no_cuda: bool = field(default=False, metadata={"help": "Do not use CUDA even when it is available"})
use_mps_device: bool = field(
default=False, metadata={"help": "Whether to use Apple Silicon chip based `mps` device."}
default=False,
metadata={
"help": "This argument is deprecated. `mps` device will be used if available similar to `cuda` device."
" It will be removed in version 5.0 of 🤗 Transformers"
},
)
seed: int = field(default=42, metadata={"help": "Random seed that will be set at the beginning of training."})
data_seed: Optional[int] = field(default=None, metadata={"help": "Random seed to be used with data samplers."})
@@ -1714,29 +1718,18 @@ class TrainingArguments:
pass
elif self.distributed_state.distributed_type == DistributedType.NO:
if self.use_mps_device:
if not torch.backends.mps.is_available():
if not torch.backends.mps.is_built():
raise AssertionError(
"MPS not available because the current PyTorch install was not "
"built with MPS enabled. Please install torch version >=1.12.0 on "
"your Apple silicon Mac running macOS 12.3 or later with a native "
"version (arm64) of Python"
)
else:
raise AssertionError(
"MPS not available because the current MacOS version is not 12.3+ "
"and/or you do not have an MPS-enabled device on this machine."
)
else:
if not version.parse(version.parse(torch.__version__).base_version) > version.parse("1.12.0"):
warnings.warn(
"We strongly recommend to install PyTorch >= 1.13 (nightly version at the time of writing)"
" on your MacOS machine. It has major fixes related to model correctness and performance"
" improvements for transformer based models. Please refer to"
" https://github.com/pytorch/pytorch/issues/82707 for more details."
)
device = torch.device("mps")
self._n_gpu = 1
warnings.warn(
"`use_mps_device` is deprecated and will be removed in version 5.0 of 🤗 Transformers."
"`mps` device will be used by default if available similar to the way `cuda` device is used."
"Therefore, no action from user is required. "
)
if device.type != "mps":
raise ValueError(
"Either you do not have an MPS-enabled device on this machine or MacOS version is not 12.3+ "
"or current PyTorch install was not built with MPS enabled."
)
if device.type == "mps":
self._n_gpu = 1
elif self.no_cuda:
device = torch.device("cpu")
self._n_gpu = 0