From 25e83894393bc0e2d42913465fef024cd1e9a9e4 Mon Sep 17 00:00:00 2001 From: LysandreJik Date: Mon, 26 Aug 2019 16:02:23 -0400 Subject: [PATCH] Tests for added AutoModels --- .../tests/modeling_auto_test.py | 42 ++++++++++++++++++- 1 file changed, 41 insertions(+), 1 deletion(-) diff --git a/pytorch_transformers/tests/modeling_auto_test.py b/pytorch_transformers/tests/modeling_auto_test.py index d0c830abc7..09d09b28fc 100644 --- a/pytorch_transformers/tests/modeling_auto_test.py +++ b/pytorch_transformers/tests/modeling_auto_test.py @@ -21,7 +21,11 @@ import shutil import pytest import logging -from pytorch_transformers import AutoConfig, BertConfig, AutoModel, BertModel +from pytorch_transformers import (AutoConfig, BertConfig, + AutoModel, BertModel, + AutoModelWithLMHead, BertForMaskedLM, + AutoModelForSequenceClassification, BertForSequenceClassification, + AutoModelForQuestionAnswering, BertForQuestionAnswering) from pytorch_transformers.modeling_bert import BERT_PRETRAINED_MODEL_ARCHIVE_MAP from .modeling_common_test import (CommonTestCases, ConfigTester, ids_tensor) @@ -42,6 +46,42 @@ class AutoModelTest(unittest.TestCase): for value in loading_info.values(): self.assertEqual(len(value), 0) + def test_lmhead_model_from_pretrained(self): + logging.basicConfig(level=logging.INFO) + for model_name in list(BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: + config = AutoConfig.from_pretrained(model_name) + self.assertIsNotNone(config) + self.assertIsInstance(config, BertConfig) + + model = AutoModelWithLMHead.from_pretrained(model_name) + model, loading_info = AutoModelWithLMHead.from_pretrained(model_name, output_loading_info=True) + self.assertIsNotNone(model) + self.assertIsInstance(model, BertForMaskedLM) + + def test_sequence_classification_model_from_pretrained(self): + logging.basicConfig(level=logging.INFO) + for model_name in list(BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: + config = AutoConfig.from_pretrained(model_name) + self.assertIsNotNone(config) + self.assertIsInstance(config, BertConfig) + + model = AutoModelForSequenceClassification.from_pretrained(model_name) + model, loading_info = AutoModelForSequenceClassification.from_pretrained(model_name, output_loading_info=True) + self.assertIsNotNone(model) + self.assertIsInstance(model, BertForSequenceClassification) + + def test_question_answering_model_from_pretrained(self): + logging.basicConfig(level=logging.INFO) + for model_name in list(BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]: + config = AutoConfig.from_pretrained(model_name) + self.assertIsNotNone(config) + self.assertIsInstance(config, BertConfig) + + model = AutoModelForQuestionAnswering.from_pretrained(model_name) + model, loading_info = AutoModelForQuestionAnswering.from_pretrained(model_name, output_loading_info=True) + self.assertIsNotNone(model) + self.assertIsInstance(model, BertForQuestionAnswering) + if __name__ == "__main__": unittest.main()