@@ -262,6 +262,24 @@ class TFEsmModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestCase)
|
||||
def test_save_load_after_resize_token_embeddings(self):
|
||||
pass
|
||||
|
||||
def test_model_common_attributes(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
assert isinstance(model.get_input_embeddings(), tf.keras.layers.Layer)
|
||||
if model_class is TFEsmForMaskedLM:
|
||||
# Output embedding test differs from the main test because they're a matrix, not a layer
|
||||
name = model.get_bias()
|
||||
assert isinstance(name, dict)
|
||||
for k, v in name.items():
|
||||
assert isinstance(v, tf.Variable)
|
||||
else:
|
||||
x = model.get_output_embeddings()
|
||||
assert x is None
|
||||
name = model.get_bias()
|
||||
assert name is None
|
||||
|
||||
|
||||
@require_tf
|
||||
class TFEsmModelIntegrationTest(unittest.TestCase):
|
||||
|
||||
Reference in New Issue
Block a user