From 9d5c49546fedb3d52f76f97f4043a07e08ded918 Mon Sep 17 00:00:00 2001 From: Lysandre Date: Thu, 7 Nov 2019 17:32:52 +0000 Subject: [PATCH] Tests for AlbertForQuestionAnswering AlbertForSequenceClassification --- transformers/tests/modeling_albert_test.py | 45 +++++++++++++++++++++- 1 file changed, 44 insertions(+), 1 deletion(-) diff --git a/transformers/tests/modeling_albert_test.py b/transformers/tests/modeling_albert_test.py index 466f473332..da87709df1 100644 --- a/transformers/tests/modeling_albert_test.py +++ b/transformers/tests/modeling_albert_test.py @@ -26,7 +26,9 @@ from .modeling_common_test import (CommonTestCases, ids_tensor) from .configuration_common_test import ConfigTester if is_torch_available(): - from transformers import (AlbertConfig, AlbertModel, AlbertForMaskedLM) + from transformers import (AlbertConfig, AlbertModel, AlbertForMaskedLM, + AlbertForSequenceClassification, AlbertForQuestionAnswering, + ) from transformers.modeling_albert import ALBERT_PRETRAINED_MODEL_ARCHIVE_MAP else: pytestmark = pytest.mark.skip("Require Torch") @@ -157,6 +159,39 @@ class AlbertModelTest(CommonTestCases.CommonModelTester): [self.batch_size, self.seq_length, self.vocab_size]) self.check_loss_output(result) + def create_and_check_albert_for_question_answering(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels): + model = AlbertForQuestionAnswering(config=config) + model.eval() + loss, start_logits, end_logits = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, + start_positions=sequence_labels, end_positions=sequence_labels) + result = { + "loss": loss, + "start_logits": start_logits, + "end_logits": end_logits, + } + self.parent.assertListEqual( + list(result["start_logits"].size()), + [self.batch_size, self.seq_length]) + self.parent.assertListEqual( + list(result["end_logits"].size()), + [self.batch_size, self.seq_length]) + self.check_loss_output(result) + + + def create_and_check_albert_for_sequence_classification(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels): + config.num_labels = self.num_labels + model = AlbertForSequenceClassification(config) + model.eval() + loss, logits = model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids, labels=sequence_labels) + result = { + "loss": loss, + "logits": logits, + } + self.parent.assertListEqual( + list(result["logits"].size()), + [self.batch_size, self.num_labels]) + self.check_loss_output(result) + def prepare_config_and_inputs_for_common(self): config_and_inputs = self.prepare_config_and_inputs() @@ -180,6 +215,14 @@ class AlbertModelTest(CommonTestCases.CommonModelTester): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_albert_for_masked_lm(*config_and_inputs) + def test_for_question_answering(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_albert_for_question_answering(*config_and_inputs) + + def test_for_sequence_classification(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_albert_for_sequence_classification(*config_and_inputs) + @pytest.mark.slow def test_model_from_pretrained(self): cache_dir = "/tmp/transformers_test/"