Make gradient_checkpointing a training argument (#13657)

* Make gradient_checkpointing a training argument

* Update src/transformers/modeling_utils.py

Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>

* Update src/transformers/configuration_utils.py

Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>

* Fix tests

* Style

* document Gradient Checkpointing as a performance feature

* Small rename

* PoC for not using the config

* Adapt BC to new PoC

* Forgot to save

* Rollout changes to all other models

* Fix typo

Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>
Co-authored-by: Stas Bekman <stas@stason.org>
This commit is contained in:
Sylvain Gugger
2021-09-22 07:51:38 -04:00
committed by GitHub
parent 75f6641eaf
commit 27d4639779
96 changed files with 531 additions and 309 deletions

View File

@@ -20,6 +20,7 @@ import unittest
from transformers import DeiTConfig
from transformers.file_utils import cached_property, is_torch_available, is_vision_available
from transformers.models.auto import get_values
from transformers.testing_utils import require_torch, require_vision, slow, torch_device
from .test_configuration_common import ConfigTester
@@ -340,7 +341,7 @@ class DeiTModelTest(ModelTesterMixin, unittest.TestCase):
for model_class in self.all_model_classes:
# DeiTForImageClassificationWithTeacher supports inference-only
if (
model_class in MODEL_MAPPING.values()
model_class in get_values(MODEL_MAPPING)
or model_class.__name__ == "DeiTForImageClassificationWithTeacher"
):
continue
@@ -351,6 +352,27 @@ class DeiTModelTest(ModelTesterMixin, unittest.TestCase):
loss = model(**inputs).loss
loss.backward()
def test_training_gradient_checkpointing(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
if not self.model_tester.is_training:
return
config.use_cache = False
config.return_dict = True
for model_class in self.all_model_classes:
if model_class in get_values(MODEL_MAPPING) or not model_class.supports_gradient_checkpointing:
continue
# DeiTForImageClassificationWithTeacher supports inference-only
if model_class.__name__ == "DeiTForImageClassificationWithTeacher":
continue
model = model_class(config)
model.to(torch_device)
model.train()
inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
loss = model(**inputs).loss
loss.backward()
def test_for_image_classification(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_image_classification(*config_and_inputs)