Fix gradient_checkpointing backward compatibility (#14408)

* Fix gradient_checkpointing backward compatibility

* Remove needless line

* make sure mask prob is big enough and length small enough

* Fix tests

Co-authored-by: patrickvonplaten <patrick.v.platen@gmail.com>
This commit is contained in:
Sylvain Gugger
2021-11-16 08:58:42 -05:00
committed by GitHub
parent 1cc453d33c
commit 040fd47162
8 changed files with 47 additions and 6 deletions

View File

@@ -265,6 +265,7 @@ class BeitModelTest(ModelTesterMixin, unittest.TestCase):
[self.model_tester.batch_size, height, width], device=torch_device
).long()
model = model_class(config)
model.gradient_checkpointing_enable()
model.to(torch_device)
model.train()
inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)

View File

@@ -213,6 +213,25 @@ class ModelTesterMixin:
)
self.assertTrue(len(load_result.unexpected_keys) == 0)
def test_gradient_checkpointing_backward_compatibility(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
if not model_class.supports_gradient_checkpointing:
continue
config.gradient_checkpointing = True
model = model_class(config)
# Model does not have gradient checkpointing activated yet, it will be done at the first forward.
self.assertFalse(model.is_gradient_checkpointing)
model.to(torch_device)
inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
_ = model(**inputs)
# Model has gradient checkpointing activated after the first forward.
self.assertTrue(model.is_gradient_checkpointing)
def test_gradient_checkpointing_enable_disable(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
@@ -418,6 +437,7 @@ class ModelTesterMixin:
continue
model = model_class(config)
model.to(torch_device)
model.gradient_checkpointing_enable()
model.train()
inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
loss = model(**inputs).loss

View File

@@ -367,6 +367,7 @@ class DeiTModelTest(ModelTesterMixin, unittest.TestCase):
if model_class.__name__ == "DeiTForImageClassificationWithTeacher":
continue
model = model_class(config)
model.gradient_checkpointing_enable()
model.to(torch_device)
model.train()
inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)

View File

@@ -66,6 +66,8 @@ class UniSpeechSatModelTester:
layer_norm_eps=1e-5,
hidden_act="gelu",
initializer_range=0.02,
mask_time_prob=0.5,
mask_time_length=2,
vocab_size=32,
do_stable_layer_norm=False,
scope=None,
@@ -93,6 +95,8 @@ class UniSpeechSatModelTester:
self.initializer_range = initializer_range
self.vocab_size = vocab_size
self.do_stable_layer_norm = do_stable_layer_norm
self.mask_time_prob = mask_time_prob
self.mask_time_length = mask_time_length
self.scope = scope
output_seq_length = self.seq_length
@@ -121,6 +125,8 @@ class UniSpeechSatModelTester:
conv_bias=self.conv_bias,
num_conv_pos_embeddings=self.num_conv_pos_embeddings,
num_conv_pos_embedding_groups=self.num_conv_pos_embedding_groups,
mask_time_prob=self.mask_time_prob,
mask_time_length=self.mask_time_length,
num_hidden_layers=self.num_hidden_layers,
num_attention_heads=self.num_attention_heads,
hidden_dropout_prob=self.hidden_dropout_prob,

View File

@@ -78,6 +78,8 @@ class Wav2Vec2ModelTester:
layer_norm_eps=1e-5,
hidden_act="gelu",
initializer_range=0.02,
mask_time_prob=0.5,
mask_time_length=2,
vocab_size=32,
do_stable_layer_norm=False,
scope=None,
@@ -105,6 +107,8 @@ class Wav2Vec2ModelTester:
self.initializer_range = initializer_range
self.vocab_size = vocab_size
self.do_stable_layer_norm = do_stable_layer_norm
self.mask_time_prob = mask_time_prob
self.mask_time_length = mask_time_length
self.scope = scope
output_seq_length = self.seq_length
@@ -131,6 +135,8 @@ class Wav2Vec2ModelTester:
conv_stride=self.conv_stride,
conv_kernel=self.conv_kernel,
conv_bias=self.conv_bias,
mask_time_prob=self.mask_time_prob,
mask_time_length=self.mask_time_length,
num_conv_pos_embeddings=self.num_conv_pos_embeddings,
num_conv_pos_embedding_groups=self.num_conv_pos_embedding_groups,
num_hidden_layers=self.num_hidden_layers,