From 83f45cd6564e8940b5df902df5b0d140584ce9be Mon Sep 17 00:00:00 2001 From: Lysandre Debut Date: Fri, 18 Feb 2022 08:50:23 -0500 Subject: [PATCH] Fix auto (#15706) --- tests/test_modeling_tf_auto.py | 38 +++++++++++++--------------------- 1 file changed, 14 insertions(+), 24 deletions(-) diff --git a/tests/test_modeling_tf_auto.py b/tests/test_modeling_tf_auto.py index dca08cf487..6987ba4779 100644 --- a/tests/test_modeling_tf_auto.py +++ b/tests/test_modeling_tf_auto.py @@ -85,35 +85,25 @@ if is_tf_available(): class TFAutoModelTest(unittest.TestCase): @slow def test_model_from_pretrained(self): - import h5py + model_name = "bert-base-cased" + config = AutoConfig.from_pretrained(model_name) + self.assertIsNotNone(config) + self.assertIsInstance(config, BertConfig) - self.assertTrue(h5py.version.hdf5_version.startswith("1.10")) - - # for model_name in TF_BERT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: - for model_name in ["bert-base-uncased"]: - config = AutoConfig.from_pretrained(model_name) - self.assertIsNotNone(config) - self.assertIsInstance(config, BertConfig) - - model = TFAutoModel.from_pretrained(model_name) - self.assertIsNotNone(model) - self.assertIsInstance(model, TFBertModel) + model = TFAutoModel.from_pretrained(model_name) + self.assertIsNotNone(model) + self.assertIsInstance(model, TFBertModel) @slow def test_model_for_pretraining_from_pretrained(self): - import h5py + model_name = "bert-base-cased" + config = AutoConfig.from_pretrained(model_name) + self.assertIsNotNone(config) + self.assertIsInstance(config, BertConfig) - self.assertTrue(h5py.version.hdf5_version.startswith("1.10")) - - # for model_name in TF_BERT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]: - for model_name in ["bert-base-uncased"]: - config = AutoConfig.from_pretrained(model_name) - self.assertIsNotNone(config) - self.assertIsInstance(config, BertConfig) - - model = TFAutoModelForPreTraining.from_pretrained(model_name) - self.assertIsNotNone(model) - self.assertIsInstance(model, TFBertForPreTraining) + model = TFAutoModelForPreTraining.from_pretrained(model_name) + self.assertIsNotNone(model) + self.assertIsInstance(model, TFBertForPreTraining) @slow def test_model_for_causal_lm(self):