Add support for gradient checkpointing (#19990)
Co-authored-by: Niels Rogge <nielsrogge@Nielss-MacBook-Pro.local>
This commit is contained in:
@@ -581,6 +581,7 @@ class BertGenerationPreTrainedModel(PreTrainedModel):
|
||||
|
||||
config_class = BertGenerationConfig
|
||||
base_model_prefix = "bert"
|
||||
supports_gradient_checkpointing = True
|
||||
_keys_to_ignore_on_load_missing = [r"position_ids"]
|
||||
|
||||
def _init_weights(self, module):
|
||||
@@ -599,6 +600,10 @@ class BertGenerationPreTrainedModel(PreTrainedModel):
|
||||
module.bias.data.zero_()
|
||||
module.weight.data.fill_(1.0)
|
||||
|
||||
def _set_gradient_checkpointing(self, module, value=False):
    """Set the ``gradient_checkpointing`` flag on ``module`` if it is a ``BertEncoder``.

    Modules of any other type are left untouched.

    Args:
        module: Candidate module to update.
        value: Value assigned to ``module.gradient_checkpointing``.
    """
    # Guard clause: only encoder modules carry the flag set here.
    if not isinstance(module, BertEncoder):
        return
    module.gradient_checkpointing = value
|
||||
|
||||
|
||||
BERT_GENERATION_START_DOCSTRING = r"""
|
||||
|
||||
|
||||
@@ -175,6 +175,8 @@ class EncoderDecoderModel(PreTrainedModel):
|
||||
"""
|
||||
config_class = EncoderDecoderConfig
|
||||
base_model_prefix = "encoder_decoder"
|
||||
main_input_name = "input_ids"
|
||||
supports_gradient_checkpointing = True
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -255,6 +257,11 @@ class EncoderDecoderModel(PreTrainedModel):
|
||||
self.encoder, self.decoder._modules[decoder_base_model_prefix], self.decoder.base_model_prefix
|
||||
)
|
||||
|
||||
def _set_gradient_checkpointing(self, module, value=False):
    """Forward a gradient-checkpointing update to both wrapped sub-models.

    Calls ``_set_gradient_checkpointing(module, value=value)`` on
    ``self.encoder`` first and then on ``self.decoder``.

    Args:
        module: Module handed through unchanged to both sub-models.
        value: Flag value handed through (as a keyword) to both sub-models.
    """
    # Attributes are looked up lazily, one per iteration, so the encoder call
    # completes before ``self.decoder`` is even accessed.
    for attr_name in ("encoder", "decoder"):
        sub_model = getattr(self, attr_name)
        sub_model._set_gradient_checkpointing(module, value=value)
|
||||
|
||||
def get_encoder(self):
    """Return the object stored on ``self.encoder``."""
    encoder = self.encoder
    return encoder
|
||||
|
||||
|
||||
@@ -611,6 +611,27 @@ class EncoderDecoderMixin:
|
||||
input_ids_dict = self.prepare_config_and_inputs()
|
||||
self.create_and_check_encoder_decoder_shared_weights(**input_ids_dict)
|
||||
|
||||
def test_training_gradient_checkpointing(self):
    """Run one forward/backward pass with gradient checkpointing enabled.

    Builds an ``EncoderDecoderModel`` from the encoder/decoder pair returned
    by ``self.get_encoder_decoder_model``, puts it in training mode, calls
    ``gradient_checkpointing_enable()``, and back-propagates the loss. The
    test contains no explicit assertions; it passes when nothing raises.
    """
    prepared = self.prepare_config_and_inputs()
    encoder_model, decoder_model = self.get_encoder_decoder_model(
        prepared["config"], prepared["decoder_config"]
    )

    model = EncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
    model.train()
    model.gradient_checkpointing_enable()
    model.config.decoder_start_token_id = 0
    model.config.pad_token_id = 0

    # Only these prepared entries are forwarded to the model call.
    forwarded_keys = ("input_ids", "attention_mask", "labels", "decoder_input_ids")
    model_inputs = {key: prepared[key] for key in forwarded_keys}

    outputs = model(**model_inputs)
    outputs.loss.backward()
|
||||
|
||||
@slow
|
||||
def test_real_model_save_load_from_pretrained(self):
|
||||
model_2 = self.get_pretrained_model()
|
||||
|
||||
Reference in New Issue
Block a user