Make ProphetNetModel really compatible with EncoderDecoder (#9033)

* improve

* finish

* upload model

* fix lm head

* fix test
This commit is contained in:
Patrick von Platen
2020-12-11 16:59:54 +01:00
committed by GitHub
parent 24f6cdeab6
commit 9cc9f4122e
5 changed files with 59 additions and 15 deletions

View File

@@ -38,6 +38,7 @@ if is_torch_available():
ProphetNetModel,
ProphetNetTokenizer,
)
from transformers.modeling_outputs import BaseModelOutput
class ProphetNetModelTester:
@@ -467,6 +468,31 @@ class ProphetNetModelTester:
)
)
def check_causal_lm_from_pretrained(
self, config, input_ids, decoder_input_ids, attention_mask, decoder_attention_mask, *args
):
model = ProphetNetForConditionalGeneration(config).to(torch_device).eval()
with tempfile.TemporaryDirectory() as tmp_dirname:
model.save_pretrained(tmp_dirname)
decoder = ProphetNetForCausalLM.from_pretrained(tmp_dirname).to(torch_device)
encoder_hidden_states = model.prophetnet.encoder(input_ids).last_hidden_state
model_outputs = model(
encoder_outputs=BaseModelOutput(last_hidden_state=encoder_hidden_states),
decoder_input_ids=decoder_input_ids,
)
dec_outputs = decoder(encoder_hidden_states=encoder_hidden_states, input_ids=decoder_input_ids)
self.parent.assertTrue(
torch.allclose(
model_outputs.logits[0, :5],
dec_outputs.logits[0, :5],
atol=1e-3,
)
)
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
(
@@ -898,6 +924,10 @@ class ProphetNetModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.Test
self.assertFalse(config.add_cross_attention)
def test_causal_lm_from_pretrained(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.check_causal_lm_from_pretrained(*config_and_inputs)
@unittest.skipIf(torch_device == "cpu", "Cant do half precision")
def test_fp16_forward(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()