[trainer] add tf32-mode control (#14606)
* [trainer] add --tf32 support * it's pt>=.17 * it's pt>=.17 * flip the default to True * add experimental note * simplify logic * style * switch to 3-state logic * doc * Apply suggestions from code review Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * re-style code Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
@@ -358,8 +358,13 @@ Like all cases with reduced precision this may or may not be satisfactory for yo
|
|||||||
|
|
||||||
If you're already using fp16 or bf16 mixed precision it may help with the throughput as well.
|
If you're already using fp16 or bf16 mixed precision it may help with the throughput as well.
|
||||||
|
|
||||||
|
You can enable this mode in the 🤗 Trainer with `--tf32`, or disable it with `--tf32 0` or `--no_tf32`.
|
||||||
|
By default the PyTorch default is used.
|
||||||
|
|
||||||
Note: tf32 mode is internal to CUDA and can't be accessed directly via `tensor.to(dtype=torch.tf32)` as `torch.tf32` doesn't exit.
|
Note: tf32 mode is internal to CUDA and can't be accessed directly via `tensor.to(dtype=torch.tf32)` as `torch.tf32` doesn't exit.
|
||||||
|
|
||||||
|
Note: you need `torch>=1.7` to enjoy this feature.
|
||||||
|
|
||||||
|
|
||||||
### Gradient Checkpointing
|
### Gradient Checkpointing
|
||||||
|
|
||||||
|
|||||||
@@ -321,35 +321,53 @@ def is_torch_cuda_available():
|
|||||||
|
|
||||||
|
|
||||||
def is_torch_bf16_available():
|
def is_torch_bf16_available():
|
||||||
if is_torch_available():
|
if not is_torch_available():
|
||||||
import torch
|
|
||||||
|
|
||||||
# since currently no utility function is available we build our own.
|
|
||||||
# some bits come from https://github.com/pytorch/pytorch/blob/2289a12f21c54da93bf5d696e3f9aea83dd9c10d/torch/testing/_internal/common_cuda.py#L51
|
|
||||||
# with additional check for torch version
|
|
||||||
# to succeed:
|
|
||||||
# 1. the hardware needs to support bf16 (arch >= Ampere)
|
|
||||||
# 2. torch >= 1.10 (1.9 should be enough for AMP API has changed in 1.10, so using 1.10 as minimal)
|
|
||||||
# 3. CUDA >= 11
|
|
||||||
# 4. torch.autocast exists
|
|
||||||
# XXX: one problem here is that it may give invalid results on mixed gpus setup, so it's
|
|
||||||
# really only correct for the 0th gpu (or currently set default device if different from 0)
|
|
||||||
|
|
||||||
if not torch.cuda.is_available() or torch.version.cuda is None:
|
|
||||||
return False
|
|
||||||
if torch.cuda.get_device_properties(torch.cuda.current_device()).major < 8:
|
|
||||||
return False
|
|
||||||
if int(torch.version.cuda.split(".")[0]) < 11:
|
|
||||||
return False
|
|
||||||
if not version.parse(torch.__version__) >= version.parse("1.10"):
|
|
||||||
return False
|
|
||||||
if not hasattr(torch, "autocast"):
|
|
||||||
return False
|
|
||||||
|
|
||||||
return True
|
|
||||||
else:
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
# since currently no utility function is available we build our own.
|
||||||
|
# some bits come from https://github.com/pytorch/pytorch/blob/2289a12f21c54da93bf5d696e3f9aea83dd9c10d/torch/testing/_internal/common_cuda.py#L51
|
||||||
|
# with additional check for torch version
|
||||||
|
# to succeed:
|
||||||
|
# 1. the hardware needs to support bf16 (arch >= Ampere)
|
||||||
|
# 2. torch >= 1.10 (1.9 should be enough for AMP API has changed in 1.10, so using 1.10 as minimal)
|
||||||
|
# 3. CUDA >= 11
|
||||||
|
# 4. torch.autocast exists
|
||||||
|
# XXX: one problem here is that it may give invalid results on mixed gpus setup, so it's
|
||||||
|
# really only correct for the 0th gpu (or currently set default device if different from 0)
|
||||||
|
|
||||||
|
if not torch.cuda.is_available() or torch.version.cuda is None:
|
||||||
|
return False
|
||||||
|
if torch.cuda.get_device_properties(torch.cuda.current_device()).major < 8:
|
||||||
|
return False
|
||||||
|
if int(torch.version.cuda.split(".")[0]) < 11:
|
||||||
|
return False
|
||||||
|
if version.parse(torch.__version__) < version.parse("1.10"):
|
||||||
|
return False
|
||||||
|
if not hasattr(torch, "autocast"):
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def is_torch_tf32_available():
|
||||||
|
if not is_torch_available():
|
||||||
|
return False
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
if not torch.cuda.is_available() or torch.version.cuda is None:
|
||||||
|
return False
|
||||||
|
if torch.cuda.get_device_properties(torch.cuda.current_device()).major < 8:
|
||||||
|
return False
|
||||||
|
if int(torch.version.cuda.split(".")[0]) < 11:
|
||||||
|
return False
|
||||||
|
if version.parse(torch.__version__) < version.parse("1.7"):
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
_torch_fx_available = _torch_onnx_dict_inputs_support_available = False
|
_torch_fx_available = _torch_onnx_dict_inputs_support_available = False
|
||||||
if _torch_available:
|
if _torch_available:
|
||||||
|
|||||||
@@ -50,6 +50,7 @@ from .file_utils import (
|
|||||||
is_tokenizers_available,
|
is_tokenizers_available,
|
||||||
is_torch_available,
|
is_torch_available,
|
||||||
is_torch_bf16_available,
|
is_torch_bf16_available,
|
||||||
|
is_torch_tf32_available,
|
||||||
is_torch_tpu_available,
|
is_torch_tpu_available,
|
||||||
is_torchaudio_available,
|
is_torchaudio_available,
|
||||||
is_vision_available,
|
is_vision_available,
|
||||||
@@ -495,9 +496,17 @@ def require_torch_gpu(test_case):
|
|||||||
|
|
||||||
|
|
||||||
def require_torch_bf16(test_case):
|
def require_torch_bf16(test_case):
|
||||||
"""Decorator marking a test that requires CUDA hardware supporting bf16 and PyTorch >= 1.10."""
|
"""Decorator marking a test that requires Ampere or a newer GPU arch, cuda>=11 and torch>=1.10."""
|
||||||
if not is_torch_bf16_available():
|
if not is_torch_bf16_available():
|
||||||
return unittest.skip("test requires CUDA hardware supporting bf16 and PyTorch >= 1.10")(test_case)
|
return unittest.skip("test requires Ampere or a newer GPU arch, cuda>=11 and torch>=1.10")(test_case)
|
||||||
|
else:
|
||||||
|
return test_case
|
||||||
|
|
||||||
|
|
||||||
|
def require_torch_tf32(test_case):
|
||||||
|
"""Decorator marking a test that requires Ampere or a newer GPU arch, cuda>=11 and torch>=1.7."""
|
||||||
|
if not is_torch_tf32_available():
|
||||||
|
return unittest.skip("test requires Ampere or a newer GPU arch, cuda>=11 and torch>=1.7")(test_case)
|
||||||
else:
|
else:
|
||||||
return test_case
|
return test_case
|
||||||
|
|
||||||
|
|||||||
@@ -29,6 +29,7 @@ from .file_utils import (
|
|||||||
is_sagemaker_dp_enabled,
|
is_sagemaker_dp_enabled,
|
||||||
is_sagemaker_mp_enabled,
|
is_sagemaker_mp_enabled,
|
||||||
is_torch_available,
|
is_torch_available,
|
||||||
|
is_torch_tf32_available,
|
||||||
is_torch_tpu_available,
|
is_torch_tpu_available,
|
||||||
torch_required,
|
torch_required,
|
||||||
)
|
)
|
||||||
@@ -227,6 +228,9 @@ class TrainingArguments:
|
|||||||
fp16_full_eval (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
fp16_full_eval (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||||
Whether to use full float16 evaluation instead of 32-bit. This will be faster and save memory but can harm
|
Whether to use full float16 evaluation instead of 32-bit. This will be faster and save memory but can harm
|
||||||
metric values.
|
metric values.
|
||||||
|
tf32 (:obj:`bool`, `optional`):
|
||||||
|
Whether to enable tf32 mode, available in Ampere and newer GPU architectures. This is an experimental API
|
||||||
|
and it may change.
|
||||||
local_rank (:obj:`int`, `optional`, defaults to -1):
|
local_rank (:obj:`int`, `optional`, defaults to -1):
|
||||||
Rank of the process during distributed training.
|
Rank of the process during distributed training.
|
||||||
xpu_backend (:obj:`str`, `optional`):
|
xpu_backend (:obj:`str`, `optional`):
|
||||||
@@ -548,6 +552,12 @@ class TrainingArguments:
|
|||||||
default=False,
|
default=False,
|
||||||
metadata={"help": "Whether to use full float16 evaluation instead of 32-bit"},
|
metadata={"help": "Whether to use full float16 evaluation instead of 32-bit"},
|
||||||
)
|
)
|
||||||
|
tf32: bool = field(
|
||||||
|
default=None,
|
||||||
|
metadata={
|
||||||
|
"help": "Whether to enable tf32 mode, available in Ampere and newer GPU architectures. This is an experimental API and it may change."
|
||||||
|
},
|
||||||
|
)
|
||||||
local_rank: int = field(default=-1, metadata={"help": "For distributed training: local_rank"})
|
local_rank: int = field(default=-1, metadata={"help": "For distributed training: local_rank"})
|
||||||
xpu_backend: str = field(
|
xpu_backend: str = field(
|
||||||
default=None,
|
default=None,
|
||||||
@@ -802,6 +812,17 @@ class TrainingArguments:
|
|||||||
"Mixed precision training with AMP or APEX (`--fp16` or `--bf16`) and half precision evaluation (`--fp16_full_eval` or `--bf16_full_eval`) can only be used on CUDA devices."
|
"Mixed precision training with AMP or APEX (`--fp16` or `--bf16`) and half precision evaluation (`--fp16_full_eval` or `--bf16_full_eval`) can only be used on CUDA devices."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if is_torch_available() and self.tf32 is not None:
|
||||||
|
if self.tf32:
|
||||||
|
if is_torch_tf32_available():
|
||||||
|
torch.backends.cuda.matmul.allow_tf32 = True
|
||||||
|
else:
|
||||||
|
raise ValueError("--tf32 requires Ampere or a newer GPU arch, cuda>=11 and torch>=1.7")
|
||||||
|
else:
|
||||||
|
if is_torch_tf32_available():
|
||||||
|
torch.backends.cuda.matmul.allow_tf32 = False
|
||||||
|
# no need to assert on else
|
||||||
|
|
||||||
if self.report_to is None:
|
if self.report_to is None:
|
||||||
logger.info(
|
logger.info(
|
||||||
"The default value for the training argument `--report_to` will change in v5 (from all installed "
|
"The default value for the training argument `--report_to` will change in v5 (from all installed "
|
||||||
|
|||||||
@@ -57,6 +57,7 @@ from transformers.testing_utils import (
|
|||||||
require_torch_gpu,
|
require_torch_gpu,
|
||||||
require_torch_multi_gpu,
|
require_torch_multi_gpu,
|
||||||
require_torch_non_multi_gpu,
|
require_torch_non_multi_gpu,
|
||||||
|
require_torch_tf32,
|
||||||
require_torch_up_to_2_gpus,
|
require_torch_up_to_2_gpus,
|
||||||
slow,
|
slow,
|
||||||
)
|
)
|
||||||
@@ -492,6 +493,15 @@ class TrainerIntegrationPrerunTest(TestCasePlus, TrainerIntegrationCommon):
|
|||||||
|
|
||||||
# will add more specific tests once there are some bugs to fix
|
# will add more specific tests once there are some bugs to fix
|
||||||
|
|
||||||
|
@require_torch_gpu
|
||||||
|
@require_torch_tf32
|
||||||
|
def test_tf32(self):
|
||||||
|
|
||||||
|
# very basic test
|
||||||
|
trainer = get_regression_trainer(learning_rate=0.1, tf32=True)
|
||||||
|
trainer.train()
|
||||||
|
self.check_trained_model(trainer.model)
|
||||||
|
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
@require_sentencepiece
|
@require_sentencepiece
|
||||||
|
|||||||
Reference in New Issue
Block a user