Fix TF s2s models (#9478)
* Fix Seq2Seq models for serving * Apply style * Fix lonfgormer * Fix mBart/Pegasus/Blenderbot * Apply style * Add a main intermediate layer * Apply style * Remove import * Apply tf.function to Longformer * Fix utils check_copy * Update S2S template * Fix BART + Blenderbot * Fix BlenderbotSmall * Fix BlenderbotSmall * Fix BlenderbotSmall * Fix MBart * Fix Marian * Fix Pegasus + template * Apply style * Fix common attributes test * Forgot to fix the LED test * Apply Patrick's comment on LED Decoder
This commit is contained in:
@@ -188,28 +188,19 @@ class TFBlenderbotSmallModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
assert isinstance(model.get_input_embeddings(), tf.keras.layers.Layer)
|
||||
x = model.get_output_layer_with_bias()
|
||||
assert x is None
|
||||
name = model.get_prefix_bias_name()
|
||||
assert name is None
|
||||
|
||||
@slow
|
||||
def test_saved_model_with_hidden_states_output(self):
|
||||
# TODO(JPLU, PVP) - fix this with s2s tf-serving PR
|
||||
pass
|
||||
|
||||
@slow
|
||||
def test_saved_model_with_attentions_output(self):
|
||||
# TODO(JPLU, PVP) - fix this with s2s tf-serving PR
|
||||
pass
|
||||
|
||||
def test_saved_model_creation(self):
|
||||
# TODO(JPLU, PVP) - fix this with s2s tf-serving PR
|
||||
pass
|
||||
|
||||
def test_saved_model_creation_extended(self):
|
||||
# TODO(JPLU, PVP) - fix this with s2s tf-serving PR
|
||||
pass
|
||||
if model_class in self.all_generative_model_classes:
|
||||
x = model.get_output_embeddings()
|
||||
assert isinstance(x, tf.keras.layers.Layer)
|
||||
name = model.get_bias()
|
||||
assert isinstance(name, dict)
|
||||
for k, v in name.items():
|
||||
assert isinstance(v, tf.Variable)
|
||||
else:
|
||||
x = model.get_output_embeddings()
|
||||
assert x is None
|
||||
name = model.get_bias()
|
||||
assert name is None
|
||||
|
||||
def test_resize_token_embeddings(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
@@ -274,6 +265,10 @@ class TFBlenderbotSmallModelTest(TFModelTesterMixin, unittest.TestCase):
|
||||
models_equal = False
|
||||
self.assertTrue(models_equal)
|
||||
|
||||
def test_saved_model_creation(self):
|
||||
# This test is too long (>30sec) and makes fail the CI
|
||||
pass
|
||||
|
||||
|
||||
def _assert_tensors_equal(a, b, atol=1e-12, prefix=""):
|
||||
"""If tensors not close, or a and b arent both tensors, raise a nice Assertion error."""
|
||||
|
||||
Reference in New Issue
Block a user