[Fix] position_ids tests again (#6100)
This commit is contained in:
@@ -568,6 +568,7 @@ class BertPreTrainedModel(PreTrainedModel):
|
|||||||
config_class = BertConfig
|
config_class = BertConfig
|
||||||
load_tf_weights = load_tf_weights_in_bert
|
load_tf_weights = load_tf_weights_in_bert
|
||||||
base_model_prefix = "bert"
|
base_model_prefix = "bert"
|
||||||
|
authorized_missing_keys = [r"position_ids"]
|
||||||
|
|
||||||
def _init_weights(self, module):
|
def _init_weights(self, module):
|
||||||
""" Initialize the weights """
|
""" Initialize the weights """
|
||||||
@@ -699,8 +700,6 @@ class BertModel(BertPreTrainedModel):
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
authorized_missing_keys = [r"position_ids"]
|
|
||||||
|
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super().__init__(config)
|
super().__init__(config)
|
||||||
self.config = config
|
self.config = config
|
||||||
|
|||||||
@@ -88,9 +88,11 @@ class AutoModelTest(unittest.TestCase):
|
|||||||
model, loading_info = AutoModelForPreTraining.from_pretrained(model_name, output_loading_info=True)
|
model, loading_info = AutoModelForPreTraining.from_pretrained(model_name, output_loading_info=True)
|
||||||
self.assertIsNotNone(model)
|
self.assertIsNotNone(model)
|
||||||
self.assertIsInstance(model, BertForPreTraining)
|
self.assertIsInstance(model, BertForPreTraining)
|
||||||
|
# Only one value should not be initialized and in the missing keys.
|
||||||
|
missing_keys = loading_info.pop("missing_keys")
|
||||||
|
self.assertListEqual(["cls.predictions.decoder.bias"], missing_keys)
|
||||||
for key, value in loading_info.items():
|
for key, value in loading_info.items():
|
||||||
# Only one value should not be initialized and in the missing keys.
|
self.assertEqual(len(value), 0)
|
||||||
self.assertEqual(len(value), 1 if key == "missing_keys" else 0)
|
|
||||||
|
|
||||||
@slow
|
@slow
|
||||||
def test_lmhead_model_from_pretrained(self):
|
def test_lmhead_model_from_pretrained(self):
|
||||||
|
|||||||
Reference in New Issue
Block a user