Proper build() methods for TF (#27794)
* Add a convenience method for building in your own name scope * Second attempt at auto layer building * Revert "Second attempt at auto layer building" This reverts commit e03a3aaecf9ec41a805582b83cbdfe3290a631be. * Attempt #3 * Revert "Attempt #3" This reverts commit b9df7a0857560d29b5abbed6127d9e9eca77cf47. * Add missing attributes that we're going to need later * Add some attributes we're going to need later * A fourth attempt! Feel the power flow through you! * Revert "A fourth attempt! Feel the power flow through you!" This reverts commit 6bf4aaf3875d6f28485f50187617a4c616c8aff7. * Add more values we'll need later * TF refactor that we'll need later * Revert "TF refactor that we'll need later" This reverts commit ca07202fb5b7b7436b893baa8d688b4f348ea7b9. * Revert "Revert "TF refactor that we'll need later"" This reverts commit 1beb0f39f293ed9c27594575e1c849aadeb15c13. * make fixup * Attempt five! * Revert "Attempt five!" This reverts commit 3302207958dfd0374b0447a51c06eea51a506044. * Attempt six - this time don't add empty methods * Revert "Attempt six - this time don't add empty methods" This reverts commit 67d60129be75416b6beb8f47c7d38d77b18d79bb. * Attempt seven - better base model class detection! * Revert "Attempt seven - better base model class detection!" This reverts commit 5f14845e92ea0e87c598da933bfbfee10f553bc9. * Another attribute we'll need later * Try again with the missing attribute! * Revert "Try again with the missing attribute!" This reverts commit 760c6f30c5dffb3e04b0e73c34a77d1882a0fef7. * This is the attempt that will pierce the heavens! * Revert "This is the attempt that will pierce the heavens!" This reverts commit c868bb657de057aca7a5260350a3f831fc4dfee6. * Attempt seven - snag list is steadily decreasing * Revert "Attempt seven - snag list is steadily decreasing" This reverts commit 46fbd975deda64429bfb3e5fac4fc0370c00d316. * Attempt eight - will an empty snag list do it? * Revert "Attempt eight - will an empty snag list do it?" This reverts commit 7c8a3c2b083253649569e9877e02054ae5cec67b. * Fixes to Hubert issues that cause problems later * Trying again with Conv1D/SeparableConv fixes * Revert "Trying again with Conv1D/SeparableConv fixes" This reverts commit 55092bca952bc0f750aa1ffe246a640bf1e2036e. * Apply the build shape fixes to Wav2Vec2 as well * One more attempt! * Revert "One more attempt!" This reverts commit 5ac3e4cb01b9458cc93312873725f9444ae7261c. * Another attempt! * Revert "Another attempt!" This reverts commit ea16d890e019d7de8792a3b8e72f3b1c02adae50. * Let's see how many failures we get without the internal build method * Fix OpenAI * Fix MobileBERT * (Mostly) fix GroupVIT * Fix BLIP * One more BLIP fix * One more BLIP fix! * Fix Regnet * Finally fully fix GroupViT * Fix Data2Vec and add the new AdaptivePool * Fix Segformer * Fix Albert * Fix Deberta/DebertaV2 * Fix XLM * Actually fix XLM * Fix Flaubert * Fix lxmert * Fix Resnet * Fix ConvBERT * Fix ESM * Fix Convnext / ConvnextV2 * Fix SAM * Fix Efficientformer * Fix LayoutLMv3 * Fix speech_to_text * Fix mpnet and mobilevit * Fix Swin * Fix CTRL * Fix CVT * Fix DPR * Fix Wav2Vec2 * Fix T5 * Fix Hubert * Fix GPT2 * Fix Whisper * Fix DeiT * Fix the encoder-decoder / dual-encoder classes * make fix-copies * build in name scope * Fix summarization test * Fix tied weight names for BART + Blenderbot * Fix tied weight name building * Fix to TFESM weight building * Update TF SAM * Expand all the shapes out into Big Boy Shapes
This commit is contained in:
@@ -2161,16 +2161,8 @@ class TF{{cookiecutter.camelcase_modelname}}Encoder(tf.keras.layers.Layer):
|
||||
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
||||
|
||||
if inputs_embeds is None:
|
||||
# if `self.embed_tokens.load_weight_prefix` is set, runs the embedding operation with the correct name
|
||||
# scope, so that its weights are registered with the desired name for loading/storing. When `tf.name_scope`
|
||||
# is used with a name ending in `/`, that name replaces the current name scope.
|
||||
# (embeddings with tf.name_scope: self.embed_tokens.load_weight_prefix/self.embed_tokens.name/embeddings:0)
|
||||
context = []
|
||||
if hasattr(self.embed_tokens, "load_weight_prefix"):
|
||||
context.append(tf.name_scope(self.embed_tokens.load_weight_prefix + "/"))
|
||||
with ContextManagers(context):
|
||||
check_embeddings_within_bounds(input_ids, self.embed_tokens.input_dim)
|
||||
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
|
||||
check_embeddings_within_bounds(input_ids, self.embed_tokens.input_dim)
|
||||
inputs_embeds = self.embed_tokens(input_ids) * self.embed_scale
|
||||
|
||||
embed_pos = self.embed_positions(input_shape)
|
||||
hidden_states = inputs_embeds + embed_pos
|
||||
@@ -2359,16 +2351,8 @@ class TF{{cookiecutter.camelcase_modelname}}Decoder(tf.keras.layers.Layer):
|
||||
positions = self.embed_positions(input_shape, past_key_values_length)
|
||||
|
||||
if inputs_embeds is None:
|
||||
# if `self.embed_tokens.load_weight_prefix` is set, runs the embedding operation with the correct name
|
||||
# scope, so that its weights are registered with the desired name for loading/storing. When `tf.name_scope`
|
||||
# is used with a name ending in `/`, that name replaces the current name scope.
|
||||
# (embeddings with tf.name_scope: self.embed_tokens.load_weight_prefix/self.embed_tokens.name/embeddings:0)
|
||||
context = []
|
||||
if hasattr(self.embed_tokens, "load_weight_prefix"):
|
||||
context.append(tf.name_scope(self.embed_tokens.load_weight_prefix + "/"))
|
||||
with ContextManagers(context):
|
||||
check_embeddings_within_bounds(input_ids, self.embed_tokens.input_dim)
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
check_embeddings_within_bounds(input_ids, self.embed_tokens.input_dim)
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
|
||||
hidden_states = inputs_embeds
|
||||
|
||||
@@ -2578,6 +2562,13 @@ class TF{{cookiecutter.camelcase_modelname}}MainLayer(tf.keras.layers.Layer):
|
||||
encoder_attentions=encoder_outputs.attentions,
|
||||
)
|
||||
|
||||
def build(self, input_shape=None):
|
||||
# The shared/tied weights expect to be in the model base namespace
|
||||
# Adding "/" to the end (not the start!) of a tf.name_scope puts it in the root namespace rather than
|
||||
# the current one.
|
||||
with tf.name_scope(self.shared.load_weight_prefix + '/' + self.shared.name + '/'):
|
||||
self.shared.build(None)
|
||||
|
||||
|
||||
@add_start_docstrings(
|
||||
"The bare {{cookiecutter.uppercase_modelname}} Model outputting raw hidden-states without any specific head on top.",
|
||||
|
||||
Reference in New Issue
Block a user