Add AutoModelForPreTraining
This commit is contained in:
@@ -28,6 +28,8 @@ if is_torch_available():
|
||||
BertConfig,
|
||||
AutoModel,
|
||||
BertModel,
|
||||
AutoModelForPreTraining,
|
||||
BertForPreTraining,
|
||||
AutoModelWithLMHead,
|
||||
BertForMaskedLM,
|
||||
RobertaForMaskedLM,
|
||||
@@ -56,6 +58,21 @@ class AutoModelTest(unittest.TestCase):
|
||||
for value in loading_info.values():
|
||||
self.assertEqual(len(value), 0)
|
||||
|
||||
@slow
|
||||
def test_model_for_pretraining_from_pretrained(self):
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
for model_name in list(BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
|
||||
config = AutoConfig.from_pretrained(model_name)
|
||||
self.assertIsNotNone(config)
|
||||
self.assertIsInstance(config, BertConfig)
|
||||
|
||||
model = AutoModelForPreTraining.from_pretrained(model_name)
|
||||
model, loading_info = AutoModelForPreTraining.from_pretrained(model_name, output_loading_info=True)
|
||||
self.assertIsNotNone(model)
|
||||
self.assertIsInstance(model, BertForPreTraining)
|
||||
for value in loading_info.values():
|
||||
self.assertEqual(len(value), 0)
|
||||
|
||||
@slow
|
||||
def test_lmhead_model_from_pretrained(self):
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
Reference in New Issue
Block a user