Fix weight tying in TF-ESM (#22839)

Fix weight tying in ESM
This commit is contained in:
Matt
2023-04-20 15:50:31 +01:00
committed by GitHub
parent 3b61d2890d
commit 6dc0a849b7
2 changed files with 35 additions and 8 deletions

View File

@@ -262,6 +262,24 @@ class TFEsmModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestCase)
def test_save_load_after_resize_token_embeddings(self):
pass
def test_model_common_attributes(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
assert isinstance(model.get_input_embeddings(), tf.keras.layers.Layer)
if model_class is TFEsmForMaskedLM:
# Output embedding test differs from the main test because they're a matrix, not a layer
name = model.get_bias()
assert isinstance(name, dict)
for k, v in name.items():
assert isinstance(v, tf.Variable)
else:
x = model.get_output_embeddings()
assert x is None
name = model.get_bias()
assert name is None
@require_tf
class TFEsmModelIntegrationTest(unittest.TestCase):