Fix embeddings resizing in TF models (#8657)

* Resize the biases in same time than the embeddings

* Trigger CI

* Biases are not reset anymore

* Remove get_output_embeddings + better LM model detection in generation utils

* Apply style

* First test on BERT

* Update docstring + new name

* Apply the new resizing logic to all the models

* fix tests

* Apply style

* Update the template

* Fix naming

* Fix naming

* Apply style

* Apply style

* Remove unused import

* Revert get_output_embeddings

* Trigger CI

* Update num parameters

* Restore get_output_embeddings in TFPretrainedModel and add comments

* Style

* Add decoder resizing

* Style

* Fix tests

* Separate bias and decoder resize

* Fix tests

* Fix tests

* Apply style

* Add bias resizing in MPNet

* Trigger CI

* Apply style
This commit is contained in:
Julien Plu
2020-12-14 05:05:24 +01:00
committed by GitHub
parent 3552d0e0d8
commit 51d9c569fa
31 changed files with 470 additions and 18 deletions

View File

@@ -18,17 +18,17 @@ import unittest
from tests.test_configuration_common import ConfigTester
from tests.test_modeling_tf_bart import TFBartModelTester
from tests.test_modeling_tf_common import TFModelTesterMixin
from transformers import (
BlenderbotConfig,
BlenderbotSmallTokenizer,
TFAutoModelForSeq2SeqLM,
TFBlenderbotForConditionalGeneration,
is_tf_available,
)
from transformers import BlenderbotConfig, BlenderbotSmallTokenizer, is_tf_available
from transformers.file_utils import cached_property
from transformers.testing_utils import is_pt_tf_cross_test, require_tf, require_tokenizers, slow
if is_tf_available():
import tensorflow as tf
from transformers import TFAutoModelForSeq2SeqLM, TFBlenderbotForConditionalGeneration
class TFBlenderbotModelTester(TFBartModelTester):
config_updates = dict(
normalize_before=True,
@@ -65,6 +65,17 @@ class TFBlenderbotModelTest(TFModelTesterMixin, unittest.TestCase):
# Should be uncommented during patrick TF refactor
pass
def test_model_common_attributes(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
assert isinstance(model.get_input_embeddings(), tf.keras.layers.Layer)
x = model.get_output_layer_with_bias()
assert x is None
name = model.get_prefix_bias_name()
assert name is None
@is_pt_tf_cross_test
@require_tokenizers