[EncoderDecoderModel] add an add_cross_attention boolean to config (#6377)

* correct encoder decoder model

* Apply suggestions from code review

* apply Sylvain's suggestions
This commit is contained in:
Patrick von Platen
2020-08-10 19:46:48 +02:00
committed by GitHub
parent 06bc347c97
commit 3425936643
6 changed files with 27 additions and 8 deletions

View File

@@ -59,6 +59,9 @@ class EncoderDecoderModelTest(unittest.TestCase):
encoder_hidden_states,
encoder_attention_mask,
) = decoder_config_and_inputs
# make sure that cross attention layers are added
decoder_config.add_cross_attention = True
return {
"config": config,
"input_ids": input_ids,
@@ -119,6 +122,7 @@ class EncoderDecoderModelTest(unittest.TestCase):
decoder_model = BertLMHeadModel(decoder_config)
enc_dec_model = EncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
self.assertTrue(enc_dec_model.config.decoder.is_decoder)
self.assertTrue(enc_dec_model.config.decoder.add_cross_attention)
self.assertTrue(enc_dec_model.config.is_encoder_decoder)
enc_dec_model.to(torch_device)
outputs_encoder_decoder = enc_dec_model(
@@ -330,7 +334,7 @@ class EncoderDecoderModelTest(unittest.TestCase):
self.assertIsNotNone(model)
@slow
def test_real_bert_model_from_pretrained_has_cross_attention(self):
def test_real_bert_model_from_pretrained_add_cross_attention(self):
model = EncoderDecoderModel.from_encoder_decoder_pretrained("bert-base-uncased", "bert-base-uncased")
self.assertTrue(hasattr(model.decoder.bert.encoder.layer[0], "crossattention"))