[TF models] Common attributes as per #1721
This commit is contained in:
@@ -360,6 +360,16 @@ class TFCommonTestCases:
|
||||
# self.assertTrue(models_equal)
|
||||
|
||||
|
||||
def test_model_common_attributes(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
assert isinstance(model.get_input_embeddings(), tf.keras.layers.Layer)
|
||||
x = model.get_output_embeddings()
|
||||
assert x is None or instanceof(x, tf.keras.layers.Layer)
|
||||
|
||||
|
||||
def test_tie_model_weights(self):
|
||||
pass
|
||||
# config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
Reference in New Issue
Block a user