[Fix] position_ids tests again (#6100)
This commit is contained in:
@@ -88,9 +88,11 @@ class AutoModelTest(unittest.TestCase):
|
||||
model, loading_info = AutoModelForPreTraining.from_pretrained(model_name, output_loading_info=True)
|
||||
self.assertIsNotNone(model)
|
||||
self.assertIsInstance(model, BertForPreTraining)
|
||||
# Only one value should not be initialized and in the missing keys.
|
||||
missing_keys = loading_info.pop("missing_keys")
|
||||
self.assertListEqual(["cls.predictions.decoder.bias"], missing_keys)
|
||||
for key, value in loading_info.items():
|
||||
# Only one value should not be initialized and in the missing keys.
|
||||
self.assertEqual(len(value), 1 if key == "missing_keys" else 0)
|
||||
self.assertEqual(len(value), 0)
|
||||
|
||||
@slow
|
||||
def test_lmhead_model_from_pretrained(self):
|
||||
|
||||
Reference in New Issue
Block a user