Custom TF weights loading (#7422)

* First try

* Fix TF utils

* Handle authorized unexpected keys when loading weights

* Add several more authorized unexpected keys

* Apply style

* Fix test

* Address Patrick's comments.

* Update src/transformers/modeling_tf_utils.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/modeling_tf_utils.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Apply style

* Make return_dict the default behavior and display a warning message

* Revert

* Replace wrong keyword

* Revert code

* Add forgot key

* Fix bug in loading PT models from a TF one.

* Fix sort

* Add a test for custom load weights in BERT

* Apply style

* Remove unused import

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
Julien Plu
2020-10-05 15:58:45 +02:00
committed by GitHub
parent d3adb985d1
commit 9cf7b23b9b
4 changed files with 139 additions and 27 deletions

View File

@@ -17,7 +17,7 @@
import unittest
from transformers import BertConfig, is_tf_available
from transformers.testing_utils import require_tf, slow
from transformers.testing_utils import require_tf
from .test_configuration_common import ConfigTester
from .test_modeling_tf_common import TFModelTesterMixin, ids_tensor
@@ -317,9 +317,14 @@ class TFBertModelTest(TFModelTesterMixin, unittest.TestCase):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_bert_for_token_classification(*config_and_inputs)
@slow
def test_model_from_pretrained(self):
# for model_name in TF_BERT_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
for model_name in ["bert-base-uncased"]:
model = TFBertModel.from_pretrained(model_name)
self.assertIsNotNone(model)
model = TFBertModel.from_pretrained("jplu/tiny-tf-bert-random")
self.assertIsNotNone(model)
def test_custom_load_tf_weights(self):
model, output_loading_info = TFBertForTokenClassification.from_pretrained(
"jplu/tiny-tf-bert-random", use_cdn=False, output_loading_info=True
)
self.assertEqual(sorted(output_loading_info["unexpected_keys"]), ["mlm___cls", "nsp___cls"])
for layer in output_loading_info["missing_keys"]:
self.assertTrue(layer.split("_")[0] in ["dropout", "classifier"])