Add AutoModelForPreTraining

This commit is contained in:
thomwolf
2020-01-24 17:49:02 -05:00
committed by Lysandre Debut
parent ea56d305be
commit 0e31e06a75
5 changed files with 373 additions and 1 deletions

View File

@@ -28,6 +28,8 @@ if is_tf_available():
BertConfig,
TFAutoModel,
TFBertModel,
TFAutoModelForPreTraining,
TFBertForPreTraining,
TFAutoModelWithLMHead,
TFBertForMaskedLM,
TFRobertaForMaskedLM,
@@ -57,6 +59,23 @@ class TFAutoModelTest(unittest.TestCase):
self.assertIsNotNone(model)
self.assertIsInstance(model, TFBertModel)
@slow
def test_model_for_pretraining_from_pretrained(self):
import h5py
self.assertTrue(h5py.version.hdf5_version.startswith("1.10"))
logging.basicConfig(level=logging.INFO)
# for model_name in list(TF_BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
for model_name in ["bert-base-uncased"]:
config = AutoConfig.from_pretrained(model_name)
self.assertIsNotNone(config)
self.assertIsInstance(config, BertConfig)
model = TFAutoModelForPreTraining.from_pretrained(model_name)
self.assertIsNotNone(model)
self.assertIsInstance(model, TFBertForPreTraining)
@slow
def test_lmhead_model_from_pretrained(self):
logging.basicConfig(level=logging.INFO)