From 9cc9f4122e2a1027a6011951e3c6629a0f1b6c3e Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Fri, 11 Dec 2020 16:59:54 +0100 Subject: [PATCH] Make ProphetNetModel really compatible with EncoderDecoder (#9033) * improve * finish * upload model * fix lm head * fix test --- .../models/prophetnet/modeling_prophetnet.py | 34 +++++++++++++------ .../xlm_prophetnet/modeling_xlm_prophetnet.py | 4 +-- tests/test_modeling_encoder_decoder.py | 4 +-- tests/test_modeling_prophetnet.py | 30 ++++++++++++++++ utils/check_repo.py | 2 ++ 5 files changed, 59 insertions(+), 15 deletions(-) diff --git a/src/transformers/models/prophetnet/modeling_prophetnet.py b/src/transformers/models/prophetnet/modeling_prophetnet.py index b3af9c62b2..e8856830e4 100644 --- a/src/transformers/models/prophetnet/modeling_prophetnet.py +++ b/src/transformers/models/prophetnet/modeling_prophetnet.py @@ -1886,7 +1886,7 @@ class ProphetNetForCausalLM(ProphetNetPreTrainedModel): config = copy.deepcopy(config) config.is_decoder = True config.is_encoder_decoder = False - self.decoder = ProphetNetDecoder(config) + self.prophetnet = ProphetNetDecoderWrapper(config) self.padding_idx = config.pad_token_id self.disable_ngram_loss = config.disable_ngram_loss @@ -1896,10 +1896,10 @@ class ProphetNetForCausalLM(ProphetNetPreTrainedModel): self.init_weights() def get_input_embeddings(self): - return self.decoder.word_embeddings + return self.prophetnet.decoder.word_embeddings def set_input_embeddings(self, value): - self.decoder.word_embeddings = value + self.prophetnet.decoder.word_embeddings = value def get_output_embeddings(self): return self.lm_head @@ -1907,6 +1907,12 @@ class ProphetNetForCausalLM(ProphetNetPreTrainedModel): def set_output_embeddings(self, new_embeddings): self.lm_head = new_embeddings + def set_decoder(self, decoder): + self.prophetnet.decoder = decoder + + def get_decoder(self): + return self.prophetnet.decoder + @add_start_docstrings_to_model_forward(PROPHETNET_STANDALONE_INPUTS_DOCSTRING) @replace_return_docstrings(output_type=ProphetNetDecoderLMOutput, config_class=_CONFIG_FOR_DOC) def forward( @@ -1956,7 +1962,7 @@ class ProphetNetForCausalLM(ProphetNetPreTrainedModel): >>> import torch >>> tokenizer = ProphetNetTokenizer.from_pretrained('microsoft/prophetnet-large-uncased') - >>> model = ProphetNetForCausalLM.from_pretrained('patrickvonplaten/prophetnet-decoder-clm-large-uncased') + >>> model = ProphetNetForCausalLM.from_pretrained('microsoft/prophetnet-large-uncased') >>> assert model.config.is_decoder, f"{model.__class__} has to be configured as a decoder." >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") >>> outputs = model(**inputs) @@ -1969,7 +1975,7 @@ class ProphetNetForCausalLM(ProphetNetPreTrainedModel): >>> tokenizer_enc = BertTokenizer.from_pretrained('bert-large-uncased') >>> tokenizer_dec = ProphetNetTokenizer.from_pretrained('microsoft/prophetnet-large-uncased') - >>> model = EncoderDecoderModel.from_encoder_decoder_pretrained("bert-large-uncased", "patrickvonplaten/prophetnet-decoder-clm-large-uncased") + >>> model = EncoderDecoderModel.from_encoder_decoder_pretrained("bert-large-uncased", "microsoft/prophetnet-large-uncased") >>> ARTICLE = ( ... "the us state department said wednesday it had received no " @@ -1985,7 +1991,7 @@ class ProphetNetForCausalLM(ProphetNetPreTrainedModel): return_dict = return_dict if return_dict is not None else self.config.use_return_dict # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) - outputs = self.decoder( + outputs = self.prophetnet.decoder( input_ids=input_ids, attention_mask=attention_mask, encoder_hidden_states=encoder_hidden_states, @@ -2086,8 +2092,16 @@ class ProphetNetForCausalLM(ProphetNetPreTrainedModel): reordered_past.append(layer_past_new) return reordered_past - def set_decoder(self, decoder): - self.decoder = decoder - def get_decoder(self): - return self.decoder +class ProphetNetDecoderWrapper(ProphetNetPreTrainedModel): + """ + This is a wrapper class, so that :class:`~transformers.ProphetNetForCausalLM` can correctly be loaded from + pretrained prophetnet classes. + """ + + def __init__(self, config): + super().__init__(config) + self.decoder = ProphetNetDecoder(config) + + def forward(self, *args, **kwargs): + return self.decoder(*args, **kwargs) diff --git a/src/transformers/models/xlm_prophetnet/modeling_xlm_prophetnet.py b/src/transformers/models/xlm_prophetnet/modeling_xlm_prophetnet.py index 9240cea230..43266ae1a4 100644 --- a/src/transformers/models/xlm_prophetnet/modeling_xlm_prophetnet.py +++ b/src/transformers/models/xlm_prophetnet/modeling_xlm_prophetnet.py @@ -136,7 +136,7 @@ class XLMProphetNetForCausalLM(ProphetNetForCausalLM): >>> import torch >>> tokenizer = XLMProphetNetTokenizer.from_pretrained('microsoft/xprophetnet-large-wiki100-cased') - >>> model = XLMProphetNetForCausalLM.from_pretrained('patrickvonplaten/xprophetnet-decoder-clm-large-uncased') + >>> model = XLMProphetNetForCausalLM.from_pretrained('microsoft/xprophetnet-large-wiki100-cased') >>> assert model.config.is_decoder, f"{model.__class__} has to be configured as a decoder." >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt") >>> outputs = model(**inputs) @@ -149,7 +149,7 @@ class XLMProphetNetForCausalLM(ProphetNetForCausalLM): >>> tokenizer_enc = XLMRobertaTokenizer.from_pretrained('xlm-roberta-large') >>> tokenizer_dec = XLMProphetNetTokenizer.from_pretrained('microsoft/xprophetnet-large-wiki100-cased') - >>> model = EncoderDecoderModel.from_encoder_decoder_pretrained("xlm-roberta-large", "patrickvonplaten/xprophetnet-decoder-clm-large-uncased") + >>> model = EncoderDecoderModel.from_encoder_decoder_pretrained("xlm-roberta-large", 'microsoft/xprophetnet-large-wiki100-cased') >>> ARTICLE = ( ... "the us state department said wednesday it had received no " diff --git a/tests/test_modeling_encoder_decoder.py b/tests/test_modeling_encoder_decoder.py index 42205dcf64..94dee583e8 100644 --- a/tests/test_modeling_encoder_decoder.py +++ b/tests/test_modeling_encoder_decoder.py @@ -802,9 +802,7 @@ class ProphetNetEncoderDecoderModelTest(EncoderDecoderMixin, unittest.TestCase): } def get_pretrained_model(self): - return EncoderDecoderModel.from_encoder_decoder_pretrained( - "bert-large-uncased", "patrickvonplaten/prophetnet-decoder-clm-large-uncased" - ) + return EncoderDecoderModel.from_encoder_decoder_pretrained("bert-large-uncased", "prophetnet-large-uncased") def test_encoder_decoder_model_shared_weights(self): pass diff --git a/tests/test_modeling_prophetnet.py b/tests/test_modeling_prophetnet.py index 00249f2a06..a88b2653c1 100644 --- a/tests/test_modeling_prophetnet.py +++ b/tests/test_modeling_prophetnet.py @@ -38,6 +38,7 @@ if is_torch_available(): ProphetNetModel, ProphetNetTokenizer, ) + from transformers.modeling_outputs import BaseModelOutput class ProphetNetModelTester: @@ -467,6 +468,31 @@ class ProphetNetModelTester: ) ) + def check_causal_lm_from_pretrained( + self, config, input_ids, decoder_input_ids, attention_mask, decoder_attention_mask, *args + ): + model = ProphetNetForConditionalGeneration(config).to(torch_device).eval() + + with tempfile.TemporaryDirectory() as tmp_dirname: + model.save_pretrained(tmp_dirname) + decoder = ProphetNetForCausalLM.from_pretrained(tmp_dirname).to(torch_device) + + encoder_hidden_states = model.prophetnet.encoder(input_ids).last_hidden_state + + model_outputs = model( + encoder_outputs=BaseModelOutput(last_hidden_state=encoder_hidden_states), + decoder_input_ids=decoder_input_ids, + ) + dec_outputs = decoder(encoder_hidden_states=encoder_hidden_states, input_ids=decoder_input_ids) + + self.parent.assertTrue( + torch.allclose( + model_outputs.logits[0, :5], + dec_outputs.logits[0, :5], + atol=1e-3, + ) + ) + def prepare_config_and_inputs_for_common(self): config_and_inputs = self.prepare_config_and_inputs() ( @@ -898,6 +924,10 @@ class ProphetNetModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.Test self.assertFalse(config.add_cross_attention) + def test_causal_lm_from_pretrained(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.check_causal_lm_from_pretrained(*config_and_inputs) + @unittest.skipIf(torch_device == "cpu", "Cant do half precision") def test_fp16_forward(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() diff --git a/utils/check_repo.py b/utils/check_repo.py index a367134fd9..0710e8101c 100644 --- a/utils/check_repo.py +++ b/utils/check_repo.py @@ -34,6 +34,7 @@ IGNORE_NON_TESTED = [ "BertLMHeadModel", # Needs to be setup as decoder. "DPREncoder", # Building part of bigger (tested) model. "DPRSpanPredictor", # Building part of bigger (tested) model. + "ProphetNetDecoderWrapper", # Building part of bigger (tested) model. "ReformerForMaskedLM", # Needs to be setup as decoder. "T5Stack", # Building part of bigger (tested) model. "TFDPREncoder", # Building part of bigger (tested) model. @@ -74,6 +75,7 @@ IGNORE_NON_AUTO_CONFIGURED = [ "OpenAIGPTDoubleHeadsModel", "ProphetNetDecoder", "ProphetNetEncoder", + "ProphetNetDecoderWrapper", "RagModel", "RagSequenceForGeneration", "RagTokenForGeneration",