ProphetNet (#7157)
* add new model prophetnet prophetnet modified modify codes as suggested v1 add prophetnet test files * still bugs, because of changed output formats of encoder and decoder * move prophetnet into the latest version * clean integration tests * clean tokenizers * add xlm config to init * correct typo in init * further refactoring * continue refactor * save parallel * add decoder_attention_mask * fix use_cache vs. past_key_values * fix common tests * change decoder output logits * fix xlm tests * make common tests pass * change model architecture * add tokenizer tests * finalize model structure * no weight mapping * correct n-gram stream attention mask as discussed with qweizhen * remove unused import * fix index.rst * fix tests * delete unnecessary code * add fast integration test * rename weights * final weight remapping * save intermediate * Descriptions for Prophetnet Config File * finish all models * finish new model outputs * delete unnecessary files * refactor encoder layer * add dummy docs * code quality * fix tests * add model pages to doctree * further refactor * more refactor, more tests * finish code refactor and tests * remove unnecessary files * further clean up * add docstring template * finish tokenizer doc * finish prophetnet * fix copies * fix typos * fix tf tests * fix fp16 * fix tf test 2nd try * fix code quality * add test for each model * merge new tests to branch * Update model_cards/microsoft/prophetnet-large-uncased-cnndm/README.md Co-authored-by: Sam Shleifer <sshleifer@gmail.com> * Update model_cards/microsoft/prophetnet-large-uncased-cnndm/README.md Co-authored-by: Sam Shleifer <sshleifer@gmail.com> * Update src/transformers/modeling_prophetnet.py Co-authored-by: Sam Shleifer <sshleifer@gmail.com> * Update utils/check_repo.py Co-authored-by: Sam Shleifer <sshleifer@gmail.com> * apply sams and sylvains comments * make style * remove unnecessary code * Update README.md Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> 
* Update README.md Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update src/transformers/configuration_prophetnet.py Co-authored-by: Lysandre Debut <lysandre@huggingface.co> * implement lysandres comments * correct docs * fix isort * fix tokenizers * fix copies Co-authored-by: weizhen <weizhen@mail.ustc.edu.cn> Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> Co-authored-by: Sam Shleifer <sshleifer@gmail.com> Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
This commit is contained in:
@@ -24,6 +24,7 @@ from .test_modeling_bert import BertModelTester
|
||||
from .test_modeling_bert_generation import BertGenerationEncoderTester
|
||||
from .test_modeling_common import ids_tensor
|
||||
from .test_modeling_gpt2 import GPT2ModelTester
|
||||
from .test_modeling_prophetnet import ProphetNetStandaloneDecoderModelTester
|
||||
from .test_modeling_roberta import RobertaModelTester
|
||||
|
||||
|
||||
@@ -41,9 +42,11 @@ if is_torch_available():
|
||||
EncoderDecoderConfig,
|
||||
EncoderDecoderModel,
|
||||
GPT2LMHeadModel,
|
||||
ProphetNetForCausalLM,
|
||||
RobertaForCausalLM,
|
||||
RobertaModel,
|
||||
)
|
||||
from transformers.modeling_outputs import BaseModelOutput
|
||||
|
||||
|
||||
@require_torch
|
||||
@@ -82,10 +85,15 @@ class EncoderDecoderMixin:
|
||||
decoder_input_ids=decoder_input_ids,
|
||||
attention_mask=attention_mask,
|
||||
decoder_attention_mask=decoder_attention_mask,
|
||||
return_dict=True,
|
||||
)
|
||||
|
||||
self.assertEqual(outputs_encoder_decoder[0].shape, (decoder_input_ids.shape + (decoder_config.vocab_size,)))
|
||||
self.assertEqual(outputs_encoder_decoder[1].shape, (input_ids.shape + (config.hidden_size,)))
|
||||
self.assertEqual(
|
||||
outputs_encoder_decoder["logits"].shape, (decoder_input_ids.shape + (decoder_config.vocab_size,))
|
||||
)
|
||||
self.assertEqual(
|
||||
outputs_encoder_decoder["encoder_last_hidden_state"].shape, (input_ids.shape + (config.hidden_size,))
|
||||
)
|
||||
|
||||
def check_encoder_decoder_model(
|
||||
self,
|
||||
@@ -109,20 +117,30 @@ class EncoderDecoderMixin:
|
||||
decoder_input_ids=decoder_input_ids,
|
||||
attention_mask=attention_mask,
|
||||
decoder_attention_mask=decoder_attention_mask,
|
||||
return_dict=True,
|
||||
)
|
||||
self.assertEqual(
|
||||
outputs_encoder_decoder["logits"].shape, (decoder_input_ids.shape + (decoder_config.vocab_size,))
|
||||
)
|
||||
self.assertEqual(
|
||||
outputs_encoder_decoder["encoder_last_hidden_state"].shape, (input_ids.shape + (config.hidden_size,))
|
||||
)
|
||||
|
||||
self.assertEqual(outputs_encoder_decoder[0].shape, (decoder_input_ids.shape + (decoder_config.vocab_size,)))
|
||||
self.assertEqual(outputs_encoder_decoder[1].shape, (input_ids.shape + (config.hidden_size,)))
|
||||
encoder_outputs = (encoder_hidden_states,)
|
||||
encoder_outputs = BaseModelOutput(last_hidden_state=encoder_hidden_states)
|
||||
outputs_encoder_decoder = enc_dec_model(
|
||||
encoder_outputs=encoder_outputs,
|
||||
decoder_input_ids=decoder_input_ids,
|
||||
attention_mask=attention_mask,
|
||||
decoder_attention_mask=decoder_attention_mask,
|
||||
return_dict=True,
|
||||
)
|
||||
|
||||
self.assertEqual(outputs_encoder_decoder[0].shape, (decoder_input_ids.shape + (decoder_config.vocab_size,)))
|
||||
self.assertEqual(outputs_encoder_decoder[1].shape, (input_ids.shape + (config.hidden_size,)))
|
||||
self.assertEqual(
|
||||
outputs_encoder_decoder["logits"].shape, (decoder_input_ids.shape + (decoder_config.vocab_size,))
|
||||
)
|
||||
self.assertEqual(
|
||||
outputs_encoder_decoder["encoder_last_hidden_state"].shape, (input_ids.shape + (config.hidden_size,))
|
||||
)
|
||||
|
||||
def check_encoder_decoder_model_from_pretrained(
|
||||
self,
|
||||
@@ -145,10 +163,15 @@ class EncoderDecoderMixin:
|
||||
decoder_input_ids=decoder_input_ids,
|
||||
attention_mask=attention_mask,
|
||||
decoder_attention_mask=decoder_attention_mask,
|
||||
return_dict=True,
|
||||
)
|
||||
|
||||
self.assertEqual(outputs_encoder_decoder[0].shape, (decoder_input_ids.shape + (decoder_config.vocab_size,)))
|
||||
self.assertEqual(outputs_encoder_decoder[1].shape, (input_ids.shape + (config.hidden_size,)))
|
||||
self.assertEqual(
|
||||
outputs_encoder_decoder["logits"].shape, (decoder_input_ids.shape + (decoder_config.vocab_size,))
|
||||
)
|
||||
self.assertEqual(
|
||||
outputs_encoder_decoder["encoder_last_hidden_state"].shape, (input_ids.shape + (config.hidden_size,))
|
||||
)
|
||||
|
||||
def check_save_and_load(
|
||||
self,
|
||||
@@ -255,14 +278,19 @@ class EncoderDecoderMixin:
|
||||
attention_mask=attention_mask,
|
||||
decoder_attention_mask=decoder_attention_mask,
|
||||
labels=labels,
|
||||
return_dict=True,
|
||||
)
|
||||
|
||||
mlm_loss = outputs_encoder_decoder[0]
|
||||
loss = outputs_encoder_decoder["loss"]
|
||||
# check that backprop works
|
||||
mlm_loss.backward()
|
||||
loss.backward()
|
||||
|
||||
self.assertEqual(outputs_encoder_decoder[1].shape, (decoder_input_ids.shape + (decoder_config.vocab_size,)))
|
||||
self.assertEqual(outputs_encoder_decoder[2].shape, (input_ids.shape + (config.hidden_size,)))
|
||||
self.assertEqual(
|
||||
outputs_encoder_decoder["logits"].shape, (decoder_input_ids.shape + (decoder_config.vocab_size,))
|
||||
)
|
||||
self.assertEqual(
|
||||
outputs_encoder_decoder["encoder_last_hidden_state"].shape, (input_ids.shape + (config.hidden_size,))
|
||||
)
|
||||
|
||||
def check_encoder_decoder_model_generate(self, input_ids, config, decoder_config, **kwargs):
|
||||
encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
|
||||
@@ -425,6 +453,7 @@ class EncoderDecoderMixin:
|
||||
self.assertLessEqual(max_diff, 1e-5)
|
||||
|
||||
|
||||
@require_torch
|
||||
class BertEncoderDecoderModelTest(EncoderDecoderMixin, unittest.TestCase):
|
||||
def get_pretrained_model(self):
|
||||
return EncoderDecoderModel.from_encoder_decoder_pretrained("bert-base-cased", "bert-base-cased")
|
||||
@@ -493,6 +522,7 @@ class BertEncoderDecoderModelTest(EncoderDecoderMixin, unittest.TestCase):
|
||||
self.assertEqual(summary, EXPECTED_SUMMARY)
|
||||
|
||||
|
||||
@require_torch
|
||||
class BertGenerationEncoderDecoderModelTest(EncoderDecoderMixin, unittest.TestCase):
|
||||
def get_pretrained_model(self):
|
||||
return EncoderDecoderModel.from_encoder_decoder_pretrained(
|
||||
@@ -554,6 +584,7 @@ class BertGenerationEncoderDecoderModelTest(EncoderDecoderMixin, unittest.TestCa
|
||||
self.assertEqual(summary, EXPECTED_SUMMARY)
|
||||
|
||||
|
||||
@require_torch
|
||||
class RoBertaEncoderDecoderModelTest(EncoderDecoderMixin, unittest.TestCase):
|
||||
def get_encoder_decoder_model(self, config, decoder_config):
|
||||
encoder_model = RobertaModel(config)
|
||||
@@ -606,6 +637,7 @@ class RoBertaEncoderDecoderModelTest(EncoderDecoderMixin, unittest.TestCase):
|
||||
return EncoderDecoderModel.from_encoder_decoder_pretrained("roberta-base", "roberta-base")
|
||||
|
||||
|
||||
@require_torch
|
||||
class GPT2EncoderDecoderModelTest(EncoderDecoderMixin, unittest.TestCase):
|
||||
def get_encoder_decoder_model(self, config, decoder_config):
|
||||
encoder_model = BertModel(config)
|
||||
@@ -663,3 +695,59 @@ class GPT2EncoderDecoderModelTest(EncoderDecoderMixin, unittest.TestCase):
|
||||
|
||||
def test_encoder_decoder_model_shared_weights(self):
|
||||
pass
|
||||
|
||||
|
||||
@require_torch
class ProphetNetEncoderDecoderModelTest(EncoderDecoderMixin, unittest.TestCase):
    """Run the shared encoder-decoder checks with a BERT encoder and a ProphetNet causal-LM decoder."""

    def get_encoder_decoder_model(self, config, decoder_config):
        """Build the (encoder, decoder) pair: ``BertModel`` and ``ProphetNetForCausalLM``."""
        return BertModel(config), ProphetNetForCausalLM(decoder_config)

    def prepare_config_and_inputs(self):
        """Assemble configs and inputs for both halves and return them as a keyword dict.

        The returned dict carries the encoder config/inputs, the decoder
        config/inputs, the decoder tester's ``encoder_hidden_states`` and the
        LM labels.
        """
        bert_tester = BertModelTester(self, batch_size=13)
        prophetnet_tester = ProphetNetStandaloneDecoderModelTester(
            self, batch_size=13, hidden_size=32, max_position_embeddings=512
        )

        # Encoder inputs are prepared first, then decoder inputs.
        (
            config,
            input_ids,
            _token_type_ids,
            input_mask,
            _sequence_labels,
            _token_labels,
            _choice_labels,
        ) = bert_tester.prepare_config_and_inputs()
        (
            decoder_config,
            decoder_input_ids,
            decoder_attention_mask,
            encoder_hidden_states,
            _encoder_attention_mask,
            lm_labels,
        ) = prophetnet_tester.prepare_config_and_inputs_for_decoder()

        # make sure that cross attention layers are added
        decoder_config.add_cross_attention = True
        # disable cache for now
        decoder_config.use_cache = False

        return {
            "config": config,
            "input_ids": input_ids,
            "attention_mask": input_mask,
            "decoder_config": decoder_config,
            "decoder_input_ids": decoder_input_ids,
            "decoder_attention_mask": decoder_attention_mask,
            "encoder_hidden_states": encoder_hidden_states,
            "labels": lm_labels,
        }

    def get_pretrained_model(self):
        """Load an ``EncoderDecoderModel`` from the pretrained BERT and ProphetNet decoder checkpoints."""
        encoder_checkpoint = "bert-large-uncased"
        decoder_checkpoint = "patrickvonplaten/prophetnet-decoder-clm-large-uncased"
        return EncoderDecoderModel.from_encoder_decoder_pretrained(encoder_checkpoint, decoder_checkpoint)

    def test_encoder_decoder_model_shared_weights(self):
        # Intentionally a no-op override of the test with this name.
        # NOTE(review): presumably weight sharing does not apply to a BERT/ProphetNet pair — confirm.
        pass
|
||||
|
||||
Reference in New Issue
Block a user