TF: standardize test_model_common_attributes for language models (#23457)
This commit is contained in:
@@ -1013,7 +1013,7 @@ class TFModelTesterMixin:
|
||||
check_hidden_states_output(config, inputs_dict, model_class)
|
||||
|
||||
def test_model_common_attributes(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
text_in_text_out_models = (
|
||||
get_values(TF_MODEL_FOR_CAUSAL_LM_MAPPING)
|
||||
+ get_values(TF_MODEL_FOR_MASKED_LM_MAPPING)
|
||||
@@ -1023,24 +1023,27 @@ class TFModelTesterMixin:
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
assert isinstance(model.get_input_embeddings(), tf.keras.layers.Layer)
|
||||
if model_class in text_in_text_out_models:
|
||||
x = model.get_output_embeddings()
|
||||
assert isinstance(x, tf.keras.layers.Layer)
|
||||
name = model.get_bias()
|
||||
assert isinstance(name, dict)
|
||||
for k, v in name.items():
|
||||
assert isinstance(v, tf.Variable)
|
||||
self.assertIsInstance(model.get_input_embeddings(), tf.keras.layers.Layer)
|
||||
|
||||
legacy_text_in_text_out = model.get_lm_head() is not None
|
||||
if model_class in text_in_text_out_models or legacy_text_in_text_out:
|
||||
out_embeddings = model.get_output_embeddings()
|
||||
self.assertIsInstance(out_embeddings, tf.keras.layers.Layer)
|
||||
bias = model.get_bias()
|
||||
if bias is not None:
|
||||
self.assertIsInstance(bias, dict)
|
||||
for _, v in bias.items():
|
||||
self.assertIsInstance(v, tf.Variable)
|
||||
elif model_class in speech_in_text_out_models:
|
||||
x = model.get_output_embeddings()
|
||||
assert isinstance(x, tf.keras.layers.Layer)
|
||||
name = model.get_bias()
|
||||
assert name is None
|
||||
out_embeddings = model.get_output_embeddings()
|
||||
self.assertIsInstance(out_embeddings, tf.keras.layers.Layer)
|
||||
bias = model.get_bias()
|
||||
self.assertIsNone(bias)
|
||||
else:
|
||||
x = model.get_output_embeddings()
|
||||
assert x is None
|
||||
name = model.get_bias()
|
||||
assert name is None
|
||||
out_embeddings = model.get_output_embeddings()
|
||||
assert out_embeddings is None
|
||||
bias = model.get_bias()
|
||||
self.assertIsNone(bias)
|
||||
|
||||
def test_determinism(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
Reference in New Issue
Block a user