[Bert2Bert] allow bert2bert + relative embeddings (#14324)

* [Bert2Bert] allow bert2bert + relative embeddings

* up

* Update README_ko.md

* up

* up
This commit is contained in:
Patrick von Platen
2021-11-09 20:26:58 +01:00
committed by GitHub
parent e4d8f517b9
commit e81d8d7fa9
11 changed files with 70 additions and 40 deletions

View File

@@ -567,6 +567,24 @@ class BertEncoderDecoderModelTest(EncoderDecoderMixin, unittest.TestCase):
"labels": decoder_token_labels,
}
def test_relative_position_embeds(self):
config_and_inputs = self.prepare_config_and_inputs()
encoder_config = config_and_inputs["config"]
decoder_config = config_and_inputs["decoder_config"]
encoder_config.position_embedding_type = "relative_key_query"
decoder_config.position_embedding_type = "relative_key_query"
config = EncoderDecoderConfig.from_encoder_decoder_configs(encoder_config, decoder_config)
model = EncoderDecoderModel(config).eval().to(torch_device)
logits = model(
input_ids=config_and_inputs["input_ids"], decoder_input_ids=config_and_inputs["decoder_input_ids"]
).logits
self.assertTrue(logits.shape, (13, 7))
@slow
def test_bert2bert_summarization(self):
model = EncoderDecoderModel.from_pretrained("patrickvonplaten/bert2bert-cnn_dailymail-fp16")