Add support for gradient checkpointing (#19990)

Co-authored-by: Niels Rogge <nielsrogge@Nielss-MacBook-Pro.local>
This commit is contained in:
NielsRogge
2022-10-31 18:37:17 +01:00
committed by GitHub
parent 8214a9f66a
commit 4c9e0f029e
3 changed files with 33 additions and 0 deletions

View File

@@ -611,6 +611,27 @@ class EncoderDecoderMixin:
input_ids_dict = self.prepare_config_and_inputs()
self.create_and_check_encoder_decoder_shared_weights(**input_ids_dict)
def test_training_gradient_checkpointing(self):
inputs_dict = self.prepare_config_and_inputs()
encoder_model, decoder_model = self.get_encoder_decoder_model(
inputs_dict["config"], inputs_dict["decoder_config"]
)
model = EncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
model.train()
model.gradient_checkpointing_enable()
model.config.decoder_start_token_id = 0
model.config.pad_token_id = 0
model_inputs = {
"input_ids": inputs_dict["input_ids"],
"attention_mask": inputs_dict["attention_mask"],
"labels": inputs_dict["labels"],
"decoder_input_ids": inputs_dict["decoder_input_ids"],
}
loss = model(**model_inputs).loss
loss.backward()
@slow
def test_real_model_save_load_from_pretrained(self):
model_2 = self.get_pretrained_model()