From 94e55253aef2ccb4b0de95e4aadd6432e3e6a65a Mon Sep 17 00:00:00 2001 From: Stefan Schweter Date: Mon, 11 Nov 2019 16:20:15 +0100 Subject: [PATCH] tests: add test case for DistilBertForTokenClassification implementation --- .../tests/modeling_distilbert_test.py | 20 +++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/transformers/tests/modeling_distilbert_test.py b/transformers/tests/modeling_distilbert_test.py index 937d03396d..8099c03586 100644 --- a/transformers/tests/modeling_distilbert_test.py +++ b/transformers/tests/modeling_distilbert_test.py @@ -23,6 +23,7 @@ from transformers import is_torch_available if is_torch_available(): from transformers import (DistilBertConfig, DistilBertModel, DistilBertForMaskedLM, + DistilBertForTokenClassification, DistilBertForQuestionAnswering, DistilBertForSequenceClassification) else: pytestmark = pytest.mark.skip("Require Torch") @@ -180,6 +181,21 @@ class DistilBertModelTest(CommonTestCases.CommonModelTester): [self.batch_size, self.num_labels]) self.check_loss_output(result) + def create_and_check_distilbert_for_token_classification(self, config, input_ids, input_mask, sequence_labels, token_labels, choice_labels): + config.num_labels = self.num_labels + model = DistilBertForTokenClassification(config=config) + model.eval() + + loss, logits = model(input_ids, attention_mask=input_mask, labels=token_labels) + result = { + "loss": loss, + "logits": logits, + } + self.parent.assertListEqual( + list(result["logits"].size()), + [self.batch_size, self.seq_length, self.num_labels]) + self.check_loss_output(result) + def prepare_config_and_inputs_for_common(self): config_and_inputs = self.prepare_config_and_inputs() (config, input_ids, input_mask, sequence_labels, token_labels, choice_labels) = config_and_inputs @@ -209,6 +225,10 @@ class DistilBertModelTest(CommonTestCases.CommonModelTester): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_distilbert_for_sequence_classification(*config_and_inputs) + def test_for_token_classification(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_distilbert_for_token_classification(*config_and_inputs) + # @pytest.mark.slow # def test_model_from_pretrained(self): # cache_dir = "/tmp/transformers_test/"