[trainer] add tf32-mode control (#14606)
* [trainer] add --tf32 support * it's pt>=.17 * it's pt>=.17 * flip the default to True * add experimental note * simplify logic * style * switch to 3-state logic * doc * Apply suggestions from code review Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * re-style code Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
@@ -358,8 +358,13 @@ Like all cases with reduced precision this may or may not be satisfactory for yo
|
||||
|
||||
If you're already using fp16 or bf16 mixed precision it may help with the throughput as well.
|
||||
|
||||
You can enable this mode in the 🤗 Trainer with `--tf32`, or disable it with `--tf32 0` or `--no_tf32`.
|
||||
By default the PyTorch default is used.
|
||||
|
||||
Note: tf32 mode is internal to CUDA and can't be accessed directly via `tensor.to(dtype=torch.tf32)` as `torch.tf32` doesn't exist.
|
||||
|
||||
Note: you need `torch>=1.7` to enjoy this feature.
|
||||
|
||||
|
||||
### Gradient Checkpointing
|
||||
|
||||
|
||||
@@ -321,35 +321,53 @@ def is_torch_cuda_available():
|
||||
|
||||
|
||||
def is_torch_bf16_available():
    """
    Check whether the current environment passes every bf16 prerequisite tested below.

    Returns:
        :obj:`bool`: :obj:`True` only if PyTorch is installed, CUDA is available, the current CUDA device reports a
        compute capability major version >= 8, the CUDA toolkit major version is >= 11, ``torch >= 1.10`` and
        ``torch.autocast`` exists; :obj:`False` otherwise.
    """
    if not is_torch_available():
        return False

    import torch

    # since currently no utility function is available we build our own.
    # some bits come from https://github.com/pytorch/pytorch/blob/2289a12f21c54da93bf5d696e3f9aea83dd9c10d/torch/testing/_internal/common_cuda.py#L51
    # with additional check for torch version
    # to succeed:
    # 1. the hardware needs to support bf16 (arch >= Ampere)
    # 2. torch >= 1.10 (1.9 should be enough, but the AMP API has changed in 1.10, so using 1.10 as minimal)
    # 3. CUDA >= 11
    # 4. torch.autocast exists
    # XXX: one problem here is that it may give invalid results on mixed gpus setup, so it's
    # really only correct for the 0th gpu (or currently set default device if different from 0)

    # `torch.version.cuda` must be checked for None before it is split further down.
    if not torch.cuda.is_available() or torch.version.cuda is None:
        return False
    if torch.cuda.get_device_properties(torch.cuda.current_device()).major < 8:
        return False
    if int(torch.version.cuda.split(".")[0]) < 11:
        return False
    if version.parse(torch.__version__) < version.parse("1.10"):
        return False
    if not hasattr(torch, "autocast"):
        return False

    return True
|
||||
|
||||
|
||||
def is_torch_tf32_available():
    """
    Check whether the current environment passes every tf32 prerequisite tested below.

    Returns:
        :obj:`bool`: :obj:`True` only if PyTorch is installed, CUDA is available, the current CUDA device reports a
        compute capability major version >= 8, the CUDA toolkit major version is >= 11 and ``torch >= 1.7``;
        :obj:`False` otherwise.
    """
    if not is_torch_available():
        return False

    import torch

    # Without a usable CUDA build there is no `torch.version.cuda` string to inspect below.
    cuda_usable = torch.cuda.is_available() and torch.version.cuda is not None
    if not cuda_usable:
        return False

    # Evaluated left to right with short-circuiting: device arch, then CUDA toolkit, then torch version.
    return (
        torch.cuda.get_device_properties(torch.cuda.current_device()).major >= 8
        and int(torch.version.cuda.split(".")[0]) >= 11
        and version.parse(torch.__version__) >= version.parse("1.7")
    )
|
||||
|
||||
|
||||
_torch_fx_available = _torch_onnx_dict_inputs_support_available = False
|
||||
if _torch_available:
|
||||
|
||||
@@ -50,6 +50,7 @@ from .file_utils import (
|
||||
is_tokenizers_available,
|
||||
is_torch_available,
|
||||
is_torch_bf16_available,
|
||||
is_torch_tf32_available,
|
||||
is_torch_tpu_available,
|
||||
is_torchaudio_available,
|
||||
is_vision_available,
|
||||
@@ -495,9 +496,17 @@ def require_torch_gpu(test_case):
|
||||
|
||||
|
||||
def require_torch_bf16(test_case):
    """Decorator marking a test that requires Ampere or a newer GPU arch, cuda>=11 and torch>=1.10."""
    # The skip reason spells out the hardware, CUDA and torch requirements so a skipped test explains itself.
    if not is_torch_bf16_available():
        return unittest.skip("test requires Ampere or a newer GPU arch, cuda>=11 and torch>=1.10")(test_case)
    else:
        return test_case
|
||||
|
||||
|
||||
def require_torch_tf32(test_case):
    """
    Decorator marking a test that requires Ampere or a newer GPU arch, cuda>=11 and torch>=1.7.

    The test is returned untouched when :obj:`is_torch_tf32_available()` is true and is wrapped in
    :obj:`unittest.skip` otherwise.
    """
    if is_torch_tf32_available():
        return test_case

    skip_reason = "test requires Ampere or a newer GPU arch, cuda>=11 and torch>=1.7"
    return unittest.skip(skip_reason)(test_case)
|
||||
|
||||
|
||||
@@ -29,6 +29,7 @@ from .file_utils import (
|
||||
is_sagemaker_dp_enabled,
|
||||
is_sagemaker_mp_enabled,
|
||||
is_torch_available,
|
||||
is_torch_tf32_available,
|
||||
is_torch_tpu_available,
|
||||
torch_required,
|
||||
)
|
||||
@@ -227,6 +228,9 @@ class TrainingArguments:
|
||||
fp16_full_eval (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
Whether to use full float16 evaluation instead of 32-bit. This will be faster and save memory but can harm
|
||||
metric values.
|
||||
tf32 (:obj:`bool`, `optional`):
|
||||
Whether to enable tf32 mode, available in Ampere and newer GPU architectures. This is an experimental API
|
||||
and it may change.
|
||||
local_rank (:obj:`int`, `optional`, defaults to -1):
|
||||
Rank of the process during distributed training.
|
||||
xpu_backend (:obj:`str`, `optional`):
|
||||
@@ -548,6 +552,12 @@ class TrainingArguments:
|
||||
default=False,
|
||||
metadata={"help": "Whether to use full float16 evaluation instead of 32-bit"},
|
||||
)
|
||||
tf32: bool = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "Whether to enable tf32 mode, available in Ampere and newer GPU architectures. This is an experimental API and it may change."
|
||||
},
|
||||
)
|
||||
local_rank: int = field(default=-1, metadata={"help": "For distributed training: local_rank"})
|
||||
xpu_backend: str = field(
|
||||
default=None,
|
||||
@@ -802,6 +812,17 @@ class TrainingArguments:
|
||||
"Mixed precision training with AMP or APEX (`--fp16` or `--bf16`) and half precision evaluation (`--fp16_full_eval` or `--bf16_full_eval`) can only be used on CUDA devices."
|
||||
)
|
||||
|
||||
if is_torch_available() and self.tf32 is not None:
|
||||
if self.tf32:
|
||||
if is_torch_tf32_available():
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
else:
|
||||
raise ValueError("--tf32 requires Ampere or a newer GPU arch, cuda>=11 and torch>=1.7")
|
||||
else:
|
||||
if is_torch_tf32_available():
|
||||
torch.backends.cuda.matmul.allow_tf32 = False
|
||||
# no need to assert on else
|
||||
|
||||
if self.report_to is None:
|
||||
logger.info(
|
||||
"The default value for the training argument `--report_to` will change in v5 (from all installed "
|
||||
|
||||
@@ -57,6 +57,7 @@ from transformers.testing_utils import (
|
||||
require_torch_gpu,
|
||||
require_torch_multi_gpu,
|
||||
require_torch_non_multi_gpu,
|
||||
require_torch_tf32,
|
||||
require_torch_up_to_2_gpus,
|
||||
slow,
|
||||
)
|
||||
@@ -492,6 +493,15 @@ class TrainerIntegrationPrerunTest(TestCasePlus, TrainerIntegrationCommon):
|
||||
|
||||
# will add more specific tests once there are some bugs to fix
|
||||
|
||||
@require_torch_gpu
@require_torch_tf32
def test_tf32(self):
    """Very basic test: train the regression trainer with ``tf32=True`` and check the trained model."""
    tf32_trainer = get_regression_trainer(learning_rate=0.1, tf32=True)
    tf32_trainer.train()
    self.check_trained_model(tf32_trainer.model)
|
||||
|
||||
|
||||
@require_torch
|
||||
@require_sentencepiece
|
||||
|
||||
Reference in New Issue
Block a user