New TF embeddings (cleaner and faster) (#9418)
* Create new embeddings + add to BERT
* Add Albert
* Add DistilBert
* Add Albert + Electra + Funnel
* Add Longformer + Lxmert
* Add last models
* Apply style
* Update the template
* Remove unused imports
* Rename attribute
* Import embeddings in their own model file
* Replace word_embeddings per weight
* fix naming
* Fix Albert
* Fix Albert
* Fix Longformer
* Fix Lxmert Mobilebert and MPNet
* Fix copy
* Fix template
* Update the get weights function
* Update src/transformers/modeling_tf_utils.py
  Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
* Update src/transformers/models/electra/modeling_tf_electra.py
  Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
* address Sylvain's comments
  Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
@@ -760,31 +760,6 @@ class TFModelTesterMixin:
|
||||
model, tuple_inputs, dict_inputs, {"output_hidden_states": True, "output_attentions": True}
|
||||
)
|
||||
|
||||
def _get_embeds(self, wte, input_ids):
|
||||
# ^^ In our TF models, the input_embeddings can take slightly different forms,
|
||||
# so we try a few of them.
|
||||
# We used to fall back to just synthetically creating a dummy tensor of ones:
|
||||
try:
|
||||
x = wte(input_ids, mode="embedding")
|
||||
except Exception:
|
||||
try:
|
||||
x = wte([input_ids], mode="embedding")
|
||||
except Exception:
|
||||
try:
|
||||
x = wte([input_ids, None, None, None], mode="embedding")
|
||||
except Exception:
|
||||
if hasattr(self.model_tester, "embedding_size"):
|
||||
x = tf.ones(
|
||||
input_ids.shape + [self.model_tester.embedding_size],
|
||||
dtype=tf.dtypes.float32,
|
||||
)
|
||||
else:
|
||||
x = tf.ones(
|
||||
input_ids.shape + [self.model_tester.hidden_size],
|
||||
dtype=tf.dtypes.float32,
|
||||
)
|
||||
return x
|
||||
|
||||
def test_inputs_embeds(self):
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
@@ -801,12 +776,11 @@ class TFModelTesterMixin:
|
||||
del inputs["input_ids"]
|
||||
inputs.pop("decoder_input_ids", None)
|
||||
|
||||
wte = model.get_input_embeddings()
|
||||
if not self.is_encoder_decoder:
|
||||
inputs["inputs_embeds"] = self._get_embeds(wte, input_ids)
|
||||
inputs["inputs_embeds"] = model.get_input_embeddings()(input_ids)
|
||||
else:
|
||||
inputs["inputs_embeds"] = self._get_embeds(wte, encoder_input_ids)
|
||||
inputs["decoder_inputs_embeds"] = self._get_embeds(wte, decoder_input_ids)
|
||||
inputs["inputs_embeds"] = model.get_input_embeddings()(encoder_input_ids)
|
||||
inputs["decoder_inputs_embeds"] = model.get_input_embeddings()(decoder_input_ids)
|
||||
|
||||
model(inputs)
|
||||
|
||||
@@ -837,24 +811,25 @@ class TFModelTesterMixin:
|
||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||
|
||||
def _get_word_embedding_weight(model, embedding_layer):
|
||||
if hasattr(embedding_layer, "word_embeddings"):
|
||||
return embedding_layer.word_embeddings
|
||||
elif hasattr(embedding_layer, "weight"):
|
||||
return embedding_layer.weight
|
||||
elif hasattr(embedding_layer, "decoder"):
|
||||
return embedding_layer.decoder
|
||||
else:
|
||||
# Here we build the word embeddings weights if not exists.
|
||||
# And then we retry to get the attribute once built.
|
||||
model(model.dummy_inputs)
|
||||
if hasattr(embedding_layer, "word_embeddings"):
|
||||
return embedding_layer.word_embeddings
|
||||
elif hasattr(embedding_layer, "weight"):
|
||||
return embedding_layer.weight
|
||||
elif hasattr(embedding_layer, "decoder"):
|
||||
return embedding_layer.decoder
|
||||
else:
|
||||
return None
|
||||
embeds = getattr(embedding_layer, "weight", None)
|
||||
if embeds is not None:
|
||||
return embeds
|
||||
|
||||
embeds = getattr(embedding_layer, "decoder", None)
|
||||
if embeds is not None:
|
||||
return embeds
|
||||
|
||||
model(model.dummy_inputs)
|
||||
|
||||
embeds = getattr(embedding_layer, "weight", None)
|
||||
if embeds is not None:
|
||||
return embeds
|
||||
|
||||
embeds = getattr(embedding_layer, "decoder", None)
|
||||
if embeds is not None:
|
||||
return embeds
|
||||
|
||||
return None
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
for size in [config.vocab_size - 10, config.vocab_size + 10, None]:
|
||||
|
||||
Reference in New Issue
Block a user