Tests for added AutoModels
This commit is contained in:
@@ -21,7 +21,11 @@ import shutil
|
|||||||
import pytest
|
import pytest
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
from pytorch_transformers import AutoConfig, BertConfig, AutoModel, BertModel
|
from pytorch_transformers import (AutoConfig, BertConfig,
|
||||||
|
AutoModel, BertModel,
|
||||||
|
AutoModelWithLMHead, BertForMaskedLM,
|
||||||
|
AutoModelForSequenceClassification, BertForSequenceClassification,
|
||||||
|
AutoModelForQuestionAnswering, BertForQuestionAnswering)
|
||||||
from pytorch_transformers.modeling_bert import BERT_PRETRAINED_MODEL_ARCHIVE_MAP
|
from pytorch_transformers.modeling_bert import BERT_PRETRAINED_MODEL_ARCHIVE_MAP
|
||||||
|
|
||||||
from .modeling_common_test import (CommonTestCases, ConfigTester, ids_tensor)
|
from .modeling_common_test import (CommonTestCases, ConfigTester, ids_tensor)
|
||||||
@@ -42,6 +46,42 @@ class AutoModelTest(unittest.TestCase):
|
|||||||
for value in loading_info.values():
|
for value in loading_info.values():
|
||||||
self.assertEqual(len(value), 0)
|
self.assertEqual(len(value), 0)
|
||||||
|
|
||||||
|
def test_lmhead_model_from_pretrained(self):
|
||||||
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
for model_name in list(BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
|
||||||
|
config = AutoConfig.from_pretrained(model_name)
|
||||||
|
self.assertIsNotNone(config)
|
||||||
|
self.assertIsInstance(config, BertConfig)
|
||||||
|
|
||||||
|
model = AutoModelWithLMHead.from_pretrained(model_name)
|
||||||
|
model, loading_info = AutoModelWithLMHead.from_pretrained(model_name, output_loading_info=True)
|
||||||
|
self.assertIsNotNone(model)
|
||||||
|
self.assertIsInstance(model, BertForMaskedLM)
|
||||||
|
|
||||||
|
def test_sequence_classification_model_from_pretrained(self):
|
||||||
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
for model_name in list(BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
|
||||||
|
config = AutoConfig.from_pretrained(model_name)
|
||||||
|
self.assertIsNotNone(config)
|
||||||
|
self.assertIsInstance(config, BertConfig)
|
||||||
|
|
||||||
|
model = AutoModelForSequenceClassification.from_pretrained(model_name)
|
||||||
|
model, loading_info = AutoModelForSequenceClassification.from_pretrained(model_name, output_loading_info=True)
|
||||||
|
self.assertIsNotNone(model)
|
||||||
|
self.assertIsInstance(model, BertForSequenceClassification)
|
||||||
|
|
||||||
|
def test_question_answering_model_from_pretrained(self):
|
||||||
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
for model_name in list(BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
|
||||||
|
config = AutoConfig.from_pretrained(model_name)
|
||||||
|
self.assertIsNotNone(config)
|
||||||
|
self.assertIsInstance(config, BertConfig)
|
||||||
|
|
||||||
|
model = AutoModelForQuestionAnswering.from_pretrained(model_name)
|
||||||
|
model, loading_info = AutoModelForQuestionAnswering.from_pretrained(model_name, output_loading_info=True)
|
||||||
|
self.assertIsNotNone(model)
|
||||||
|
self.assertIsInstance(model, BertForQuestionAnswering)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|||||||
Reference in New Issue
Block a user