Make ProphetNetModel really compatible with EncoderDecoder (#9033)
* improve * finish * upload model * fix lm head * fix test
This commit is contained in:
committed by
GitHub
parent
24f6cdeab6
commit
9cc9f4122e
@@ -802,9 +802,7 @@ class ProphetNetEncoderDecoderModelTest(EncoderDecoderMixin, unittest.TestCase):
|
||||
}
|
||||
|
||||
def get_pretrained_model(self):
# Builds the encoder-decoder model used by this test class from two pretrained
# checkpoints via EncoderDecoderModel.from_encoder_decoder_pretrained.
# NOTE(review): this span is a rendered diff hunk whose +/- markers were lost, so
# both the old and the new `return` appear below. The hunk header (-802,9 +802,7)
# is consistent with the three-line `return` having been removed and the one-line
# `return` at the end having been added -- confirm against the repository before
# relying on either decoder checkpoint name.
|
||||
return EncoderDecoderModel.from_encoder_decoder_pretrained(
|
||||
"bert-large-uncased", "patrickvonplaten/prophetnet-decoder-clm-large-uncased"
|
||||
)
|
||||
return EncoderDecoderModel.from_encoder_decoder_pretrained("bert-large-uncased", "prophetnet-large-uncased")
|
||||
|
||||
def test_encoder_decoder_model_shared_weights(self):
    """Deliberately empty test body.

    NOTE(review): presumably this overrides a shared-weights check inherited
    from ``EncoderDecoderMixin`` so that it is a no-op for this encoder/decoder
    pairing -- confirm against the mixin.
    """
    pass
|
||||
|
||||
@@ -38,6 +38,7 @@ if is_torch_available():
|
||||
ProphetNetModel,
|
||||
ProphetNetTokenizer,
|
||||
)
|
||||
from transformers.modeling_outputs import BaseModelOutput
|
||||
|
||||
|
||||
class ProphetNetModelTester:
|
||||
@@ -467,6 +468,31 @@ class ProphetNetModelTester:
|
||||
)
|
||||
)
|
||||
|
||||
def check_causal_lm_from_pretrained(
    self, config, input_ids, decoder_input_ids, attention_mask, decoder_attention_mask, *args
):
    """Check that a saved seq2seq checkpoint reloads as a stand-alone causal LM.

    A ``ProphetNetForConditionalGeneration`` model is built from ``config``,
    saved to a temporary directory and reloaded as ``ProphetNetForCausalLM``.
    Both models are then fed the same encoder hidden states and decoder input
    ids, and their logits are required to agree.

    Args:
        config: Model configuration used to instantiate the seq2seq model.
        input_ids: Encoder input ids, passed to ``model.prophetnet.encoder``.
        decoder_input_ids: Decoder input ids given to both models.
        attention_mask: Accepted but not used by this check.
        decoder_attention_mask: Accepted but not used by this check.
        *args: Extra positional values; ignored.

    The comparison fails through ``self.parent.assertTrue`` when the logits
    differ by more than the tolerance.
    """
    model = ProphetNetForConditionalGeneration(config).to(torch_device).eval()

    # NOTE(review): indentation was lost in the listing this was restored from;
    # the `with` scope below is the minimum that must hold (the reload reads
    # from the temporary directory) -- confirm against the repository.
    with tempfile.TemporaryDirectory() as tmp_dirname:
        model.save_pretrained(tmp_dirname)
        decoder = ProphetNetForCausalLM.from_pretrained(tmp_dirname).to(torch_device)

    # Run the encoder once so both models condition on identical hidden states.
    encoder_hidden_states = model.prophetnet.encoder(input_ids).last_hidden_state

    model_outputs = model(
        encoder_outputs=BaseModelOutput(last_hidden_state=encoder_hidden_states),
        decoder_input_ids=decoder_input_ids,
    )
    dec_outputs = decoder(encoder_hidden_states=encoder_hidden_states, input_ids=decoder_input_ids)

    # Only index 0 of the first axis and the first five entries of the second
    # axis are compared (presumably batch item 0, first five positions).
    self.parent.assertTrue(
        torch.allclose(
            model_outputs.logits[0, :5],
            dec_outputs.logits[0, :5],
            atol=1e-3,
        )
    )
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
config_and_inputs = self.prepare_config_and_inputs()
|
||||
(
|
||||
@@ -898,6 +924,10 @@ class ProphetNetModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.Test
|
||||
|
||||
self.assertFalse(config.add_cross_attention)
|
||||
|
||||
def test_causal_lm_from_pretrained(self):
    """Run the tester's causal-LM-from-pretrained check on freshly prepared inputs.

    The tuple returned by ``self.model_tester.prepare_config_and_inputs()`` is
    unpacked positionally into
    ``self.model_tester.check_causal_lm_from_pretrained``.
    """
    config_and_inputs = self.model_tester.prepare_config_and_inputs()
    self.model_tester.check_causal_lm_from_pretrained(*config_and_inputs)
|
||||
|
||||
@unittest.skipIf(torch_device == "cpu", "Cant do half precision")
|
||||
def test_fp16_forward(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
|
||||
Reference in New Issue
Block a user