New TF embeddings (cleaner and faster) (#9418)
* Create new embeddings + add to BERT * Add Albert * Add DistilBert * Add Albert + Electra + Funnel * Add Longformer + Lxmert * Add last models * Apply style * Update the template * Remove unused imports * Rename attribute * Import embeddings in their own model file * Replace word_embeddings per weight * fix naming * Fix Albert * Fix Albert * Fix Longformer * Fix Lxmert Mobilebert and MPNet * Fix copy * Fix template * Update the get weights function * Update src/transformers/modeling_tf_utils.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update src/transformers/models/electra/modeling_tf_electra.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * address Sylvain's comments Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
@@ -121,124 +121,174 @@ class TFBertPreTrainingLoss:
|
||||
return masked_lm_loss + next_sentence_loss
|
||||
|
||||
|
||||
class TFBertWordEmbeddings(tf.keras.layers.Layer):
|
||||
def __init__(self, vocab_size: int, hidden_size: int, initializer_range: float, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.vocab_size = vocab_size
|
||||
self.hidden_size = hidden_size
|
||||
self.initializer_range = initializer_range
|
||||
|
||||
def build(self, input_shape):
|
||||
self.weight = self.add_weight(
|
||||
name="weight",
|
||||
shape=[self.vocab_size, self.hidden_size],
|
||||
initializer=get_initializer(initializer_range=self.initializer_range),
|
||||
)
|
||||
|
||||
super().build(input_shape=input_shape)
|
||||
|
||||
def get_config(self):
|
||||
config = {
|
||||
"vocab_size": self.vocab_size,
|
||||
"hidden_size": self.hidden_size,
|
||||
"initializer_range": self.initializer_range,
|
||||
}
|
||||
base_config = super().get_config()
|
||||
|
||||
return dict(list(base_config.items()) + list(config.items()))
|
||||
|
||||
def call(self, input_ids):
|
||||
flat_input_ids = tf.reshape(tensor=input_ids, shape=[-1])
|
||||
embeddings = tf.gather(params=self.weight, indices=flat_input_ids)
|
||||
embeddings = tf.reshape(
|
||||
tensor=embeddings, shape=tf.concat(values=[shape_list(tensor=input_ids), [self.hidden_size]], axis=0)
|
||||
)
|
||||
|
||||
embeddings.set_shape(shape=input_ids.shape.as_list() + [self.hidden_size])
|
||||
|
||||
return embeddings
|
||||
|
||||
|
||||
class TFBertTokenTypeEmbeddings(tf.keras.layers.Layer):
|
||||
def __init__(self, type_vocab_size: int, hidden_size: int, initializer_range: float, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.type_vocab_size = type_vocab_size
|
||||
self.hidden_size = hidden_size
|
||||
self.initializer_range = initializer_range
|
||||
|
||||
def build(self, input_shape):
|
||||
self.token_type_embeddings = self.add_weight(
|
||||
name="embeddings",
|
||||
shape=[self.type_vocab_size, self.hidden_size],
|
||||
initializer=get_initializer(initializer_range=self.initializer_range),
|
||||
)
|
||||
|
||||
super().build(input_shape=input_shape)
|
||||
|
||||
def get_config(self):
|
||||
config = {
|
||||
"type_vocab_size": self.type_vocab_size,
|
||||
"hidden_size": self.hidden_size,
|
||||
"initializer_range": self.initializer_range,
|
||||
}
|
||||
base_config = super().get_config()
|
||||
|
||||
return dict(list(base_config.items()) + list(config.items()))
|
||||
|
||||
def call(self, token_type_ids):
|
||||
flat_token_type_ids = tf.reshape(tensor=token_type_ids, shape=[-1])
|
||||
one_hot_data = tf.one_hot(indices=flat_token_type_ids, depth=self.type_vocab_size, dtype=self._compute_dtype)
|
||||
embeddings = tf.matmul(a=one_hot_data, b=self.token_type_embeddings)
|
||||
embeddings = tf.reshape(
|
||||
tensor=embeddings, shape=tf.concat(values=[shape_list(tensor=token_type_ids), [self.hidden_size]], axis=0)
|
||||
)
|
||||
|
||||
embeddings.set_shape(shape=token_type_ids.shape.as_list() + [self.hidden_size])
|
||||
|
||||
return embeddings
|
||||
|
||||
|
||||
class TFBertPositionEmbeddings(tf.keras.layers.Layer):
|
||||
def __init__(self, max_position_embeddings: int, hidden_size: int, initializer_range: float, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.hidden_size = hidden_size
|
||||
self.initializer_range = initializer_range
|
||||
|
||||
def build(self, input_shape):
|
||||
self.position_embeddings = self.add_weight(
|
||||
name="embeddings",
|
||||
shape=[self.max_position_embeddings, self.hidden_size],
|
||||
initializer=get_initializer(initializer_range=self.initializer_range),
|
||||
)
|
||||
|
||||
super().build(input_shape)
|
||||
|
||||
def get_config(self):
|
||||
config = {
|
||||
"max_position_embeddings": self.max_position_embeddings,
|
||||
"hidden_size": self.hidden_size,
|
||||
"initializer_range": self.initializer_range,
|
||||
}
|
||||
base_config = super().get_config()
|
||||
|
||||
return dict(list(base_config.items()) + list(config.items()))
|
||||
|
||||
def call(self, position_ids):
|
||||
input_shape = shape_list(tensor=position_ids)
|
||||
position_embeddings = self.position_embeddings[: input_shape[1], :]
|
||||
|
||||
return tf.broadcast_to(input=position_embeddings, shape=input_shape)
|
||||
|
||||
|
||||
class TFBertEmbeddings(tf.keras.layers.Layer):
|
||||
"""Construct the embeddings from word, position and token_type embeddings."""
|
||||
|
||||
def __init__(self, config, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.vocab_size = config.vocab_size
|
||||
self.hidden_size = config.hidden_size
|
||||
self.initializer_range = config.initializer_range
|
||||
self.position_embeddings = tf.keras.layers.Embedding(
|
||||
config.max_position_embeddings,
|
||||
config.hidden_size,
|
||||
embeddings_initializer=get_initializer(self.initializer_range),
|
||||
self.word_embeddings = TFBertWordEmbeddings(
|
||||
vocab_size=config.vocab_size,
|
||||
hidden_size=config.hidden_size,
|
||||
initializer_range=config.initializer_range,
|
||||
name="word_embeddings",
|
||||
)
|
||||
self.position_embeddings = TFBertPositionEmbeddings(
|
||||
max_position_embeddings=config.max_position_embeddings,
|
||||
hidden_size=config.hidden_size,
|
||||
initializer_range=config.initializer_range,
|
||||
name="position_embeddings",
|
||||
)
|
||||
self.token_type_embeddings = tf.keras.layers.Embedding(
|
||||
config.type_vocab_size,
|
||||
config.hidden_size,
|
||||
embeddings_initializer=get_initializer(self.initializer_range),
|
||||
self.token_type_embeddings = TFBertTokenTypeEmbeddings(
|
||||
type_vocab_size=config.type_vocab_size,
|
||||
hidden_size=config.hidden_size,
|
||||
initializer_range=config.initializer_range,
|
||||
name="token_type_embeddings",
|
||||
)
|
||||
|
||||
# self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
|
||||
# any TensorFlow checkpoint file
|
||||
self.embeddings_sum = tf.keras.layers.Add()
|
||||
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
|
||||
self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
|
||||
self.dropout = tf.keras.layers.Dropout(rate=config.hidden_dropout_prob)
|
||||
|
||||
def build(self, input_shape):
|
||||
"""Build shared word embedding layer """
|
||||
with tf.name_scope("word_embeddings"):
|
||||
# Create and initialize weights. The random normal initializer was chosen
|
||||
# arbitrarily, and works well.
|
||||
self.word_embeddings = self.add_weight(
|
||||
"weight",
|
||||
shape=[self.vocab_size, self.hidden_size],
|
||||
initializer=get_initializer(self.initializer_range),
|
||||
)
|
||||
|
||||
super().build(input_shape)
|
||||
|
||||
def call(
|
||||
self,
|
||||
input_ids=None,
|
||||
position_ids=None,
|
||||
token_type_ids=None,
|
||||
inputs_embeds=None,
|
||||
mode="embedding",
|
||||
training=False,
|
||||
):
|
||||
def call(self, input_ids=None, position_ids=None, token_type_ids=None, inputs_embeds=None, training=False):
|
||||
"""
|
||||
Get token embeddings of inputs.
|
||||
|
||||
Args:
|
||||
inputs: list of three int64 tensors with shape [batch_size, length]: (input_ids, position_ids, token_type_ids)
|
||||
mode: string, a valid value is one of "embedding" and "linear".
|
||||
Applies embedding based on inputs tensor.
|
||||
|
||||
Returns:
|
||||
outputs: If mode == "embedding", output embedding tensor, float32 with shape [batch_size, length,
|
||||
embedding_size]; if mode == "linear", output linear tensor, float32 with shape [batch_size, length,
|
||||
vocab_size].
|
||||
|
||||
Raises:
|
||||
ValueError: if mode is not valid.
|
||||
|
||||
Shared weights logic adapted from
|
||||
https://github.com/tensorflow/models/blob/a009f4fb9d2fc4949e32192a944688925ef78659/official/transformer/v2/embedding_layer.py#L24
|
||||
final_embeddings (:obj:`tf.Tensor`): output embedding tensor.
|
||||
"""
|
||||
if mode == "embedding":
|
||||
return self._embedding(input_ids, position_ids, token_type_ids, inputs_embeds, training=training)
|
||||
elif mode == "linear":
|
||||
return self._linear(input_ids)
|
||||
else:
|
||||
raise ValueError("mode {} is not valid.".format(mode))
|
||||
|
||||
def _embedding(self, input_ids, position_ids, token_type_ids, inputs_embeds, training=False):
|
||||
"""Applies embedding based on inputs tensor."""
|
||||
assert not (input_ids is None and inputs_embeds is None)
|
||||
|
||||
if input_ids is not None:
|
||||
input_shape = shape_list(input_ids)
|
||||
else:
|
||||
input_shape = shape_list(inputs_embeds)[:-1]
|
||||
|
||||
seq_length = input_shape[1]
|
||||
|
||||
if position_ids is None:
|
||||
position_ids = tf.range(seq_length, dtype=tf.int32)[tf.newaxis, :]
|
||||
inputs_embeds = self.word_embeddings(input_ids=input_ids)
|
||||
|
||||
if token_type_ids is None:
|
||||
token_type_ids = tf.fill(input_shape, 0)
|
||||
input_shape = shape_list(tensor=inputs_embeds)[:-1]
|
||||
token_type_ids = tf.fill(dims=input_shape, value=0)
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = tf.gather(self.word_embeddings, input_ids)
|
||||
if position_ids is None:
|
||||
position_embeds = self.position_embeddings(position_ids=inputs_embeds)
|
||||
else:
|
||||
position_embeds = self.position_embeddings(position_ids=position_ids)
|
||||
|
||||
position_embeddings = tf.cast(self.position_embeddings(position_ids), inputs_embeds.dtype)
|
||||
token_type_embeddings = tf.cast(self.token_type_embeddings(token_type_ids), inputs_embeds.dtype)
|
||||
embeddings = inputs_embeds + position_embeddings + token_type_embeddings
|
||||
embeddings = self.LayerNorm(embeddings)
|
||||
embeddings = self.dropout(embeddings, training=training)
|
||||
token_type_embeds = self.token_type_embeddings(token_type_ids=token_type_ids)
|
||||
final_embeddings = self.embeddings_sum(inputs=[inputs_embeds, position_embeds, token_type_embeds])
|
||||
final_embeddings = self.LayerNorm(inputs=final_embeddings)
|
||||
final_embeddings = self.dropout(inputs=final_embeddings, training=training)
|
||||
|
||||
return embeddings
|
||||
|
||||
def _linear(self, inputs):
|
||||
"""
|
||||
Computes logits by running inputs through a linear layer.
|
||||
|
||||
Args:
|
||||
inputs: A float32 tensor with shape [batch_size, length, hidden_size].
|
||||
|
||||
Returns:
|
||||
float32 tensor with shape [batch_size, length, vocab_size].
|
||||
"""
|
||||
batch_size = shape_list(inputs)[0]
|
||||
length = shape_list(inputs)[1]
|
||||
x = tf.reshape(inputs, [-1, self.hidden_size])
|
||||
logits = tf.matmul(x, self.word_embeddings, transpose_b=True)
|
||||
|
||||
return tf.reshape(logits, [batch_size, length, self.vocab_size])
|
||||
return final_embeddings
|
||||
|
||||
|
||||
class TFBertSelfAttention(tf.keras.layers.Layer):
|
||||
@@ -251,8 +301,8 @@ class TFBertSelfAttention(tf.keras.layers.Layer):
|
||||
f"of attention heads ({config.num_attention_heads})"
|
||||
)
|
||||
|
||||
self.num_attention_heads = config.num_attention_heads
|
||||
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
|
||||
|
||||
self.query = tf.keras.layers.experimental.EinsumDense(
|
||||
equation="abc,cde->abde",
|
||||
output_shape=(None, config.num_attention_heads, self.attention_head_size),
|
||||
@@ -318,9 +368,9 @@ class TFBertSelfOutput(tf.keras.layers.Layer):
|
||||
f"of attention heads ({config.num_attention_heads})"
|
||||
)
|
||||
|
||||
self.num_attention_heads = config.num_attention_heads
|
||||
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
|
||||
self.all_head_size = self.num_attention_heads * self.attention_head_size
|
||||
self.all_head_size = config.num_attention_heads * self.attention_head_size
|
||||
|
||||
self.dense = tf.keras.layers.experimental.EinsumDense(
|
||||
equation="abcd,cde->abe",
|
||||
output_shape=(None, self.all_head_size),
|
||||
@@ -516,6 +566,8 @@ class TFBertLMPredictionHead(tf.keras.layers.Layer):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.vocab_size = config.vocab_size
|
||||
self.hidden_size = config.hidden_size
|
||||
|
||||
self.transform = TFBertPredictionHeadTransform(config, name="transform")
|
||||
|
||||
# The output weights are the same as the input embeddings, but there is
|
||||
@@ -531,7 +583,7 @@ class TFBertLMPredictionHead(tf.keras.layers.Layer):
|
||||
return self.input_embeddings
|
||||
|
||||
def set_output_embeddings(self, value):
|
||||
self.input_embeddings.word_embeddings = value
|
||||
self.input_embeddings.weight = value
|
||||
self.input_embeddings.vocab_size = shape_list(value)[0]
|
||||
|
||||
def get_bias(self):
|
||||
@@ -542,9 +594,12 @@ class TFBertLMPredictionHead(tf.keras.layers.Layer):
|
||||
self.vocab_size = shape_list(value["bias"])[0]
|
||||
|
||||
def call(self, hidden_states):
|
||||
hidden_states = self.transform(hidden_states)
|
||||
hidden_states = self.input_embeddings(hidden_states, mode="linear")
|
||||
hidden_states = hidden_states + self.bias
|
||||
hidden_states = self.transform(hidden_states=hidden_states)
|
||||
seq_length = shape_list(tensor=hidden_states)[1]
|
||||
hidden_states = tf.reshape(tensor=hidden_states, shape=[-1, self.hidden_size])
|
||||
hidden_states = tf.matmul(a=hidden_states, b=self.input_embeddings.weight, transpose_b=True)
|
||||
hidden_states = tf.reshape(tensor=hidden_states, shape=[-1, seq_length, self.vocab_size])
|
||||
hidden_states = tf.nn.bias_add(value=hidden_states, bias=self.bias)
|
||||
|
||||
return hidden_states
|
||||
|
||||
@@ -583,21 +638,17 @@ class TFBertMainLayer(tf.keras.layers.Layer):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.config = config
|
||||
self.num_hidden_layers = config.num_hidden_layers
|
||||
self.initializer_range = config.initializer_range
|
||||
self.output_attentions = config.output_attentions
|
||||
self.output_hidden_states = config.output_hidden_states
|
||||
self.return_dict = config.use_return_dict
|
||||
|
||||
self.embeddings = TFBertEmbeddings(config, name="embeddings")
|
||||
self.encoder = TFBertEncoder(config, name="encoder")
|
||||
self.pooler = TFBertPooler(config, name="pooler") if add_pooling_layer else None
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.embeddings
|
||||
return self.embeddings.word_embeddings
|
||||
|
||||
def set_input_embeddings(self, value):
|
||||
self.embeddings.word_embeddings = value
|
||||
self.embeddings.vocab_size = shape_list(value)[0]
|
||||
self.embeddings.word_embeddings.weight = value
|
||||
self.embeddings.word_embeddings.vocab_size = shape_list(value)[0]
|
||||
|
||||
def _prune_heads(self, heads_to_prune):
|
||||
"""
|
||||
@@ -682,7 +733,7 @@ class TFBertMainLayer(tf.keras.layers.Layer):
|
||||
if inputs["head_mask"] is not None:
|
||||
raise NotImplementedError
|
||||
else:
|
||||
inputs["head_mask"] = [None] * self.num_hidden_layers
|
||||
inputs["head_mask"] = [None] * self.config.num_hidden_layers
|
||||
|
||||
encoder_outputs = self.encoder(
|
||||
embedding_output,
|
||||
@@ -931,7 +982,7 @@ class TFBertForPreTraining(TFBertPreTrainedModel, TFBertPreTrainingLoss):
|
||||
|
||||
self.bert = TFBertMainLayer(config, name="bert")
|
||||
self.nsp = TFBertNSPHead(config, name="nsp___cls")
|
||||
self.mlm = TFBertMLMHead(config, self.bert.embeddings, name="mlm___cls")
|
||||
self.mlm = TFBertMLMHead(config, self.bert.embeddings.word_embeddings, name="mlm___cls")
|
||||
|
||||
def get_lm_head(self):
|
||||
return self.mlm.predictions
|
||||
@@ -1055,7 +1106,7 @@ class TFBertForMaskedLM(TFBertPreTrainedModel, TFMaskedLanguageModelingLoss):
|
||||
)
|
||||
|
||||
self.bert = TFBertMainLayer(config, add_pooling_layer=False, name="bert")
|
||||
self.mlm = TFBertMLMHead(config, self.bert.embeddings, name="mlm___cls")
|
||||
self.mlm = TFBertMLMHead(config, self.bert.embeddings.word_embeddings, name="mlm___cls")
|
||||
|
||||
def get_lm_head(self):
|
||||
return self.mlm.predictions
|
||||
@@ -1158,7 +1209,7 @@ class TFBertLMHeadModel(TFBertPreTrainedModel, TFCausalLanguageModelingLoss):
|
||||
logger.warning("If you want to use `TFBertLMHeadModel` as a standalone, add `is_decoder=True.`")
|
||||
|
||||
self.bert = TFBertMainLayer(config, add_pooling_layer=False, name="bert")
|
||||
self.mlm = TFBertMLMHead(config, self.bert.embeddings, name="mlm___cls")
|
||||
self.mlm = TFBertMLMHead(config, self.bert.embeddings.word_embeddings, name="mlm___cls")
|
||||
|
||||
def get_lm_head(self):
|
||||
return self.mlm.predictions
|
||||
|
||||
Reference in New Issue
Block a user