[EncoderDecoder] Add encoder-decoder for roberta/ vanilla longformer (#6411)

* add encoder-decoder for roberta

* fix headmask

* apply Sylvains suggestions

* fix typo

* Apply suggestions from code review
This commit is contained in:
Patrick von Platen
2020-08-12 18:23:30 +02:00
committed by GitHub
parent fd3de2000f
commit 0735def8e1
10 changed files with 671 additions and 357 deletions

View File

@@ -152,7 +152,7 @@ class BertModelTester:
encoder_attention_mask,
)
def create_and_check_bert_model(
def create_and_check_model(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
model = BertModel(config=config)
@@ -164,7 +164,7 @@ class BertModelTester:
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
self.parent.assertEqual(result.pooler_output.shape, (self.batch_size, self.hidden_size))
def create_and_check_bert_model_as_decoder(
def create_and_check_model_as_decoder(
self,
config,
input_ids,
@@ -197,7 +197,7 @@ class BertModelTester:
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
self.parent.assertEqual(result.pooler_output.shape, (self.batch_size, self.hidden_size))
def create_and_check_bert_for_causal_lm(
def create_and_check_for_causal_lm(
self,
config,
input_ids,
@@ -215,7 +215,7 @@ class BertModelTester:
result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
def create_and_check_bert_for_masked_lm(
def create_and_check_for_masked_lm(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
model = BertForMaskedLM(config=config)
@@ -224,7 +224,7 @@ class BertModelTester:
result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
def create_and_check_bert_model_for_causal_lm_as_decoder(
def create_and_check_model_for_causal_lm_as_decoder(
self,
config,
input_ids,
@@ -257,7 +257,7 @@ class BertModelTester:
)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
def create_and_check_bert_for_next_sequence_prediction(
def create_and_check_for_next_sequence_prediction(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
model = BertForNextSentencePrediction(config=config)
@@ -268,7 +268,7 @@ class BertModelTester:
)
self.parent.assertEqual(result.logits.shape, (self.batch_size, 2))
def create_and_check_bert_for_pretraining(
def create_and_check_for_pretraining(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
model = BertForPreTraining(config=config)
@@ -284,7 +284,7 @@ class BertModelTester:
self.parent.assertEqual(result.prediction_logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
self.parent.assertEqual(result.seq_relationship_logits.shape, (self.batch_size, 2))
def create_and_check_bert_for_question_answering(
def create_and_check_for_question_answering(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
model = BertForQuestionAnswering(config=config)
@@ -300,7 +300,7 @@ class BertModelTester:
self.parent.assertEqual(result.start_logits.shape, (self.batch_size, self.seq_length))
self.parent.assertEqual(result.end_logits.shape, (self.batch_size, self.seq_length))
def create_and_check_bert_for_sequence_classification(
def create_and_check_for_sequence_classification(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
config.num_labels = self.num_labels
@@ -310,7 +310,7 @@ class BertModelTester:
result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=sequence_labels)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))
def create_and_check_bert_for_token_classification(
def create_and_check_for_token_classification(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
config.num_labels = self.num_labels
@@ -320,7 +320,7 @@ class BertModelTester:
result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.num_labels))
def create_and_check_bert_for_multiple_choice(
def create_and_check_for_multiple_choice(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
config.num_choices = self.num_choices
@@ -379,15 +379,15 @@ class BertModelTest(ModelTesterMixin, unittest.TestCase):
def test_config(self):
self.config_tester.run_common_tests()
def test_bert_model(self):
def test_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_bert_model(*config_and_inputs)
self.model_tester.create_and_check_model(*config_and_inputs)
def test_bert_model_as_decoder(self):
def test_model_as_decoder(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder()
self.model_tester.create_and_check_bert_model_as_decoder(*config_and_inputs)
self.model_tester.create_and_check_model_as_decoder(*config_and_inputs)
def test_bert_model_as_decoder_with_default_input_mask(self):
def test_model_as_decoder_with_default_input_mask(self):
# This regression test was failing with PyTorch < 1.3
(
config,
@@ -403,7 +403,7 @@ class BertModelTest(ModelTesterMixin, unittest.TestCase):
input_mask = None
self.model_tester.create_and_check_bert_model_as_decoder(
self.model_tester.create_and_check_model_as_decoder(
config,
input_ids,
token_type_ids,
@@ -417,39 +417,39 @@ class BertModelTest(ModelTesterMixin, unittest.TestCase):
def test_for_causal_lm(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder()
self.model_tester.create_and_check_bert_for_causal_lm(*config_and_inputs)
self.model_tester.create_and_check_for_causal_lm(*config_and_inputs)
def test_for_masked_lm(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_bert_for_masked_lm(*config_and_inputs)
self.model_tester.create_and_check_for_masked_lm(*config_and_inputs)
def test_for_causal_lm_decoder(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs_for_decoder()
self.model_tester.create_and_check_bert_model_for_causal_lm_as_decoder(*config_and_inputs)
self.model_tester.create_and_check_model_for_causal_lm_as_decoder(*config_and_inputs)
def test_for_multiple_choice(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_bert_for_multiple_choice(*config_and_inputs)
self.model_tester.create_and_check_for_multiple_choice(*config_and_inputs)
def test_for_next_sequence_prediction(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_bert_for_next_sequence_prediction(*config_and_inputs)
self.model_tester.create_and_check_for_next_sequence_prediction(*config_and_inputs)
def test_for_pretraining(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_bert_for_pretraining(*config_and_inputs)
self.model_tester.create_and_check_for_pretraining(*config_and_inputs)
def test_for_question_answering(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_bert_for_question_answering(*config_and_inputs)
self.model_tester.create_and_check_for_question_answering(*config_and_inputs)
def test_for_sequence_classification(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_bert_for_sequence_classification(*config_and_inputs)
self.model_tester.create_and_check_for_sequence_classification(*config_and_inputs)
def test_for_token_classification(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_bert_for_token_classification(*config_and_inputs)
self.model_tester.create_and_check_for_token_classification(*config_and_inputs)
@slow
def test_model_from_pretrained(self):