Fix gradient_checkpointing backward compatibility (#14408)
* Fix gradient_checkpointing backward compatibility
* Remove needless line
* Make sure mask prob is big enough and length small enough
* Fix tests

Co-authored-by: patrickvonplaten <patrick.v.platen@gmail.com>
This commit is contained in:
@@ -412,6 +412,17 @@ class ModuleUtilsMixin:
|
|||||||
return 6 * self.estimate_tokens(input_dict) * self.num_parameters(exclude_embeddings=exclude_embeddings)
|
return 6 * self.estimate_tokens(input_dict) * self.num_parameters(exclude_embeddings=exclude_embeddings)
|
||||||
|
|
||||||
|
|
||||||
|
def gradient_checkpointing_hook(module, _):
|
||||||
|
# Hook to enable backward compatibility for gradient checkpointing. Will be removed once all models have a
|
||||||
|
# proper post_init method.
|
||||||
|
if getattr(module.config, "gradient_checkpointing", False):
|
||||||
|
module.gradient_checkpointing_enable()
|
||||||
|
# Remove the attribute now that it has been consumed, so it's not saved in the config.
|
||||||
|
delattr(module.config, "gradient_checkpointing")
|
||||||
|
# The hook will remove itself after the first execution
|
||||||
|
module._gradient_checkpointing_hook.remove()
|
||||||
|
|
||||||
|
|
||||||
class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
|
class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
|
||||||
r"""
|
r"""
|
||||||
Base class for all models.
|
Base class for all models.
|
||||||
@@ -479,10 +490,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
|||||||
# Save config and origin of the pretrained weights if given in model
|
# Save config and origin of the pretrained weights if given in model
|
||||||
self.config = config
|
self.config = config
|
||||||
self.name_or_path = config.name_or_path
|
self.name_or_path = config.name_or_path
|
||||||
if getattr(self.config, "gradient_checkpointing", False):
|
if self.supports_gradient_checkpointing:
|
||||||
self.gradient_checkpointing_enable()
|
self._gradient_checkpointing_hook = self.register_forward_pre_hook(gradient_checkpointing_hook)
|
||||||
# Remove the attribute now that it has been consumed, so it's not saved in the config.
|
|
||||||
delattr(self.config, "gradient_checkpointing")
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _from_config(cls, config, **kwargs):
|
def _from_config(cls, config, **kwargs):
|
||||||
|
|||||||
@@ -783,7 +783,6 @@ class DetrClassificationHead(nn.Module):
|
|||||||
class DetrPreTrainedModel(PreTrainedModel):
|
class DetrPreTrainedModel(PreTrainedModel):
|
||||||
config_class = DetrConfig
|
config_class = DetrConfig
|
||||||
base_model_prefix = "model"
|
base_model_prefix = "model"
|
||||||
supports_gradient_checkpointing = True
|
|
||||||
|
|
||||||
def _init_weights(self, module):
|
def _init_weights(self, module):
|
||||||
std = self.config.init_std
|
std = self.config.init_std
|
||||||
|
|||||||
@@ -504,7 +504,6 @@ class LayoutLMv2PreTrainedModel(PreTrainedModel):
|
|||||||
config_class = LayoutLMv2Config
|
config_class = LayoutLMv2Config
|
||||||
pretrained_model_archive_map = LAYOUTLMV2_PRETRAINED_MODEL_ARCHIVE_LIST
|
pretrained_model_archive_map = LAYOUTLMV2_PRETRAINED_MODEL_ARCHIVE_LIST
|
||||||
base_model_prefix = "layoutlmv2"
|
base_model_prefix = "layoutlmv2"
|
||||||
supports_gradient_checkpointing = True
|
|
||||||
_keys_to_ignore_on_load_missing = [r"position_ids"]
|
_keys_to_ignore_on_load_missing = [r"position_ids"]
|
||||||
|
|
||||||
def _init_weights(self, module):
|
def _init_weights(self, module):
|
||||||
|
|||||||
@@ -239,6 +239,7 @@ class BeitModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
if model_class.__name__ == "BeitForMaskedImageModeling":
|
if model_class.__name__ == "BeitForMaskedImageModeling":
|
||||||
continue
|
continue
|
||||||
model = model_class(config)
|
model = model_class(config)
|
||||||
|
model.gradient_checkpointing_enable()
|
||||||
model.to(torch_device)
|
model.to(torch_device)
|
||||||
model.train()
|
model.train()
|
||||||
inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
|
inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
|
||||||
|
|||||||
@@ -209,6 +209,25 @@ class ModelTesterMixin:
|
|||||||
)
|
)
|
||||||
self.assertTrue(len(load_result.unexpected_keys) == 0)
|
self.assertTrue(len(load_result.unexpected_keys) == 0)
|
||||||
|
|
||||||
|
def test_gradient_checkpointing_backward_compatibility(self):
|
||||||
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|
||||||
|
for model_class in self.all_model_classes:
|
||||||
|
if not model_class.supports_gradient_checkpointing:
|
||||||
|
continue
|
||||||
|
|
||||||
|
config.gradient_checkpointing = True
|
||||||
|
model = model_class(config)
|
||||||
|
# Model does not have gradient checkpointing activated yet, it will be done at the first forward.
|
||||||
|
self.assertFalse(model.is_gradient_checkpointing)
|
||||||
|
|
||||||
|
model.to(torch_device)
|
||||||
|
inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
|
||||||
|
_ = model(**inputs)
|
||||||
|
|
||||||
|
# Model has gradient checkpointing activated after the first forward.
|
||||||
|
self.assertTrue(model.is_gradient_checkpointing)
|
||||||
|
|
||||||
def test_gradient_checkpointing_enable_disable(self):
|
def test_gradient_checkpointing_enable_disable(self):
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|
||||||
@@ -413,6 +432,7 @@ class ModelTesterMixin:
|
|||||||
continue
|
continue
|
||||||
model = model_class(config)
|
model = model_class(config)
|
||||||
model.to(torch_device)
|
model.to(torch_device)
|
||||||
|
model.gradient_checkpointing_enable()
|
||||||
model.train()
|
model.train()
|
||||||
inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
|
inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
|
||||||
loss = model(**inputs).loss
|
loss = model(**inputs).loss
|
||||||
|
|||||||
@@ -367,6 +367,7 @@ class DeiTModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
if model_class.__name__ == "DeiTForImageClassificationWithTeacher":
|
if model_class.__name__ == "DeiTForImageClassificationWithTeacher":
|
||||||
continue
|
continue
|
||||||
model = model_class(config)
|
model = model_class(config)
|
||||||
|
model.gradient_checkpointing_enable()
|
||||||
model.to(torch_device)
|
model.to(torch_device)
|
||||||
model.train()
|
model.train()
|
||||||
inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
|
inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
|
||||||
|
|||||||
@@ -65,6 +65,8 @@ class UniSpeechSatModelTester:
|
|||||||
layer_norm_eps=1e-5,
|
layer_norm_eps=1e-5,
|
||||||
hidden_act="gelu",
|
hidden_act="gelu",
|
||||||
initializer_range=0.02,
|
initializer_range=0.02,
|
||||||
|
mask_time_prob=0.5,
|
||||||
|
mask_time_length=2,
|
||||||
vocab_size=32,
|
vocab_size=32,
|
||||||
do_stable_layer_norm=False,
|
do_stable_layer_norm=False,
|
||||||
scope=None,
|
scope=None,
|
||||||
@@ -92,6 +94,8 @@ class UniSpeechSatModelTester:
|
|||||||
self.initializer_range = initializer_range
|
self.initializer_range = initializer_range
|
||||||
self.vocab_size = vocab_size
|
self.vocab_size = vocab_size
|
||||||
self.do_stable_layer_norm = do_stable_layer_norm
|
self.do_stable_layer_norm = do_stable_layer_norm
|
||||||
|
self.mask_time_prob = mask_time_prob
|
||||||
|
self.mask_time_length = mask_time_length
|
||||||
self.scope = scope
|
self.scope = scope
|
||||||
|
|
||||||
output_seq_length = self.seq_length
|
output_seq_length = self.seq_length
|
||||||
@@ -120,6 +124,8 @@ class UniSpeechSatModelTester:
|
|||||||
conv_bias=self.conv_bias,
|
conv_bias=self.conv_bias,
|
||||||
num_conv_pos_embeddings=self.num_conv_pos_embeddings,
|
num_conv_pos_embeddings=self.num_conv_pos_embeddings,
|
||||||
num_conv_pos_embedding_groups=self.num_conv_pos_embedding_groups,
|
num_conv_pos_embedding_groups=self.num_conv_pos_embedding_groups,
|
||||||
|
mask_time_prob=self.mask_time_prob,
|
||||||
|
mask_time_length=self.mask_time_length,
|
||||||
num_hidden_layers=self.num_hidden_layers,
|
num_hidden_layers=self.num_hidden_layers,
|
||||||
num_attention_heads=self.num_attention_heads,
|
num_attention_heads=self.num_attention_heads,
|
||||||
hidden_dropout_prob=self.hidden_dropout_prob,
|
hidden_dropout_prob=self.hidden_dropout_prob,
|
||||||
|
|||||||
@@ -78,6 +78,8 @@ class Wav2Vec2ModelTester:
|
|||||||
layer_norm_eps=1e-5,
|
layer_norm_eps=1e-5,
|
||||||
hidden_act="gelu",
|
hidden_act="gelu",
|
||||||
initializer_range=0.02,
|
initializer_range=0.02,
|
||||||
|
mask_time_prob=0.5,
|
||||||
|
mask_time_length=2,
|
||||||
vocab_size=32,
|
vocab_size=32,
|
||||||
do_stable_layer_norm=False,
|
do_stable_layer_norm=False,
|
||||||
scope=None,
|
scope=None,
|
||||||
@@ -105,6 +107,8 @@ class Wav2Vec2ModelTester:
|
|||||||
self.initializer_range = initializer_range
|
self.initializer_range = initializer_range
|
||||||
self.vocab_size = vocab_size
|
self.vocab_size = vocab_size
|
||||||
self.do_stable_layer_norm = do_stable_layer_norm
|
self.do_stable_layer_norm = do_stable_layer_norm
|
||||||
|
self.mask_time_prob = mask_time_prob
|
||||||
|
self.mask_time_length = mask_time_length
|
||||||
self.scope = scope
|
self.scope = scope
|
||||||
|
|
||||||
output_seq_length = self.seq_length
|
output_seq_length = self.seq_length
|
||||||
@@ -131,6 +135,8 @@ class Wav2Vec2ModelTester:
|
|||||||
conv_stride=self.conv_stride,
|
conv_stride=self.conv_stride,
|
||||||
conv_kernel=self.conv_kernel,
|
conv_kernel=self.conv_kernel,
|
||||||
conv_bias=self.conv_bias,
|
conv_bias=self.conv_bias,
|
||||||
|
mask_time_prob=self.mask_time_prob,
|
||||||
|
mask_time_length=self.mask_time_length,
|
||||||
num_conv_pos_embeddings=self.num_conv_pos_embeddings,
|
num_conv_pos_embeddings=self.num_conv_pos_embeddings,
|
||||||
num_conv_pos_embedding_groups=self.num_conv_pos_embedding_groups,
|
num_conv_pos_embedding_groups=self.num_conv_pos_embedding_groups,
|
||||||
num_hidden_layers=self.num_hidden_layers,
|
num_hidden_layers=self.num_hidden_layers,
|
||||||
|
|||||||
Reference in New Issue
Block a user