WIP: Support for Training with BF16 (#13207)
* started bf16 integration * minor changes * code now runs * style * lay foundation for bf16 testing * lay foundation for bf16 testing * start the tests * better bf16 check * style * 2 separate checkers - one for bf16 support, another for bf16+autocast * Update src/transformers/training_args.py Co-authored-by: Stas Bekman <stas00@users.noreply.github.com> * a couple of comment resolutions * more comment resolutions * resolved a small bug * just some print statemtns * added todo marking * added a todo * adjust for API change s/fast_dtype/dtype/ * fix style * merge 2 bf16 util functions * bf16 now does scaling too * Add support for bfloat16 * Revert T5 layernorm to float32 This is based on the comment at https://github.com/huggingface/transformers/pull/14448/files#r752660929 and the PyTorch PR https://github.com/pytorch/pytorch/pull/66920 . * Add comment about conversion to float32 before returning the numpy data * Add comment about AMP-bfloat16 incompatibility * Fix formatting * typo * reformer / bf16 * cleanup * require at least pt-1.10 * fix * will deal with deepspeed separately * cleanup * revert * cleanup * fp16_full_eval and bf16_full_eval are separate modes * proper deprecation * cleanup * test and fixes * spelling * cleanup * add a note that this API is experimental Co-authored-by: jamie <jamie@cortx.com> Co-authored-by: Stas Bekman <stas@stason.org> Co-authored-by: Stas Bekman <stas00@users.noreply.github.com> Co-authored-by: suriya <suriya@cortx.com> Co-authored-by: Manuel R. Ciosici <manuelrciosici@gmail.com>
This commit is contained in:
@@ -320,6 +320,37 @@ def is_torch_cuda_available():
|
||||
return False
|
||||
|
||||
|
||||
def is_torch_bf16_available():
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
# since currently no utility function is available we build our own.
|
||||
# some bits come from https://github.com/pytorch/pytorch/blob/2289a12f21c54da93bf5d696e3f9aea83dd9c10d/torch/testing/_internal/common_cuda.py#L51
|
||||
# with additional check for torch version
|
||||
# to succeed:
|
||||
# 1. the hardware needs to support bf16 (arch >= Ampere)
|
||||
# 2. torch >= 1.10 (1.9 should be enough for AMP API has changed in 1.10, so using 1.10 as minimal)
|
||||
# 3. CUDA >= 11
|
||||
# 4. torch.autocast exists
|
||||
# XXX: one problem here is that it may give invalid results on mixed gpus setup, so it's
|
||||
# really only correct for the 0th gpu (or currently set default device if different from 0)
|
||||
|
||||
if not torch.cuda.is_available() or torch.version.cuda is None:
|
||||
return False
|
||||
if torch.cuda.get_device_properties(torch.cuda.current_device()).major < 8:
|
||||
return False
|
||||
if int(torch.version.cuda.split(".")[0]) < 11:
|
||||
return False
|
||||
if not version.parse(torch.__version__) >= version.parse("1.10"):
|
||||
return False
|
||||
if not hasattr(torch, "autocast"):
|
||||
return False
|
||||
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
|
||||
_torch_fx_available = _torch_onnx_dict_inputs_support_available = False
|
||||
if _torch_available:
|
||||
torch_version = version.parse(importlib_metadata.version("torch"))
|
||||
|
||||
@@ -233,7 +233,7 @@ class ModuleUtilsMixin:
|
||||
|
||||
if self.dtype == torch.float16:
|
||||
encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -1e4
|
||||
elif self.dtype == torch.float32:
|
||||
elif self.dtype in [torch.bfloat16, torch.float32]:
|
||||
encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -1e9
|
||||
else:
|
||||
raise ValueError(
|
||||
|
||||
@@ -242,9 +242,10 @@ class T5LayerNorm(nn.Module):
|
||||
variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
|
||||
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
||||
|
||||
# convert into float16 if necessary
|
||||
if self.weight.dtype == torch.float16:
|
||||
hidden_states = hidden_states.to(torch.float16)
|
||||
# convert into half-precision if necessary
|
||||
if self.weight.dtype in [torch.float16, torch.bfloat16]:
|
||||
hidden_states = hidden_states.to(self.weight.dtype)
|
||||
|
||||
return self.weight * hidden_states
|
||||
|
||||
|
||||
|
||||
@@ -49,6 +49,7 @@ from .file_utils import (
|
||||
is_timm_available,
|
||||
is_tokenizers_available,
|
||||
is_torch_available,
|
||||
is_torch_bf16_available,
|
||||
is_torch_tpu_available,
|
||||
is_torchaudio_available,
|
||||
is_vision_available,
|
||||
@@ -493,6 +494,14 @@ def require_torch_gpu(test_case):
|
||||
return test_case
|
||||
|
||||
|
||||
def require_torch_bf16(test_case):
|
||||
"""Decorator marking a test that requires CUDA hardware supporting bf16 and PyTorch >= 1.10."""
|
||||
if not is_torch_bf16_available():
|
||||
return unittest.skip("test requires CUDA hardware supporting bf16 and PyTorch >= 1.10")(test_case)
|
||||
else:
|
||||
return test_case
|
||||
|
||||
|
||||
def require_datasets(test_case):
|
||||
"""Decorator marking a test that requires datasets."""
|
||||
|
||||
|
||||
@@ -353,13 +353,13 @@ class Trainer:
|
||||
# 1. MP - since we are trying to fit a much bigger than 1 gpu model
|
||||
# 2. fp16-enabled DeepSpeed loads the model in half the size and it doesn't need .to() anyway,
|
||||
# and we only use deepspeed for training at the moment
|
||||
# 3. full fp16 eval - since the model needs to be half'ed first
|
||||
# 3. full bf16 or fp16 eval - since the model needs to be cast to the right dtype first
|
||||
# 4. Sharded DDP - same as MP
|
||||
self.place_model_on_device = args.place_model_on_device
|
||||
if (
|
||||
self.is_model_parallel
|
||||
or args.deepspeed
|
||||
or (args.fp16_full_eval and not args.do_train)
|
||||
or ((args.fp16_full_eval or args.bf16_full_eval) and not args.do_train)
|
||||
or (self.sharded_ddp in [ShardedDDPOption.ZERO_DP_2, ShardedDDPOption.ZERO_DP_3])
|
||||
):
|
||||
self.place_model_on_device = False
|
||||
@@ -424,18 +424,24 @@ class Trainer:
|
||||
# Mixed precision setup
|
||||
self.use_apex = False
|
||||
self.use_amp = False
|
||||
self.fp16_backend = None
|
||||
|
||||
if args.fp16:
|
||||
if args.fp16_backend == "auto":
|
||||
self.fp16_backend = "amp" if _is_native_amp_available else "apex"
|
||||
if args.fp16 or args.bf16:
|
||||
if args.half_precision_backend == "auto":
|
||||
if _is_native_amp_available:
|
||||
args.half_precision_backend = "amp"
|
||||
else:
|
||||
self.fp16_backend = args.fp16_backend
|
||||
logger.info(f"Using {self.fp16_backend} fp16 backend")
|
||||
if args.bf16:
|
||||
raise ValueError("Tried to use `bf16` but native amp is not available")
|
||||
else:
|
||||
args.half_precision_backend = "apex"
|
||||
logger.info(f"Using {args.half_precision_backend} half precision backend")
|
||||
|
||||
if args.fp16 and not args.deepspeed: # deepspeed manages its own fp16
|
||||
if self.fp16_backend == "amp":
|
||||
self.do_grad_scaling = False
|
||||
if (args.fp16 or args.bf16) and not args.deepspeed: # deepspeed manages its own half precision
|
||||
if args.half_precision_backend == "amp":
|
||||
self.use_amp = True
|
||||
self.amp_dtype = torch.float16 if args.fp16 else torch.bfloat16
|
||||
self.do_grad_scaling = True
|
||||
if is_sagemaker_mp_enabled():
|
||||
self.scaler = smp.amp.GradScaler()
|
||||
elif self.sharded_ddp is not None:
|
||||
@@ -975,7 +981,7 @@ class Trainer:
|
||||
if self.sharded_ddp == ShardedDDPOption.SIMPLE:
|
||||
model = ShardedDDP(model, self.optimizer)
|
||||
else:
|
||||
mixed_precision = self.args.fp16
|
||||
mixed_precision = self.args.fp16 or self.args.bf16
|
||||
cpu_offload = ShardedDDPOption.OFFLOAD in self.args.sharded_ddp
|
||||
zero_3 = self.sharded_ddp == ShardedDDPOption.ZERO_DP_3
|
||||
# XXX: Breaking the self.model convention but I see no way around it for now.
|
||||
@@ -1043,7 +1049,7 @@ class Trainer:
|
||||
|
||||
# do_train is not a reliable argument, as it might not be set and .train() still called, so
|
||||
# the following is a workaround:
|
||||
if args.fp16_full_eval and not args.do_train:
|
||||
if (args.fp16_full_eval or args.bf16_full_eval) and not args.do_train:
|
||||
self._move_model_to_device(self.model, args.device)
|
||||
|
||||
if "model_path" in kwargs:
|
||||
@@ -1341,7 +1347,7 @@ class Trainer:
|
||||
if args.max_grad_norm is not None and args.max_grad_norm > 0 and not self.deepspeed:
|
||||
# deepspeed does its own clipping
|
||||
|
||||
if self.use_amp:
|
||||
if self.do_grad_scaling:
|
||||
# AMP: gradients need unscaling
|
||||
self.scaler.unscale_(self.optimizer)
|
||||
|
||||
@@ -1364,7 +1370,7 @@ class Trainer:
|
||||
pass # called outside the loop
|
||||
elif is_torch_tpu_available():
|
||||
xm.optimizer_step(self.optimizer)
|
||||
elif self.use_amp:
|
||||
elif self.do_grad_scaling:
|
||||
scale_before = self.scaler.get_scale()
|
||||
self.scaler.step(self.optimizer)
|
||||
self.scaler.update()
|
||||
@@ -1588,7 +1594,7 @@ class Trainer:
|
||||
with warnings.catch_warnings(record=True) as caught_warnings:
|
||||
torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
|
||||
reissue_pt_warnings(caught_warnings)
|
||||
if self.use_amp:
|
||||
if self.do_grad_scaling:
|
||||
torch.save(self.scaler.state_dict(), os.path.join(output_dir, SCALER_NAME))
|
||||
elif self.args.should_save and not self.deepspeed:
|
||||
# deepspeed.save_checkpoint above saves model/optim/sched
|
||||
@@ -1596,7 +1602,7 @@ class Trainer:
|
||||
with warnings.catch_warnings(record=True) as caught_warnings:
|
||||
torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
|
||||
reissue_pt_warnings(caught_warnings)
|
||||
if self.use_amp:
|
||||
if self.do_grad_scaling:
|
||||
torch.save(self.scaler.state_dict(), os.path.join(output_dir, SCALER_NAME))
|
||||
|
||||
# Determine the new best metric / best model checkpoint
|
||||
@@ -1684,7 +1690,7 @@ class Trainer:
|
||||
with warnings.catch_warnings(record=True) as caught_warnings:
|
||||
self.lr_scheduler.load_state_dict(torch.load(os.path.join(checkpoint, SCHEDULER_NAME)))
|
||||
reissue_pt_warnings(caught_warnings)
|
||||
if self.use_amp and os.path.isfile(os.path.join(checkpoint, SCALER_NAME)):
|
||||
if self.do_grad_scaling and os.path.isfile(os.path.join(checkpoint, SCALER_NAME)):
|
||||
self.scaler.load_state_dict(torch.load(os.path.join(checkpoint, SCALER_NAME)))
|
||||
|
||||
def hyperparameter_search(
|
||||
@@ -1846,12 +1852,12 @@ class Trainer:
|
||||
inputs = self._prepare_inputs(inputs)
|
||||
|
||||
if is_sagemaker_mp_enabled():
|
||||
scaler = self.scaler if self.use_amp else None
|
||||
scaler = self.scaler if self.do_grad_scaling else None
|
||||
loss_mb = smp_forward_backward(model, inputs, self.args.gradient_accumulation_steps, scaler=scaler)
|
||||
return loss_mb.reduce_mean().detach().to(self.args.device)
|
||||
|
||||
if self.use_amp:
|
||||
with autocast():
|
||||
with autocast(dtype=self.amp_dtype):
|
||||
loss = self.compute_loss(model, inputs)
|
||||
else:
|
||||
loss = self.compute_loss(model, inputs)
|
||||
@@ -1863,7 +1869,7 @@ class Trainer:
|
||||
# deepspeed handles loss scaling by gradient_accumulation_steps in its `backward`
|
||||
loss = loss / self.args.gradient_accumulation_steps
|
||||
|
||||
if self.use_amp:
|
||||
if self.do_grad_scaling:
|
||||
self.scaler.scale(loss).backward()
|
||||
elif self.use_apex:
|
||||
with amp.scale_loss(loss, self.optimizer) as scaled_loss:
|
||||
@@ -2220,12 +2226,12 @@ class Trainer:
|
||||
|
||||
Works both with or without labels.
|
||||
"""
|
||||
prediction_loss_only = (
|
||||
prediction_loss_only if prediction_loss_only is not None else self.args.prediction_loss_only
|
||||
)
|
||||
args = self.args
|
||||
|
||||
prediction_loss_only = prediction_loss_only if prediction_loss_only is not None else args.prediction_loss_only
|
||||
|
||||
# if eval is called w/o train init deepspeed here
|
||||
if self.args.deepspeed and not self.deepspeed:
|
||||
if args.deepspeed and not self.deepspeed:
|
||||
|
||||
# XXX: eval doesn't have `resume_from_checkpoint` arg but we should be able to do eval
|
||||
# from the checkpoint eventually
|
||||
@@ -2238,10 +2244,13 @@ class Trainer:
|
||||
|
||||
model = self._wrap_model(self.model, training=False)
|
||||
|
||||
# if full fp16 is wanted on eval and this ``evaluation`` or ``predict`` isn't called while
|
||||
# ``train`` is running, halve it first and then put on device
|
||||
if not self.is_in_train and self.args.fp16_full_eval:
|
||||
model = model.half().to(self.args.device)
|
||||
# if full fp16 or bf16 eval is wanted and this ``evaluation`` or ``predict`` isn't called
|
||||
# while ``train`` is running, cast it to the right dtype first and then put on device
|
||||
if not self.is_in_train:
|
||||
if args.fp16_full_eval:
|
||||
model = model.to(dtype=torch.float16, device=args.device)
|
||||
elif args.bf16_full_eval:
|
||||
model = model.to(dtype=torch.bfloat16, device=args.device)
|
||||
|
||||
batch_size = dataloader.batch_size
|
||||
|
||||
@@ -2259,9 +2268,9 @@ class Trainer:
|
||||
eval_dataset = dataloader.dataset
|
||||
|
||||
if is_torch_tpu_available():
|
||||
dataloader = pl.ParallelLoader(dataloader, [self.args.device]).per_device_loader(self.args.device)
|
||||
dataloader = pl.ParallelLoader(dataloader, [args.device]).per_device_loader(args.device)
|
||||
|
||||
if self.args.past_index >= 0:
|
||||
if args.past_index >= 0:
|
||||
self._past = None
|
||||
|
||||
# Initialize containers
|
||||
@@ -2301,10 +2310,10 @@ class Trainer:
|
||||
labels = self._pad_across_processes(labels)
|
||||
labels = self._nested_gather(labels)
|
||||
labels_host = labels if labels_host is None else nested_concat(labels_host, labels, padding_index=-100)
|
||||
self.control = self.callback_handler.on_prediction_step(self.args, self.state, self.control)
|
||||
self.control = self.callback_handler.on_prediction_step(args, self.state, self.control)
|
||||
|
||||
# Gather all tensors and put them back on the CPU if we have done enough accumulation steps.
|
||||
if self.args.eval_accumulation_steps is not None and (step + 1) % self.args.eval_accumulation_steps == 0:
|
||||
if args.eval_accumulation_steps is not None and (step + 1) % args.eval_accumulation_steps == 0:
|
||||
if losses_host is not None:
|
||||
losses = nested_numpify(losses_host)
|
||||
all_losses = losses if all_losses is None else np.concatenate((all_losses, losses), axis=0)
|
||||
@@ -2320,7 +2329,7 @@ class Trainer:
|
||||
# Set back to None to begin a new accumulation
|
||||
losses_host, preds_host, labels_host = None, None, None
|
||||
|
||||
if self.args.past_index and hasattr(self, "_past"):
|
||||
if args.past_index and hasattr(self, "_past"):
|
||||
# Clean the state at the end of the evaluation loop
|
||||
delattr(self, "_past")
|
||||
|
||||
@@ -2492,11 +2501,12 @@ class Trainer:
|
||||
else:
|
||||
if has_labels:
|
||||
if self.use_amp:
|
||||
with autocast():
|
||||
with autocast(dtype=self.amp_dtype):
|
||||
loss, outputs = self.compute_loss(model, inputs, return_outputs=True)
|
||||
else:
|
||||
loss, outputs = self.compute_loss(model, inputs, return_outputs=True)
|
||||
loss = loss.mean().detach()
|
||||
|
||||
if isinstance(outputs, dict):
|
||||
logits = tuple(v for k, v in outputs.items() if k not in ignore_keys + ["loss"])
|
||||
else:
|
||||
@@ -2504,7 +2514,7 @@ class Trainer:
|
||||
else:
|
||||
loss = None
|
||||
if self.use_amp:
|
||||
with autocast():
|
||||
with autocast(dtype=self.amp_dtype):
|
||||
outputs = model(**inputs)
|
||||
else:
|
||||
outputs = model(**inputs)
|
||||
@@ -2719,14 +2729,14 @@ class Trainer:
|
||||
|
||||
Works both with or without labels.
|
||||
"""
|
||||
args = self.args
|
||||
|
||||
if not isinstance(dataloader.dataset, collections.abc.Sized):
|
||||
raise ValueError("dataset must implement __len__")
|
||||
prediction_loss_only = (
|
||||
prediction_loss_only if prediction_loss_only is not None else self.args.prediction_loss_only
|
||||
)
|
||||
prediction_loss_only = prediction_loss_only if prediction_loss_only is not None else args.prediction_loss_only
|
||||
|
||||
# if eval is called w/o train init deepspeed here
|
||||
if self.args.deepspeed and not self.deepspeed:
|
||||
if args.deepspeed and not self.deepspeed:
|
||||
|
||||
# XXX: eval doesn't have `resume_from_checkpoint` arg but we should be able to do eval
|
||||
# from the checkpoint eventually
|
||||
@@ -2742,10 +2752,13 @@ class Trainer:
|
||||
|
||||
model = self._wrap_model(self.model, training=False)
|
||||
|
||||
# if full fp16 is wanted on eval and this ``evaluation`` or ``predict`` isn't called while
|
||||
# ``train`` is running, halve it first and then put on device
|
||||
if not self.is_in_train and self.args.fp16_full_eval:
|
||||
model = model.half().to(self.args.device)
|
||||
# if full fp16 or bf16 eval is wanted and this ``evaluation`` or ``predict`` isn't called
|
||||
# while ``train`` is running, cast it to the right dtype first and then put on device
|
||||
if not self.is_in_train:
|
||||
if args.fp16_full_eval:
|
||||
model = model.to(dtype=torch.float16, device=args.device)
|
||||
elif args.bf16_full_eval:
|
||||
model = model.to(dtype=torch.bfloat16, device=args.device)
|
||||
|
||||
batch_size = dataloader.batch_size
|
||||
num_examples = self.num_examples(dataloader)
|
||||
@@ -2756,7 +2769,7 @@ class Trainer:
|
||||
preds_host: Union[torch.Tensor, List[torch.Tensor]] = None
|
||||
labels_host: Union[torch.Tensor, List[torch.Tensor]] = None
|
||||
|
||||
world_size = max(1, self.args.world_size)
|
||||
world_size = max(1, args.world_size)
|
||||
|
||||
eval_losses_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=batch_size)
|
||||
if not prediction_loss_only:
|
||||
@@ -2771,9 +2784,9 @@ class Trainer:
|
||||
model.eval()
|
||||
|
||||
if is_torch_tpu_available():
|
||||
dataloader = pl.ParallelLoader(dataloader, [self.args.device]).per_device_loader(self.args.device)
|
||||
dataloader = pl.ParallelLoader(dataloader, [args.device]).per_device_loader(args.device)
|
||||
|
||||
if self.args.past_index >= 0:
|
||||
if args.past_index >= 0:
|
||||
self._past = None
|
||||
|
||||
self.callback_handler.eval_dataloader = dataloader
|
||||
@@ -2787,10 +2800,10 @@ class Trainer:
|
||||
preds_host = logits if preds_host is None else nested_concat(preds_host, logits, padding_index=-100)
|
||||
if labels is not None:
|
||||
labels_host = labels if labels_host is None else nested_concat(labels_host, labels, padding_index=-100)
|
||||
self.control = self.callback_handler.on_prediction_step(self.args, self.state, self.control)
|
||||
self.control = self.callback_handler.on_prediction_step(args, self.state, self.control)
|
||||
|
||||
# Gather all tensors and put them back on the CPU if we have done enough accumulation steps.
|
||||
if self.args.eval_accumulation_steps is not None and (step + 1) % self.args.eval_accumulation_steps == 0:
|
||||
if args.eval_accumulation_steps is not None and (step + 1) % args.eval_accumulation_steps == 0:
|
||||
eval_losses_gatherer.add_arrays(self._gather_and_numpify(losses_host, "eval_losses"))
|
||||
if not prediction_loss_only:
|
||||
preds_gatherer.add_arrays(self._gather_and_numpify(preds_host, "eval_preds"))
|
||||
@@ -2799,7 +2812,7 @@ class Trainer:
|
||||
# Set back to None to begin a new accumulation
|
||||
losses_host, preds_host, labels_host = None, None, None
|
||||
|
||||
if self.args.past_index and hasattr(self, "_past"):
|
||||
if args.past_index and hasattr(self, "_past"):
|
||||
# Clean the state at the end of the evaluation loop
|
||||
delattr(self, "_past")
|
||||
|
||||
|
||||
@@ -136,7 +136,13 @@ def nested_numpify(tensors):
|
||||
"Numpify `tensors` (even if it's a nested list/tuple of tensors)."
|
||||
if isinstance(tensors, (list, tuple)):
|
||||
return type(tensors)(nested_numpify(t) for t in tensors)
|
||||
return tensors.cpu().numpy()
|
||||
t = tensors.cpu()
|
||||
if t.dtype == torch.bfloat16:
|
||||
# As of Numpy 1.21.4, NumPy does not support bfloat16 (see
|
||||
# https://github.com/numpy/numpy/blob/a47ecdea856986cd60eabbd53265c2ca5916ad5d/doc/source/user/basics.types.rst ).
|
||||
# Until Numpy adds bfloat16, we must convert float32.
|
||||
t = t.to(torch.float32)
|
||||
return t.numpy()
|
||||
|
||||
|
||||
def nested_detach(tensors):
|
||||
|
||||
@@ -207,18 +207,26 @@ class TrainingArguments:
|
||||
Random seed that will be set at the beginning of training. To ensure reproducibility across runs, use the
|
||||
:func:`~transformers.Trainer.model_init` function to instantiate the model if it has some randomly
|
||||
initialized parameters.
|
||||
bf16 (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
Whether to use bf16 16-bit (mixed) precision training instead of 32-bit training. Requires Ampere or higher
|
||||
NVIDIA architecture. This is an experimental API and it may change.
|
||||
fp16 (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
Whether to use 16-bit (mixed) precision training instead of 32-bit training.
|
||||
Whether to use fp16 16-bit (mixed) precision training instead of 32-bit training.
|
||||
fp16_opt_level (:obj:`str`, `optional`, defaults to 'O1'):
|
||||
For :obj:`fp16` training, Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']. See details
|
||||
on the `Apex documentation <https://nvidia.github.io/apex/amp.html>`__.
|
||||
fp16_backend (:obj:`str`, `optional`, defaults to :obj:`"auto"`):
|
||||
This argument is deprecated. Use ``half_precision_backend`` instead.
|
||||
half_precision_backend (:obj:`str`, `optional`, defaults to :obj:`"auto"`):
|
||||
The backend to use for mixed precision training. Must be one of :obj:`"auto"`, :obj:`"amp"` or
|
||||
:obj:`"apex"`. :obj:`"auto"` will use AMP or APEX depending on the PyTorch version detected, while the
|
||||
other choices will force the requested backend.
|
||||
bf16_full_eval (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
Whether to use full bfloat16 evaluation instead of 32-bit. This will be faster and save memory but can harm
|
||||
metric values. This is an experimental API and it may change.
|
||||
fp16_full_eval (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
Whether to use full 16-bit precision evaluation instead of 32-bit. This will be faster and save memory but
|
||||
can harm metric values.
|
||||
Whether to use full float16 evaluation instead of 32-bit. This will be faster and save memory but can harm
|
||||
metric values.
|
||||
local_rank (:obj:`int`, `optional`, defaults to -1):
|
||||
Rank of the process during distributed training.
|
||||
xpu_backend (:obj:`str`, `optional`):
|
||||
@@ -507,10 +515,15 @@ class TrainingArguments:
|
||||
)
|
||||
no_cuda: bool = field(default=False, metadata={"help": "Do not use CUDA even when it is available"})
|
||||
seed: int = field(default=42, metadata={"help": "Random seed that will be set at the beginning of training."})
|
||||
|
||||
bf16: bool = field(
|
||||
default=False,
|
||||
metadata={
|
||||
"help": "Whether to use bf16 (mixed) precision instead of 32-bit. Requires Ampere or higher NVIDIA architecture. This is an experimental API and it may change."
|
||||
},
|
||||
)
|
||||
fp16: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Whether to use 16-bit (mixed) precision instead of 32-bit"},
|
||||
metadata={"help": "Whether to use fp16 (mixed) precision instead of 32-bit"},
|
||||
)
|
||||
fp16_opt_level: str = field(
|
||||
default="O1",
|
||||
@@ -521,13 +534,19 @@ class TrainingArguments:
|
||||
)
|
||||
},
|
||||
)
|
||||
fp16_backend: str = field(
|
||||
half_precision_backend: str = field(
|
||||
default="auto",
|
||||
metadata={"help": "The backend to be used for mixed precision.", "choices": ["auto", "amp", "apex"]},
|
||||
metadata={"help": "The backend to be used for half precision.", "choices": ["auto", "amp", "apex"]},
|
||||
)
|
||||
bf16_full_eval: bool = field(
|
||||
default=False,
|
||||
metadata={
|
||||
"help": "Whether to use full bfloat16 evaluation instead of 32-bit. This is an experimental API and it may change."
|
||||
},
|
||||
)
|
||||
fp16_full_eval: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Whether to use full 16-bit precision evaluation instead of 32-bit"},
|
||||
metadata={"help": "Whether to use full float16 evaluation instead of 32-bit"},
|
||||
)
|
||||
local_rank: int = field(default=-1, metadata={"help": "For distributed training: local_rank"})
|
||||
xpu_backend: str = field(
|
||||
@@ -666,6 +685,10 @@ class TrainingArguments:
|
||||
},
|
||||
)
|
||||
# Deprecated arguments
|
||||
fp16_backend: str = field(
|
||||
default="auto",
|
||||
metadata={"help": "Deprecated. Use half_precision_backend instead", "choices": ["auto", "amp", "apex"]},
|
||||
)
|
||||
push_to_hub_model_id: str = field(
|
||||
default=None, metadata={"help": "The name of the repository to which push the `Trainer`."}
|
||||
)
|
||||
@@ -754,10 +777,31 @@ class TrainingArguments:
|
||||
if self.run_name is None:
|
||||
self.run_name = self.output_dir
|
||||
|
||||
if is_torch_available() and self.device.type != "cuda" and (self.fp16 or self.fp16_full_eval):
|
||||
raise ValueError(
|
||||
"Mixed precision training with AMP or APEX (`--fp16`) and FP16 evaluation can only be used on CUDA devices."
|
||||
if self.fp16_backend and self.fp16_backend != "auto":
|
||||
warnings.warn(
|
||||
"`fp16_backend` is deprecated and will be removed in version 5 of 🤗 Transformers. Use `half_precision_backend` instead",
|
||||
FutureWarning,
|
||||
)
|
||||
self.half_precision_backend = self.fp16_backend
|
||||
|
||||
if self.fp16 and self.bf16:
|
||||
raise ValueError("At most one of fp16 and bf16 can be True, but not both")
|
||||
if self.bf16:
|
||||
if self.half_precision_backend == "apex":
|
||||
raise ValueError(
|
||||
" `--half_precision_backend apex`: bf16 is not supported by apex. Use `--half_precision_backend amp` instead"
|
||||
)
|
||||
if not (self.sharded_ddp == "" or not self.sharded_ddp):
|
||||
raise ValueError("sharded_ddp is not supported with bf16")
|
||||
if (
|
||||
is_torch_available()
|
||||
and self.device.type != "cuda"
|
||||
and (self.fp16 or self.fp16_full_eval or self.bf16 or self.bf16_full_eval)
|
||||
):
|
||||
raise ValueError(
|
||||
"Mixed precision training with AMP or APEX (`--fp16` or `--bf16`) and half precision evaluation (`--fp16_full_eval` or `--bf16_full_eval`) can only be used on CUDA devices."
|
||||
)
|
||||
|
||||
if self.report_to is None:
|
||||
logger.info(
|
||||
"The default value for the training argument `--report_to` will change in v5 (from all installed "
|
||||
|
||||
@@ -53,6 +53,7 @@ from transformers.testing_utils import (
|
||||
require_sigopt,
|
||||
require_tokenizers,
|
||||
require_torch,
|
||||
require_torch_bf16,
|
||||
require_torch_gpu,
|
||||
require_torch_multi_gpu,
|
||||
require_torch_non_multi_gpu,
|
||||
@@ -476,6 +477,21 @@ class TrainerIntegrationPrerunTest(TestCasePlus, TrainerIntegrationCommon):
|
||||
self.assertFalse(torch.allclose(trainer.model.b, b))
|
||||
self.assertGreater(trainer.optimizer.state_dict()["param_groups"][0]["lr"], 0)
|
||||
|
||||
@require_torch_gpu
|
||||
@require_torch_bf16
|
||||
def test_mixed_bf16(self):
|
||||
|
||||
# very basic test
|
||||
trainer = get_regression_trainer(learning_rate=0.1, bf16=True)
|
||||
trainer.train()
|
||||
self.check_trained_model(trainer.model)
|
||||
|
||||
# --bf16 --half_precision_backend apex can't be used together
|
||||
with self.assertRaises(ValueError):
|
||||
trainer = get_regression_trainer(learning_rate=0.1, bf16=True, half_precision_backend="apex")
|
||||
|
||||
# will add more specific tests once there are some bugs to fix
|
||||
|
||||
|
||||
@require_torch
|
||||
@require_sentencepiece
|
||||
@@ -1323,6 +1339,66 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon):
|
||||
# perfect world: fp32_init/2 == fp16_eval
|
||||
self.assertAlmostEqual(fp16_eval, fp32_init / 2, delta=5_000)
|
||||
|
||||
@require_torch_gpu
|
||||
@require_torch_bf16
|
||||
def test_bf16_full_eval(self):
|
||||
# note: most of the logic is the same as test_fp16_full_eval
|
||||
|
||||
# this is a sensitive test so let's keep debugging printouts in place for quick diagnosis.
|
||||
# it's using pretty large safety margins, but small enough to detect broken functionality.
|
||||
debug = 0
|
||||
n_gpus = get_gpu_count()
|
||||
|
||||
bs = 8
|
||||
eval_len = 16 * n_gpus
|
||||
# make the params somewhat big so that there will be enough RAM consumed to be able to
|
||||
# measure things. We should get about 64KB for a+b in fp32
|
||||
a = torch.ones(1000, bs) + 0.001
|
||||
b = torch.ones(1000, bs) - 0.001
|
||||
|
||||
# 1. with mem metrics enabled
|
||||
trainer = get_regression_trainer(a=a, b=b, eval_len=eval_len, skip_memory_metrics=False)
|
||||
metrics = trainer.evaluate()
|
||||
del trainer
|
||||
gc.collect()
|
||||
|
||||
fp32_init = metrics["init_mem_gpu_alloc_delta"]
|
||||
fp32_eval = metrics["eval_mem_gpu_alloc_delta"]
|
||||
|
||||
if debug:
|
||||
print(f"fp32_init {fp32_init}")
|
||||
print(f"fp32_eval {fp32_eval}")
|
||||
|
||||
# here we expect the model to be preloaded in trainer.__init__ and consume around 64K gpu ram.
|
||||
# perfect world: fp32_init == 64<<10
|
||||
self.assertGreater(fp32_init, 59_000)
|
||||
# after eval should be no extra memory allocated - with a small margin (other than the peak
|
||||
# memory consumption for the forward calculation that gets recovered)
|
||||
# perfect world: fp32_eval == close to zero
|
||||
self.assertLess(fp32_eval, 5_000)
|
||||
|
||||
# 2. with mem metrics disabled
|
||||
trainer = get_regression_trainer(a=a, b=b, eval_len=eval_len, bf16_full_eval=True, skip_memory_metrics=False)
|
||||
metrics = trainer.evaluate()
|
||||
bf16_init = metrics["init_mem_gpu_alloc_delta"]
|
||||
bf16_eval = metrics["eval_mem_gpu_alloc_delta"]
|
||||
|
||||
if debug:
|
||||
print(f"bf16_init {bf16_init}")
|
||||
print(f"bf16_eval {bf16_eval}")
|
||||
|
||||
# here we expect the model to not be preloaded in trainer.__init__, so with a small margin it should be close to 0
|
||||
# perfect world: bf16_init == close to zero
|
||||
self.assertLess(bf16_init, 5_000)
|
||||
# here we put the model on device in eval and only `half()` of it, i.e. about 32K,(again we ignore the peak margin which gets returned back)
|
||||
# perfect world: fp32_init == 32<<10
|
||||
self.assertGreater(bf16_eval, 27_000)
|
||||
|
||||
# 3. relative comparison fp32 vs full bf16
|
||||
# should be about half of bf16_init
|
||||
# perfect world: fp32_init/2 == bf16_eval
|
||||
self.assertAlmostEqual(bf16_eval, fp32_init / 2, delta=5_000)
|
||||
|
||||
def test_no_wd_param_group(self):
|
||||
model = nn.Sequential(TstLayer(128), nn.ModuleList([TstLayer(128), TstLayer(128)]))
|
||||
trainer = Trainer(model=model)
|
||||
|
||||
Reference in New Issue
Block a user