Restore TF embeddings and attention layers to their previous version (#9890)
* Refacto BERT * Restore all the concerned models * Remove print * Update template * Apply Sylvain's and Morgan's comments * Fix cast * Put the cast inside call * Remove cond in ebds * Fix funnel * Restore previous dot product (attention_scores) computation * Add ConvBERT and BART * Make all the S2S models ONNX compliant * Fix test * Fix check copies
This commit is contained in:
@@ -18,7 +18,6 @@
|
||||
|
||||
import math
|
||||
import warnings
|
||||
from typing import Any, Dict
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
@@ -87,86 +86,6 @@ class TFMPNetPreTrainedModel(TFPreTrainedModel):
|
||||
return self.serving_output(output)
|
||||
|
||||
|
||||
# Copied from transformers.models.bert.modeling_tf_bert.TFBertWordEmbeddings
|
||||
class TFMPNetWordEmbeddings(tf.keras.layers.Layer):
|
||||
def __init__(self, vocab_size: int, hidden_size: int, initializer_range: float, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.vocab_size = vocab_size
|
||||
self.hidden_size = hidden_size
|
||||
self.initializer_range = initializer_range
|
||||
|
||||
def build(self, input_shape: tf.TensorShape):
|
||||
self.weight = self.add_weight(
|
||||
name="weight",
|
||||
shape=[self.vocab_size, self.hidden_size],
|
||||
initializer=get_initializer(self.initializer_range),
|
||||
)
|
||||
|
||||
super().build(input_shape)
|
||||
|
||||
def get_config(self) -> Dict[str, Any]:
|
||||
config = {
|
||||
"vocab_size": self.vocab_size,
|
||||
"hidden_size": self.hidden_size,
|
||||
"initializer_range": self.initializer_range,
|
||||
}
|
||||
base_config = super().get_config()
|
||||
|
||||
return dict(list(base_config.items()) + list(config.items()))
|
||||
|
||||
def call(self, input_ids: tf.Tensor) -> tf.Tensor:
|
||||
flat_input_ids = tf.reshape(tensor=input_ids, shape=[-1])
|
||||
embeddings = tf.gather(params=self.weight, indices=flat_input_ids)
|
||||
embeddings = tf.reshape(
|
||||
tensor=embeddings, shape=tf.concat(values=[shape_list(input_ids), [self.hidden_size]], axis=0)
|
||||
)
|
||||
|
||||
embeddings.set_shape(input_ids.shape.as_list() + [self.hidden_size])
|
||||
|
||||
return embeddings
|
||||
|
||||
|
||||
# Copied from transformers.models.longformer.modeling_tf_longformer.TFLongformerPositionEmbeddings
|
||||
class TFMPNetPositionEmbeddings(tf.keras.layers.Layer):
|
||||
def __init__(self, max_position_embeddings: int, hidden_size: int, initializer_range: float, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.hidden_size = hidden_size
|
||||
self.initializer_range = initializer_range
|
||||
|
||||
def build(self, input_shape):
|
||||
self.position_embeddings = self.add_weight(
|
||||
name="embeddings",
|
||||
shape=[self.max_position_embeddings, self.hidden_size],
|
||||
initializer=get_initializer(self.initializer_range),
|
||||
)
|
||||
|
||||
super().build(input_shape)
|
||||
|
||||
def get_config(self):
|
||||
config = {
|
||||
"max_position_embeddings": self.max_position_embeddings,
|
||||
"hidden_size": self.hidden_size,
|
||||
"initializer_range": self.initializer_range,
|
||||
}
|
||||
base_config = super().get_config()
|
||||
|
||||
return dict(list(base_config.items()) + list(config.items()))
|
||||
|
||||
def call(self, position_ids):
|
||||
flat_position_ids = tf.reshape(tensor=position_ids, shape=[-1])
|
||||
embeddings = tf.gather(params=self.position_embeddings, indices=flat_position_ids)
|
||||
embeddings = tf.reshape(
|
||||
tensor=embeddings, shape=tf.concat(values=[shape_list(position_ids), [self.hidden_size]], axis=0)
|
||||
)
|
||||
|
||||
embeddings.set_shape(position_ids.shape.as_list() + [self.hidden_size])
|
||||
|
||||
return embeddings
|
||||
|
||||
|
||||
class TFMPNetEmbeddings(tf.keras.layers.Layer):
|
||||
"""Construct the embeddings from word, position embeddings."""
|
||||
|
||||
@@ -174,22 +93,31 @@ class TFMPNetEmbeddings(tf.keras.layers.Layer):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.padding_idx = 1
|
||||
self.word_embeddings = TFMPNetWordEmbeddings(
|
||||
vocab_size=config.vocab_size,
|
||||
hidden_size=config.hidden_size,
|
||||
initializer_range=config.initializer_range,
|
||||
name="word_embeddings",
|
||||
)
|
||||
self.position_embeddings = TFMPNetPositionEmbeddings(
|
||||
max_position_embeddings=config.max_position_embeddings,
|
||||
hidden_size=config.hidden_size,
|
||||
initializer_range=config.initializer_range,
|
||||
name="position_embeddings",
|
||||
)
|
||||
self.vocab_size = config.vocab_size
|
||||
self.hidden_size = config.hidden_size
|
||||
self.max_position_embeddings = config.max_position_embeddings
|
||||
self.initializer_range = config.initializer_range
|
||||
self.embeddings_sum = tf.keras.layers.Add()
|
||||
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
|
||||
self.dropout = tf.keras.layers.Dropout(rate=config.hidden_dropout_prob)
|
||||
|
||||
def build(self, input_shape: tf.TensorShape):
|
||||
with tf.name_scope("word_embeddings"):
|
||||
self.weight = self.add_weight(
|
||||
name="weight",
|
||||
shape=[self.vocab_size, self.hidden_size],
|
||||
initializer=get_initializer(initializer_range=self.initializer_range),
|
||||
)
|
||||
|
||||
with tf.name_scope("position_embeddings"):
|
||||
self.position_embeddings = self.add_weight(
|
||||
name="embeddings",
|
||||
shape=[self.max_position_embeddings, self.hidden_size],
|
||||
initializer=get_initializer(initializer_range=self.initializer_range),
|
||||
)
|
||||
|
||||
super().build(input_shape)
|
||||
|
||||
def create_position_ids_from_input_ids(self, input_ids):
|
||||
"""
|
||||
Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding
|
||||
@@ -197,36 +125,13 @@ class TFMPNetEmbeddings(tf.keras.layers.Layer):
|
||||
|
||||
Args:
|
||||
input_ids: tf.Tensor
|
||||
|
||||
Returns: tf.Tensor
|
||||
"""
|
||||
input_ids_shape = shape_list(tensor=input_ids)
|
||||
|
||||
# multiple choice has 3 dimensions
|
||||
if len(input_ids_shape) == 3:
|
||||
input_ids = tf.reshape(
|
||||
tensor=input_ids, shape=(input_ids_shape[0] * input_ids_shape[1], input_ids_shape[2])
|
||||
)
|
||||
|
||||
mask = tf.cast(tf.math.not_equal(input_ids, self.padding_idx), dtype=input_ids.dtype)
|
||||
incremental_indices = tf.math.cumsum(mask, axis=1) * mask
|
||||
|
||||
return incremental_indices + self.padding_idx
|
||||
|
||||
def create_position_ids_from_inputs_embeds(self, inputs_embeds):
|
||||
"""
|
||||
We are provided embeddings directly. We cannot infer which are padded so just generate sequential position ids.
|
||||
|
||||
Args:
|
||||
inputs_embeds: tf.Tensor
|
||||
|
||||
Returns: tf.Tensor
|
||||
"""
|
||||
batch_size, seq_length = shape_list(tensor=inputs_embeds)[:2]
|
||||
position_ids = tf.range(start=self.padding_idx + 1, limit=seq_length + self.padding_idx + 1)[tf.newaxis, :]
|
||||
|
||||
return tf.tile(input=position_ids, multiples=(batch_size, 1))
|
||||
|
||||
def call(self, input_ids=None, position_ids=None, inputs_embeds=None, training=False):
|
||||
"""
|
||||
Applies embedding based on inputs tensor.
|
||||
@@ -237,16 +142,21 @@ class TFMPNetEmbeddings(tf.keras.layers.Layer):
|
||||
assert not (input_ids is None and inputs_embeds is None)
|
||||
|
||||
if input_ids is not None:
|
||||
inputs_embeds = self.word_embeddings(input_ids=input_ids)
|
||||
inputs_embeds = tf.gather(params=self.weight, indices=input_ids)
|
||||
|
||||
input_shape = shape_list(inputs_embeds)[:-1]
|
||||
|
||||
if position_ids is None:
|
||||
if input_ids is not None:
|
||||
# Create the position ids from the input token ids. Any padded tokens remain padded.
|
||||
position_ids = self.create_position_ids_from_input_ids(input_ids=input_ids)
|
||||
else:
|
||||
position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds=inputs_embeds)
|
||||
position_ids = tf.range(start=self.padding_idx + 1, limit=input_shape[-1] + self.padding_idx + 1)[
|
||||
tf.newaxis, :
|
||||
]
|
||||
position_ids = tf.tile(input=position_ids, multiples=(input_shape[0], 1))
|
||||
|
||||
position_embeds = self.position_embeddings(position_ids=position_ids)
|
||||
position_embeds = tf.gather(params=self.position_embeddings, indices=position_ids)
|
||||
final_embeddings = self.embeddings_sum(inputs=[inputs_embeds, position_embeds])
|
||||
final_embeddings = self.LayerNorm(inputs=final_embeddings)
|
||||
final_embeddings = self.dropout(inputs=final_embeddings, training=training)
|
||||
@@ -281,58 +191,55 @@ class TFMPNetSelfAttention(tf.keras.layers.Layer):
|
||||
|
||||
if config.hidden_size % config.num_attention_heads != 0:
|
||||
raise ValueError(
|
||||
f"The hidden size ({config.hidden_size}) is not a multiple of the number "
|
||||
f"of attention heads ({config.num_attention_heads})"
|
||||
"The hidden size (%d) is not a multiple of the number of attention "
|
||||
"heads (%d)" % (config.hidden_size, config.num_attention_heads)
|
||||
)
|
||||
|
||||
self.num_attention_heads = config.num_attention_heads
|
||||
assert config.hidden_size % config.num_attention_heads == 0
|
||||
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
|
||||
self.all_head_size = self.num_attention_heads * self.attention_head_size
|
||||
self.q = tf.keras.layers.experimental.EinsumDense(
|
||||
equation="abc,cde->abde",
|
||||
output_shape=(None, config.num_attention_heads, self.attention_head_size),
|
||||
bias_axes="de",
|
||||
kernel_initializer=get_initializer(config.initializer_range),
|
||||
name="q",
|
||||
|
||||
self.q = tf.keras.layers.Dense(
|
||||
self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="q"
|
||||
)
|
||||
self.k = tf.keras.layers.experimental.EinsumDense(
|
||||
equation="abc,cde->abde",
|
||||
output_shape=(None, config.num_attention_heads, self.attention_head_size),
|
||||
bias_axes="de",
|
||||
kernel_initializer=get_initializer(config.initializer_range),
|
||||
name="k",
|
||||
self.k = tf.keras.layers.Dense(
|
||||
self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="k"
|
||||
)
|
||||
self.v = tf.keras.layers.experimental.EinsumDense(
|
||||
equation="abc,cde->abde",
|
||||
output_shape=(None, config.num_attention_heads, self.attention_head_size),
|
||||
bias_axes="de",
|
||||
kernel_initializer=get_initializer(config.initializer_range),
|
||||
name="v",
|
||||
self.v = tf.keras.layers.Dense(
|
||||
self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="v"
|
||||
)
|
||||
self.o = tf.keras.layers.experimental.EinsumDense(
|
||||
equation="abcd,cde->abe",
|
||||
output_shape=(None, self.all_head_size),
|
||||
bias_axes="e",
|
||||
kernel_initializer=get_initializer(config.initializer_range),
|
||||
name="o",
|
||||
self.o = tf.keras.layers.Dense(
|
||||
config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="o"
|
||||
)
|
||||
self.dropout = tf.keras.layers.Dropout(config.attention_probs_dropout_prob)
|
||||
|
||||
def transpose_for_scores(self, x, batch_size):
|
||||
# Reshape from [batch_size, seq_length, all_head_size] to [batch_size, seq_length, num_attention_heads, attention_head_size]
|
||||
x = tf.reshape(x, (batch_size, -1, self.num_attention_heads, self.attention_head_size))
|
||||
|
||||
return tf.transpose(x, perm=[0, 2, 1, 3])
|
||||
|
||||
def call(self, hidden_states, attention_mask, head_mask, output_attentions, position_bias=None, training=False):
|
||||
batch_size = shape_list(hidden_states)[0]
|
||||
|
||||
q = self.q(hidden_states)
|
||||
k = self.k(hidden_states)
|
||||
v = self.v(hidden_states)
|
||||
|
||||
dk = tf.cast(self.attention_head_size, dtype=q.dtype)
|
||||
q = tf.multiply(q, y=tf.math.rsqrt(dk))
|
||||
attention_scores = tf.einsum("aecd,abcd->acbe", k, q)
|
||||
q = self.transpose_for_scores(q, batch_size)
|
||||
k = self.transpose_for_scores(k, batch_size)
|
||||
v = self.transpose_for_scores(v, batch_size)
|
||||
|
||||
attention_scores = tf.matmul(q, k, transpose_b=True)
|
||||
dk = tf.cast(shape_list(k)[-1], attention_scores.dtype)
|
||||
attention_scores = attention_scores / tf.math.sqrt(dk)
|
||||
|
||||
# Apply relative position embedding (precomputed in MPNetEncoder) if provided.
|
||||
if position_bias is not None:
|
||||
attention_scores += position_bias
|
||||
|
||||
if attention_mask is not None:
|
||||
# Apply the attention mask is (precomputed for all layers in TFMPNetModel call() function)
|
||||
attention_scores = attention_scores + attention_mask
|
||||
|
||||
attention_probs = tf.nn.softmax(attention_scores, axis=-1)
|
||||
@@ -342,7 +249,9 @@ class TFMPNetSelfAttention(tf.keras.layers.Layer):
|
||||
if head_mask is not None:
|
||||
attention_probs = attention_probs * head_mask
|
||||
|
||||
c = tf.einsum("acbe,aecd->abcd", attention_probs, v)
|
||||
c = tf.matmul(attention_probs, v)
|
||||
c = tf.transpose(c, perm=[0, 2, 1, 3])
|
||||
c = tf.reshape(c, (batch_size, -1, self.all_head_size))
|
||||
o = self.o(c)
|
||||
|
||||
outputs = (o, attention_probs) if output_attentions else (o,)
|
||||
@@ -374,12 +283,8 @@ class TFMPNetIntermediate(tf.keras.layers.Layer):
|
||||
def __init__(self, config: MPNetConfig, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.dense = tf.keras.layers.experimental.EinsumDense(
|
||||
equation="abc,cd->abd",
|
||||
output_shape=(None, config.intermediate_size),
|
||||
bias_axes="d",
|
||||
kernel_initializer=get_initializer(config.initializer_range),
|
||||
name="dense",
|
||||
self.dense = tf.keras.layers.Dense(
|
||||
units=config.intermediate_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
|
||||
)
|
||||
|
||||
if isinstance(config.hidden_act, str):
|
||||
@@ -399,12 +304,8 @@ class TFMPNetOutput(tf.keras.layers.Layer):
|
||||
def __init__(self, config: MPNetConfig, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.dense = tf.keras.layers.experimental.EinsumDense(
|
||||
equation="abc,cd->abd",
|
||||
bias_axes="d",
|
||||
output_shape=(None, config.hidden_size),
|
||||
kernel_initializer=get_initializer(config.initializer_range),
|
||||
name="dense",
|
||||
self.dense = tf.keras.layers.Dense(
|
||||
units=config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
|
||||
)
|
||||
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
|
||||
self.dropout = tf.keras.layers.Dropout(rate=config.hidden_dropout_prob)
|
||||
@@ -565,12 +466,12 @@ class TFMPNetMainLayer(tf.keras.layers.Layer):
|
||||
|
||||
# Copied from transformers.models.bert.modeling_tf_bert.TFBertMainLayer.get_input_embeddings
|
||||
def get_input_embeddings(self) -> tf.keras.layers.Layer:
|
||||
return self.embeddings.word_embeddings
|
||||
return self.embeddings
|
||||
|
||||
# Copied from transformers.models.bert.modeling_tf_bert.TFBertMainLayer.set_input_embeddings
|
||||
def set_input_embeddings(self, value: tf.Variable):
|
||||
self.embeddings.word_embeddings.weight = value
|
||||
self.embeddings.word_embeddings.vocab_size = shape_list(value)[0]
|
||||
self.embeddings.weight = value
|
||||
self.embeddings.vocab_size = shape_list(value)[0]
|
||||
|
||||
# Copied from transformers.models.bert.modeling_tf_bert.TFBertMainLayer._prune_heads
|
||||
def _prune_heads(self, heads_to_prune):
|
||||
@@ -894,7 +795,7 @@ class TFMPNetForMaskedLM(TFMPNetPreTrainedModel, TFMaskedLanguageModelingLoss):
|
||||
super().__init__(config, *inputs, **kwargs)
|
||||
|
||||
self.mpnet = TFMPNetMainLayer(config, name="mpnet")
|
||||
self.lm_head = TFMPNetLMHead(config, self.mpnet.embeddings.word_embeddings, name="lm_head")
|
||||
self.lm_head = TFMPNetLMHead(config, self.mpnet.embeddings, name="lm_head")
|
||||
|
||||
def get_lm_head(self):
|
||||
return self.lm_head
|
||||
|
||||
Reference in New Issue
Block a user