[VisionEncoderDecoder] Add gradient checkpointing (#18697)

* add first generation tutorial

* VisionEnocderDecoder gradient checkpointing

* remove generation

* add tests
This commit is contained in:
Patrick von Platen
2022-08-26 14:11:27 +02:00
committed by GitHub
parent 06a6a4bd51
commit 8869bf41fe
3 changed files with 52 additions and 0 deletions

View File

@@ -396,6 +396,28 @@ class EncoderDecoderMixin:
input_ids_dict = self.prepare_config_and_inputs()
self.check_encoder_decoder_model_generate(**input_ids_dict)
def test_training_gradient_checkpointing(self):
inputs_dict = self.prepare_config_and_inputs()
encoder_model, decoder_model = self.get_encoder_decoder_model(
inputs_dict["config"], inputs_dict["decoder_config"]
)
model = SpeechEncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
model.train()
model.gradient_checkpointing_enable()
model.config.decoder_start_token_id = 0
model.config.pad_token_id = 0
model_inputs = {
"attention_mask": inputs_dict["attention_mask"],
"labels": inputs_dict["labels"],
"decoder_input_ids": inputs_dict["decoder_input_ids"],
}
inputs = inputs_dict["input_features"] if "input_features" in inputs_dict else inputs_dict["input_values"]
loss = model(inputs, **model_inputs).loss
loss.backward()
@slow
def test_real_model_save_load_from_pretrained(self):
model_2, inputs = self.get_pretrained_model_and_inputs()
@@ -590,6 +612,7 @@ class Wav2Vec2Speech2Text2(EncoderDecoderMixin, unittest.TestCase):
"decoder_config": decoder_config,
"decoder_input_ids": decoder_input_ids,
"decoder_attention_mask": decoder_attention_mask,
"labels": decoder_input_ids,
}
# there are no published pretrained Speech2Text2ForCausalLM for now