Update serving code to enable saved_model=True (#18153)
* Add serving_output and serving methods to some vision models * Add serving outputs for DeiT * Don't convert hidden states - differing shapes * Make saveable * Fix up * Make swin saveable * Add in tests * Fix funnel tests (can't convert to tensor) * Fix numpy call * Tidy up a bit * Add in hidden states - resnet * Remove numpy * Fix failing tests - tensor shape and skipping tests * Remove duplicated function * PR comments - formatting and var names * PR comments Add suggestions made by Joao Gante: * Use tf.shape instead of shape_list * Use @tooslow decorator on tests * Simplify some of the logic * PR comments Address Yih-Dar Sheih comments - making tensor names consistent and make types float * Types consistent with docs; disable test on swin (slow) * CI trigger * Change input_features to float32 * Add serving_output for segformer * Fixup Co-authored-by: Amy Roberts <amyeroberts@users.noreply.github.com>
This commit is contained in:
@@ -201,27 +201,6 @@ class TFCoreModelTesterMixin:
|
||||
val_loss = history.history["val_loss"][0]
|
||||
self.assertTrue(not isnan(val_loss))
|
||||
|
||||
@slow
|
||||
def test_saved_model_creation(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
config.output_hidden_states = False
|
||||
config.output_attentions = False
|
||||
|
||||
if hasattr(config, "use_cache"):
|
||||
config.use_cache = False
|
||||
|
||||
model_class = self.all_model_classes[0]
|
||||
|
||||
class_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
|
||||
model = model_class(config)
|
||||
|
||||
model(class_inputs_dict)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
model.save_pretrained(tmpdirname, saved_model=True)
|
||||
saved_model_dir = os.path.join(tmpdirname, "saved_model", "1")
|
||||
self.assertTrue(os.path.exists(saved_model_dir))
|
||||
|
||||
@slow
|
||||
def test_saved_model_creation_extended(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
Reference in New Issue
Block a user