Fix saved model creation (#5468)
* Fix TF Serving when output_hidden_states and output_attentions are True * Add tests for saved model creation + bug fix for multiple choices models * remove unused import * Fix the input for several layers * Fix test * Fix conflict printing * Apply style * Fix XLM and Flaubert for TensorFlow * Apply style * Fix TF check version * Apply style * Trigger CI
This commit is contained in:
@@ -23,7 +23,7 @@ import unittest
|
||||
from importlib import import_module
|
||||
|
||||
from transformers import is_tf_available, is_torch_available
|
||||
from transformers.testing_utils import _tf_gpu_memory_limit, require_tf
|
||||
from transformers.testing_utils import _tf_gpu_memory_limit, require_tf, slow
|
||||
|
||||
|
||||
if is_tf_available():
|
||||
@@ -130,6 +130,61 @@ class TFModelTesterMixin:
|
||||
|
||||
self.assert_outputs_same(after_outputs, outputs)
|
||||
|
||||
@slow
|
||||
def test_saved_model_with_hidden_states_output(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
config.output_hidden_states = True
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
inputs_dict = self._prepare_for_class(inputs_dict, model_class)
|
||||
model = model_class(config)
|
||||
num_out = len(model(inputs_dict))
|
||||
model._saved_model_inputs_spec = None
|
||||
model._set_save_spec(inputs_dict)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
tf.saved_model.save(model, tmpdirname)
|
||||
model = tf.keras.models.load_model(tmpdirname)
|
||||
outputs = model(inputs_dict)
|
||||
hidden_states = [t.numpy() for t in outputs[-1]]
|
||||
self.assertEqual(len(outputs), num_out)
|
||||
self.assertEqual(len(hidden_states), self.model_tester.num_hidden_layers + 1)
|
||||
self.assertListEqual(
|
||||
list(hidden_states[0].shape[-2:]), [self.model_tester.seq_length, self.model_tester.hidden_size],
|
||||
)
|
||||
|
||||
@slow
|
||||
def test_saved_model_with_attentions_output(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
config.output_attentions = True
|
||||
encoder_seq_length = (
|
||||
self.model_tester.encoder_seq_length
|
||||
if hasattr(self.model_tester, "encoder_seq_length")
|
||||
else self.model_tester.seq_length
|
||||
)
|
||||
encoder_key_length = (
|
||||
self.model_tester.key_length if hasattr(self.model_tester, "key_length") else encoder_seq_length
|
||||
)
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
inputs_dict = self._prepare_for_class(inputs_dict, model_class)
|
||||
model = model_class(config)
|
||||
num_out = len(model(inputs_dict))
|
||||
model._saved_model_inputs_spec = None
|
||||
model._set_save_spec(inputs_dict)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
tf.saved_model.save(model, tmpdirname)
|
||||
model = tf.keras.models.load_model(tmpdirname)
|
||||
outputs = model(inputs_dict)
|
||||
attentions = [t.numpy() for t in outputs[-1]]
|
||||
self.assertEqual(len(outputs), num_out)
|
||||
self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
|
||||
self.assertListEqual(
|
||||
list(attentions[0].shape[-3:]),
|
||||
[self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
|
||||
)
|
||||
|
||||
def test_keras_save_load(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
|
||||
@@ -342,11 +342,17 @@ class TFXLNetModelTester:
|
||||
"attention_mask": multiple_choice_input_mask,
|
||||
"token_type_ids": multiple_choice_token_type_ids,
|
||||
}
|
||||
(logits,) = model(inputs)
|
||||
(logits, mems_1) = model(inputs)
|
||||
result = {
|
||||
"mems_1": [mem.numpy() for mem in mems_1],
|
||||
"logits": logits.numpy(),
|
||||
}
|
||||
|
||||
self.parent.assertListEqual(list(result["logits"].shape), [self.batch_size, self.num_choices])
|
||||
self.parent.assertListEqual(
|
||||
list(list(mem.shape) for mem in result["mems_1"]),
|
||||
[[self.seq_length, self.batch_size * self.num_choices, self.hidden_size]] * self.num_hidden_layers,
|
||||
)
|
||||
|
||||
def prepare_config_and_inputs_for_common(self):
|
||||
config_and_inputs = self.prepare_config_and_inputs()
|
||||
|
||||
Reference in New Issue
Block a user