More TF fixes (#28081)
* More build_in_name_scope() * Make sure we set the save spec now we don't do it with dummies anymore * make fixup
This commit is contained in:
@@ -1147,6 +1147,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
|
||||
self.config = config
|
||||
self.name_or_path = config.name_or_path
|
||||
self.generation_config = GenerationConfig.from_model_config(config) if self.can_generate() else None
|
||||
self._set_save_spec(self.input_signature)
|
||||
|
||||
def get_config(self):
|
||||
return self.config.to_dict()
|
||||
|
||||
@@ -211,7 +211,7 @@ class TFAutoModelTest(unittest.TestCase):
|
||||
config = copy.deepcopy(model.config)
|
||||
config.architectures = ["FunnelBaseModel"]
|
||||
model = TFAutoModel.from_config(config)
|
||||
model.build()
|
||||
model.build_in_name_scope()
|
||||
|
||||
self.assertIsInstance(model, TFFunnelBaseModel)
|
||||
|
||||
@@ -249,7 +249,7 @@ class TFAutoModelTest(unittest.TestCase):
|
||||
config = NewModelConfig(**tiny_config.to_dict())
|
||||
|
||||
model = auto_class.from_config(config)
|
||||
model.build()
|
||||
model.build_in_name_scope()
|
||||
|
||||
self.assertIsInstance(model, TFNewModel)
|
||||
|
||||
|
||||
@@ -445,7 +445,7 @@ class TFGPT2ModelTest(TFModelTesterMixin, TFCoreModelTesterMixin, PipelineTester
|
||||
continue
|
||||
|
||||
model = model_class(config)
|
||||
model.build()
|
||||
model.build_in_name_scope()
|
||||
|
||||
onnx_model_proto, _ = tf2onnx.convert.from_keras(model, opset=self.onnx_min_opset)
|
||||
|
||||
|
||||
@@ -312,7 +312,7 @@ class TFWhisperModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestC
|
||||
config = self.model_tester.get_config()
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
model.build()
|
||||
model.build_in_name_scope()
|
||||
|
||||
embeds = model.get_encoder().embed_positions.get_weights()[0]
|
||||
sinusoids = sinusoidal_embedding_init(embeds.shape).numpy()
|
||||
|
||||
@@ -217,7 +217,7 @@ class TFCoreModelTesterMixin:
|
||||
for model_class in self.all_model_classes[:2]:
|
||||
class_inputs_dict = self._prepare_for_class(inputs_dict, model_class)
|
||||
model = model_class(config)
|
||||
model.build()
|
||||
model.build_in_name_scope()
|
||||
num_out = len(model(class_inputs_dict))
|
||||
|
||||
for key in list(class_inputs_dict.keys()):
|
||||
|
||||
Reference in New Issue
Block a user