More TF fixes (#28081)
* More build_in_name_scope() * Make sure we set the save spec now we don't do it with dummies anymore * make fixup
This commit is contained in:
@@ -1147,6 +1147,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
|
|||||||
self.config = config
|
self.config = config
|
||||||
self.name_or_path = config.name_or_path
|
self.name_or_path = config.name_or_path
|
||||||
self.generation_config = GenerationConfig.from_model_config(config) if self.can_generate() else None
|
self.generation_config = GenerationConfig.from_model_config(config) if self.can_generate() else None
|
||||||
|
self._set_save_spec(self.input_signature)
|
||||||
|
|
||||||
def get_config(self):
|
def get_config(self):
|
||||||
return self.config.to_dict()
|
return self.config.to_dict()
|
||||||
|
|||||||
@@ -211,7 +211,7 @@ class TFAutoModelTest(unittest.TestCase):
|
|||||||
config = copy.deepcopy(model.config)
|
config = copy.deepcopy(model.config)
|
||||||
config.architectures = ["FunnelBaseModel"]
|
config.architectures = ["FunnelBaseModel"]
|
||||||
model = TFAutoModel.from_config(config)
|
model = TFAutoModel.from_config(config)
|
||||||
model.build()
|
model.build_in_name_scope()
|
||||||
|
|
||||||
self.assertIsInstance(model, TFFunnelBaseModel)
|
self.assertIsInstance(model, TFFunnelBaseModel)
|
||||||
|
|
||||||
@@ -249,7 +249,7 @@ class TFAutoModelTest(unittest.TestCase):
|
|||||||
config = NewModelConfig(**tiny_config.to_dict())
|
config = NewModelConfig(**tiny_config.to_dict())
|
||||||
|
|
||||||
model = auto_class.from_config(config)
|
model = auto_class.from_config(config)
|
||||||
model.build()
|
model.build_in_name_scope()
|
||||||
|
|
||||||
self.assertIsInstance(model, TFNewModel)
|
self.assertIsInstance(model, TFNewModel)
|
||||||
|
|
||||||
|
|||||||
@@ -445,7 +445,7 @@ class TFGPT2ModelTest(TFModelTesterMixin, TFCoreModelTesterMixin, PipelineTester
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
model = model_class(config)
|
model = model_class(config)
|
||||||
model.build()
|
model.build_in_name_scope()
|
||||||
|
|
||||||
onnx_model_proto, _ = tf2onnx.convert.from_keras(model, opset=self.onnx_min_opset)
|
onnx_model_proto, _ = tf2onnx.convert.from_keras(model, opset=self.onnx_min_opset)
|
||||||
|
|
||||||
|
|||||||
@@ -312,7 +312,7 @@ class TFWhisperModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestC
|
|||||||
config = self.model_tester.get_config()
|
config = self.model_tester.get_config()
|
||||||
for model_class in self.all_model_classes:
|
for model_class in self.all_model_classes:
|
||||||
model = model_class(config)
|
model = model_class(config)
|
||||||
model.build()
|
model.build_in_name_scope()
|
||||||
|
|
||||||
embeds = model.get_encoder().embed_positions.get_weights()[0]
|
embeds = model.get_encoder().embed_positions.get_weights()[0]
|
||||||
sinusoids = sinusoidal_embedding_init(embeds.shape).numpy()
|
sinusoids = sinusoidal_embedding_init(embeds.shape).numpy()
|
||||||
|
|||||||
@@ -217,7 +217,7 @@ class TFCoreModelTesterMixin:
|
|||||||
for model_class in self.all_model_classes[:2]:
|
for model_class in self.all_model_classes[:2]:
|
||||||
class_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
|
class_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
|
||||||
model = model_class(config)
|
model = model_class(config)
|
||||||
model.build()
|
model.build_in_name_scope()
|
||||||
num_out = len(model(class_inputs_dict))
|
num_out = len(model(class_inputs_dict))
|
||||||
|
|
||||||
for key in list(class_inputs_dict.keys()):
|
for key in list(class_inputs_dict.keys()):
|
||||||
|
|||||||
Reference in New Issue
Block a user