TF Longformer for sequence classification (#8231)
* working on LongformerForSequenceClassification
* add TFLongformerForMultipleChoice
* add TFLongformerForTokenClassification
* use add_start_docstrings_to_model_forward
* test TFLongformerForSequenceClassification
* test TFLongformerForMultipleChoice
* test TFLongformerForTokenClassification
* remove test from repo
* add test and doc for TFLongformerForSequenceClassification, TFLongformerForTokenClassification, TFLongformerForMultipleChoice
* add requested classes to modeling_tf_auto.py, update dummy_tf_objects, fix tests, fix bugs in requested classes
* pass all tests except test_inputs_embeds
* sync with master
* pass all tests except test_inputs_embeds
* pass all tests
* pass all tests
* work on test_inputs_embeds
* fix style and quality
* make multi choice work
* fix TFLongformerForTokenClassification signature
* fix TFLongformerForMultipleChoice, TFLongformerForSequenceClassification signature
* fix multiple choice
* fix mc hint
* fix input embeds
* fix input embeds
* refactor input embeds
* fix copy issue
* apply Sylvain's changes and clean more

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
@@ -129,7 +129,7 @@ class LongformerModelTester:
|
||||
output_without_mask = model(input_ids)["last_hidden_state"]
|
||||
self.parent.assertTrue(torch.allclose(output_with_mask[0, 0, :5], output_without_mask[0, 0, :5], atol=1e-4))
|
||||
|
||||
def create_and_check_longformer_model(
|
||||
def create_and_check_model(
|
||||
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
):
|
||||
model = LongformerModel(config=config)
|
||||
@@ -141,7 +141,7 @@ class LongformerModelTester:
|
||||
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
|
||||
self.parent.assertEqual(result.pooler_output.shape, (self.batch_size, self.hidden_size))
|
||||
|
||||
def create_and_check_longformer_model_with_global_attention_mask(
|
||||
def create_and_check_model_with_global_attention_mask(
|
||||
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
):
|
||||
model = LongformerModel(config=config)
|
||||
@@ -163,7 +163,7 @@ class LongformerModelTester:
|
||||
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
|
||||
self.parent.assertEqual(result.pooler_output.shape, (self.batch_size, self.hidden_size))
|
||||
|
||||
def create_and_check_longformer_for_masked_lm(
|
||||
def create_and_check_for_masked_lm(
|
||||
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
):
|
||||
model = LongformerForMaskedLM(config=config)
|
||||
@@ -172,7 +172,7 @@ class LongformerModelTester:
|
||||
result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)
|
||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
|
||||
|
||||
def create_and_check_longformer_for_question_answering(
|
||||
def create_and_check_for_question_answering(
|
||||
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
):
|
||||
model = LongformerForQuestionAnswering(config=config)
|
||||
@@ -189,7 +189,7 @@ class LongformerModelTester:
|
||||
self.parent.assertEqual(result.start_logits.shape, (self.batch_size, self.seq_length))
|
||||
self.parent.assertEqual(result.end_logits.shape, (self.batch_size, self.seq_length))
|
||||
|
||||
def create_and_check_longformer_for_sequence_classification(
|
||||
def create_and_check_for_sequence_classification(
|
||||
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
):
|
||||
config.num_labels = self.num_labels
|
||||
@@ -199,7 +199,7 @@ class LongformerModelTester:
|
||||
result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=sequence_labels)
|
||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.num_labels))
|
||||
|
||||
def create_and_check_longformer_for_token_classification(
|
||||
def create_and_check_for_token_classification(
|
||||
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
):
|
||||
config.num_labels = self.num_labels
|
||||
@@ -209,7 +209,7 @@ class LongformerModelTester:
|
||||
result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)
|
||||
self.parent.assertEqual(result.logits.shape, (self.batch_size, self.seq_length, self.num_labels))
|
||||
|
||||
def create_and_check_longformer_for_multiple_choice(
|
||||
def create_and_check_for_multiple_choice(
|
||||
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
):
|
||||
config.num_choices = self.num_choices
|
||||
@@ -296,37 +296,37 @@ class LongformerModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
def test_config(self):
    """Run the common configuration checks exposed by ``self.config_tester``."""
    config_tester = self.config_tester
    config_tester.run_common_tests()
|
||||
|
||||
def test_longformer_model(self):
|
||||
def test_model(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_longformer_model(*config_and_inputs)
|
||||
self.model_tester.create_and_check_model(*config_and_inputs)
|
||||
|
||||
def test_longformer_model_attention_mask_determinism(self):
|
||||
def test_model_attention_mask_determinism(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_attention_mask_determinism(*config_and_inputs)
|
||||
|
||||
def test_longformer_model_global_attention_mask(self):
|
||||
def test_model_global_attention_mask(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_longformer_model_with_global_attention_mask(*config_and_inputs)
|
||||
self.model_tester.create_and_check_model_with_global_attention_mask(*config_and_inputs)
|
||||
|
||||
def test_longformer_for_masked_lm(self):
|
||||
def test_for_masked_lm(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_longformer_for_masked_lm(*config_and_inputs)
|
||||
self.model_tester.create_and_check_for_masked_lm(*config_and_inputs)
|
||||
|
||||
def test_longformer_for_question_answering(self):
|
||||
def test_for_question_answering(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs_for_question_answering()
|
||||
self.model_tester.create_and_check_longformer_for_question_answering(*config_and_inputs)
|
||||
self.model_tester.create_and_check_for_question_answering(*config_and_inputs)
|
||||
|
||||
def test_for_sequence_classification(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_longformer_for_sequence_classification(*config_and_inputs)
|
||||
self.model_tester.create_and_check_for_sequence_classification(*config_and_inputs)
|
||||
|
||||
def test_for_token_classification(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_longformer_for_token_classification(*config_and_inputs)
|
||||
self.model_tester.create_and_check_for_token_classification(*config_and_inputs)
|
||||
|
||||
def test_for_multiple_choice(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_longformer_for_multiple_choice(*config_and_inputs)
|
||||
self.model_tester.create_and_check_for_multiple_choice(*config_and_inputs)
|
||||
|
||||
|
||||
@require_torch
|
||||
@@ -691,7 +691,7 @@ class LongformerModelIntegrationTest(unittest.TestCase):
|
||||
) # long input
|
||||
input_ids = input_ids.to(torch_device)
|
||||
|
||||
loss, prediction_scores = model(input_ids, labels=input_ids)
|
||||
loss, prediction_scores = model(input_ids, labels=input_ids).to_tuple()
|
||||
|
||||
expected_loss = torch.tensor(0.0074, device=torch_device)
|
||||
expected_prediction_scores_sum = torch.tensor(-6.1048e08, device=torch_device)
|
||||
|
||||
@@ -29,7 +29,10 @@ if is_tf_available():
|
||||
from transformers import (
|
||||
LongformerConfig,
|
||||
TFLongformerForMaskedLM,
|
||||
TFLongformerForMultipleChoice,
|
||||
TFLongformerForQuestionAnswering,
|
||||
TFLongformerForSequenceClassification,
|
||||
TFLongformerForTokenClassification,
|
||||
TFLongformerModel,
|
||||
TFLongformerSelfAttention,
|
||||
)
|
||||
@@ -130,7 +133,7 @@ class TFLongformerModelTester:
|
||||
output_without_mask = model(input_ids)[0]
|
||||
tf.debugging.assert_near(output_with_mask[0, 0, :5], output_without_mask[0, 0, :5], rtol=1e-4)
|
||||
|
||||
def create_and_check_longformer_model(
|
||||
def create_and_check_model(
|
||||
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
):
|
||||
config.return_dict = True
|
||||
@@ -144,7 +147,7 @@ class TFLongformerModelTester:
|
||||
)
|
||||
self.parent.assertListEqual(shape_list(result.pooler_output), [self.batch_size, self.hidden_size])
|
||||
|
||||
def create_and_check_longformer_model_with_global_attention_mask(
|
||||
def create_and_check_model_with_global_attention_mask(
|
||||
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
):
|
||||
config.return_dict = True
|
||||
@@ -172,7 +175,7 @@ class TFLongformerModelTester:
|
||||
)
|
||||
self.parent.assertListEqual(shape_list(result.pooler_output), [self.batch_size, self.hidden_size])
|
||||
|
||||
def create_and_check_longformer_for_masked_lm(
|
||||
def create_and_check_for_masked_lm(
|
||||
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
):
|
||||
config.return_dict = True
|
||||
@@ -180,7 +183,7 @@ class TFLongformerModelTester:
|
||||
result = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=token_labels)
|
||||
self.parent.assertListEqual(shape_list(result.logits), [self.batch_size, self.seq_length, self.vocab_size])
|
||||
|
||||
def create_and_check_longformer_for_question_answering(
|
||||
def create_and_check_for_question_answering(
|
||||
self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
|
||||
):
|
||||
config.return_dict = True
|
||||
@@ -196,6 +199,41 @@ class TFLongformerModelTester:
|
||||
self.parent.assertListEqual(shape_list(result.start_logits), [self.batch_size, self.seq_length])
|
||||
self.parent.assertListEqual(shape_list(result.end_logits), [self.batch_size, self.seq_length])
|
||||
|
||||
def create_and_check_for_sequence_classification(
    self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
    """Check the logits shape returned by ``TFLongformerForSequenceClassification``.

    ``config.num_labels`` is set from the tester, the model is run once with
    ``sequence_labels`` as labels, and the resulting logits are asserted to
    have shape ``[batch_size, num_labels]``.
    """
    config.num_labels = self.num_labels
    classifier = TFLongformerForSequenceClassification(config=config)
    result = classifier(
        input_ids,
        attention_mask=input_mask,
        token_type_ids=token_type_ids,
        labels=sequence_labels,
    )
    expected_shape = [self.batch_size, self.num_labels]
    self.parent.assertListEqual(shape_list(result.logits), expected_shape)
|
||||
|
||||
def create_and_check_for_token_classification(
    self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
    """Check the logits shape returned by ``TFLongformerForTokenClassification``.

    ``config.num_labels`` is set from the tester, the model is run once with
    ``token_labels`` as labels, and the resulting logits are asserted to have
    shape ``[batch_size, seq_length, num_labels]``.
    """
    config.num_labels = self.num_labels
    tagger = TFLongformerForTokenClassification(config=config)
    result = tagger(
        input_ids,
        attention_mask=input_mask,
        token_type_ids=token_type_ids,
        labels=token_labels,
    )
    expected_shape = [self.batch_size, self.seq_length, self.num_labels]
    self.parent.assertListEqual(shape_list(result.logits), expected_shape)
|
||||
|
||||
def create_and_check_for_multiple_choice(
    self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
    """Check the logits shape returned by ``TFLongformerForMultipleChoice``.

    Each input is given an extra axis at position 1 and tiled ``num_choices``
    times along it before being fed to the model. The tiled attention mask is
    also passed as ``global_attention_mask``. The logits are asserted to have
    shape ``[batch_size, num_choices]``.
    """
    config.num_choices = self.num_choices
    model = TFLongformerForMultipleChoice(config=config)

    def _repeat_per_choice(tensor):
        # Insert a choice axis at position 1 and repeat the tensor along it.
        return tf.tile(tf.expand_dims(tensor, 1), (1, self.num_choices, 1))

    multiple_choice_inputs_ids = _repeat_per_choice(input_ids)
    multiple_choice_token_type_ids = _repeat_per_choice(token_type_ids)
    multiple_choice_input_mask = _repeat_per_choice(input_mask)

    output = model(
        multiple_choice_inputs_ids,
        attention_mask=multiple_choice_input_mask,
        global_attention_mask=multiple_choice_input_mask,
        token_type_ids=multiple_choice_token_type_ids,
        labels=choice_labels,
    ).logits
    # Use shape_list like the sibling checkers instead of list(output.shape).
    self.parent.assertListEqual(shape_list(output), [self.batch_size, self.num_choices])
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
config_and_inputs = self.prepare_config_and_inputs()
|
||||
(
|
||||
@@ -252,6 +290,9 @@ class TFLongformerModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
TFLongformerModel,
|
||||
TFLongformerForMaskedLM,
|
||||
TFLongformerForQuestionAnswering,
|
||||
TFLongformerForSequenceClassification,
|
||||
TFLongformerForMultipleChoice,
|
||||
TFLongformerForTokenClassification,
|
||||
)
|
||||
if is_tf_available()
|
||||
else ()
|
||||
@@ -264,25 +305,37 @@ class TFLongformerModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
def test_config(self):
    """Run the common configuration checks exposed by ``self.config_tester``."""
    config_tester = self.config_tester
    config_tester.run_common_tests()
|
||||
|
||||
def test_longformer_model_attention_mask_determinism(self):
|
||||
def test_model_attention_mask_determinism(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_attention_mask_determinism(*config_and_inputs)
|
||||
|
||||
def test_longformer_model(self):
|
||||
def test_model(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_longformer_model(*config_and_inputs)
|
||||
self.model_tester.create_and_check_model(*config_and_inputs)
|
||||
|
||||
def test_longformer_model_global_attention_mask(self):
|
||||
def test_model_global_attention_mask(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_longformer_model_with_global_attention_mask(*config_and_inputs)
|
||||
self.model_tester.create_and_check_model_with_global_attention_mask(*config_and_inputs)
|
||||
|
||||
def test_longformer_for_masked_lm(self):
|
||||
def test_for_masked_lm(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_longformer_for_masked_lm(*config_and_inputs)
|
||||
self.model_tester.create_and_check_for_masked_lm(*config_and_inputs)
|
||||
|
||||
def test_longformer_for_question_answering(self):
|
||||
def test_for_question_answering(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs_for_question_answering()
|
||||
self.model_tester.create_and_check_longformer_for_question_answering(*config_and_inputs)
|
||||
self.model_tester.create_and_check_for_question_answering(*config_and_inputs)
|
||||
|
||||
def test_for_sequence_classification(self):
    """Run the model tester's sequence-classification check on freshly prepared inputs."""
    tester = self.model_tester
    prepared = tester.prepare_config_and_inputs()
    tester.create_and_check_for_sequence_classification(*prepared)
|
||||
|
||||
def test_for_token_classification(self):
    """Run the model tester's token-classification check on freshly prepared inputs."""
    tester = self.model_tester
    prepared = tester.prepare_config_and_inputs()
    tester.create_and_check_for_token_classification(*prepared)
|
||||
|
||||
def test_for_multiple_choice(self):
    """Run the model tester's multiple-choice check on freshly prepared inputs."""
    tester = self.model_tester
    prepared = tester.prepare_config_and_inputs()
    tester.create_and_check_for_multiple_choice(*prepared)
|
||||
|
||||
@slow
|
||||
def test_saved_model_with_attentions_output(self):
|
||||
|
||||
Reference in New Issue
Block a user