[Trainer / GC] Add gradient_checkpointing_kwargs in trainer and training arguments (#27068)

* add `gradient_checkpointing_kwargs` in trainer and training arguments

* add comment

* add test - currently failing

* now tests pass
This commit is contained in:
Younes Belkada
2023-10-30 12:41:48 +01:00
committed by GitHub
parent e830495c1c
commit 5fbed2d7ca
3 changed files with 72 additions and 2 deletions

View File

@@ -283,6 +283,38 @@ if is_torch_available():
loss = nn.functional.mse_loss(y, labels)
return (loss, y, y) if self.double_output else (loss, y)
class RegressionPreTrainedModelWithGradientCheckpointing(PreTrainedModel):
    """Small regression model used to exercise gradient checkpointing in tests.

    The network is a stack of four square ``nn.Linear`` layers followed by a
    ``nn.Linear`` head with a single output feature. When the module is in
    training mode and ``gradient_checkpointing`` is true, each stacked layer is
    invoked through ``self._gradient_checkpointing_func`` instead of directly.
    """

    config_class = RegressionModelConfig
    base_model_prefix = "regression"
    supports_gradient_checkpointing = True

    def __init__(self, config):
        """Build the layers from ``config.hidden_size`` and ``config.double_output``."""
        super().__init__(config)
        self.layers = nn.ModuleList([nn.Linear(config.hidden_size, config.hidden_size) for _ in range(4)])
        self.head = nn.Linear(config.hidden_size, 1)
        # Disabled by default. NOTE(review): presumably switched on by the base
        # class when checkpointing is enabled -- confirm against PreTrainedModel.
        self.gradient_checkpointing = False
        self.double_output = config.double_output

    def forward(self, input_x, labels=None, **kwargs):
        """Run the layer stack and head; compute an MSE loss when labels are given.

        Args:
            input_x: Input tensor; a leading dimension is added before the first layer.
            labels: Optional regression targets. When ``None`` no loss is computed.
            **kwargs: Accepted and ignored.

        Returns:
            ``(logits,)`` without labels, or ``(loss, logits)`` with labels. When
            ``double_output`` is set the prediction is repeated:
            ``(logits, logits)`` or ``(loss, logits, logits)``.
        """
        y = input_x.unsqueeze(0)

        for layer in self.layers:
            if self.training and self.gradient_checkpointing:
                # `_gradient_checkpointing_func` is not defined in this class;
                # presumably the base class installs it -- TODO confirm.
                outputs = self._gradient_checkpointing_func(layer.__call__, y)
            else:
                outputs = layer(y)

            y = outputs * 3

        logits = self.head(y)

        if labels is None:
            return (logits, logits) if self.double_output else (logits,)

        loss = nn.functional.mse_loss(logits, labels)

        # Return the head output as the prediction so the labelled path agrees
        # with the unlabelled one; previously the pre-head activation `y` was
        # returned here even though the loss is computed on `logits`.
        return (loss, logits, logits) if self.double_output else (loss, logits)
class RegressionRandomPreTrainedModel(PreTrainedModel):
config_class = RegressionModelConfig
base_model_prefix = "regression"
@@ -327,6 +359,7 @@ if is_torch_available():
a=0, b=0, double_output=False, train_len=64, eval_len=64, pretrained=True, keep_report_to=False, **kwargs
):
label_names = kwargs.get("label_names", None)
gradient_checkpointing = kwargs.get("gradient_checkpointing", False)
train_dataset = RegressionDataset(length=train_len, label_names=label_names)
eval_dataset = RegressionDataset(length=eval_len, label_names=label_names)
@@ -336,7 +369,13 @@ if is_torch_available():
else:
if pretrained:
config = RegressionModelConfig(a=a, b=b, double_output=double_output)
model = RegressionPreTrainedModel(config)
# Pick the model class based on whether gradient checkpointing was requested
target_cls = (
RegressionPreTrainedModel
if not gradient_checkpointing
else RegressionPreTrainedModelWithGradientCheckpointing
)
model = target_cls(config)
else:
model = RegressionModel(a=a, b=b, double_output=double_output)
@@ -548,6 +587,24 @@ class TrainerIntegrationPrerunTest(TestCasePlus, TrainerIntegrationCommon):
trainer.train()
self.check_trained_model(trainer.model)
def test_gradient_checkpointing(self):
    """Train with gradient checkpointing enabled and check that every parameter changed."""
    trainer = get_regression_trainer(
        per_device_train_batch_size=1,
        learning_rate=0.1,
        gradient_checkpointing=True,
        gradient_checkpointing_kwargs={"use_reentrant": False},
    )

    # Snapshot detached copies of the weights so training cannot alias them.
    weights_before = {}
    for name, param in trainer.model.named_parameters():
        weights_before[name] = param.detach().clone()

    trainer.train()

    # Every parameter must differ from its snapshot beyond the tolerance.
    for k, param in trainer.model.named_parameters():
        is_unchanged = torch.allclose(weights_before[k], param, rtol=1e-4, atol=1e-4)
        self.assertFalse(is_unchanged, f"Model weights for {k} have not been updated")
def test_training_loss(self):
n_gpus = max(1, get_gpu_count())