Clean TF Bert (#9788)
* Start cleaning BERT * Clean BERT and all those that depend on it * Fix attribute name * Apply style * Apply Sylvain's comments * Apply Lysandre's comments * Remove unused import
This commit is contained in:
@@ -18,6 +18,7 @@
|
||||
|
||||
import math
|
||||
import warnings
|
||||
from typing import Any, Dict
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
@@ -95,16 +96,16 @@ class TFMPNetWordEmbeddings(tf.keras.layers.Layer):
|
||||
self.hidden_size = hidden_size
|
||||
self.initializer_range = initializer_range
|
||||
|
||||
def build(self, input_shape):
|
||||
def build(self, input_shape: tf.TensorShape):
|
||||
self.weight = self.add_weight(
|
||||
name="weight",
|
||||
shape=[self.vocab_size, self.hidden_size],
|
||||
initializer=get_initializer(initializer_range=self.initializer_range),
|
||||
initializer=get_initializer(self.initializer_range),
|
||||
)
|
||||
|
||||
super().build(input_shape=input_shape)
|
||||
super().build(input_shape)
|
||||
|
||||
def get_config(self):
|
||||
def get_config(self) -> Dict[str, Any]:
|
||||
config = {
|
||||
"vocab_size": self.vocab_size,
|
||||
"hidden_size": self.hidden_size,
|
||||
@@ -114,14 +115,14 @@ class TFMPNetWordEmbeddings(tf.keras.layers.Layer):
|
||||
|
||||
return dict(list(base_config.items()) + list(config.items()))
|
||||
|
||||
def call(self, input_ids):
|
||||
def call(self, input_ids: tf.Tensor) -> tf.Tensor:
|
||||
flat_input_ids = tf.reshape(tensor=input_ids, shape=[-1])
|
||||
embeddings = tf.gather(params=self.weight, indices=flat_input_ids)
|
||||
embeddings = tf.reshape(
|
||||
tensor=embeddings, shape=tf.concat(values=[shape_list(tensor=input_ids), [self.hidden_size]], axis=0)
|
||||
tensor=embeddings, shape=tf.concat(values=[shape_list(input_ids), [self.hidden_size]], axis=0)
|
||||
)
|
||||
|
||||
embeddings.set_shape(shape=input_ids.shape.as_list() + [self.hidden_size])
|
||||
embeddings.set_shape(input_ids.shape.as_list() + [self.hidden_size])
|
||||
|
||||
return embeddings
|
||||
|
||||
@@ -139,7 +140,7 @@ class TFMPNetPositionEmbeddings(tf.keras.layers.Layer):
|
||||
self.position_embeddings = self.add_weight(
|
||||
name="embeddings",
|
||||
shape=[self.max_position_embeddings, self.hidden_size],
|
||||
initializer=get_initializer(initializer_range=self.initializer_range),
|
||||
initializer=get_initializer(self.initializer_range),
|
||||
)
|
||||
|
||||
super().build(input_shape)
|
||||
@@ -158,10 +159,10 @@ class TFMPNetPositionEmbeddings(tf.keras.layers.Layer):
|
||||
flat_position_ids = tf.reshape(tensor=position_ids, shape=[-1])
|
||||
embeddings = tf.gather(params=self.position_embeddings, indices=flat_position_ids)
|
||||
embeddings = tf.reshape(
|
||||
tensor=embeddings, shape=tf.concat(values=[shape_list(tensor=position_ids), [self.hidden_size]], axis=0)
|
||||
tensor=embeddings, shape=tf.concat(values=[shape_list(position_ids), [self.hidden_size]], axis=0)
|
||||
)
|
||||
|
||||
embeddings.set_shape(shape=position_ids.shape.as_list() + [self.hidden_size])
|
||||
embeddings.set_shape(position_ids.shape.as_list() + [self.hidden_size])
|
||||
|
||||
return embeddings
|
||||
|
||||
@@ -207,8 +208,8 @@ class TFMPNetEmbeddings(tf.keras.layers.Layer):
|
||||
tensor=input_ids, shape=(input_ids_shape[0] * input_ids_shape[1], input_ids_shape[2])
|
||||
)
|
||||
|
||||
mask = tf.cast(x=tf.math.not_equal(x=input_ids, y=self.padding_idx), dtype=input_ids.dtype)
|
||||
incremental_indices = tf.math.cumsum(x=mask, axis=1) * mask
|
||||
mask = tf.cast(tf.math.not_equal(input_ids, self.padding_idx), dtype=input_ids.dtype)
|
||||
incremental_indices = tf.math.cumsum(mask, axis=1) * mask
|
||||
|
||||
return incremental_indices + self.padding_idx
|
||||
|
||||
@@ -253,23 +254,23 @@ class TFMPNetEmbeddings(tf.keras.layers.Layer):
|
||||
return final_embeddings
|
||||
|
||||
|
||||
# Copied from transformers.models.bert.modeling_tf_bert.TFBertPooler
|
||||
# Copied from transformers.models.bert.modeling_tf_bert.TFBertPooler with Bert->MPNet
|
||||
class TFMPNetPooler(tf.keras.layers.Layer):
|
||||
def __init__(self, config, **kwargs):
|
||||
def __init__(self, config: MPNetConfig, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.dense = tf.keras.layers.Dense(
|
||||
config.hidden_size,
|
||||
units=config.hidden_size,
|
||||
kernel_initializer=get_initializer(config.initializer_range),
|
||||
activation="tanh",
|
||||
name="dense",
|
||||
)
|
||||
|
||||
def call(self, hidden_states):
|
||||
def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
|
||||
# We "pool" the model by simply taking the hidden state corresponding
|
||||
# to the first token.
|
||||
first_token_tensor = hidden_states[:, 0]
|
||||
pooled_output = self.dense(first_token_tensor)
|
||||
pooled_output = self.dense(inputs=first_token_tensor)
|
||||
|
||||
return pooled_output
|
||||
|
||||
@@ -291,28 +292,28 @@ class TFMPNetSelfAttention(tf.keras.layers.Layer):
|
||||
equation="abc,cde->abde",
|
||||
output_shape=(None, config.num_attention_heads, self.attention_head_size),
|
||||
bias_axes="de",
|
||||
kernel_initializer=get_initializer(initializer_range=config.initializer_range),
|
||||
kernel_initializer=get_initializer(config.initializer_range),
|
||||
name="q",
|
||||
)
|
||||
self.k = tf.keras.layers.experimental.EinsumDense(
|
||||
equation="abc,cde->abde",
|
||||
output_shape=(None, config.num_attention_heads, self.attention_head_size),
|
||||
bias_axes="de",
|
||||
kernel_initializer=get_initializer(initializer_range=config.initializer_range),
|
||||
kernel_initializer=get_initializer(config.initializer_range),
|
||||
name="k",
|
||||
)
|
||||
self.v = tf.keras.layers.experimental.EinsumDense(
|
||||
equation="abc,cde->abde",
|
||||
output_shape=(None, config.num_attention_heads, self.attention_head_size),
|
||||
bias_axes="de",
|
||||
kernel_initializer=get_initializer(initializer_range=config.initializer_range),
|
||||
kernel_initializer=get_initializer(config.initializer_range),
|
||||
name="v",
|
||||
)
|
||||
self.o = tf.keras.layers.experimental.EinsumDense(
|
||||
equation="abcd,cde->abe",
|
||||
output_shape=(None, self.all_head_size),
|
||||
bias_axes="e",
|
||||
kernel_initializer=get_initializer(initializer_range=config.initializer_range),
|
||||
kernel_initializer=get_initializer(config.initializer_range),
|
||||
name="o",
|
||||
)
|
||||
self.dropout = tf.keras.layers.Dropout(config.attention_probs_dropout_prob)
|
||||
@@ -322,8 +323,8 @@ class TFMPNetSelfAttention(tf.keras.layers.Layer):
|
||||
k = self.k(hidden_states)
|
||||
v = self.v(hidden_states)
|
||||
|
||||
dk = tf.cast(x=self.attention_head_size, dtype=q.dtype)
|
||||
q = tf.multiply(x=q, y=tf.math.rsqrt(x=dk))
|
||||
dk = tf.cast(self.attention_head_size, dtype=q.dtype)
|
||||
q = tf.multiply(q, y=tf.math.rsqrt(dk))
|
||||
attention_scores = tf.einsum("aecd,abcd->acbe", k, q)
|
||||
|
||||
# Apply relative position embedding (precomputed in MPNetEncoder) if provided.
|
||||
@@ -368,34 +369,34 @@ class TFMPNetAttention(tf.keras.layers.Layer):
|
||||
return outputs
|
||||
|
||||
|
||||
# Copied from transformers.models.bert.modeling_tf_bert.TFBertIntermediate
|
||||
# Copied from transformers.models.bert.modeling_tf_bert.TFBertIntermediate with Bert->MPNet
|
||||
class TFMPNetIntermediate(tf.keras.layers.Layer):
|
||||
def __init__(self, config, **kwargs):
|
||||
def __init__(self, config: MPNetConfig, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.dense = tf.keras.layers.experimental.EinsumDense(
|
||||
equation="abc,cd->abd",
|
||||
output_shape=(None, config.intermediate_size),
|
||||
bias_axes="d",
|
||||
kernel_initializer=get_initializer(initializer_range=config.initializer_range),
|
||||
kernel_initializer=get_initializer(config.initializer_range),
|
||||
name="dense",
|
||||
)
|
||||
|
||||
if isinstance(config.hidden_act, str):
|
||||
self.intermediate_act_fn = get_tf_activation(activation_string=config.hidden_act)
|
||||
self.intermediate_act_fn = get_tf_activation(config.hidden_act)
|
||||
else:
|
||||
self.intermediate_act_fn = config.hidden_act
|
||||
|
||||
def call(self, hidden_states):
|
||||
def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
|
||||
hidden_states = self.dense(inputs=hidden_states)
|
||||
hidden_states = self.intermediate_act_fn(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
# Copied from transformers.models.bert.modeling_tf_bert.TFBertOutput
|
||||
# Copied from transformers.models.bert.modeling_tf_bert.TFBertOutput with Bert->MPNet
|
||||
class TFMPNetOutput(tf.keras.layers.Layer):
|
||||
def __init__(self, config, **kwargs):
|
||||
def __init__(self, config: MPNetConfig, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.dense = tf.keras.layers.experimental.EinsumDense(
|
||||
@@ -408,7 +409,7 @@ class TFMPNetOutput(tf.keras.layers.Layer):
|
||||
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
|
||||
self.dropout = tf.keras.layers.Dropout(rate=config.hidden_dropout_prob)
|
||||
|
||||
def call(self, hidden_states, input_tensor, training=False):
|
||||
def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor:
|
||||
hidden_states = self.dense(inputs=hidden_states)
|
||||
hidden_states = self.dropout(inputs=hidden_states, training=training)
|
||||
hidden_states = self.LayerNorm(inputs=hidden_states + input_tensor)
|
||||
@@ -563,11 +564,11 @@ class TFMPNetMainLayer(tf.keras.layers.Layer):
|
||||
self.embeddings = TFMPNetEmbeddings(config, name="embeddings")
|
||||
|
||||
# Copied from transformers.models.bert.modeling_tf_bert.TFBertMainLayer.get_input_embeddings
|
||||
def get_input_embeddings(self):
|
||||
def get_input_embeddings(self) -> tf.keras.layers.Layer:
|
||||
return self.embeddings.word_embeddings
|
||||
|
||||
# Copied from transformers.models.bert.modeling_tf_bert.TFBertMainLayer.set_input_embeddings
|
||||
def set_input_embeddings(self, value):
|
||||
def set_input_embeddings(self, value: tf.Variable):
|
||||
self.embeddings.word_embeddings.weight = value
|
||||
self.embeddings.word_embeddings.vocab_size = shape_list(value)[0]
|
||||
|
||||
@@ -820,7 +821,7 @@ class TFMPNetModel(TFMPNetPreTrainedModel):
|
||||
return outputs
|
||||
|
||||
# Copied from transformers.models.bert.modeling_tf_bert.TFBertModel.serving_output
|
||||
def serving_output(self, output):
|
||||
def serving_output(self, output: TFBaseModelOutputWithPooling) -> TFBaseModelOutputWithPooling:
|
||||
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
|
||||
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
|
||||
|
||||
@@ -973,7 +974,7 @@ class TFMPNetForMaskedLM(TFMPNetPreTrainedModel, TFMaskedLanguageModelingLoss):
|
||||
)
|
||||
|
||||
# Copied from transformers.models.bert.modeling_tf_bert.TFBertForMaskedLM.serving_output
|
||||
def serving_output(self, output):
|
||||
def serving_output(self, output: TFMaskedLMOutput) -> TFMaskedLMOutput:
|
||||
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
|
||||
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
|
||||
|
||||
@@ -1095,7 +1096,7 @@ class TFMPNetForSequenceClassification(TFMPNetPreTrainedModel, TFSequenceClassif
|
||||
)
|
||||
|
||||
# Copied from transformers.models.bert.modeling_tf_bert.TFBertForSequenceClassification.serving_output
|
||||
def serving_output(self, output):
|
||||
def serving_output(self, output: TFSequenceClassifierOutput) -> TFSequenceClassifierOutput:
|
||||
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
|
||||
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
|
||||
|
||||
@@ -1233,7 +1234,7 @@ class TFMPNetForMultipleChoice(TFMPNetPreTrainedModel, TFMultipleChoiceLoss):
|
||||
return self.serving_output(output)
|
||||
|
||||
# Copied from transformers.models.bert.modeling_tf_bert.TFBertForMultipleChoice.serving_output
|
||||
def serving_output(self, output):
|
||||
def serving_output(self, output: TFMultipleChoiceModelOutput) -> TFMultipleChoiceModelOutput:
|
||||
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
|
||||
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
|
||||
|
||||
@@ -1333,7 +1334,7 @@ class TFMPNetForTokenClassification(TFMPNetPreTrainedModel, TFTokenClassificatio
|
||||
)
|
||||
|
||||
# Copied from transformers.models.bert.modeling_tf_bert.TFBertForTokenClassification.serving_output
|
||||
def serving_output(self, output):
|
||||
def serving_output(self, output: TFTokenClassifierOutput) -> TFTokenClassifierOutput:
|
||||
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
|
||||
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
|
||||
|
||||
@@ -1446,7 +1447,7 @@ class TFMPNetForQuestionAnswering(TFMPNetPreTrainedModel, TFQuestionAnsweringLos
|
||||
)
|
||||
|
||||
# Copied from transformers.models.bert.modeling_tf_bert.TFBertForQuestionAnswering.serving_output
|
||||
def serving_output(self, output):
|
||||
def serving_output(self, output: TFQuestionAnsweringModelOutput) -> TFQuestionAnsweringModelOutput:
|
||||
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
|
||||
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
|
||||
|
||||
|
||||
Reference in New Issue
Block a user