From 70996a5420f6b28cb0330e373b99f75893c8fbb3 Mon Sep 17 00:00:00 2001 From: Jamie DeAntonis <33379057+JamesDeAntonis@users.noreply.github.com> Date: Tue, 30 Nov 2021 21:00:47 -0500 Subject: [PATCH] WIP: Support for Training with BF16 (#13207) * started bf16 integration * minor changes * code now runs * style * lay foundation for bf16 testing * lay foundation for bf16 testing * start the tests * better bf16 check * style * 2 separate checkers - one for bf16 support, another for bf16+autocast * Update src/transformers/training_args.py Co-authored-by: Stas Bekman * a couple of comment resolutions * more comment resolutions * resolved a small bug * just some print statements * added todo marking * added a todo * adjust for API change s/fast_dtype/dtype/ * fix style * merge 2 bf16 util functions * bf16 now does scaling too * Add support for bfloat16 * Revert T5 layernorm to float32 This is based on the comment at https://github.com/huggingface/transformers/pull/14448/files#r752660929 and the PyTorch PR https://github.com/pytorch/pytorch/pull/66920 . * Add comment about conversion to float32 before returning the numpy data * Add comment about AMP-bfloat16 incompatibility * Fix formatting * typo * reformer / bf16 * cleanup * require at least pt-1.10 * fix * will deal with deepspeed separately * cleanup * revert * cleanup * fp16_full_eval and bf16_full_eval are separate modes * proper deprecation * cleanup * test and fixes * spelling * cleanup * add a note that this API is experimental Co-authored-by: jamie Co-authored-by: Stas Bekman Co-authored-by: Stas Bekman Co-authored-by: suriya Co-authored-by: Manuel R. 
Ciosici --- src/transformers/file_utils.py | 31 ++++++ src/transformers/modeling_utils.py | 2 +- src/transformers/models/t5/modeling_t5.py | 7 +- src/transformers/testing_utils.py | 9 ++ src/transformers/trainer.py | 113 ++++++++++++---------- src/transformers/trainer_pt_utils.py | 8 +- src/transformers/training_args.py | 66 ++++++++++--- tests/test_trainer.py | 76 +++++++++++++++ 8 files changed, 246 insertions(+), 66 deletions(-) diff --git a/src/transformers/file_utils.py b/src/transformers/file_utils.py index ed8f458356..d88da95dbb 100644 --- a/src/transformers/file_utils.py +++ b/src/transformers/file_utils.py @@ -320,6 +320,37 @@ def is_torch_cuda_available(): return False +def is_torch_bf16_available(): + if is_torch_available(): + import torch + + # since currently no utility function is available we build our own. + # some bits come from https://github.com/pytorch/pytorch/blob/2289a12f21c54da93bf5d696e3f9aea83dd9c10d/torch/testing/_internal/common_cuda.py#L51 + # with additional check for torch version + # to succeed: + # 1. the hardware needs to support bf16 (arch >= Ampere) + # 2. torch >= 1.10 (1.9 should be enough for AMP API has changed in 1.10, so using 1.10 as minimal) + # 3. CUDA >= 11 + # 4. 
torch.autocast exists + # XXX: one problem here is that it may give invalid results on mixed gpus setup, so it's + # really only correct for the 0th gpu (or currently set default device if different from 0) + + if not torch.cuda.is_available() or torch.version.cuda is None: + return False + if torch.cuda.get_device_properties(torch.cuda.current_device()).major < 8: + return False + if int(torch.version.cuda.split(".")[0]) < 11: + return False + if not version.parse(torch.__version__) >= version.parse("1.10"): + return False + if not hasattr(torch, "autocast"): + return False + + return True + else: + return False + + _torch_fx_available = _torch_onnx_dict_inputs_support_available = False if _torch_available: torch_version = version.parse(importlib_metadata.version("torch")) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 9ec64ebb73..177f0cb79a 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -233,7 +233,7 @@ class ModuleUtilsMixin: if self.dtype == torch.float16: encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -1e4 - elif self.dtype == torch.float32: + elif self.dtype in [torch.bfloat16, torch.float32]: encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -1e9 else: raise ValueError( diff --git a/src/transformers/models/t5/modeling_t5.py b/src/transformers/models/t5/modeling_t5.py index 78ccd07236..e5c1f340ea 100644 --- a/src/transformers/models/t5/modeling_t5.py +++ b/src/transformers/models/t5/modeling_t5.py @@ -242,9 +242,10 @@ class T5LayerNorm(nn.Module): variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) - # convert into float16 if necessary - if self.weight.dtype == torch.float16: - hidden_states = hidden_states.to(torch.float16) + # convert into half-precision if necessary + if self.weight.dtype in [torch.float16, 
torch.bfloat16]: + hidden_states = hidden_states.to(self.weight.dtype) + return self.weight * hidden_states diff --git a/src/transformers/testing_utils.py b/src/transformers/testing_utils.py index 63700e0147..e5f96d830e 100644 --- a/src/transformers/testing_utils.py +++ b/src/transformers/testing_utils.py @@ -49,6 +49,7 @@ from .file_utils import ( is_timm_available, is_tokenizers_available, is_torch_available, + is_torch_bf16_available, is_torch_tpu_available, is_torchaudio_available, is_vision_available, @@ -493,6 +494,14 @@ def require_torch_gpu(test_case): return test_case +def require_torch_bf16(test_case): + """Decorator marking a test that requires CUDA hardware supporting bf16 and PyTorch >= 1.10.""" + if not is_torch_bf16_available(): + return unittest.skip("test requires CUDA hardware supporting bf16 and PyTorch >= 1.10")(test_case) + else: + return test_case + + def require_datasets(test_case): """Decorator marking a test that requires datasets.""" diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 7e6d500265..3b3da57ca5 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -353,13 +353,13 @@ class Trainer: # 1. MP - since we are trying to fit a much bigger than 1 gpu model # 2. fp16-enabled DeepSpeed loads the model in half the size and it doesn't need .to() anyway, # and we only use deepspeed for training at the moment - # 3. full fp16 eval - since the model needs to be half'ed first + # 3. full bf16 or fp16 eval - since the model needs to be cast to the right dtype first # 4. 
Sharded DDP - same as MP self.place_model_on_device = args.place_model_on_device if ( self.is_model_parallel or args.deepspeed - or (args.fp16_full_eval and not args.do_train) + or ((args.fp16_full_eval or args.bf16_full_eval) and not args.do_train) or (self.sharded_ddp in [ShardedDDPOption.ZERO_DP_2, ShardedDDPOption.ZERO_DP_3]) ): self.place_model_on_device = False @@ -424,18 +424,24 @@ class Trainer: # Mixed precision setup self.use_apex = False self.use_amp = False - self.fp16_backend = None - if args.fp16: - if args.fp16_backend == "auto": - self.fp16_backend = "amp" if _is_native_amp_available else "apex" - else: - self.fp16_backend = args.fp16_backend - logger.info(f"Using {self.fp16_backend} fp16 backend") + if args.fp16 or args.bf16: + if args.half_precision_backend == "auto": + if _is_native_amp_available: + args.half_precision_backend = "amp" + else: + if args.bf16: + raise ValueError("Tried to use `bf16` but native amp is not available") + else: + args.half_precision_backend = "apex" + logger.info(f"Using {args.half_precision_backend} half precision backend") - if args.fp16 and not args.deepspeed: # deepspeed manages its own fp16 - if self.fp16_backend == "amp": + self.do_grad_scaling = False + if (args.fp16 or args.bf16) and not args.deepspeed: # deepspeed manages its own half precision + if args.half_precision_backend == "amp": self.use_amp = True + self.amp_dtype = torch.float16 if args.fp16 else torch.bfloat16 + self.do_grad_scaling = True if is_sagemaker_mp_enabled(): self.scaler = smp.amp.GradScaler() elif self.sharded_ddp is not None: @@ -975,7 +981,7 @@ class Trainer: if self.sharded_ddp == ShardedDDPOption.SIMPLE: model = ShardedDDP(model, self.optimizer) else: - mixed_precision = self.args.fp16 + mixed_precision = self.args.fp16 or self.args.bf16 cpu_offload = ShardedDDPOption.OFFLOAD in self.args.sharded_ddp zero_3 = self.sharded_ddp == ShardedDDPOption.ZERO_DP_3 # XXX: Breaking the self.model convention but I see no way around it for now. 
@@ -1043,7 +1049,7 @@ class Trainer: # do_train is not a reliable argument, as it might not be set and .train() still called, so # the following is a workaround: - if args.fp16_full_eval and not args.do_train: + if (args.fp16_full_eval or args.bf16_full_eval) and not args.do_train: self._move_model_to_device(self.model, args.device) if "model_path" in kwargs: @@ -1341,7 +1347,7 @@ class Trainer: if args.max_grad_norm is not None and args.max_grad_norm > 0 and not self.deepspeed: # deepspeed does its own clipping - if self.use_amp: + if self.do_grad_scaling: # AMP: gradients need unscaling self.scaler.unscale_(self.optimizer) @@ -1364,7 +1370,7 @@ class Trainer: pass # called outside the loop elif is_torch_tpu_available(): xm.optimizer_step(self.optimizer) - elif self.use_amp: + elif self.do_grad_scaling: scale_before = self.scaler.get_scale() self.scaler.step(self.optimizer) self.scaler.update() @@ -1588,7 +1594,7 @@ class Trainer: with warnings.catch_warnings(record=True) as caught_warnings: torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME)) reissue_pt_warnings(caught_warnings) - if self.use_amp: + if self.do_grad_scaling: torch.save(self.scaler.state_dict(), os.path.join(output_dir, SCALER_NAME)) elif self.args.should_save and not self.deepspeed: # deepspeed.save_checkpoint above saves model/optim/sched @@ -1596,7 +1602,7 @@ class Trainer: with warnings.catch_warnings(record=True) as caught_warnings: torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME)) reissue_pt_warnings(caught_warnings) - if self.use_amp: + if self.do_grad_scaling: torch.save(self.scaler.state_dict(), os.path.join(output_dir, SCALER_NAME)) # Determine the new best metric / best model checkpoint @@ -1684,7 +1690,7 @@ class Trainer: with warnings.catch_warnings(record=True) as caught_warnings: self.lr_scheduler.load_state_dict(torch.load(os.path.join(checkpoint, SCHEDULER_NAME))) reissue_pt_warnings(caught_warnings) - if 
self.use_amp and os.path.isfile(os.path.join(checkpoint, SCALER_NAME)): + if self.do_grad_scaling and os.path.isfile(os.path.join(checkpoint, SCALER_NAME)): self.scaler.load_state_dict(torch.load(os.path.join(checkpoint, SCALER_NAME))) def hyperparameter_search( @@ -1846,12 +1852,12 @@ class Trainer: inputs = self._prepare_inputs(inputs) if is_sagemaker_mp_enabled(): - scaler = self.scaler if self.use_amp else None + scaler = self.scaler if self.do_grad_scaling else None loss_mb = smp_forward_backward(model, inputs, self.args.gradient_accumulation_steps, scaler=scaler) return loss_mb.reduce_mean().detach().to(self.args.device) if self.use_amp: - with autocast(): + with autocast(dtype=self.amp_dtype): loss = self.compute_loss(model, inputs) else: loss = self.compute_loss(model, inputs) @@ -1863,7 +1869,7 @@ class Trainer: # deepspeed handles loss scaling by gradient_accumulation_steps in its `backward` loss = loss / self.args.gradient_accumulation_steps - if self.use_amp: + if self.do_grad_scaling: self.scaler.scale(loss).backward() elif self.use_apex: with amp.scale_loss(loss, self.optimizer) as scaled_loss: @@ -2220,12 +2226,12 @@ class Trainer: Works both with or without labels. 
""" - prediction_loss_only = ( - prediction_loss_only if prediction_loss_only is not None else self.args.prediction_loss_only - ) + args = self.args + + prediction_loss_only = prediction_loss_only if prediction_loss_only is not None else args.prediction_loss_only # if eval is called w/o train init deepspeed here - if self.args.deepspeed and not self.deepspeed: + if args.deepspeed and not self.deepspeed: # XXX: eval doesn't have `resume_from_checkpoint` arg but we should be able to do eval # from the checkpoint eventually @@ -2238,10 +2244,13 @@ class Trainer: model = self._wrap_model(self.model, training=False) - # if full fp16 is wanted on eval and this ``evaluation`` or ``predict`` isn't called while - # ``train`` is running, halve it first and then put on device - if not self.is_in_train and self.args.fp16_full_eval: - model = model.half().to(self.args.device) + # if full fp16 or bf16 eval is wanted and this ``evaluation`` or ``predict`` isn't called + # while ``train`` is running, cast it to the right dtype first and then put on device + if not self.is_in_train: + if args.fp16_full_eval: + model = model.to(dtype=torch.float16, device=args.device) + elif args.bf16_full_eval: + model = model.to(dtype=torch.bfloat16, device=args.device) batch_size = dataloader.batch_size @@ -2259,9 +2268,9 @@ class Trainer: eval_dataset = dataloader.dataset if is_torch_tpu_available(): - dataloader = pl.ParallelLoader(dataloader, [self.args.device]).per_device_loader(self.args.device) + dataloader = pl.ParallelLoader(dataloader, [args.device]).per_device_loader(args.device) - if self.args.past_index >= 0: + if args.past_index >= 0: self._past = None # Initialize containers @@ -2301,10 +2310,10 @@ class Trainer: labels = self._pad_across_processes(labels) labels = self._nested_gather(labels) labels_host = labels if labels_host is None else nested_concat(labels_host, labels, padding_index=-100) - self.control = self.callback_handler.on_prediction_step(self.args, self.state, 
self.control) + self.control = self.callback_handler.on_prediction_step(args, self.state, self.control) # Gather all tensors and put them back on the CPU if we have done enough accumulation steps. - if self.args.eval_accumulation_steps is not None and (step + 1) % self.args.eval_accumulation_steps == 0: + if args.eval_accumulation_steps is not None and (step + 1) % args.eval_accumulation_steps == 0: if losses_host is not None: losses = nested_numpify(losses_host) all_losses = losses if all_losses is None else np.concatenate((all_losses, losses), axis=0) @@ -2320,7 +2329,7 @@ class Trainer: # Set back to None to begin a new accumulation losses_host, preds_host, labels_host = None, None, None - if self.args.past_index and hasattr(self, "_past"): + if args.past_index and hasattr(self, "_past"): # Clean the state at the end of the evaluation loop delattr(self, "_past") @@ -2492,11 +2501,12 @@ class Trainer: else: if has_labels: if self.use_amp: - with autocast(): + with autocast(dtype=self.amp_dtype): loss, outputs = self.compute_loss(model, inputs, return_outputs=True) else: loss, outputs = self.compute_loss(model, inputs, return_outputs=True) loss = loss.mean().detach() + if isinstance(outputs, dict): logits = tuple(v for k, v in outputs.items() if k not in ignore_keys + ["loss"]) else: @@ -2504,7 +2514,7 @@ class Trainer: else: loss = None if self.use_amp: - with autocast(): + with autocast(dtype=self.amp_dtype): outputs = model(**inputs) else: outputs = model(**inputs) @@ -2719,14 +2729,14 @@ class Trainer: Works both with or without labels. 
""" + args = self.args + if not isinstance(dataloader.dataset, collections.abc.Sized): raise ValueError("dataset must implement __len__") - prediction_loss_only = ( - prediction_loss_only if prediction_loss_only is not None else self.args.prediction_loss_only - ) + prediction_loss_only = prediction_loss_only if prediction_loss_only is not None else args.prediction_loss_only # if eval is called w/o train init deepspeed here - if self.args.deepspeed and not self.deepspeed: + if args.deepspeed and not self.deepspeed: # XXX: eval doesn't have `resume_from_checkpoint` arg but we should be able to do eval # from the checkpoint eventually @@ -2742,10 +2752,13 @@ class Trainer: model = self._wrap_model(self.model, training=False) - # if full fp16 is wanted on eval and this ``evaluation`` or ``predict`` isn't called while - # ``train`` is running, halve it first and then put on device - if not self.is_in_train and self.args.fp16_full_eval: - model = model.half().to(self.args.device) + # if full fp16 or bf16 eval is wanted and this ``evaluation`` or ``predict`` isn't called + # while ``train`` is running, cast it to the right dtype first and then put on device + if not self.is_in_train: + if args.fp16_full_eval: + model = model.to(dtype=torch.float16, device=args.device) + elif args.bf16_full_eval: + model = model.to(dtype=torch.bfloat16, device=args.device) batch_size = dataloader.batch_size num_examples = self.num_examples(dataloader) @@ -2756,7 +2769,7 @@ class Trainer: preds_host: Union[torch.Tensor, List[torch.Tensor]] = None labels_host: Union[torch.Tensor, List[torch.Tensor]] = None - world_size = max(1, self.args.world_size) + world_size = max(1, args.world_size) eval_losses_gatherer = DistributedTensorGatherer(world_size, num_examples, make_multiple_of=batch_size) if not prediction_loss_only: @@ -2771,9 +2784,9 @@ class Trainer: model.eval() if is_torch_tpu_available(): - dataloader = pl.ParallelLoader(dataloader, 
[self.args.device]).per_device_loader(self.args.device) + dataloader = pl.ParallelLoader(dataloader, [args.device]).per_device_loader(args.device) - if self.args.past_index >= 0: + if args.past_index >= 0: self._past = None self.callback_handler.eval_dataloader = dataloader @@ -2787,10 +2800,10 @@ class Trainer: preds_host = logits if preds_host is None else nested_concat(preds_host, logits, padding_index=-100) if labels is not None: labels_host = labels if labels_host is None else nested_concat(labels_host, labels, padding_index=-100) - self.control = self.callback_handler.on_prediction_step(self.args, self.state, self.control) + self.control = self.callback_handler.on_prediction_step(args, self.state, self.control) # Gather all tensors and put them back on the CPU if we have done enough accumulation steps. - if self.args.eval_accumulation_steps is not None and (step + 1) % self.args.eval_accumulation_steps == 0: + if args.eval_accumulation_steps is not None and (step + 1) % args.eval_accumulation_steps == 0: eval_losses_gatherer.add_arrays(self._gather_and_numpify(losses_host, "eval_losses")) if not prediction_loss_only: preds_gatherer.add_arrays(self._gather_and_numpify(preds_host, "eval_preds")) @@ -2799,7 +2812,7 @@ class Trainer: # Set back to None to begin a new accumulation losses_host, preds_host, labels_host = None, None, None - if self.args.past_index and hasattr(self, "_past"): + if args.past_index and hasattr(self, "_past"): # Clean the state at the end of the evaluation loop delattr(self, "_past") diff --git a/src/transformers/trainer_pt_utils.py b/src/transformers/trainer_pt_utils.py index a08c2ddd64..c9c61ac7ba 100644 --- a/src/transformers/trainer_pt_utils.py +++ b/src/transformers/trainer_pt_utils.py @@ -136,7 +136,13 @@ def nested_numpify(tensors): "Numpify `tensors` (even if it's a nested list/tuple of tensors)." 
if isinstance(tensors, (list, tuple)): return type(tensors)(nested_numpify(t) for t in tensors) - return tensors.cpu().numpy() + t = tensors.cpu() + if t.dtype == torch.bfloat16: + # As of NumPy 1.21.4, NumPy does not support bfloat16 (see + # https://github.com/numpy/numpy/blob/a47ecdea856986cd60eabbd53265c2ca5916ad5d/doc/source/user/basics.types.rst ). + # Until NumPy adds bfloat16, we must convert to float32. + t = t.to(torch.float32) + return t.numpy() def nested_detach(tensors): diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index 644f83665e..74a01aace4 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -207,18 +207,26 @@ class TrainingArguments: Random seed that will be set at the beginning of training. To ensure reproducibility across runs, use the :func:`~transformers.Trainer.model_init` function to instantiate the model if it has some randomly initialized parameters. + bf16 (:obj:`bool`, `optional`, defaults to :obj:`False`): + Whether to use bf16 16-bit (mixed) precision training instead of 32-bit training. Requires Ampere or higher + NVIDIA architecture. This is an experimental API and it may change. fp16 (:obj:`bool`, `optional`, defaults to :obj:`False`): - Whether to use 16-bit (mixed) precision training instead of 32-bit training. + Whether to use fp16 16-bit (mixed) precision training instead of 32-bit training. fp16_opt_level (:obj:`str`, `optional`, defaults to 'O1'): For :obj:`fp16` training, Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']. See details on the `Apex documentation `__. fp16_backend (:obj:`str`, `optional`, defaults to :obj:`"auto"`): + This argument is deprecated. Use ``half_precision_backend`` instead. + half_precision_backend (:obj:`str`, `optional`, defaults to :obj:`"auto"`): The backend to use for mixed precision training. Must be one of :obj:`"auto"`, :obj:`"amp"` or :obj:`"apex"`. 
:obj:`"auto"` will use AMP or APEX depending on the PyTorch version detected, while the other choices will force the requested backend. + bf16_full_eval (:obj:`bool`, `optional`, defaults to :obj:`False`): + Whether to use full bfloat16 evaluation instead of 32-bit. This will be faster and save memory but can harm + metric values. This is an experimental API and it may change. fp16_full_eval (:obj:`bool`, `optional`, defaults to :obj:`False`): - Whether to use full 16-bit precision evaluation instead of 32-bit. This will be faster and save memory but - can harm metric values. + Whether to use full float16 evaluation instead of 32-bit. This will be faster and save memory but can harm + metric values. local_rank (:obj:`int`, `optional`, defaults to -1): Rank of the process during distributed training. xpu_backend (:obj:`str`, `optional`): @@ -507,10 +515,15 @@ class TrainingArguments: ) no_cuda: bool = field(default=False, metadata={"help": "Do not use CUDA even when it is available"}) seed: int = field(default=42, metadata={"help": "Random seed that will be set at the beginning of training."}) - + bf16: bool = field( + default=False, + metadata={ + "help": "Whether to use bf16 (mixed) precision instead of 32-bit. Requires Ampere or higher NVIDIA architecture. This is an experimental API and it may change." 
+ }, + ) fp16: bool = field( default=False, - metadata={"help": "Whether to use 16-bit (mixed) precision instead of 32-bit"}, + metadata={"help": "Whether to use fp16 (mixed) precision instead of 32-bit"}, ) fp16_opt_level: str = field( default="O1", @@ -521,13 +534,19 @@ class TrainingArguments: ) }, ) - fp16_backend: str = field( + half_precision_backend: str = field( default="auto", - metadata={"help": "The backend to be used for mixed precision.", "choices": ["auto", "amp", "apex"]}, + metadata={"help": "The backend to be used for half precision.", "choices": ["auto", "amp", "apex"]}, + ) + bf16_full_eval: bool = field( + default=False, + metadata={ + "help": "Whether to use full bfloat16 evaluation instead of 32-bit. This is an experimental API and it may change." + }, ) fp16_full_eval: bool = field( default=False, - metadata={"help": "Whether to use full 16-bit precision evaluation instead of 32-bit"}, + metadata={"help": "Whether to use full float16 evaluation instead of 32-bit"}, ) local_rank: int = field(default=-1, metadata={"help": "For distributed training: local_rank"}) xpu_backend: str = field( @@ -666,6 +685,10 @@ class TrainingArguments: }, ) # Deprecated arguments + fp16_backend: str = field( + default="auto", + metadata={"help": "Deprecated. Use half_precision_backend instead", "choices": ["auto", "amp", "apex"]}, + ) push_to_hub_model_id: str = field( default=None, metadata={"help": "The name of the repository to which push the `Trainer`."} ) @@ -754,10 +777,31 @@ class TrainingArguments: if self.run_name is None: self.run_name = self.output_dir - if is_torch_available() and self.device.type != "cuda" and (self.fp16 or self.fp16_full_eval): - raise ValueError( - "Mixed precision training with AMP or APEX (`--fp16`) and FP16 evaluation can only be used on CUDA devices." + if self.fp16_backend and self.fp16_backend != "auto": + warnings.warn( + "`fp16_backend` is deprecated and will be removed in version 5 of 🤗 Transformers. 
Use `half_precision_backend` instead", + FutureWarning, ) + self.half_precision_backend = self.fp16_backend + + if self.fp16 and self.bf16: + raise ValueError("At most one of fp16 and bf16 can be True, but not both") + if self.bf16: + if self.half_precision_backend == "apex": + raise ValueError( + " `--half_precision_backend apex`: bf16 is not supported by apex. Use `--half_precision_backend amp` instead" + ) + if not (self.sharded_ddp == "" or not self.sharded_ddp): + raise ValueError("sharded_ddp is not supported with bf16") + if ( + is_torch_available() + and self.device.type != "cuda" + and (self.fp16 or self.fp16_full_eval or self.bf16 or self.bf16_full_eval) + ): + raise ValueError( + "Mixed precision training with AMP or APEX (`--fp16` or `--bf16`) and half precision evaluation (`--fp16_full_eval` or `--bf16_full_eval`) can only be used on CUDA devices." + ) + if self.report_to is None: logger.info( "The default value for the training argument `--report_to` will change in v5 (from all installed " diff --git a/tests/test_trainer.py b/tests/test_trainer.py index 5b2029a299..4ccc5122e7 100644 --- a/tests/test_trainer.py +++ b/tests/test_trainer.py @@ -53,6 +53,7 @@ from transformers.testing_utils import ( require_sigopt, require_tokenizers, require_torch, + require_torch_bf16, require_torch_gpu, require_torch_multi_gpu, require_torch_non_multi_gpu, @@ -476,6 +477,21 @@ class TrainerIntegrationPrerunTest(TestCasePlus, TrainerIntegrationCommon): self.assertFalse(torch.allclose(trainer.model.b, b)) self.assertGreater(trainer.optimizer.state_dict()["param_groups"][0]["lr"], 0) + @require_torch_gpu + @require_torch_bf16 + def test_mixed_bf16(self): + + # very basic test + trainer = get_regression_trainer(learning_rate=0.1, bf16=True) + trainer.train() + self.check_trained_model(trainer.model) + + # --bf16 --half_precision_backend apex can't be used together + with self.assertRaises(ValueError): + trainer = get_regression_trainer(learning_rate=0.1, bf16=True, 
half_precision_backend="apex") + + # will add more specific tests once there are some bugs to fix + @require_torch @require_sentencepiece @@ -1323,6 +1339,66 @@ class TrainerIntegrationTest(TestCasePlus, TrainerIntegrationCommon): # perfect world: fp32_init/2 == fp16_eval self.assertAlmostEqual(fp16_eval, fp32_init / 2, delta=5_000) + @require_torch_gpu + @require_torch_bf16 + def test_bf16_full_eval(self): + # note: most of the logic is the same as test_fp16_full_eval + + # this is a sensitive test so let's keep debugging printouts in place for quick diagnosis. + # it's using pretty large safety margins, but small enough to detect broken functionality. + debug = 0 + n_gpus = get_gpu_count() + + bs = 8 + eval_len = 16 * n_gpus + # make the params somewhat big so that there will be enough RAM consumed to be able to + # measure things. We should get about 64KB for a+b in fp32 + a = torch.ones(1000, bs) + 0.001 + b = torch.ones(1000, bs) - 0.001 + + # 1. with mem metrics enabled + trainer = get_regression_trainer(a=a, b=b, eval_len=eval_len, skip_memory_metrics=False) + metrics = trainer.evaluate() + del trainer + gc.collect() + + fp32_init = metrics["init_mem_gpu_alloc_delta"] + fp32_eval = metrics["eval_mem_gpu_alloc_delta"] + + if debug: + print(f"fp32_init {fp32_init}") + print(f"fp32_eval {fp32_eval}") + + # here we expect the model to be preloaded in trainer.__init__ and consume around 64K gpu ram. + # perfect world: fp32_init == 64<<10 + self.assertGreater(fp32_init, 59_000) + # after eval should be no extra memory allocated - with a small margin (other than the peak + # memory consumption for the forward calculation that gets recovered) + # perfect world: fp32_eval == close to zero + self.assertLess(fp32_eval, 5_000) + + # 2. 
with full bf16 eval enabled (mem metrics still enabled) + trainer = get_regression_trainer(a=a, b=b, eval_len=eval_len, bf16_full_eval=True, skip_memory_metrics=False) + metrics = trainer.evaluate() + bf16_init = metrics["init_mem_gpu_alloc_delta"] + bf16_eval = metrics["eval_mem_gpu_alloc_delta"] + + if debug: + print(f"bf16_init {bf16_init}") + print(f"bf16_eval {bf16_eval}") + + # here we expect the model to not be preloaded in trainer.__init__, so with a small margin it should be close to 0 + # perfect world: bf16_init == close to zero + self.assertLess(bf16_init, 5_000) + # here we put the model on device in eval and cast it to bf16, i.e. about 32K (again we ignore the peak margin which gets returned back) + # perfect world: bf16_eval == 32<<10 + self.assertGreater(bf16_eval, 27_000) + + # 3. relative comparison fp32 vs full bf16 + # bf16_eval should be about half of fp32_init + # perfect world: fp32_init/2 == bf16_eval + self.assertAlmostEqual(bf16_eval, fp32_init / 2, delta=5_000) + def test_no_wd_param_group(self): model = nn.Sequential(TstLayer(128), nn.ModuleList([TstLayer(128), TstLayer(128)])) trainer = Trainer(model=model)