TF Longformer for sequence classification (#8231)

* working on LongformerForSequenceClassification

* add TFLongformerForMultipleChoice

* add TFLongformerForTokenClassification

* use add_start_docstrings_to_model_forward

* test TFLongformerForSequenceClassification

* test TFLongformerForMultipleChoice

* test TFLongformerForTokenClassification

* remove test from repo

* add test and doc for TFLongformerForSequenceClassification, TFLongformerForTokenClassification, TFLongformerForMultipleChoice

* add requested classes to modeling_tf_auto.py
update dummy_tf_objects
fix tests
fix bugs in requested classes

* pass all tests except test_inputs_embeds

* sync with master

* pass all tests except test_inputs_embeds

* pass all tests

* pass all tests

* work on test_inputs_embeds

* fix style and quality

* make multi choice work

* fix TFLongformerForTokenClassification signature

* fix TFLongformerForMultipleChoice, TFLongformerForSequenceClassification signature

* fix multiple choice

* fix mc hint

* fix input embeds

* fix input embeds

* refactor input embeds

* fix copy issue

* apply Sylvain's changes and clean up more

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
elk-cloner
2020-11-19 19:07:27 +03:30
committed by GitHub
parent 62cd9ce9f8
commit 5362bb8a6b
10 changed files with 892 additions and 71 deletions

View File

@@ -29,7 +29,10 @@ if is_tf_available():
from transformers import (
LongformerConfig,
TFLongformerForMaskedLM,
TFLongformerForMultipleChoice,
TFLongformerForQuestionAnswering,
TFLongformerForSequenceClassification,
TFLongformerForTokenClassification,
TFLongformerModel,
TFLongformerSelfAttention,
)
@@ -130,7 +133,7 @@ class TFLongformerModelTester:
output_without_mask = model(input_ids)[0]
tf.debugging.assert_near(output_with_mask[0, 0, :5], output_without_mask[0, 0, :5], rtol=1e-4)
def create_and_check_longformer_model(
def create_and_check_model(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
config.return_dict = True
@@ -144,7 +147,7 @@ class TFLongformerModelTester:
)
self.parent.assertListEqual(shape_list(result.pooler_output), [self.batch_size, self.hidden_size])
def create_and_check_longformer_model_with_global_attention_mask(
def create_and_check_model_with_global_attention_mask(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
config.return_dict = True
@@ -172,7 +175,7 @@ class TFLongformerModelTester:
)
self.parent.assertListEqual(shape_list(result.pooler_output), [self.batch_size, self.hidden_size])
def create_and_check_longformer_for_masked_lm(
def create_and_check_for_masked_lm(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
config.return_dict = True
@@ -180,7 +183,7 @@ class TFLongformerModelTester:
result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)
self.parent.assertListEqual(shape_list(result.logits), [self.batch_size, self.seq_length, self.vocab_size])
def create_and_check_longformer_for_question_answering(
def create_and_check_for_question_answering(
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
config.return_dict = True
@@ -196,6 +199,41 @@ class TFLongformerModelTester:
self.parent.assertListEqual(shape_list(result.start_logits), [self.batch_size, self.seq_length])
self.parent.assertListEqual(shape_list(result.end_logits), [self.batch_size, self.seq_length])
def create_and_check_for_sequence_classification(
    self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
    """Check the logits shape produced by TFLongformerForSequenceClassification.

    Sets ``config.num_labels`` from the tester, runs the model on the given
    inputs with ``sequence_labels`` as labels, and asserts that the logits
    have shape ``[batch_size, num_labels]``.
    """
    config.num_labels = self.num_labels
    classifier = TFLongformerForSequenceClassification(config=config)
    result = classifier(
        input_ids,
        attention_mask=input_mask,
        token_type_ids=token_type_ids,
        labels=sequence_labels,
    )
    expected_shape = [self.batch_size, self.num_labels]
    self.parent.assertListEqual(shape_list(result.logits), expected_shape)
def create_and_check_for_token_classification(
    self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
    """Check the logits shape produced by TFLongformerForTokenClassification.

    Sets ``config.num_labels`` from the tester, runs the model on the given
    inputs with ``token_labels`` as labels, and asserts that the logits have
    shape ``[batch_size, seq_length, num_labels]``.
    """
    config.num_labels = self.num_labels
    tagger = TFLongformerForTokenClassification(config=config)
    result = tagger(
        input_ids,
        attention_mask=input_mask,
        token_type_ids=token_type_ids,
        labels=token_labels,
    )
    expected_shape = [self.batch_size, self.seq_length, self.num_labels]
    self.parent.assertListEqual(shape_list(result.logits), expected_shape)
def create_and_check_for_multiple_choice(
    self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
    """Check the logits shape produced by TFLongformerForMultipleChoice.

    Each input tensor gets a new axis at position 1 that is tiled
    ``num_choices`` times before being fed to the model. The tiled attention
    mask is also passed as ``global_attention_mask``. The resulting logits
    must have shape ``[batch_size, num_choices]``.
    """
    config.num_choices = self.num_choices
    model = TFLongformerForMultipleChoice(config=config)

    def _repeat_per_choice(tensor):
        # Add a choice axis at position 1 and replicate the tensor along it.
        with_choice_axis = tf.expand_dims(tensor, 1)
        return tf.tile(with_choice_axis, (1, self.num_choices, 1))

    mc_input_ids = _repeat_per_choice(input_ids)
    mc_token_type_ids = _repeat_per_choice(token_type_ids)
    mc_input_mask = _repeat_per_choice(input_mask)

    result = model(
        mc_input_ids,
        attention_mask=mc_input_mask,
        global_attention_mask=mc_input_mask,
        token_type_ids=mc_token_type_ids,
        labels=choice_labels,
    )
    logits = result.logits
    expected_shape = [self.batch_size, self.num_choices]
    self.parent.assertListEqual(list(logits.shape), expected_shape)
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
(
@@ -252,6 +290,9 @@ class TFLongformerModelTest(TFModelTesterMixin, unittest.TestCase):
TFLongformerModel,
TFLongformerForMaskedLM,
TFLongformerForQuestionAnswering,
TFLongformerForSequenceClassification,
TFLongformerForMultipleChoice,
TFLongformerForTokenClassification,
)
if is_tf_available()
else ()
@@ -264,25 +305,37 @@ class TFLongformerModelTest(TFModelTesterMixin, unittest.TestCase):
def test_config(self):
    """Run the common configuration tests via the config tester."""
    config_checks = self.config_tester
    config_checks.run_common_tests()
def test_longformer_model_attention_mask_determinism(self):
def test_model_attention_mask_determinism(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_attention_mask_determinism(*config_and_inputs)
def test_longformer_model(self):
def test_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_longformer_model(*config_and_inputs)
self.model_tester.create_and_check_model(*config_and_inputs)
def test_longformer_model_global_attention_mask(self):
def test_model_global_attention_mask(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_longformer_model_with_global_attention_mask(*config_and_inputs)
self.model_tester.create_and_check_model_with_global_attention_mask(*config_and_inputs)
def test_longformer_for_masked_lm(self):
def test_for_masked_lm(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_longformer_for_masked_lm(*config_and_inputs)
self.model_tester.create_and_check_for_masked_lm(*config_and_inputs)
def test_longformer_for_question_answering(self):
def test_for_question_answering(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs_for_question_answering()
self.model_tester.create_and_check_longformer_for_question_answering(*config_and_inputs)
self.model_tester.create_and_check_for_question_answering(*config_and_inputs)
def test_for_sequence_classification(self):
    """Run the sequence-classification check on freshly prepared config and inputs."""
    tester = self.model_tester
    prepared = tester.prepare_config_and_inputs()
    tester.create_and_check_for_sequence_classification(*prepared)
def test_for_token_classification(self):
    """Run the token-classification check on freshly prepared config and inputs."""
    tester = self.model_tester
    prepared = tester.prepare_config_and_inputs()
    tester.create_and_check_for_token_classification(*prepared)
def test_for_multiple_choice(self):
    """Run the multiple-choice check on freshly prepared config and inputs."""
    tester = self.model_tester
    prepared = tester.prepare_config_and_inputs()
    tester.create_and_check_for_multiple_choice(*prepared)
@slow
def test_saved_model_with_attentions_output(self):