From 3290315a2aae431a901d4f93aa8ebff518a96fe3 Mon Sep 17 00:00:00 2001 From: Lysandre Debut Date: Thu, 15 Jul 2021 15:06:12 +0200 Subject: [PATCH] Fix AutoModel tests (#12733) --- tests/test_modeling_auto.py | 7 +++++-- tests/test_modeling_common.py | 7 +++++-- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/tests/test_modeling_auto.py b/tests/test_modeling_auto.py index 0ba839c42a..65315e8cac 100644 --- a/tests/test_modeling_auto.py +++ b/tests/test_modeling_auto.py @@ -88,8 +88,11 @@ class AutoModelTest(unittest.TestCase): model, loading_info = AutoModel.from_pretrained(model_name, output_loading_info=True) self.assertIsNotNone(model) self.assertIsInstance(model, BertModel) - for value in loading_info.values(): - self.assertEqual(len(value), 0) + + self.assertEqual(len(loading_info["missing_keys"]), 0) + self.assertEqual(len(loading_info["unexpected_keys"]), 8) + self.assertEqual(len(loading_info["mismatched_keys"]), 0) + self.assertEqual(len(loading_info["error_msgs"]), 0) @slow def test_model_for_pretraining_from_pretrained(self): diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 3d4c57c158..d109bbf6a5 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -1617,8 +1617,11 @@ class ModelUtilsTest(TestCasePlus): model, loading_info = BertModel.from_pretrained(model_name, output_loading_info=True) self.assertIsNotNone(model) self.assertIsInstance(model, PreTrainedModel) - for value in loading_info.values(): - self.assertEqual(len(value), 0) + + self.assertEqual(len(loading_info["missing_keys"]), 0) + self.assertEqual(len(loading_info["unexpected_keys"]), 8) + self.assertEqual(len(loading_info["mismatched_keys"]), 0) + self.assertEqual(len(loading_info["error_msgs"]), 0) config = BertConfig.from_pretrained(model_name, output_attentions=True, output_hidden_states=True)