[EncoderDecoderModel] Add an `add_cross_attention` boolean to the config (#6377)
* Correct the encoder-decoder model
* Apply suggestions from code review
* Apply Sylvain's suggestions
This commit is contained in:
committed by
GitHub
parent
06bc347c97
commit
3425936643
@@ -59,6 +59,9 @@ class EncoderDecoderModelTest(unittest.TestCase):
|
||||
encoder_hidden_states,
|
||||
encoder_attention_mask,
|
||||
) = decoder_config_and_inputs
|
||||
|
||||
# make sure that cross attention layers are added
|
||||
decoder_config.add_cross_attention = True
|
||||
return {
|
||||
"config": config,
|
||||
"input_ids": input_ids,
|
||||
@@ -119,6 +122,7 @@ class EncoderDecoderModelTest(unittest.TestCase):
|
||||
decoder_model = BertLMHeadModel(decoder_config)
|
||||
enc_dec_model = EncoderDecoderModel(encoder=encoder_model, decoder=decoder_model)
|
||||
self.assertTrue(enc_dec_model.config.decoder.is_decoder)
|
||||
self.assertTrue(enc_dec_model.config.decoder.add_cross_attention)
|
||||
self.assertTrue(enc_dec_model.config.is_encoder_decoder)
|
||||
enc_dec_model.to(torch_device)
|
||||
outputs_encoder_decoder = enc_dec_model(
|
||||
@@ -330,7 +334,7 @@ class EncoderDecoderModelTest(unittest.TestCase):
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
@slow
|
||||
def test_real_bert_model_from_pretrained_has_cross_attention(self):
|
||||
def test_real_bert_model_from_pretrained_add_cross_attention(self):
|
||||
model = EncoderDecoderModel.from_encoder_decoder_pretrained("bert-base-uncased", "bert-base-uncased")
|
||||
self.assertTrue(hasattr(model.decoder.bert.encoder.layer[0], "crossattention"))
|
||||
|
||||
|
||||
Reference in New Issue
Block a user