Clean TF Bert (#9788)

* Start cleaning BERT

* Clean BERT and all those depends of it

* Fix attribute name

* Apply style

* Apply Sylvain's comments

* Apply Lysandre's comments

* remove unused import
This commit is contained in:
Julien Plu
2021-01-27 11:28:11 +01:00
committed by GitHub
parent f0329ea516
commit 4adbdce5ee
15 changed files with 1295 additions and 1059 deletions

View File

@@ -18,6 +18,7 @@
import math
import warnings
from typing import Any, Dict
import tensorflow as tf
@@ -95,16 +96,16 @@ class TFMPNetWordEmbeddings(tf.keras.layers.Layer):
self.hidden_size = hidden_size
self.initializer_range = initializer_range
def build(self, input_shape):
def build(self, input_shape: tf.TensorShape):
self.weight = self.add_weight(
name="weight",
shape=[self.vocab_size, self.hidden_size],
initializer=get_initializer(initializer_range=self.initializer_range),
initializer=get_initializer(self.initializer_range),
)
super().build(input_shape=input_shape)
super().build(input_shape)
def get_config(self):
def get_config(self) -> Dict[str, Any]:
config = {
"vocab_size": self.vocab_size,
"hidden_size": self.hidden_size,
@@ -114,14 +115,14 @@ class TFMPNetWordEmbeddings(tf.keras.layers.Layer):
return dict(list(base_config.items()) + list(config.items()))
def call(self, input_ids):
def call(self, input_ids: tf.Tensor) -> tf.Tensor:
flat_input_ids = tf.reshape(tensor=input_ids, shape=[-1])
embeddings = tf.gather(params=self.weight, indices=flat_input_ids)
embeddings = tf.reshape(
tensor=embeddings, shape=tf.concat(values=[shape_list(tensor=input_ids), [self.hidden_size]], axis=0)
tensor=embeddings, shape=tf.concat(values=[shape_list(input_ids), [self.hidden_size]], axis=0)
)
embeddings.set_shape(shape=input_ids.shape.as_list() + [self.hidden_size])
embeddings.set_shape(input_ids.shape.as_list() + [self.hidden_size])
return embeddings
@@ -139,7 +140,7 @@ class TFMPNetPositionEmbeddings(tf.keras.layers.Layer):
self.position_embeddings = self.add_weight(
name="embeddings",
shape=[self.max_position_embeddings, self.hidden_size],
initializer=get_initializer(initializer_range=self.initializer_range),
initializer=get_initializer(self.initializer_range),
)
super().build(input_shape)
@@ -158,10 +159,10 @@ class TFMPNetPositionEmbeddings(tf.keras.layers.Layer):
flat_position_ids = tf.reshape(tensor=position_ids, shape=[-1])
embeddings = tf.gather(params=self.position_embeddings, indices=flat_position_ids)
embeddings = tf.reshape(
tensor=embeddings, shape=tf.concat(values=[shape_list(tensor=position_ids), [self.hidden_size]], axis=0)
tensor=embeddings, shape=tf.concat(values=[shape_list(position_ids), [self.hidden_size]], axis=0)
)
embeddings.set_shape(shape=position_ids.shape.as_list() + [self.hidden_size])
embeddings.set_shape(position_ids.shape.as_list() + [self.hidden_size])
return embeddings
@@ -207,8 +208,8 @@ class TFMPNetEmbeddings(tf.keras.layers.Layer):
tensor=input_ids, shape=(input_ids_shape[0] * input_ids_shape[1], input_ids_shape[2])
)
mask = tf.cast(x=tf.math.not_equal(x=input_ids, y=self.padding_idx), dtype=input_ids.dtype)
incremental_indices = tf.math.cumsum(x=mask, axis=1) * mask
mask = tf.cast(tf.math.not_equal(input_ids, self.padding_idx), dtype=input_ids.dtype)
incremental_indices = tf.math.cumsum(mask, axis=1) * mask
return incremental_indices + self.padding_idx
@@ -253,23 +254,23 @@ class TFMPNetEmbeddings(tf.keras.layers.Layer):
return final_embeddings
# Copied from transformers.models.bert.modeling_tf_bert.TFBertPooler
# Copied from transformers.models.bert.modeling_tf_bert.TFBertPooler with Bert->MPNet
class TFMPNetPooler(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
def __init__(self, config: MPNetConfig, **kwargs):
super().__init__(**kwargs)
self.dense = tf.keras.layers.Dense(
config.hidden_size,
units=config.hidden_size,
kernel_initializer=get_initializer(config.initializer_range),
activation="tanh",
name="dense",
)
def call(self, hidden_states):
def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
# We "pool" the model by simply taking the hidden state corresponding
# to the first token.
first_token_tensor = hidden_states[:, 0]
pooled_output = self.dense(first_token_tensor)
pooled_output = self.dense(inputs=first_token_tensor)
return pooled_output
@@ -291,28 +292,28 @@ class TFMPNetSelfAttention(tf.keras.layers.Layer):
equation="abc,cde->abde",
output_shape=(None, config.num_attention_heads, self.attention_head_size),
bias_axes="de",
kernel_initializer=get_initializer(initializer_range=config.initializer_range),
kernel_initializer=get_initializer(config.initializer_range),
name="q",
)
self.k = tf.keras.layers.experimental.EinsumDense(
equation="abc,cde->abde",
output_shape=(None, config.num_attention_heads, self.attention_head_size),
bias_axes="de",
kernel_initializer=get_initializer(initializer_range=config.initializer_range),
kernel_initializer=get_initializer(config.initializer_range),
name="k",
)
self.v = tf.keras.layers.experimental.EinsumDense(
equation="abc,cde->abde",
output_shape=(None, config.num_attention_heads, self.attention_head_size),
bias_axes="de",
kernel_initializer=get_initializer(initializer_range=config.initializer_range),
kernel_initializer=get_initializer(config.initializer_range),
name="v",
)
self.o = tf.keras.layers.experimental.EinsumDense(
equation="abcd,cde->abe",
output_shape=(None, self.all_head_size),
bias_axes="e",
kernel_initializer=get_initializer(initializer_range=config.initializer_range),
kernel_initializer=get_initializer(config.initializer_range),
name="o",
)
self.dropout = tf.keras.layers.Dropout(config.attention_probs_dropout_prob)
@@ -322,8 +323,8 @@ class TFMPNetSelfAttention(tf.keras.layers.Layer):
k = self.k(hidden_states)
v = self.v(hidden_states)
dk = tf.cast(x=self.attention_head_size, dtype=q.dtype)
q = tf.multiply(x=q, y=tf.math.rsqrt(x=dk))
dk = tf.cast(self.attention_head_size, dtype=q.dtype)
q = tf.multiply(q, y=tf.math.rsqrt(dk))
attention_scores = tf.einsum("aecd,abcd->acbe", k, q)
# Apply relative position embedding (precomputed in MPNetEncoder) if provided.
@@ -368,34 +369,34 @@ class TFMPNetAttention(tf.keras.layers.Layer):
return outputs
# Copied from transformers.models.bert.modeling_tf_bert.TFBertIntermediate
# Copied from transformers.models.bert.modeling_tf_bert.TFBertIntermediate with Bert->MPNet
class TFMPNetIntermediate(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
def __init__(self, config: MPNetConfig, **kwargs):
super().__init__(**kwargs)
self.dense = tf.keras.layers.experimental.EinsumDense(
equation="abc,cd->abd",
output_shape=(None, config.intermediate_size),
bias_axes="d",
kernel_initializer=get_initializer(initializer_range=config.initializer_range),
kernel_initializer=get_initializer(config.initializer_range),
name="dense",
)
if isinstance(config.hidden_act, str):
self.intermediate_act_fn = get_tf_activation(activation_string=config.hidden_act)
self.intermediate_act_fn = get_tf_activation(config.hidden_act)
else:
self.intermediate_act_fn = config.hidden_act
def call(self, hidden_states):
def call(self, hidden_states: tf.Tensor) -> tf.Tensor:
hidden_states = self.dense(inputs=hidden_states)
hidden_states = self.intermediate_act_fn(hidden_states)
return hidden_states
# Copied from transformers.models.bert.modeling_tf_bert.TFBertOutput
# Copied from transformers.models.bert.modeling_tf_bert.TFBertOutput with Bert->MPNet
class TFMPNetOutput(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
def __init__(self, config: MPNetConfig, **kwargs):
super().__init__(**kwargs)
self.dense = tf.keras.layers.experimental.EinsumDense(
@@ -408,7 +409,7 @@ class TFMPNetOutput(tf.keras.layers.Layer):
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
self.dropout = tf.keras.layers.Dropout(rate=config.hidden_dropout_prob)
def call(self, hidden_states, input_tensor, training=False):
def call(self, hidden_states: tf.Tensor, input_tensor: tf.Tensor, training: bool = False) -> tf.Tensor:
hidden_states = self.dense(inputs=hidden_states)
hidden_states = self.dropout(inputs=hidden_states, training=training)
hidden_states = self.LayerNorm(inputs=hidden_states + input_tensor)
@@ -563,11 +564,11 @@ class TFMPNetMainLayer(tf.keras.layers.Layer):
self.embeddings = TFMPNetEmbeddings(config, name="embeddings")
# Copied from transformers.models.bert.modeling_tf_bert.TFBertMainLayer.get_input_embeddings
def get_input_embeddings(self):
def get_input_embeddings(self) -> tf.keras.layers.Layer:
return self.embeddings.word_embeddings
# Copied from transformers.models.bert.modeling_tf_bert.TFBertMainLayer.set_input_embeddings
def set_input_embeddings(self, value):
def set_input_embeddings(self, value: tf.Variable):
self.embeddings.word_embeddings.weight = value
self.embeddings.word_embeddings.vocab_size = shape_list(value)[0]
@@ -820,7 +821,7 @@ class TFMPNetModel(TFMPNetPreTrainedModel):
return outputs
# Copied from transformers.models.bert.modeling_tf_bert.TFBertModel.serving_output
def serving_output(self, output):
def serving_output(self, output: TFBaseModelOutputWithPooling) -> TFBaseModelOutputWithPooling:
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
@@ -973,7 +974,7 @@ class TFMPNetForMaskedLM(TFMPNetPreTrainedModel, TFMaskedLanguageModelingLoss):
)
# Copied from transformers.models.bert.modeling_tf_bert.TFBertForMaskedLM.serving_output
def serving_output(self, output):
def serving_output(self, output: TFMaskedLMOutput) -> TFMaskedLMOutput:
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
@@ -1095,7 +1096,7 @@ class TFMPNetForSequenceClassification(TFMPNetPreTrainedModel, TFSequenceClassif
)
# Copied from transformers.models.bert.modeling_tf_bert.TFBertForSequenceClassification.serving_output
def serving_output(self, output):
def serving_output(self, output: TFSequenceClassifierOutput) -> TFSequenceClassifierOutput:
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
@@ -1233,7 +1234,7 @@ class TFMPNetForMultipleChoice(TFMPNetPreTrainedModel, TFMultipleChoiceLoss):
return self.serving_output(output)
# Copied from transformers.models.bert.modeling_tf_bert.TFBertForMultipleChoice.serving_output
def serving_output(self, output):
def serving_output(self, output: TFMultipleChoiceModelOutput) -> TFMultipleChoiceModelOutput:
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
@@ -1333,7 +1334,7 @@ class TFMPNetForTokenClassification(TFMPNetPreTrainedModel, TFTokenClassificatio
)
# Copied from transformers.models.bert.modeling_tf_bert.TFBertForTokenClassification.serving_output
def serving_output(self, output):
def serving_output(self, output: TFTokenClassifierOutput) -> TFTokenClassifierOutput:
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None
@@ -1446,7 +1447,7 @@ class TFMPNetForQuestionAnswering(TFMPNetPreTrainedModel, TFQuestionAnsweringLos
)
# Copied from transformers.models.bert.modeling_tf_bert.TFBertForQuestionAnswering.serving_output
def serving_output(self, output):
def serving_output(self, output: TFQuestionAnsweringModelOutput) -> TFQuestionAnsweringModelOutput:
hs = tf.convert_to_tensor(output.hidden_states) if self.config.output_hidden_states else None
attns = tf.convert_to_tensor(output.attentions) if self.config.output_attentions else None