[TF models] Common attributes as per #1721

This commit is contained in:
Julien Chaumond
2019-11-11 16:30:22 -05:00
parent 872403be1c
commit 70d97ddd60
11 changed files with 84 additions and 0 deletions

View File

@@ -360,6 +360,16 @@ class TFCommonTestCases:
# self.assertTrue(models_equal)
def test_model_common_attributes(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
assert isinstance(model.get_input_embeddings(), tf.keras.layers.Layer)
x = model.get_output_embeddings()
assert x is None or instanceof(x, tf.keras.layers.Layer)
def test_tie_model_weights(self):
pass
# config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()