Proper build() methods for TF (#27794)

* Add a convenience method for building in your own name scope

* Second attempt at auto layer building

* Revert "Second attempt at auto layer building"

This reverts commit e03a3aaecf9ec41a805582b83cbdfe3290a631be.

* Attempt #3

* Revert "Attempt #3"

This reverts commit b9df7a0857560d29b5abbed6127d9e9eca77cf47.

* Add missing attributes that we're going to need later

* Add some attributes we're going to need later

* A fourth attempt! Feel the power flow through you!

* Revert "A fourth attempt! Feel the power flow through you!"

This reverts commit 6bf4aaf3875d6f28485f50187617a4c616c8aff7.

* Add more values we'll need later

* TF refactor that we'll need later

* Revert "TF refactor that we'll need later"

This reverts commit ca07202fb5b7b7436b893baa8d688b4f348ea7b9.

* Revert "Revert "TF refactor that we'll need later""

This reverts commit 1beb0f39f293ed9c27594575e1c849aadeb15c13.

* make fixup

* Attempt five!

* Revert "Attempt five!"

This reverts commit 3302207958dfd0374b0447a51c06eea51a506044.

* Attempt six - this time don't add empty methods

* Revert "Attempt six - this time don't add empty methods"

This reverts commit 67d60129be75416b6beb8f47c7d38d77b18d79bb.

* Attempt seven - better base model class detection!

* Revert "Attempt seven - better base model class detection!"

This reverts commit 5f14845e92ea0e87c598da933bfbfee10f553bc9.

* Another attribute we'll need later

* Try again with the missing attribute!

* Revert "Try again with the missing attribute!"

This reverts commit 760c6f30c5dffb3e04b0e73c34a77d1882a0fef7.

* This is the attempt that will pierce the heavens!

* Revert "This is the attempt that will pierce the heavens!"

This reverts commit c868bb657de057aca7a5260350a3f831fc4dfee6.

* Attempt seven - snag list is steadily decreasing

* Revert "Attempt seven - snag list is steadily decreasing"

This reverts commit 46fbd975deda64429bfb3e5fac4fc0370c00d316.

* Attempt eight - will an empty snag list do it?

* Revert "Attempt eight - will an empty snag list do it?"

This reverts commit 7c8a3c2b083253649569e9877e02054ae5cec67b.

* Fixes to Hubert issues that cause problems later

* Trying again with Conv1D/SeparableConv fixes

* Revert "Trying again with Conv1D/SeparableConv fixes"

This reverts commit 55092bca952bc0f750aa1ffe246a640bf1e2036e.

* Apply the build shape fixes to Wav2Vec2 as well

* One more attempt!

* Revert "One more attempt!"

This reverts commit 5ac3e4cb01b9458cc93312873725f9444ae7261c.

* Another attempt!

* Revert "Another attempt!"

This reverts commit ea16d890e019d7de8792a3b8e72f3b1c02adae50.

* Let's see how many failures we get without the internal build method

* Fix OpenAI

* Fix MobileBERT

* (Mostly) fix GroupVIT

* Fix BLIP

* One more BLIP fix

* One more BLIP fix!

* Fix Regnet

* Finally fully fix GroupViT

* Fix Data2Vec and add the new AdaptivePool

* Fix Segformer

* Fix Albert

* Fix Deberta/DebertaV2

* Fix XLM

* Actually fix XLM

* Fix Flaubert

* Fix lxmert

* Fix Resnet

* Fix ConvBERT

* Fix ESM

* Fix Convnext / ConvnextV2

* Fix SAM

* Fix Efficientformer

* Fix LayoutLMv3

* Fix speech_to_text

* Fix mpnet and mobilevit

* Fix Swin

* Fix CTRL

* Fix CVT

* Fix DPR

* Fix Wav2Vec2

* Fix T5

* Fix Hubert

* Fix GPT2

* Fix Whisper

* Fix DeiT

* Fix the encoder-decoder / dual-encoder classes

* make fix-copies

* build in name scope

* Fix summarization test

* Fix tied weight names for BART + Blenderbot

* Fix tied weight name building

* Fix to TFESM weight building

* Update TF SAM

* Expand all the shapes out into Big Boy Shapes
This commit is contained in:
Matt
2023-12-14 15:17:30 +00:00
committed by GitHub
parent 52c37882fb
commit 050e0b44f6
73 changed files with 11040 additions and 504 deletions

View File

@@ -2161,16 +2161,8 @@ class TF{{cookiecutter.camelcase_modelname}}Encoder(tf.keras.layers.Layer):
raise ValueError("You have to specify either input_ids or inputs_embeds")
if inputs_embeds is None:
# if `self.embed_tokens.load_weight_prefix` is set, runs the embedding operation with the correct name
# scope, so that its weights are registered with the desired name for loading/storing. When `tf.name_scope`
# is used with a name ending in `/`, that name replaces the current name scope.
# (embeddings with tf.name_scope: self.embed_tokens.load_weight_prefix/self.embed_tokens.name/embeddings:0)
context = []
if hasattr(self.embed_tokens, "load_weight_prefix"):
context.append(tf.name_scope(self.embed_tokens.load_weight_prefix + "/"))
with ContextManagers(context):
check_embeddings_within_bounds(input_ids, self.embed_tokens.input_dim)
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
check_embeddings_within_bounds(input_ids, self.embed_tokens.input_dim)
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
embed_pos = self.embed_positions(input_shape)
hidden_states = inputs_embeds + embed_pos
@@ -2359,16 +2351,8 @@ class TF{{cookiecutter.camelcase_modelname}}Decoder(tf.keras.layers.Layer):
positions = self.embed_positions(input_shape, past_key_values_length)
if inputs_embeds is None:
# if `self.embed_tokens.load_weight_prefix` is set, runs the embedding operation with the correct name
# scope, so that its weights are registered with the desired name for loading/storing. When `tf.name_scope`
# is used with a name ending in `/`, that name replaces the current name scope.
# (embeddings with tf.name_scope: self.embed_tokens.load_weight_prefix/self.embed_tokens.name/embeddings:0)
context = []
if hasattr(self.embed_tokens, "load_weight_prefix"):
context.append(tf.name_scope(self.embed_tokens.load_weight_prefix + "/"))
with ContextManagers(context):
check_embeddings_within_bounds(input_ids, self.embed_tokens.input_dim)
inputs_embeds = self.embed_tokens(input_ids)
check_embeddings_within_bounds(input_ids, self.embed_tokens.input_dim)
inputs_embeds = self.embed_tokens(input_ids)
hidden_states = inputs_embeds
@@ -2578,6 +2562,13 @@ class TF{{cookiecutter.camelcase_modelname}}MainLayer(tf.keras.layers.Layer):
encoder_attentions=encoder_outputs.attentions,
)
def build(self, input_shape=None):
# The shared/tied weights expect to be in the model base namespace
# Adding "/" to the end (not the start!) of a tf.name_scope puts it in the root namespace rather than
# the current one.
with tf.name_scope(self.shared.load_weight_prefix + '/' + self.shared.name + '/'):
self.shared.build(None)
@add_start_docstrings(
"The bare {{cookiecutter.uppercase_modelname}} Model outputting raw hidden-states without any specific head on top.",