Replaces xxx_required with requires_backends (#20715)
* Replaces xxx_required with requires_backends * Fixup
This commit is contained in:
@@ -17,7 +17,7 @@
|
|||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Tuple
|
from typing import Tuple
|
||||||
|
|
||||||
from ..utils import cached_property, is_torch_available, is_torch_tpu_available, logging, torch_required
|
from ..utils import cached_property, is_torch_available, is_torch_tpu_available, logging, requires_backends
|
||||||
from .benchmark_args_utils import BenchmarkArguments
|
from .benchmark_args_utils import BenchmarkArguments
|
||||||
|
|
||||||
|
|
||||||
@@ -76,8 +76,8 @@ class PyTorchBenchmarkArguments(BenchmarkArguments):
|
|||||||
)
|
)
|
||||||
|
|
||||||
@cached_property
|
@cached_property
|
||||||
@torch_required
|
|
||||||
def _setup_devices(self) -> Tuple["torch.device", int]:
|
def _setup_devices(self) -> Tuple["torch.device", int]:
|
||||||
|
requires_backends(self, ["torch"])
|
||||||
logger.info("PyTorch: setting up devices")
|
logger.info("PyTorch: setting up devices")
|
||||||
if not self.cuda:
|
if not self.cuda:
|
||||||
device = torch.device("cpu")
|
device = torch.device("cpu")
|
||||||
@@ -95,19 +95,19 @@ class PyTorchBenchmarkArguments(BenchmarkArguments):
|
|||||||
return is_torch_tpu_available() and self.tpu
|
return is_torch_tpu_available() and self.tpu
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@torch_required
|
|
||||||
def device_idx(self) -> int:
|
def device_idx(self) -> int:
|
||||||
|
requires_backends(self, ["torch"])
|
||||||
# TODO(PVP): currently only single GPU is supported
|
# TODO(PVP): currently only single GPU is supported
|
||||||
return torch.cuda.current_device()
|
return torch.cuda.current_device()
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@torch_required
|
|
||||||
def device(self) -> "torch.device":
|
def device(self) -> "torch.device":
|
||||||
|
requires_backends(self, ["torch"])
|
||||||
return self._setup_devices[0]
|
return self._setup_devices[0]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@torch_required
|
|
||||||
def n_gpu(self):
|
def n_gpu(self):
|
||||||
|
requires_backends(self, ["torch"])
|
||||||
return self._setup_devices[1]
|
return self._setup_devices[1]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
|||||||
@@ -17,7 +17,7 @@
|
|||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Tuple
|
from typing import Tuple
|
||||||
|
|
||||||
from ..utils import cached_property, is_tf_available, logging, tf_required
|
from ..utils import cached_property, is_tf_available, logging, requires_backends
|
||||||
from .benchmark_args_utils import BenchmarkArguments
|
from .benchmark_args_utils import BenchmarkArguments
|
||||||
|
|
||||||
|
|
||||||
@@ -77,8 +77,8 @@ class TensorFlowBenchmarkArguments(BenchmarkArguments):
|
|||||||
)
|
)
|
||||||
|
|
||||||
@cached_property
|
@cached_property
|
||||||
@tf_required
|
|
||||||
def _setup_tpu(self) -> Tuple["tf.distribute.cluster_resolver.TPUClusterResolver"]:
|
def _setup_tpu(self) -> Tuple["tf.distribute.cluster_resolver.TPUClusterResolver"]:
|
||||||
|
requires_backends(self, ["tf"])
|
||||||
tpu = None
|
tpu = None
|
||||||
if self.tpu:
|
if self.tpu:
|
||||||
try:
|
try:
|
||||||
@@ -91,8 +91,8 @@ class TensorFlowBenchmarkArguments(BenchmarkArguments):
|
|||||||
return tpu
|
return tpu
|
||||||
|
|
||||||
@cached_property
|
@cached_property
|
||||||
@tf_required
|
|
||||||
def _setup_strategy(self) -> Tuple["tf.distribute.Strategy", "tf.distribute.cluster_resolver.TPUClusterResolver"]:
|
def _setup_strategy(self) -> Tuple["tf.distribute.Strategy", "tf.distribute.cluster_resolver.TPUClusterResolver"]:
|
||||||
|
requires_backends(self, ["tf"])
|
||||||
if self.is_tpu:
|
if self.is_tpu:
|
||||||
tf.config.experimental_connect_to_cluster(self._setup_tpu)
|
tf.config.experimental_connect_to_cluster(self._setup_tpu)
|
||||||
tf.tpu.experimental.initialize_tpu_system(self._setup_tpu)
|
tf.tpu.experimental.initialize_tpu_system(self._setup_tpu)
|
||||||
@@ -111,23 +111,23 @@ class TensorFlowBenchmarkArguments(BenchmarkArguments):
|
|||||||
return strategy
|
return strategy
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@tf_required
|
|
||||||
def is_tpu(self) -> bool:
|
def is_tpu(self) -> bool:
|
||||||
|
requires_backends(self, ["tf"])
|
||||||
return self._setup_tpu is not None
|
return self._setup_tpu is not None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@tf_required
|
|
||||||
def strategy(self) -> "tf.distribute.Strategy":
|
def strategy(self) -> "tf.distribute.Strategy":
|
||||||
|
requires_backends(self, ["tf"])
|
||||||
return self._setup_strategy
|
return self._setup_strategy
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@tf_required
|
|
||||||
def gpu_list(self):
|
def gpu_list(self):
|
||||||
|
requires_backends(self, ["tf"])
|
||||||
return tf.config.list_physical_devices("GPU")
|
return tf.config.list_physical_devices("GPU")
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@tf_required
|
|
||||||
def n_gpu(self) -> int:
|
def n_gpu(self) -> int:
|
||||||
|
requires_backends(self, ["tf"])
|
||||||
if self.cuda:
|
if self.cuda:
|
||||||
return len(self.gpu_list)
|
return len(self.gpu_list)
|
||||||
return 0
|
return 0
|
||||||
|
|||||||
@@ -42,7 +42,7 @@ from .utils import (
|
|||||||
is_torch_device,
|
is_torch_device,
|
||||||
is_torch_dtype,
|
is_torch_dtype,
|
||||||
logging,
|
logging,
|
||||||
torch_required,
|
requires_backends,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -175,7 +175,6 @@ class BatchFeature(UserDict):
|
|||||||
|
|
||||||
return self
|
return self
|
||||||
|
|
||||||
@torch_required
|
|
||||||
def to(self, *args, **kwargs) -> "BatchFeature":
|
def to(self, *args, **kwargs) -> "BatchFeature":
|
||||||
"""
|
"""
|
||||||
Send all values to device by calling `v.to(*args, **kwargs)` (PyTorch only). This should support casting in
|
Send all values to device by calling `v.to(*args, **kwargs)` (PyTorch only). This should support casting in
|
||||||
@@ -190,6 +189,7 @@ class BatchFeature(UserDict):
|
|||||||
Returns:
|
Returns:
|
||||||
[`BatchFeature`]: The same instance after modification.
|
[`BatchFeature`]: The same instance after modification.
|
||||||
"""
|
"""
|
||||||
|
requires_backends(self, ["torch"])
|
||||||
import torch # noqa
|
import torch # noqa
|
||||||
|
|
||||||
new_data = {}
|
new_data = {}
|
||||||
|
|||||||
@@ -127,10 +127,8 @@ from .utils import (
|
|||||||
is_vision_available,
|
is_vision_available,
|
||||||
replace_return_docstrings,
|
replace_return_docstrings,
|
||||||
requires_backends,
|
requires_backends,
|
||||||
tf_required,
|
|
||||||
to_numpy,
|
to_numpy,
|
||||||
to_py_obj,
|
to_py_obj,
|
||||||
torch_only_method,
|
torch_only_method,
|
||||||
torch_required,
|
|
||||||
torch_version,
|
torch_version,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -56,8 +56,8 @@ from .utils import (
|
|||||||
is_torch_device,
|
is_torch_device,
|
||||||
is_torch_tensor,
|
is_torch_tensor,
|
||||||
logging,
|
logging,
|
||||||
|
requires_backends,
|
||||||
to_py_obj,
|
to_py_obj,
|
||||||
torch_required,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -739,7 +739,6 @@ class BatchEncoding(UserDict):
|
|||||||
|
|
||||||
return self
|
return self
|
||||||
|
|
||||||
@torch_required
|
|
||||||
def to(self, device: Union[str, "torch.device"]) -> "BatchEncoding":
|
def to(self, device: Union[str, "torch.device"]) -> "BatchEncoding":
|
||||||
"""
|
"""
|
||||||
Send all values to device by calling `v.to(device)` (PyTorch only).
|
Send all values to device by calling `v.to(device)` (PyTorch only).
|
||||||
@@ -750,6 +749,7 @@ class BatchEncoding(UserDict):
|
|||||||
Returns:
|
Returns:
|
||||||
[`BatchEncoding`]: The same instance after modification.
|
[`BatchEncoding`]: The same instance after modification.
|
||||||
"""
|
"""
|
||||||
|
requires_backends(self, ["torch"])
|
||||||
|
|
||||||
# This check catches things like APEX blindly calling "to" on all inputs to a module
|
# This check catches things like APEX blindly calling "to" on all inputs to a module
|
||||||
# Otherwise it passes the casts down and casts the LongTensor containing the token idxs
|
# Otherwise it passes the casts down and casts the LongTensor containing the token idxs
|
||||||
|
|||||||
@@ -50,7 +50,6 @@ from .utils import (
|
|||||||
is_torch_tpu_available,
|
is_torch_tpu_available,
|
||||||
logging,
|
logging,
|
||||||
requires_backends,
|
requires_backends,
|
||||||
torch_required,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -1386,8 +1385,8 @@ class TrainingArguments:
|
|||||||
return timedelta(seconds=self.ddp_timeout)
|
return timedelta(seconds=self.ddp_timeout)
|
||||||
|
|
||||||
@cached_property
|
@cached_property
|
||||||
@torch_required
|
|
||||||
def _setup_devices(self) -> "torch.device":
|
def _setup_devices(self) -> "torch.device":
|
||||||
|
requires_backends(self, ["torch"])
|
||||||
logger.info("PyTorch: setting up devices")
|
logger.info("PyTorch: setting up devices")
|
||||||
if torch.distributed.is_available() and torch.distributed.is_initialized() and self.local_rank == -1:
|
if torch.distributed.is_available() and torch.distributed.is_initialized() and self.local_rank == -1:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
@@ -1537,15 +1536,14 @@ class TrainingArguments:
|
|||||||
return device
|
return device
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@torch_required
|
|
||||||
def device(self) -> "torch.device":
|
def device(self) -> "torch.device":
|
||||||
"""
|
"""
|
||||||
The device used by this process.
|
The device used by this process.
|
||||||
"""
|
"""
|
||||||
|
requires_backends(self, ["torch"])
|
||||||
return self._setup_devices
|
return self._setup_devices
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@torch_required
|
|
||||||
def n_gpu(self):
|
def n_gpu(self):
|
||||||
"""
|
"""
|
||||||
The number of GPUs used by this process.
|
The number of GPUs used by this process.
|
||||||
@@ -1554,12 +1552,12 @@ class TrainingArguments:
|
|||||||
This will only be greater than one when you have multiple GPUs available but are not using distributed
|
This will only be greater than one when you have multiple GPUs available but are not using distributed
|
||||||
training. For distributed training, it will always be 1.
|
training. For distributed training, it will always be 1.
|
||||||
"""
|
"""
|
||||||
|
requires_backends(self, ["torch"])
|
||||||
# Make sure `self._n_gpu` is properly setup.
|
# Make sure `self._n_gpu` is properly setup.
|
||||||
_ = self._setup_devices
|
_ = self._setup_devices
|
||||||
return self._n_gpu
|
return self._n_gpu
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@torch_required
|
|
||||||
def parallel_mode(self):
|
def parallel_mode(self):
|
||||||
"""
|
"""
|
||||||
The current mode used for parallelism if multiple GPUs/TPU cores are available. One of:
|
The current mode used for parallelism if multiple GPUs/TPU cores are available. One of:
|
||||||
@@ -1570,6 +1568,7 @@ class TrainingArguments:
|
|||||||
`torch.nn.DistributedDataParallel`).
|
`torch.nn.DistributedDataParallel`).
|
||||||
- `ParallelMode.TPU`: several TPU cores.
|
- `ParallelMode.TPU`: several TPU cores.
|
||||||
"""
|
"""
|
||||||
|
requires_backends(self, ["torch"])
|
||||||
if is_torch_tpu_available():
|
if is_torch_tpu_available():
|
||||||
return ParallelMode.TPU
|
return ParallelMode.TPU
|
||||||
elif is_sagemaker_mp_enabled():
|
elif is_sagemaker_mp_enabled():
|
||||||
@@ -1584,11 +1583,12 @@ class TrainingArguments:
|
|||||||
return ParallelMode.NOT_PARALLEL
|
return ParallelMode.NOT_PARALLEL
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@torch_required
|
|
||||||
def world_size(self):
|
def world_size(self):
|
||||||
"""
|
"""
|
||||||
The number of processes used in parallel.
|
The number of processes used in parallel.
|
||||||
"""
|
"""
|
||||||
|
requires_backends(self, ["torch"])
|
||||||
|
|
||||||
if is_torch_tpu_available():
|
if is_torch_tpu_available():
|
||||||
return xm.xrt_world_size()
|
return xm.xrt_world_size()
|
||||||
elif is_sagemaker_mp_enabled():
|
elif is_sagemaker_mp_enabled():
|
||||||
@@ -1600,11 +1600,11 @@ class TrainingArguments:
|
|||||||
return 1
|
return 1
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@torch_required
|
|
||||||
def process_index(self):
|
def process_index(self):
|
||||||
"""
|
"""
|
||||||
The index of the current process used.
|
The index of the current process used.
|
||||||
"""
|
"""
|
||||||
|
requires_backends(self, ["torch"])
|
||||||
if is_torch_tpu_available():
|
if is_torch_tpu_available():
|
||||||
return xm.get_ordinal()
|
return xm.get_ordinal()
|
||||||
elif is_sagemaker_mp_enabled():
|
elif is_sagemaker_mp_enabled():
|
||||||
@@ -1616,11 +1616,11 @@ class TrainingArguments:
|
|||||||
return 0
|
return 0
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@torch_required
|
|
||||||
def local_process_index(self):
|
def local_process_index(self):
|
||||||
"""
|
"""
|
||||||
The index of the local process used.
|
The index of the local process used.
|
||||||
"""
|
"""
|
||||||
|
requires_backends(self, ["torch"])
|
||||||
if is_torch_tpu_available():
|
if is_torch_tpu_available():
|
||||||
return xm.get_local_ordinal()
|
return xm.get_local_ordinal()
|
||||||
elif is_sagemaker_mp_enabled():
|
elif is_sagemaker_mp_enabled():
|
||||||
|
|||||||
@@ -17,7 +17,7 @@ from dataclasses import dataclass, field
|
|||||||
from typing import Optional, Tuple
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
from .training_args import TrainingArguments
|
from .training_args import TrainingArguments
|
||||||
from .utils import cached_property, is_tf_available, logging, tf_required
|
from .utils import cached_property, is_tf_available, logging, requires_backends
|
||||||
|
|
||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
@@ -185,8 +185,8 @@ class TFTrainingArguments(TrainingArguments):
|
|||||||
xla: bool = field(default=False, metadata={"help": "Whether to activate the XLA compilation or not"})
|
xla: bool = field(default=False, metadata={"help": "Whether to activate the XLA compilation or not"})
|
||||||
|
|
||||||
@cached_property
|
@cached_property
|
||||||
@tf_required
|
|
||||||
def _setup_strategy(self) -> Tuple["tf.distribute.Strategy", int]:
|
def _setup_strategy(self) -> Tuple["tf.distribute.Strategy", int]:
|
||||||
|
requires_backends(self, ["tf"])
|
||||||
logger.info("Tensorflow: setting up strategy")
|
logger.info("Tensorflow: setting up strategy")
|
||||||
|
|
||||||
gpus = tf.config.list_physical_devices("GPU")
|
gpus = tf.config.list_physical_devices("GPU")
|
||||||
@@ -234,19 +234,19 @@ class TFTrainingArguments(TrainingArguments):
|
|||||||
return strategy
|
return strategy
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@tf_required
|
|
||||||
def strategy(self) -> "tf.distribute.Strategy":
|
def strategy(self) -> "tf.distribute.Strategy":
|
||||||
"""
|
"""
|
||||||
The strategy used for distributed training.
|
The strategy used for distributed training.
|
||||||
"""
|
"""
|
||||||
|
requires_backends(self, ["tf"])
|
||||||
return self._setup_strategy
|
return self._setup_strategy
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@tf_required
|
|
||||||
def n_replicas(self) -> int:
|
def n_replicas(self) -> int:
|
||||||
"""
|
"""
|
||||||
The number of replicas (CPUs, GPUs or TPU cores) used in this training.
|
The number of replicas (CPUs, GPUs or TPU cores) used in this training.
|
||||||
"""
|
"""
|
||||||
|
requires_backends(self, ["tf"])
|
||||||
return self._setup_strategy.num_replicas_in_sync
|
return self._setup_strategy.num_replicas_in_sync
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@@ -276,11 +276,11 @@ class TFTrainingArguments(TrainingArguments):
|
|||||||
return per_device_batch_size * self.n_replicas
|
return per_device_batch_size * self.n_replicas
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@tf_required
|
|
||||||
def n_gpu(self) -> int:
|
def n_gpu(self) -> int:
|
||||||
"""
|
"""
|
||||||
The number of replicas (CPUs, GPUs or TPU cores) used in this training.
|
The number of replicas (CPUs, GPUs or TPU cores) used in this training.
|
||||||
"""
|
"""
|
||||||
|
requires_backends(self, ["tf"])
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
"The n_gpu argument is deprecated and will be removed in a future version, use n_replicas instead.",
|
"The n_gpu argument is deprecated and will be removed in a future version, use n_replicas instead.",
|
||||||
FutureWarning,
|
FutureWarning,
|
||||||
|
|||||||
@@ -163,9 +163,7 @@ from .import_utils import (
|
|||||||
is_training_run_on_sagemaker,
|
is_training_run_on_sagemaker,
|
||||||
is_vision_available,
|
is_vision_available,
|
||||||
requires_backends,
|
requires_backends,
|
||||||
tf_required,
|
|
||||||
torch_only_method,
|
torch_only_method,
|
||||||
torch_required,
|
|
||||||
torch_version,
|
torch_version,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -22,7 +22,7 @@ import shutil
|
|||||||
import sys
|
import sys
|
||||||
import warnings
|
import warnings
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from functools import lru_cache, wraps
|
from functools import lru_cache
|
||||||
from itertools import chain
|
from itertools import chain
|
||||||
from types import ModuleType
|
from types import ModuleType
|
||||||
from typing import Any
|
from typing import Any
|
||||||
@@ -1039,30 +1039,6 @@ class DummyObject(type):
|
|||||||
requires_backends(cls, cls._backends)
|
requires_backends(cls, cls._backends)
|
||||||
|
|
||||||
|
|
||||||
def torch_required(func):
|
|
||||||
# Chose a different decorator name than in tests so it's clear they are not the same.
|
|
||||||
@wraps(func)
|
|
||||||
def wrapper(*args, **kwargs):
|
|
||||||
if is_torch_available():
|
|
||||||
return func(*args, **kwargs)
|
|
||||||
else:
|
|
||||||
raise ImportError(f"Method `{func.__name__}` requires PyTorch.")
|
|
||||||
|
|
||||||
return wrapper
|
|
||||||
|
|
||||||
|
|
||||||
def tf_required(func):
|
|
||||||
# Chose a different decorator name than in tests so it's clear they are not the same.
|
|
||||||
@wraps(func)
|
|
||||||
def wrapper(*args, **kwargs):
|
|
||||||
if is_tf_available():
|
|
||||||
return func(*args, **kwargs)
|
|
||||||
else:
|
|
||||||
raise ImportError(f"Method `{func.__name__}` requires TF.")
|
|
||||||
|
|
||||||
return wrapper
|
|
||||||
|
|
||||||
|
|
||||||
def is_torch_fx_proxy(x):
|
def is_torch_fx_proxy(x):
|
||||||
if is_torch_fx_available():
|
if is_torch_fx_available():
|
||||||
import torch.fx
|
import torch.fx
|
||||||
|
|||||||
Reference in New Issue
Block a user