[Trainer / GC] Add gradient_checkpointing_kwargs in trainer and training arguments (#27068)
* add `gradient_checkpointing_kwargs` in trainer and training arguments * add comment * add test - currently failing * now tests pass
This commit is contained in:
@@ -1616,7 +1616,12 @@ class Trainer:
|
|||||||
|
|
||||||
# Activate gradient checkpointing if needed
|
# Activate gradient checkpointing if needed
|
||||||
if args.gradient_checkpointing:
|
if args.gradient_checkpointing:
|
||||||
self.model.gradient_checkpointing_enable()
|
if args.gradient_checkpointing_kwargs is None:
|
||||||
|
gradient_checkpointing_kwargs = {}
|
||||||
|
else:
|
||||||
|
gradient_checkpointing_kwargs = args.gradient_checkpointing_kwargs
|
||||||
|
|
||||||
|
self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=gradient_checkpointing_kwargs)
|
||||||
|
|
||||||
model = self._wrap_model(self.model_wrapped)
|
model = self._wrap_model(self.model_wrapped)
|
||||||
|
|
||||||
|
|||||||
@@ -572,6 +572,8 @@ class TrainingArguments:
|
|||||||
Unless this is `True`, the `Trainer` will skip pushing a checkpoint when the previous push is not finished.
|
Unless this is `True`, the `Trainer` will skip pushing a checkpoint when the previous push is not finished.
|
||||||
gradient_checkpointing (`bool`, *optional*, defaults to `False`):
|
gradient_checkpointing (`bool`, *optional*, defaults to `False`):
|
||||||
If True, use gradient checkpointing to save memory at the expense of slower backward pass.
|
If True, use gradient checkpointing to save memory at the expense of slower backward pass.
|
||||||
|
gradient_checkpointing_kwargs (`dict`, *optional*, defaults to `None`):
|
||||||
|
Keyword arguments to be passed to the `gradient_checkpointing_enable` method.
|
||||||
include_inputs_for_metrics (`bool`, *optional*, defaults to `False`):
|
include_inputs_for_metrics (`bool`, *optional*, defaults to `False`):
|
||||||
Whether or not the inputs will be passed to the `compute_metrics` function. This is intended for metrics
|
Whether or not the inputs will be passed to the `compute_metrics` function. This is intended for metrics
|
||||||
that need inputs, predictions and references for scoring calculation in Metric class.
|
that need inputs, predictions and references for scoring calculation in Metric class.
|
||||||
@@ -1119,6 +1121,12 @@ class TrainingArguments:
|
|||||||
"help": "If True, use gradient checkpointing to save memory at the expense of slower backward pass."
|
"help": "If True, use gradient checkpointing to save memory at the expense of slower backward pass."
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
gradient_checkpointing_kwargs: dict = field(
|
||||||
|
default=None,
|
||||||
|
metadata={
|
||||||
|
"help": "Gradient checkpointing key word arguments such as `use_reentrant`. Will be passed to `torch.utils.checkpoint.checkpoint` through `model.gradient_checkpointing_enable`."
|
||||||
|
},
|
||||||
|
)
|
||||||
include_inputs_for_metrics: bool = field(
|
include_inputs_for_metrics: bool = field(
|
||||||
default=False, metadata={"help": "Whether or not the inputs will be passed to the `compute_metrics` function."}
|
default=False, metadata={"help": "Whether or not the inputs will be passed to the `compute_metrics` function."}
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -283,6 +283,38 @@ if is_torch_available():
|
|||||||
loss = nn.functional.mse_loss(y, labels)
|
loss = nn.functional.mse_loss(y, labels)
|
||||||
return (loss, y, y) if self.double_output else (loss, y)
|
return (loss, y, y) if self.double_output else (loss, y)
|
||||||
|
|
||||||
|
class RegressionPreTrainedModelWithGradientCheckpointing(PreTrainedModel):
|
||||||
|
config_class = RegressionModelConfig
|
||||||
|
base_model_prefix = "regression"
|
||||||
|
supports_gradient_checkpointing = True
|
||||||
|
|
||||||
|
def __init__(self, config):
|
||||||
|
super().__init__(config)
|
||||||
|
self.layers = nn.ModuleList([nn.Linear(config.hidden_size, config.hidden_size) for _ in range(4)])
|
||||||
|
self.head = nn.Linear(config.hidden_size, 1)
|
||||||
|
self.gradient_checkpointing = False
|
||||||
|
self.double_output = config.double_output
|
||||||
|
|
||||||
|
def forward(self, input_x, labels=None, **kwargs):
|
||||||
|
y = input_x.unsqueeze(0)
|
||||||
|
|
||||||
|
for layer in self.layers:
|
||||||
|
if self.training and self.gradient_checkpointing:
|
||||||
|
outputs = self._gradient_checkpointing_func(layer.__call__, y)
|
||||||
|
else:
|
||||||
|
outputs = layer(y)
|
||||||
|
|
||||||
|
y = outputs * 3
|
||||||
|
|
||||||
|
logits = self.head(y)
|
||||||
|
|
||||||
|
if labels is None:
|
||||||
|
return (logits, logits) if self.double_output else (logits,)
|
||||||
|
|
||||||
|
loss = nn.functional.mse_loss(logits, labels)
|
||||||
|
|
||||||
|
return (loss, y, y) if self.double_output else (loss, y)
|
||||||
|
|
||||||
class RegressionRandomPreTrainedModel(PreTrainedModel):
|
class RegressionRandomPreTrainedModel(PreTrainedModel):
|
||||||
config_class = RegressionModelConfig
|
config_class = RegressionModelConfig
|
||||||
base_model_prefix = "regression"
|
base_model_prefix = "regression"
|
||||||
@@ -327,6 +359,7 @@ if is_torch_available():
|
|||||||
a=0, b=0, double_output=False, train_len=64, eval_len=64, pretrained=True, keep_report_to=False, **kwargs
|
a=0, b=0, double_output=False, train_len=64, eval_len=64, pretrained=True, keep_report_to=False, **kwargs
|
||||||
):
|
):
|
||||||
label_names = kwargs.get("label_names", None)
|
label_names = kwargs.get("label_names", None)
|
||||||
|
gradient_checkpointing = kwargs.get("gradient_checkpointing", False)
|
||||||
train_dataset = RegressionDataset(length=train_len, label_names=label_names)
|
train_dataset = RegressionDataset(length=train_len, label_names=label_names)
|
||||||
eval_dataset = RegressionDataset(length=eval_len, label_names=label_names)
|
eval_dataset = RegressionDataset(length=eval_len, label_names=label_names)
|
||||||
|
|
||||||
@@ -336,7 +369,13 @@ if is_torch_available():
|
|||||||
else:
|
else:
|
||||||
if pretrained:
|
if pretrained:
|
||||||
config = RegressionModelConfig(a=a, b=b, double_output=double_output)
|
config = RegressionModelConfig(a=a, b=b, double_output=double_output)
|
||||||
model = RegressionPreTrainedModel(config)
|
# We infer the correct model class if one uses gradient_checkpointing or not
|
||||||
|
target_cls = (
|
||||||
|
RegressionPreTrainedModel
|
||||||
|
if not gradient_checkpointing
|
||||||
|
else RegressionPreTrainedModelWithGradientCheckpointing
|
||||||
|
)
|
||||||
|
model = target_cls(config)
|
||||||
else:
|
else:
|
||||||
model = RegressionModel(a=a, b=b, double_output=double_output)
|
model = RegressionModel(a=a, b=b, double_output=double_output)
|
||||||
|
|
||||||
@@ -548,6 +587,24 @@ class TrainerIntegrationPrerunTest(TestCasePlus, TrainerIntegrationCommon):
|
|||||||
trainer.train()
|
trainer.train()
|
||||||
self.check_trained_model(trainer.model)
|
self.check_trained_model(trainer.model)
|
||||||
|
|
||||||
|
def test_gradient_checkpointing(self):
|
||||||
|
trainer = get_regression_trainer(
|
||||||
|
per_device_train_batch_size=1,
|
||||||
|
learning_rate=0.1,
|
||||||
|
gradient_checkpointing=True,
|
||||||
|
gradient_checkpointing_kwargs={"use_reentrant": False},
|
||||||
|
)
|
||||||
|
previous_params = {k: v.detach().clone() for k, v in trainer.model.named_parameters()}
|
||||||
|
|
||||||
|
trainer.train()
|
||||||
|
|
||||||
|
# Check if model weights have been updated
|
||||||
|
for k, v in trainer.model.named_parameters():
|
||||||
|
self.assertFalse(
|
||||||
|
torch.allclose(previous_params[k], v, rtol=1e-4, atol=1e-4),
|
||||||
|
f"Model weights for {k} have not been updated",
|
||||||
|
)
|
||||||
|
|
||||||
def test_training_loss(self):
|
def test_training_loss(self):
|
||||||
n_gpus = max(1, get_gpu_count())
|
n_gpus = max(1, get_gpu_count())
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user