Custom TF weights loading (#7422)
* First try
* Fix TF utils
* Handle authorized unexpected keys when loading weights
* Add several more authorized unexpected keys
* Apply style
* Fix test
* Address Patrick's comments.
* Update src/transformers/modeling_tf_utils.py
  Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
* Update src/transformers/modeling_tf_utils.py
  Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
* Apply style
* Make return_dict the default behavior and display a warning message
* Revert
* Replace wrong keyword
* Revert code
* Add forgotten key
* Fix bug in loading PT models from a TF one.
* Fix sort
* Add a test for custom load weights in BERT
* Apply style
* Remove unused import

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
@@ -17,7 +17,7 @@
|
||||
import unittest
|
||||
|
||||
from transformers import BertConfig, is_tf_available
|
||||
from transformers.testing_utils import require_tf, slow
|
||||
from transformers.testing_utils import require_tf
|
||||
|
||||
from .test_configuration_common import ConfigTester
|
||||
from .test_modeling_tf_common import TFModelTesterMixin, ids_tensor
|
||||
@@ -317,9 +317,14 @@ class TFBertModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_bert_for_token_classification(*config_and_inputs)
|
||||
|
||||
@slow
|
||||
def test_model_from_pretrained(self):
|
||||
# for model_name in TF_BERT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
|
||||
for model_name in ["bert-base-uncased"]:
|
||||
model = TFBertModel.from_pretrained(model_name)
|
||||
self.assertIsNotNone(model)
|
||||
model = TFBertModel.from_pretrained("jplu/tiny-tf-bert-random")
|
||||
self.assertIsNotNone(model)
|
||||
|
||||
# Check the loading report returned by from_pretrained(..., output_loading_info=True)
# when the "jplu/tiny-tf-bert-random" checkpoint is loaded into
# TFBertForTokenClassification.
def test_custom_load_tf_weights(self):
|
||||
model, output_loading_info = TFBertForTokenClassification.from_pretrained(
|
||||
"jplu/tiny-tf-bert-random", use_cdn=False, output_loading_info=True
|
||||
)
|
||||
# The reported unexpected keys, once sorted, must be exactly these two names.
self.assertEqual(sorted(output_loading_info["unexpected_keys"]), ["mlm___cls", "nsp___cls"])
|
||||
# Every reported missing key must start with "dropout" or "classifier"
# (the text before its first underscore).
for layer in output_loading_info["missing_keys"]:
|
||||
self.assertTrue(layer.split("_")[0] in ["dropout", "classifier"])
|
||||
|
||||
Reference in New Issue
Block a user