From 5fbed2d7ca6c8690f18b5c33f7b166de1c14fd26 Mon Sep 17 00:00:00 2001 From: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Date: Mon, 30 Oct 2023 12:41:48 +0100 Subject: [PATCH] [`Trainer` / `GC`] Add `gradient_checkpointing_kwargs` in trainer and training arguments (#27068) * add `gradient_checkpointing_kwargs` in trainer and training arguments * add comment * add test - currently failing * now tests pass --- src/transformers/trainer.py | 7 +++- src/transformers/training_args.py | 8 +++++ tests/trainer/test_trainer.py | 59 ++++++++++++++++++++++++++++++- 3 files changed, 72 insertions(+), 2 deletions(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 25941ff0c7..06879cbce7 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -1616,7 +1616,12 @@ class Trainer: # Activate gradient checkpointing if needed if args.gradient_checkpointing: - self.model.gradient_checkpointing_enable() + if args.gradient_checkpointing_kwargs is None: + gradient_checkpointing_kwargs = {} + else: + gradient_checkpointing_kwargs = args.gradient_checkpointing_kwargs + + self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=gradient_checkpointing_kwargs) model = self._wrap_model(self.model_wrapped) diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index da7fe2b61e..cc8a3de56b 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -572,6 +572,8 @@ class TrainingArguments: Unless this is `True`, the `Trainer` will skip pushing a checkpoint when the previous push is not finished. gradient_checkpointing (`bool`, *optional*, defaults to `False`): If True, use gradient checkpointing to save memory at the expense of slower backward pass. + gradient_checkpointing_kwargs (`dict`, *optional*, defaults to `None`): + Key word arguments to be passed to the `gradient_checkpointing_enable` method. 
include_inputs_for_metrics (`bool`, *optional*, defaults to `False`): Whether or not the inputs will be passed to the `compute_metrics` function. This is intended for metrics that need inputs, predictions and references for scoring calculation in Metric class. @@ -1119,6 +1121,12 @@ class TrainingArguments: "help": "If True, use gradient checkpointing to save memory at the expense of slower backward pass." }, ) + gradient_checkpointing_kwargs: dict = field( + default=None, + metadata={ + "help": "Gradient checkpointing key word arguments such as `use_reentrant`. Will be passed to `torch.utils.checkpoint.checkpoint` through `model.gradient_checkpointing_enable`." + }, + ) include_inputs_for_metrics: bool = field( default=False, metadata={"help": "Whether or not the inputs will be passed to the `compute_metrics` function."} ) diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 6400852e62..624d3833f4 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -283,6 +283,38 @@ if is_torch_available(): loss = nn.functional.mse_loss(y, labels) return (loss, y, y) if self.double_output else (loss, y) + class RegressionPreTrainedModelWithGradientCheckpointing(PreTrainedModel): + config_class = RegressionModelConfig + base_model_prefix = "regression" + supports_gradient_checkpointing = True + + def __init__(self, config): + super().__init__(config) + self.layers = nn.ModuleList([nn.Linear(config.hidden_size, config.hidden_size) for _ in range(4)]) + self.head = nn.Linear(config.hidden_size, 1) + self.gradient_checkpointing = False + self.double_output = config.double_output + + def forward(self, input_x, labels=None, **kwargs): + y = input_x.unsqueeze(0) + + for layer in self.layers: + if self.training and self.gradient_checkpointing: + outputs = self._gradient_checkpointing_func(layer.__call__, y) + else: + outputs = layer(y) + + y = outputs * 3 + + logits = self.head(y) + + if labels is None: + return (logits, logits) if 
self.double_output else (logits,) + + loss = nn.functional.mse_loss(logits, labels) + + return (loss, y, y) if self.double_output else (loss, y) + class RegressionRandomPreTrainedModel(PreTrainedModel): config_class = RegressionModelConfig base_model_prefix = "regression" @@ -327,6 +359,7 @@ if is_torch_available(): a=0, b=0, double_output=False, train_len=64, eval_len=64, pretrained=True, keep_report_to=False, **kwargs ): label_names = kwargs.get("label_names", None) + gradient_checkpointing = kwargs.get("gradient_checkpointing", False) train_dataset = RegressionDataset(length=train_len, label_names=label_names) eval_dataset = RegressionDataset(length=eval_len, label_names=label_names) @@ -336,7 +369,13 @@ if is_torch_available(): else: if pretrained: config = RegressionModelConfig(a=a, b=b, double_output=double_output) - model = RegressionPreTrainedModel(config) + # We infer the correct model class if one uses gradient_checkpointing or not + target_cls = ( + RegressionPreTrainedModel + if not gradient_checkpointing + else RegressionPreTrainedModelWithGradientCheckpointing + ) + model = target_cls(config) else: model = RegressionModel(a=a, b=b, double_output=double_output) @@ -548,6 +587,24 @@ class TrainerIntegrationPrerunTest(TestCasePlus, TrainerIntegrationCommon): trainer.train() self.check_trained_model(trainer.model) + def test_gradient_checkpointing(self): + trainer = get_regression_trainer( + per_device_train_batch_size=1, + learning_rate=0.1, + gradient_checkpointing=True, + gradient_checkpointing_kwargs={"use_reentrant": False}, + ) + previous_params = {k: v.detach().clone() for k, v in trainer.model.named_parameters()} + + trainer.train() + + # Check if model weights have been updated + for k, v in trainer.model.named_parameters(): + self.assertFalse( + torch.allclose(previous_params[k], v, rtol=1e-4, atol=1e-4), + f"Model weights for {k} have not been updated", + ) + def test_training_loss(self): n_gpus = max(1, get_gpu_count())