From 040fd471621e27e961f4ab87c03e3202d173cd6c Mon Sep 17 00:00:00 2001 From: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Date: Tue, 16 Nov 2021 08:58:42 -0500 Subject: [PATCH] Fix gradient_checkpointing backward compatibility (#14408) * Fix gradient_checkpointing backward compatibility * Remove needless line * make sure mask prob is big enough and length small enough * Fix tests Co-authored-by: patrickvonplaten --- src/transformers/modeling_utils.py | 17 ++++++++++++---- src/transformers/models/detr/modeling_detr.py | 1 - .../models/layoutlmv2/modeling_layoutlmv2.py | 1 - tests/test_modeling_beit.py | 1 + tests/test_modeling_common.py | 20 +++++++++++++++++++ tests/test_modeling_deit.py | 1 + tests/test_modeling_unispeech_sat.py | 6 ++++++ tests/test_modeling_wav2vec2.py | 6 ++++++ 8 files changed, 47 insertions(+), 6 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 213f5b57ee..f45c11087f 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -412,6 +412,17 @@ class ModuleUtilsMixin: return 6 * self.estimate_tokens(input_dict) * self.num_parameters(exclude_embeddings=exclude_embeddings) +def gradient_checkpointing_hook(module, _): + # Hook to enable backward compatibility for gradient checkpointing. Will be removed once all models have a + # proper post_init method. + if getattr(module.config, "gradient_checkpointing", False): + module.gradient_checkpointing_enable() + # Remove the attribute now that it has been consumed, so it's not saved in the config. + delattr(module.config, "gradient_checkpointing") + # The hook will remove itself after the first execution + module._gradient_checkpointing_hook.remove() + + class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin): r""" Base class for all models. 
@@ -479,10 +490,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix # Save config and origin of the pretrained weights if given in model self.config = config self.name_or_path = config.name_or_path - if getattr(self.config, "gradient_checkpointing", False): - self.gradient_checkpointing_enable() - # Remove the attribute now that is has been consumed, so it's no saved in the config. - delattr(self.config, "gradient_checkpointing") + if self.supports_gradient_checkpointing: + self._gradient_checkpointing_hook = self.register_forward_pre_hook(gradient_checkpointing_hook) @classmethod def _from_config(cls, config, **kwargs): diff --git a/src/transformers/models/detr/modeling_detr.py b/src/transformers/models/detr/modeling_detr.py index 5e95cc3f32..70287626b2 100644 --- a/src/transformers/models/detr/modeling_detr.py +++ b/src/transformers/models/detr/modeling_detr.py @@ -784,7 +784,6 @@ class DetrClassificationHead(nn.Module): class DetrPreTrainedModel(PreTrainedModel): config_class = DetrConfig base_model_prefix = "model" - supports_gradient_checkpointing = True def _init_weights(self, module): std = self.config.init_std diff --git a/src/transformers/models/layoutlmv2/modeling_layoutlmv2.py b/src/transformers/models/layoutlmv2/modeling_layoutlmv2.py index 653d6ea627..e80029a300 100755 --- a/src/transformers/models/layoutlmv2/modeling_layoutlmv2.py +++ b/src/transformers/models/layoutlmv2/modeling_layoutlmv2.py @@ -504,7 +504,6 @@ class LayoutLMv2PreTrainedModel(PreTrainedModel): config_class = LayoutLMv2Config pretrained_model_archive_map = LAYOUTLMV2_PRETRAINED_MODEL_ARCHIVE_LIST base_model_prefix = "layoutlmv2" - supports_gradient_checkpointing = True _keys_to_ignore_on_load_missing = [r"position_ids"] def _init_weights(self, module): diff --git a/tests/test_modeling_beit.py b/tests/test_modeling_beit.py index c38f956cd8..9ead09a7d3 100644 --- a/tests/test_modeling_beit.py +++ b/tests/test_modeling_beit.py @@ -265,6 +265,7 @@ class 
BeitModelTest(ModelTesterMixin, unittest.TestCase): [self.model_tester.batch_size, height, width], device=torch_device ).long() model = model_class(config) + model.gradient_checkpointing_enable() model.to(torch_device) model.train() inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True) diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index af002cea63..49027d3f7e 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -213,6 +213,25 @@ class ModelTesterMixin: ) self.assertTrue(len(load_result.unexpected_keys) == 0) + def test_gradient_checkpointing_backward_compatibility(self): + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + + for model_class in self.all_model_classes: + if not model_class.supports_gradient_checkpointing: + continue + + config.gradient_checkpointing = True + model = model_class(config) + # Model does not have gradient checkpointing activated yet, it will be done at the first forward. + self.assertFalse(model.is_gradient_checkpointing) + + model.to(torch_device) + inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True) + _ = model(**inputs) + + # Model has gradient checkpointing activated after the first forward. 
+ self.assertTrue(model.is_gradient_checkpointing) + def test_gradient_checkpointing_enable_disable(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() @@ -418,6 +437,7 @@ class ModelTesterMixin: continue model = model_class(config) model.to(torch_device) + model.gradient_checkpointing_enable() model.train() inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True) loss = model(**inputs).loss diff --git a/tests/test_modeling_deit.py b/tests/test_modeling_deit.py index 119e098891..222f2afbe2 100644 --- a/tests/test_modeling_deit.py +++ b/tests/test_modeling_deit.py @@ -367,6 +367,7 @@ class DeiTModelTest(ModelTesterMixin, unittest.TestCase): if model_class.__name__ == "DeiTForImageClassificationWithTeacher": continue model = model_class(config) + model.gradient_checkpointing_enable() model.to(torch_device) model.train() inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True) diff --git a/tests/test_modeling_unispeech_sat.py b/tests/test_modeling_unispeech_sat.py index c6e09f0a5a..56c429d2f9 100644 --- a/tests/test_modeling_unispeech_sat.py +++ b/tests/test_modeling_unispeech_sat.py @@ -66,6 +66,8 @@ class UniSpeechSatModelTester: layer_norm_eps=1e-5, hidden_act="gelu", initializer_range=0.02, + mask_time_prob=0.5, + mask_time_length=2, vocab_size=32, do_stable_layer_norm=False, scope=None, @@ -93,6 +95,8 @@ class UniSpeechSatModelTester: self.initializer_range = initializer_range self.vocab_size = vocab_size self.do_stable_layer_norm = do_stable_layer_norm + self.mask_time_prob = mask_time_prob + self.mask_time_length = mask_time_length self.scope = scope output_seq_length = self.seq_length @@ -121,6 +125,8 @@ class UniSpeechSatModelTester: conv_bias=self.conv_bias, num_conv_pos_embeddings=self.num_conv_pos_embeddings, num_conv_pos_embedding_groups=self.num_conv_pos_embedding_groups, + mask_time_prob=self.mask_time_prob, + mask_time_length=self.mask_time_length, 
num_hidden_layers=self.num_hidden_layers, num_attention_heads=self.num_attention_heads, hidden_dropout_prob=self.hidden_dropout_prob, diff --git a/tests/test_modeling_wav2vec2.py b/tests/test_modeling_wav2vec2.py index c68acbeff1..5ac40fe98e 100644 --- a/tests/test_modeling_wav2vec2.py +++ b/tests/test_modeling_wav2vec2.py @@ -78,6 +78,8 @@ class Wav2Vec2ModelTester: layer_norm_eps=1e-5, hidden_act="gelu", initializer_range=0.02, + mask_time_prob=0.5, + mask_time_length=2, vocab_size=32, do_stable_layer_norm=False, scope=None, @@ -105,6 +107,8 @@ class Wav2Vec2ModelTester: self.initializer_range = initializer_range self.vocab_size = vocab_size self.do_stable_layer_norm = do_stable_layer_norm + self.mask_time_prob = mask_time_prob + self.mask_time_length = mask_time_length self.scope = scope output_seq_length = self.seq_length @@ -131,6 +135,8 @@ class Wav2Vec2ModelTester: conv_stride=self.conv_stride, conv_kernel=self.conv_kernel, conv_bias=self.conv_bias, + mask_time_prob=self.mask_time_prob, + mask_time_length=self.mask_time_length, num_conv_pos_embeddings=self.num_conv_pos_embeddings, num_conv_pos_embedding_groups=self.num_conv_pos_embedding_groups, num_hidden_layers=self.num_hidden_layers,