TF: standardize test_model_common_attributes for language models (#23457)
This commit is contained in:
@@ -363,24 +363,6 @@ class TFGPTJModelTest(TFModelTesterMixin, TFCoreModelTesterMixin, PipelineTester
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_gptj_lm_head_model(*config_and_inputs)
|
||||
|
||||
def test_model_common_attributes(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
assert isinstance(model.get_input_embeddings(), tf.keras.layers.Layer)
|
||||
|
||||
if model_class in self.all_generative_model_classes:
|
||||
x = model.get_output_embeddings()
|
||||
assert isinstance(x, tf.keras.layers.Layer)
|
||||
name = model.get_bias()
|
||||
assert name is None
|
||||
else:
|
||||
x = model.get_output_embeddings()
|
||||
assert x is None
|
||||
name = model.get_bias()
|
||||
assert name is None
|
||||
|
||||
@slow
|
||||
@unittest.skipIf(
|
||||
not is_tf_available() or len(tf.config.list_physical_devices("GPU")) > 0,
|
||||
|
||||
Reference in New Issue
Block a user