Move TF building to an actual build() method (#23760)
* A fun new PR where I break the entire codebase again * A fun new PR where I break the entire codebase again * Handle cross-attention * Move calls to model(model.dummy_inputs) to the new build() method * Seeing what fails with the build context thing * make fix-copies * Let's see what fails with new build methods * Fix the pytorch crossload build calls * Fix the overridden build methods in vision_text_dual_encoder * Make sure all our build methods set self.built or call super().build(), which also sets it * make fix-copies * Remove finished TODO * Tentatively remove unneeded (?) line * Transpose b in deberta correctly and remove unused threading local * Get rid of build_with_dummies and all it stands for * Rollback some changes to TF-PT crossloading * Correctly call super().build()
This commit is contained in:
@@ -328,7 +328,7 @@ class TFBartModelTest(TFModelTesterMixin, TFCoreModelTesterMixin, PipelineTester
|
||||
old_total_size = config.vocab_size
|
||||
new_total_size = old_total_size + new_tokens_size
|
||||
model = model_class(config=copy.deepcopy(config)) # `resize_token_embeddings` mutates `config`
|
||||
model(model.dummy_inputs) # builds the embeddings layer
|
||||
model.build()
|
||||
model.resize_token_embeddings(new_total_size)
|
||||
|
||||
# fetch the output for an input exclusively made of new members of the vocabulary
|
||||
|
||||
@@ -1070,9 +1070,9 @@ class TFEncoderDecoderModelSaveLoadTests(unittest.TestCase):
|
||||
|
||||
# create two random BERT models for bert2bert & initialize weights (+cross_attention weights)
|
||||
encoder = TFBertModel(config.encoder)
|
||||
encoder(encoder.dummy_inputs)
|
||||
encoder.build()
|
||||
decoder = TFBertLMHeadModel(config.decoder)
|
||||
decoder(decoder.dummy_inputs)
|
||||
decoder.build()
|
||||
|
||||
encoder_decoder_orig = TFEncoderDecoderModel(encoder=encoder, decoder=decoder)
|
||||
|
||||
|
||||
@@ -463,7 +463,7 @@ class TFGPT2ModelTest(TFModelTesterMixin, TFCoreModelTesterMixin, PipelineTester
|
||||
continue
|
||||
|
||||
model = model_class(config)
|
||||
model(model.dummy_inputs)
|
||||
model.build()
|
||||
|
||||
onnx_model_proto, _ = tf2onnx.convert.from_keras(model, opset=self.onnx_min_opset)
|
||||
|
||||
|
||||
@@ -194,7 +194,7 @@ class TFOPTModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestCase)
|
||||
else:
|
||||
# Here we build the word embeddings weights if not exists.
|
||||
# And then we retry to get the attribute once built.
|
||||
model(model.dummy_inputs)
|
||||
model.build()
|
||||
if hasattr(embedding_layer, "weight"):
|
||||
return embedding_layer.weight
|
||||
else:
|
||||
|
||||
@@ -729,9 +729,9 @@ class TFVisionEncoderDecoderModelSaveLoadTests(unittest.TestCase):
|
||||
|
||||
# create two random ViT/GPT2 models for vit-gpt2 & initialize weights (+cross_attention weights)
|
||||
encoder = TFViTModel(config.encoder)
|
||||
encoder(encoder.dummy_inputs)
|
||||
encoder.build()
|
||||
decoder = TFGPT2LMHeadModel(config.decoder)
|
||||
decoder(decoder.dummy_inputs)
|
||||
decoder.build()
|
||||
|
||||
encoder_decoder_orig = TFVisionEncoderDecoderModel(encoder=encoder, decoder=decoder)
|
||||
|
||||
|
||||
@@ -281,7 +281,7 @@ class TFWhisperModelTest(TFModelTesterMixin, PipelineTesterMixin, unittest.TestC
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
|
||||
model(model.dummy_inputs)
|
||||
model.build()
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
model.save_pretrained(tmpdirname, saved_model=False)
|
||||
|
||||
@@ -348,7 +348,7 @@ class TFModelTesterMixin:
|
||||
|
||||
with tf.Graph().as_default() as g:
|
||||
model = model_class(config)
|
||||
model(model.dummy_inputs)
|
||||
model.build()
|
||||
|
||||
for op in g.get_operations():
|
||||
model_op_names.add(op.node_def.op)
|
||||
@@ -375,7 +375,7 @@ class TFModelTesterMixin:
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
model = model_class(config)
|
||||
model(model.dummy_inputs)
|
||||
model.build()
|
||||
|
||||
onnx_model_proto, _ = tf2onnx.convert.from_keras(model, opset=self.onnx_min_opset)
|
||||
|
||||
@@ -1180,7 +1180,7 @@ class TFModelTesterMixin:
|
||||
def _get_word_embedding_weight(model, embedding_layer):
|
||||
if isinstance(embedding_layer, tf.keras.layers.Embedding):
|
||||
# builds the embeddings layer
|
||||
model(model.dummy_inputs)
|
||||
model.build()
|
||||
return embedding_layer.embeddings
|
||||
else:
|
||||
return model._get_word_embedding_weight(embedding_layer)
|
||||
@@ -1243,7 +1243,7 @@ class TFModelTesterMixin:
|
||||
old_total_size = config.vocab_size
|
||||
new_total_size = old_total_size + new_tokens_size
|
||||
model = model_class(config=copy.deepcopy(config)) # `resize_token_embeddings` mutates `config`
|
||||
model(model.dummy_inputs) # builds the embeddings layer
|
||||
model.build()
|
||||
model.resize_token_embeddings(new_total_size)
|
||||
|
||||
# fetch the output for an input exclusively made of new members of the vocabulary
|
||||
@@ -2313,8 +2313,8 @@ class UtilsFunctionsTest(unittest.TestCase):
|
||||
# Finally, check the model can be reloaded
|
||||
new_model = TFBertModel.from_pretrained(tmp_dir)
|
||||
|
||||
model(model.dummy_inputs)
|
||||
new_model(model.dummy_inputs)
|
||||
model.build()
|
||||
new_model.build()
|
||||
|
||||
for p1, p2 in zip(model.weights, new_model.weights):
|
||||
self.assertTrue(np.allclose(p1.numpy(), p2.numpy()))
|
||||
@@ -2440,7 +2440,7 @@ class TFModelPushToHubTester(unittest.TestCase):
|
||||
)
|
||||
model = TFBertModel(config)
|
||||
# Make sure model is properly initialized
|
||||
_ = model(model.dummy_inputs)
|
||||
model.build()
|
||||
|
||||
logging.set_verbosity_info()
|
||||
logger = logging.get_logger("transformers.utils.hub")
|
||||
@@ -2509,7 +2509,7 @@ class TFModelPushToHubTester(unittest.TestCase):
|
||||
)
|
||||
model = TFBertModel(config)
|
||||
# Make sure model is properly initialized
|
||||
_ = model(model.dummy_inputs)
|
||||
model.build()
|
||||
|
||||
model.push_to_hub("valid_org/test-model-tf-org", use_auth_token=self._token)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user