Make ProphetNetModel really compatible with EncoderDecoder (#9033)
* improve * finish * upload model * fix lm head * fix test
This commit is contained in:
committed by
GitHub
parent
24f6cdeab6
commit
9cc9f4122e
@@ -1886,7 +1886,7 @@ class ProphetNetForCausalLM(ProphetNetPreTrainedModel):
|
|||||||
config = copy.deepcopy(config)
|
config = copy.deepcopy(config)
|
||||||
config.is_decoder = True
|
config.is_decoder = True
|
||||||
config.is_encoder_decoder = False
|
config.is_encoder_decoder = False
|
||||||
self.decoder = ProphetNetDecoder(config)
|
self.prophetnet = ProphetNetDecoderWrapper(config)
|
||||||
|
|
||||||
self.padding_idx = config.pad_token_id
|
self.padding_idx = config.pad_token_id
|
||||||
self.disable_ngram_loss = config.disable_ngram_loss
|
self.disable_ngram_loss = config.disable_ngram_loss
|
||||||
@@ -1896,10 +1896,10 @@ class ProphetNetForCausalLM(ProphetNetPreTrainedModel):
|
|||||||
self.init_weights()
|
self.init_weights()
|
||||||
|
|
||||||
def get_input_embeddings(self):
|
def get_input_embeddings(self):
|
||||||
return self.decoder.word_embeddings
|
return self.prophetnet.decoder.word_embeddings
|
||||||
|
|
||||||
def set_input_embeddings(self, value):
|
def set_input_embeddings(self, value):
|
||||||
self.decoder.word_embeddings = value
|
self.prophetnet.decoder.word_embeddings = value
|
||||||
|
|
||||||
def get_output_embeddings(self):
|
def get_output_embeddings(self):
|
||||||
return self.lm_head
|
return self.lm_head
|
||||||
@@ -1907,6 +1907,12 @@ class ProphetNetForCausalLM(ProphetNetPreTrainedModel):
|
|||||||
def set_output_embeddings(self, new_embeddings):
|
def set_output_embeddings(self, new_embeddings):
|
||||||
self.lm_head = new_embeddings
|
self.lm_head = new_embeddings
|
||||||
|
|
||||||
|
def set_decoder(self, decoder):
|
||||||
|
self.prophetnet.decoder = decoder
|
||||||
|
|
||||||
|
def get_decoder(self):
|
||||||
|
return self.prophetnet.decoder
|
||||||
|
|
||||||
@add_start_docstrings_to_model_forward(PROPHETNET_STANDALONE_INPUTS_DOCSTRING)
|
@add_start_docstrings_to_model_forward(PROPHETNET_STANDALONE_INPUTS_DOCSTRING)
|
||||||
@replace_return_docstrings(output_type=ProphetNetDecoderLMOutput, config_class=_CONFIG_FOR_DOC)
|
@replace_return_docstrings(output_type=ProphetNetDecoderLMOutput, config_class=_CONFIG_FOR_DOC)
|
||||||
def forward(
|
def forward(
|
||||||
@@ -1956,7 +1962,7 @@ class ProphetNetForCausalLM(ProphetNetPreTrainedModel):
|
|||||||
>>> import torch
|
>>> import torch
|
||||||
|
|
||||||
>>> tokenizer = ProphetNetTokenizer.from_pretrained('microsoft/prophetnet-large-uncased')
|
>>> tokenizer = ProphetNetTokenizer.from_pretrained('microsoft/prophetnet-large-uncased')
|
||||||
>>> model = ProphetNetForCausalLM.from_pretrained('patrickvonplaten/prophetnet-decoder-clm-large-uncased')
|
>>> model = ProphetNetForCausalLM.from_pretrained('microsoft/prophetnet-large-uncased')
|
||||||
>>> assert model.config.is_decoder, f"{model.__class__} has to be configured as a decoder."
|
>>> assert model.config.is_decoder, f"{model.__class__} has to be configured as a decoder."
|
||||||
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
|
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
|
||||||
>>> outputs = model(**inputs)
|
>>> outputs = model(**inputs)
|
||||||
@@ -1969,7 +1975,7 @@ class ProphetNetForCausalLM(ProphetNetPreTrainedModel):
|
|||||||
|
|
||||||
>>> tokenizer_enc = BertTokenizer.from_pretrained('bert-large-uncased')
|
>>> tokenizer_enc = BertTokenizer.from_pretrained('bert-large-uncased')
|
||||||
>>> tokenizer_dec = ProphetNetTokenizer.from_pretrained('microsoft/prophetnet-large-uncased')
|
>>> tokenizer_dec = ProphetNetTokenizer.from_pretrained('microsoft/prophetnet-large-uncased')
|
||||||
>>> model = EncoderDecoderModel.from_encoder_decoder_pretrained("bert-large-uncased", "patrickvonplaten/prophetnet-decoder-clm-large-uncased")
|
>>> model = EncoderDecoderModel.from_encoder_decoder_pretrained("bert-large-uncased", "microsoft/prophetnet-large-uncased")
|
||||||
|
|
||||||
>>> ARTICLE = (
|
>>> ARTICLE = (
|
||||||
... "the us state department said wednesday it had received no "
|
... "the us state department said wednesday it had received no "
|
||||||
@@ -1985,7 +1991,7 @@ class ProphetNetForCausalLM(ProphetNetPreTrainedModel):
|
|||||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
|
||||||
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
||||||
outputs = self.decoder(
|
outputs = self.prophetnet.decoder(
|
||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
attention_mask=attention_mask,
|
attention_mask=attention_mask,
|
||||||
encoder_hidden_states=encoder_hidden_states,
|
encoder_hidden_states=encoder_hidden_states,
|
||||||
@@ -2086,8 +2092,16 @@ class ProphetNetForCausalLM(ProphetNetPreTrainedModel):
|
|||||||
reordered_past.append(layer_past_new)
|
reordered_past.append(layer_past_new)
|
||||||
return reordered_past
|
return reordered_past
|
||||||
|
|
||||||
def set_decoder(self, decoder):
|
|
||||||
self.decoder = decoder
|
|
||||||
|
|
||||||
def get_decoder(self):
|
class ProphetNetDecoderWrapper(ProphetNetPreTrainedModel):
|
||||||
return self.decoder
|
"""
|
||||||
|
This is a wrapper class, so that :class:`~transformers.ProphetNetForCausalLM` can correctly be loaded from
|
||||||
|
pretrained prophetnet classes.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, config):
|
||||||
|
super().__init__(config)
|
||||||
|
self.decoder = ProphetNetDecoder(config)
|
||||||
|
|
||||||
|
def forward(self, *args, **kwargs):
|
||||||
|
return self.decoder(*args, **kwargs)
|
||||||
|
|||||||
@@ -136,7 +136,7 @@ class XLMProphetNetForCausalLM(ProphetNetForCausalLM):
|
|||||||
>>> import torch
|
>>> import torch
|
||||||
|
|
||||||
>>> tokenizer = XLMProphetNetTokenizer.from_pretrained('microsoft/xprophetnet-large-wiki100-cased')
|
>>> tokenizer = XLMProphetNetTokenizer.from_pretrained('microsoft/xprophetnet-large-wiki100-cased')
|
||||||
>>> model = XLMProphetNetForCausalLM.from_pretrained('patrickvonplaten/xprophetnet-decoder-clm-large-uncased')
|
>>> model = XLMProphetNetForCausalLM.from_pretrained('microsoft/xprophetnet-large-wiki100-cased')
|
||||||
>>> assert model.config.is_decoder, f"{model.__class__} has to be configured as a decoder."
|
>>> assert model.config.is_decoder, f"{model.__class__} has to be configured as a decoder."
|
||||||
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
|
>>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
|
||||||
>>> outputs = model(**inputs)
|
>>> outputs = model(**inputs)
|
||||||
@@ -149,7 +149,7 @@ class XLMProphetNetForCausalLM(ProphetNetForCausalLM):
|
|||||||
|
|
||||||
>>> tokenizer_enc = XLMRobertaTokenizer.from_pretrained('xlm-roberta-large')
|
>>> tokenizer_enc = XLMRobertaTokenizer.from_pretrained('xlm-roberta-large')
|
||||||
>>> tokenizer_dec = XLMProphetNetTokenizer.from_pretrained('microsoft/xprophetnet-large-wiki100-cased')
|
>>> tokenizer_dec = XLMProphetNetTokenizer.from_pretrained('microsoft/xprophetnet-large-wiki100-cased')
|
||||||
>>> model = EncoderDecoderModel.from_encoder_decoder_pretrained("xlm-roberta-large", "patrickvonplaten/xprophetnet-decoder-clm-large-uncased")
|
>>> model = EncoderDecoderModel.from_encoder_decoder_pretrained("xlm-roberta-large", 'microsoft/xprophetnet-large-wiki100-cased')
|
||||||
|
|
||||||
>>> ARTICLE = (
|
>>> ARTICLE = (
|
||||||
... "the us state department said wednesday it had received no "
|
... "the us state department said wednesday it had received no "
|
||||||
|
|||||||
@@ -802,9 +802,7 @@ class ProphetNetEncoderDecoderModelTest(EncoderDecoderMixin, unittest.TestCase):
|
|||||||
}
|
}
|
||||||
|
|
||||||
def get_pretrained_model(self):
|
def get_pretrained_model(self):
|
||||||
return EncoderDecoderModel.from_encoder_decoder_pretrained(
|
return EncoderDecoderModel.from_encoder_decoder_pretrained("bert-large-uncased", "prophetnet-large-uncased")
|
||||||
"bert-large-uncased", "patrickvonplaten/prophetnet-decoder-clm-large-uncased"
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_encoder_decoder_model_shared_weights(self):
|
def test_encoder_decoder_model_shared_weights(self):
|
||||||
pass
|
pass
|
||||||
|
|||||||
@@ -38,6 +38,7 @@ if is_torch_available():
|
|||||||
ProphetNetModel,
|
ProphetNetModel,
|
||||||
ProphetNetTokenizer,
|
ProphetNetTokenizer,
|
||||||
)
|
)
|
||||||
|
from transformers.modeling_outputs import BaseModelOutput
|
||||||
|
|
||||||
|
|
||||||
class ProphetNetModelTester:
|
class ProphetNetModelTester:
|
||||||
@@ -467,6 +468,31 @@ class ProphetNetModelTester:
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def check_causal_lm_from_pretrained(
|
||||||
|
self, config, input_ids, decoder_input_ids, attention_mask, decoder_attention_mask, *args
|
||||||
|
):
|
||||||
|
model = ProphetNetForConditionalGeneration(config).to(torch_device).eval()
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dirname:
|
||||||
|
model.save_pretrained(tmp_dirname)
|
||||||
|
decoder = ProphetNetForCausalLM.from_pretrained(tmp_dirname).to(torch_device)
|
||||||
|
|
||||||
|
encoder_hidden_states = model.prophetnet.encoder(input_ids).last_hidden_state
|
||||||
|
|
||||||
|
model_outputs = model(
|
||||||
|
encoder_outputs=BaseModelOutput(last_hidden_state=encoder_hidden_states),
|
||||||
|
decoder_input_ids=decoder_input_ids,
|
||||||
|
)
|
||||||
|
dec_outputs = decoder(encoder_hidden_states=encoder_hidden_states, input_ids=decoder_input_ids)
|
||||||
|
|
||||||
|
self.parent.assertTrue(
|
||||||
|
torch.allclose(
|
||||||
|
model_outputs.logits[0, :5],
|
||||||
|
dec_outputs.logits[0, :5],
|
||||||
|
atol=1e-3,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
def prepare_config_and_inputs_for_common(self):
|
def prepare_config_and_inputs_for_common(self):
|
||||||
config_and_inputs = self.prepare_config_and_inputs()
|
config_and_inputs = self.prepare_config_and_inputs()
|
||||||
(
|
(
|
||||||
@@ -898,6 +924,10 @@ class ProphetNetModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.Test
|
|||||||
|
|
||||||
self.assertFalse(config.add_cross_attention)
|
self.assertFalse(config.add_cross_attention)
|
||||||
|
|
||||||
|
def test_causal_lm_from_pretrained(self):
|
||||||
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
|
self.model_tester.check_causal_lm_from_pretrained(*config_and_inputs)
|
||||||
|
|
||||||
@unittest.skipIf(torch_device == "cpu", "Cant do half precision")
|
@unittest.skipIf(torch_device == "cpu", "Cant do half precision")
|
||||||
def test_fp16_forward(self):
|
def test_fp16_forward(self):
|
||||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
|
|||||||
@@ -34,6 +34,7 @@ IGNORE_NON_TESTED = [
|
|||||||
"BertLMHeadModel", # Needs to be setup as decoder.
|
"BertLMHeadModel", # Needs to be setup as decoder.
|
||||||
"DPREncoder", # Building part of bigger (tested) model.
|
"DPREncoder", # Building part of bigger (tested) model.
|
||||||
"DPRSpanPredictor", # Building part of bigger (tested) model.
|
"DPRSpanPredictor", # Building part of bigger (tested) model.
|
||||||
|
"ProphetNetDecoderWrapper", # Building part of bigger (tested) model.
|
||||||
"ReformerForMaskedLM", # Needs to be setup as decoder.
|
"ReformerForMaskedLM", # Needs to be setup as decoder.
|
||||||
"T5Stack", # Building part of bigger (tested) model.
|
"T5Stack", # Building part of bigger (tested) model.
|
||||||
"TFDPREncoder", # Building part of bigger (tested) model.
|
"TFDPREncoder", # Building part of bigger (tested) model.
|
||||||
@@ -74,6 +75,7 @@ IGNORE_NON_AUTO_CONFIGURED = [
|
|||||||
"OpenAIGPTDoubleHeadsModel",
|
"OpenAIGPTDoubleHeadsModel",
|
||||||
"ProphetNetDecoder",
|
"ProphetNetDecoder",
|
||||||
"ProphetNetEncoder",
|
"ProphetNetEncoder",
|
||||||
|
"ProphetNetDecoderWrapper",
|
||||||
"RagModel",
|
"RagModel",
|
||||||
"RagSequenceForGeneration",
|
"RagSequenceForGeneration",
|
||||||
"RagTokenForGeneration",
|
"RagTokenForGeneration",
|
||||||
|
|||||||
Reference in New Issue
Block a user