Replaces xxx_required with requires_backends (#20715)
* Replaces xxx_required with requires_backends * Fixup
This commit is contained in:
@@ -17,7 +17,7 @@
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Tuple
|
||||
|
||||
from ..utils import cached_property, is_torch_available, is_torch_tpu_available, logging, torch_required
|
||||
from ..utils import cached_property, is_torch_available, is_torch_tpu_available, logging, requires_backends
|
||||
from .benchmark_args_utils import BenchmarkArguments
|
||||
|
||||
|
||||
@@ -76,8 +76,8 @@ class PyTorchBenchmarkArguments(BenchmarkArguments):
|
||||
)
|
||||
|
||||
@cached_property
|
||||
@torch_required
|
||||
def _setup_devices(self) -> Tuple["torch.device", int]:
|
||||
requires_backends(self, ["torch"])
|
||||
logger.info("PyTorch: setting up devices")
|
||||
if not self.cuda:
|
||||
device = torch.device("cpu")
|
||||
@@ -95,19 +95,19 @@ class PyTorchBenchmarkArguments(BenchmarkArguments):
|
||||
return is_torch_tpu_available() and self.tpu
|
||||
|
||||
@property
|
||||
@torch_required
|
||||
def device_idx(self) -> int:
|
||||
requires_backends(self, ["torch"])
|
||||
# TODO(PVP): currently only single GPU is supported
|
||||
return torch.cuda.current_device()
|
||||
|
||||
@property
|
||||
@torch_required
|
||||
def device(self) -> "torch.device":
|
||||
requires_backends(self, ["torch"])
|
||||
return self._setup_devices[0]
|
||||
|
||||
@property
|
||||
@torch_required
|
||||
def n_gpu(self):
|
||||
requires_backends(self, ["torch"])
|
||||
return self._setup_devices[1]
|
||||
|
||||
@property
|
||||
|
||||
@@ -17,7 +17,7 @@
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Tuple
|
||||
|
||||
from ..utils import cached_property, is_tf_available, logging, tf_required
|
||||
from ..utils import cached_property, is_tf_available, logging, requires_backends
|
||||
from .benchmark_args_utils import BenchmarkArguments
|
||||
|
||||
|
||||
@@ -77,8 +77,8 @@ class TensorFlowBenchmarkArguments(BenchmarkArguments):
|
||||
)
|
||||
|
||||
@cached_property
|
||||
@tf_required
|
||||
def _setup_tpu(self) -> Tuple["tf.distribute.cluster_resolver.TPUClusterResolver"]:
|
||||
requires_backends(self, ["tf"])
|
||||
tpu = None
|
||||
if self.tpu:
|
||||
try:
|
||||
@@ -91,8 +91,8 @@ class TensorFlowBenchmarkArguments(BenchmarkArguments):
|
||||
return tpu
|
||||
|
||||
@cached_property
|
||||
@tf_required
|
||||
def _setup_strategy(self) -> Tuple["tf.distribute.Strategy", "tf.distribute.cluster_resolver.TPUClusterResolver"]:
|
||||
requires_backends(self, ["tf"])
|
||||
if self.is_tpu:
|
||||
tf.config.experimental_connect_to_cluster(self._setup_tpu)
|
||||
tf.tpu.experimental.initialize_tpu_system(self._setup_tpu)
|
||||
@@ -111,23 +111,23 @@ class TensorFlowBenchmarkArguments(BenchmarkArguments):
|
||||
return strategy
|
||||
|
||||
@property
|
||||
@tf_required
|
||||
def is_tpu(self) -> bool:
|
||||
requires_backends(self, ["tf"])
|
||||
return self._setup_tpu is not None
|
||||
|
||||
@property
|
||||
@tf_required
|
||||
def strategy(self) -> "tf.distribute.Strategy":
|
||||
requires_backends(self, ["tf"])
|
||||
return self._setup_strategy
|
||||
|
||||
@property
|
||||
@tf_required
|
||||
def gpu_list(self):
|
||||
requires_backends(self, ["tf"])
|
||||
return tf.config.list_physical_devices("GPU")
|
||||
|
||||
@property
|
||||
@tf_required
|
||||
def n_gpu(self) -> int:
|
||||
requires_backends(self, ["tf"])
|
||||
if self.cuda:
|
||||
return len(self.gpu_list)
|
||||
return 0
|
||||
|
||||
@@ -42,7 +42,7 @@ from .utils import (
|
||||
is_torch_device,
|
||||
is_torch_dtype,
|
||||
logging,
|
||||
torch_required,
|
||||
requires_backends,
|
||||
)
|
||||
|
||||
|
||||
@@ -175,7 +175,6 @@ class BatchFeature(UserDict):
|
||||
|
||||
return self
|
||||
|
||||
@torch_required
|
||||
def to(self, *args, **kwargs) -> "BatchFeature":
|
||||
"""
|
||||
Send all values to device by calling `v.to(*args, **kwargs)` (PyTorch only). This should support casting in
|
||||
@@ -190,6 +189,7 @@ class BatchFeature(UserDict):
|
||||
Returns:
|
||||
[`BatchFeature`]: The same instance after modification.
|
||||
"""
|
||||
requires_backends(self, ["torch"])
|
||||
import torch # noqa
|
||||
|
||||
new_data = {}
|
||||
|
||||
@@ -127,10 +127,8 @@ from .utils import (
|
||||
is_vision_available,
|
||||
replace_return_docstrings,
|
||||
requires_backends,
|
||||
tf_required,
|
||||
to_numpy,
|
||||
to_py_obj,
|
||||
torch_only_method,
|
||||
torch_required,
|
||||
torch_version,
|
||||
)
|
||||
|
||||
@@ -56,8 +56,8 @@ from .utils import (
|
||||
is_torch_device,
|
||||
is_torch_tensor,
|
||||
logging,
|
||||
requires_backends,
|
||||
to_py_obj,
|
||||
torch_required,
|
||||
)
|
||||
|
||||
|
||||
@@ -739,7 +739,6 @@ class BatchEncoding(UserDict):
|
||||
|
||||
return self
|
||||
|
||||
@torch_required
|
||||
def to(self, device: Union[str, "torch.device"]) -> "BatchEncoding":
|
||||
"""
|
||||
Send all values to device by calling `v.to(device)` (PyTorch only).
|
||||
@@ -750,6 +749,7 @@ class BatchEncoding(UserDict):
|
||||
Returns:
|
||||
[`BatchEncoding`]: The same instance after modification.
|
||||
"""
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
# This check catches things like APEX blindly calling "to" on all inputs to a module
|
||||
# Otherwise it passes the casts down and casts the LongTensor containing the token idxs
|
||||
|
||||
@@ -50,7 +50,6 @@ from .utils import (
|
||||
is_torch_tpu_available,
|
||||
logging,
|
||||
requires_backends,
|
||||
torch_required,
|
||||
)
|
||||
|
||||
|
||||
@@ -1386,8 +1385,8 @@ class TrainingArguments:
|
||||
return timedelta(seconds=self.ddp_timeout)
|
||||
|
||||
@cached_property
|
||||
@torch_required
|
||||
def _setup_devices(self) -> "torch.device":
|
||||
requires_backends(self, ["torch"])
|
||||
logger.info("PyTorch: setting up devices")
|
||||
if torch.distributed.is_available() and torch.distributed.is_initialized() and self.local_rank == -1:
|
||||
logger.warning(
|
||||
@@ -1537,15 +1536,14 @@ class TrainingArguments:
|
||||
return device
|
||||
|
||||
@property
|
||||
@torch_required
|
||||
def device(self) -> "torch.device":
|
||||
"""
|
||||
The device used by this process.
|
||||
"""
|
||||
requires_backends(self, ["torch"])
|
||||
return self._setup_devices
|
||||
|
||||
@property
|
||||
@torch_required
|
||||
def n_gpu(self):
|
||||
"""
|
||||
The number of GPUs used by this process.
|
||||
@@ -1554,12 +1552,12 @@ class TrainingArguments:
|
||||
This will only be greater than one when you have multiple GPUs available but are not using distributed
|
||||
training. For distributed training, it will always be 1.
|
||||
"""
|
||||
requires_backends(self, ["torch"])
|
||||
# Make sure `self._n_gpu` is properly setup.
|
||||
_ = self._setup_devices
|
||||
return self._n_gpu
|
||||
|
||||
@property
|
||||
@torch_required
|
||||
def parallel_mode(self):
|
||||
"""
|
||||
The current mode used for parallelism if multiple GPUs/TPU cores are available. One of:
|
||||
@@ -1570,6 +1568,7 @@ class TrainingArguments:
|
||||
`torch.nn.DistributedDataParallel`).
|
||||
- `ParallelMode.TPU`: several TPU cores.
|
||||
"""
|
||||
requires_backends(self, ["torch"])
|
||||
if is_torch_tpu_available():
|
||||
return ParallelMode.TPU
|
||||
elif is_sagemaker_mp_enabled():
|
||||
@@ -1584,11 +1583,12 @@ class TrainingArguments:
|
||||
return ParallelMode.NOT_PARALLEL
|
||||
|
||||
@property
|
||||
@torch_required
|
||||
def world_size(self):
|
||||
"""
|
||||
The number of processes used in parallel.
|
||||
"""
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
if is_torch_tpu_available():
|
||||
return xm.xrt_world_size()
|
||||
elif is_sagemaker_mp_enabled():
|
||||
@@ -1600,11 +1600,11 @@ class TrainingArguments:
|
||||
return 1
|
||||
|
||||
@property
|
||||
@torch_required
|
||||
def process_index(self):
|
||||
"""
|
||||
The index of the current process used.
|
||||
"""
|
||||
requires_backends(self, ["torch"])
|
||||
if is_torch_tpu_available():
|
||||
return xm.get_ordinal()
|
||||
elif is_sagemaker_mp_enabled():
|
||||
@@ -1616,11 +1616,11 @@ class TrainingArguments:
|
||||
return 0
|
||||
|
||||
@property
|
||||
@torch_required
|
||||
def local_process_index(self):
|
||||
"""
|
||||
The index of the local process used.
|
||||
"""
|
||||
requires_backends(self, ["torch"])
|
||||
if is_torch_tpu_available():
|
||||
return xm.get_local_ordinal()
|
||||
elif is_sagemaker_mp_enabled():
|
||||
|
||||
@@ -17,7 +17,7 @@ from dataclasses import dataclass, field
|
||||
from typing import Optional, Tuple
|
||||
|
||||
from .training_args import TrainingArguments
|
||||
from .utils import cached_property, is_tf_available, logging, tf_required
|
||||
from .utils import cached_property, is_tf_available, logging, requires_backends
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
@@ -185,8 +185,8 @@ class TFTrainingArguments(TrainingArguments):
|
||||
xla: bool = field(default=False, metadata={"help": "Whether to activate the XLA compilation or not"})
|
||||
|
||||
@cached_property
|
||||
@tf_required
|
||||
def _setup_strategy(self) -> Tuple["tf.distribute.Strategy", int]:
|
||||
requires_backends(self, ["tf"])
|
||||
logger.info("Tensorflow: setting up strategy")
|
||||
|
||||
gpus = tf.config.list_physical_devices("GPU")
|
||||
@@ -234,19 +234,19 @@ class TFTrainingArguments(TrainingArguments):
|
||||
return strategy
|
||||
|
||||
@property
|
||||
@tf_required
|
||||
def strategy(self) -> "tf.distribute.Strategy":
|
||||
"""
|
||||
The strategy used for distributed training.
|
||||
"""
|
||||
requires_backends(self, ["tf"])
|
||||
return self._setup_strategy
|
||||
|
||||
@property
|
||||
@tf_required
|
||||
def n_replicas(self) -> int:
|
||||
"""
|
||||
The number of replicas (CPUs, GPUs or TPU cores) used in this training.
|
||||
"""
|
||||
requires_backends(self, ["tf"])
|
||||
return self._setup_strategy.num_replicas_in_sync
|
||||
|
||||
@property
|
||||
@@ -276,11 +276,11 @@ class TFTrainingArguments(TrainingArguments):
|
||||
return per_device_batch_size * self.n_replicas
|
||||
|
||||
@property
|
||||
@tf_required
|
||||
def n_gpu(self) -> int:
|
||||
"""
|
||||
The number of replicas (CPUs, GPUs or TPU cores) used in this training.
|
||||
"""
|
||||
requires_backends(self, ["tf"])
|
||||
warnings.warn(
|
||||
"The n_gpu argument is deprecated and will be removed in a future version, use n_replicas instead.",
|
||||
FutureWarning,
|
||||
|
||||
@@ -163,9 +163,7 @@ from .import_utils import (
|
||||
is_training_run_on_sagemaker,
|
||||
is_vision_available,
|
||||
requires_backends,
|
||||
tf_required,
|
||||
torch_only_method,
|
||||
torch_required,
|
||||
torch_version,
|
||||
)
|
||||
|
||||
|
||||
@@ -22,7 +22,7 @@ import shutil
|
||||
import sys
|
||||
import warnings
|
||||
from collections import OrderedDict
|
||||
from functools import lru_cache, wraps
|
||||
from functools import lru_cache
|
||||
from itertools import chain
|
||||
from types import ModuleType
|
||||
from typing import Any
|
||||
@@ -1039,30 +1039,6 @@ class DummyObject(type):
|
||||
requires_backends(cls, cls._backends)
|
||||
|
||||
|
||||
def torch_required(func):
|
||||
# Chose a different decorator name than in tests so it's clear they are not the same.
|
||||
@wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
if is_torch_available():
|
||||
return func(*args, **kwargs)
|
||||
else:
|
||||
raise ImportError(f"Method `{func.__name__}` requires PyTorch.")
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def tf_required(func):
|
||||
# Chose a different decorator name than in tests so it's clear they are not the same.
|
||||
@wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
if is_tf_available():
|
||||
return func(*args, **kwargs)
|
||||
else:
|
||||
raise ImportError(f"Method `{func.__name__}` requires TF.")
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def is_torch_fx_proxy(x):
|
||||
if is_torch_fx_available():
|
||||
import torch.fx
|
||||
|
||||
Reference in New Issue
Block a user