[Bert2Bert] allow bert2bert + relative embeddings (#14324)
* [Bert2Bert] allow bert2bert + relative embeddings * up * Update README_ko.md * up * up
This commit is contained in:
committed by
GitHub
parent
e4d8f517b9
commit
e81d8d7fa9
@@ -567,6 +567,24 @@ class BertEncoderDecoderModelTest(EncoderDecoderMixin, unittest.TestCase):
|
||||
"labels": decoder_token_labels,
|
||||
}
|
||||
|
||||
def test_relative_position_embeds(self):
|
||||
config_and_inputs = self.prepare_config_and_inputs()
|
||||
|
||||
encoder_config = config_and_inputs["config"]
|
||||
decoder_config = config_and_inputs["decoder_config"]
|
||||
|
||||
encoder_config.position_embedding_type = "relative_key_query"
|
||||
decoder_config.position_embedding_type = "relative_key_query"
|
||||
|
||||
config = EncoderDecoderConfig.from_encoder_decoder_configs(encoder_config, decoder_config)
|
||||
model = EncoderDecoderModel(config).eval().to(torch_device)
|
||||
|
||||
logits = model(
|
||||
input_ids=config_and_inputs["input_ids"], decoder_input_ids=config_and_inputs["decoder_input_ids"]
|
||||
).logits
|
||||
|
||||
self.assertTrue(logits.shape, (13, 7))
|
||||
|
||||
@slow
|
||||
def test_bert2bert_summarization(self):
|
||||
model = EncoderDecoderModel.from_pretrained("patrickvonplaten/bert2bert-cnn_dailymail-fp16")
|
||||
|
||||
Reference in New Issue
Block a user