Fix saved model creation (#5468)
* Fix TF Serving when output_hidden_states and output_attentions are True * Add tests for saved model creation + bug fix for multiple choices models * remove unused import * Fix the input for several layers * Fix test * Fix conflict printing * Apply style * Fix XLM and Flaubert for TensorFlow * Apply style * Fix TF check version * Apply style * Trigger CI
This commit is contained in:
@@ -35,7 +35,6 @@ from .modeling_tf_utils import (
|
|||||||
TFQuestionAnsweringLoss,
|
TFQuestionAnsweringLoss,
|
||||||
TFSequenceClassificationLoss,
|
TFSequenceClassificationLoss,
|
||||||
TFTokenClassificationLoss,
|
TFTokenClassificationLoss,
|
||||||
cast_bool_to_primitive,
|
|
||||||
get_initializer,
|
get_initializer,
|
||||||
keras_serializable,
|
keras_serializable,
|
||||||
shape_list,
|
shape_list,
|
||||||
@@ -99,7 +98,15 @@ class TFAlbertEmbeddings(tf.keras.layers.Layer):
|
|||||||
)
|
)
|
||||||
super().build(input_shape)
|
super().build(input_shape)
|
||||||
|
|
||||||
def call(self, inputs, mode="embedding", training=False):
|
def call(
|
||||||
|
self,
|
||||||
|
input_ids=None,
|
||||||
|
position_ids=None,
|
||||||
|
token_type_ids=None,
|
||||||
|
inputs_embeds=None,
|
||||||
|
mode="embedding",
|
||||||
|
training=False,
|
||||||
|
):
|
||||||
"""Get token embeddings of inputs.
|
"""Get token embeddings of inputs.
|
||||||
Args:
|
Args:
|
||||||
inputs: list of three int64 tensors with shape [batch_size, length]: (input_ids, position_ids, token_type_ids)
|
inputs: list of three int64 tensors with shape [batch_size, length]: (input_ids, position_ids, token_type_ids)
|
||||||
@@ -115,15 +122,15 @@ class TFAlbertEmbeddings(tf.keras.layers.Layer):
|
|||||||
https://github.com/tensorflow/models/blob/a009f4fb9d2fc4949e32192a944688925ef78659/official/transformer/v2/embedding_layer.py#L24
|
https://github.com/tensorflow/models/blob/a009f4fb9d2fc4949e32192a944688925ef78659/official/transformer/v2/embedding_layer.py#L24
|
||||||
"""
|
"""
|
||||||
if mode == "embedding":
|
if mode == "embedding":
|
||||||
return self._embedding(inputs, training=training)
|
return self._embedding(input_ids, position_ids, token_type_ids, inputs_embeds, training=training)
|
||||||
elif mode == "linear":
|
elif mode == "linear":
|
||||||
return self._linear(inputs)
|
return self._linear(input_ids)
|
||||||
else:
|
else:
|
||||||
raise ValueError("mode {} is not valid.".format(mode))
|
raise ValueError("mode {} is not valid.".format(mode))
|
||||||
|
|
||||||
def _embedding(self, inputs, training=False):
|
def _embedding(self, input_ids, position_ids, token_type_ids, inputs_embeds, training=False):
|
||||||
"""Applies embedding based on inputs tensor."""
|
"""Applies embedding based on inputs tensor."""
|
||||||
input_ids, position_ids, token_type_ids, inputs_embeds = inputs
|
assert not (input_ids is None and inputs_embeds is None)
|
||||||
|
|
||||||
if input_ids is not None:
|
if input_ids is not None:
|
||||||
input_shape = shape_list(input_ids)
|
input_shape = shape_list(input_ids)
|
||||||
@@ -175,6 +182,7 @@ class TFAlbertSelfAttention(tf.keras.layers.Layer):
|
|||||||
), f"Hidden size {config.hidden_size} not dividable by number of heads {config.num_attention_heads}"
|
), f"Hidden size {config.hidden_size} not dividable by number of heads {config.num_attention_heads}"
|
||||||
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
|
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
|
||||||
self.all_head_size = self.num_attention_heads * self.attention_head_size
|
self.all_head_size = self.num_attention_heads * self.attention_head_size
|
||||||
|
self.output_attentions = config.output_attentions
|
||||||
|
|
||||||
self.query = tf.keras.layers.Dense(
|
self.query = tf.keras.layers.Dense(
|
||||||
self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="query"
|
self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="query"
|
||||||
@@ -192,9 +200,7 @@ class TFAlbertSelfAttention(tf.keras.layers.Layer):
|
|||||||
x = tf.reshape(x, (batch_size, -1, self.num_attention_heads, self.attention_head_size))
|
x = tf.reshape(x, (batch_size, -1, self.num_attention_heads, self.attention_head_size))
|
||||||
return tf.transpose(x, perm=[0, 2, 1, 3])
|
return tf.transpose(x, perm=[0, 2, 1, 3])
|
||||||
|
|
||||||
def call(self, inputs, training=False):
|
def call(self, hidden_states, attention_mask, head_mask, output_attentions, training=False):
|
||||||
hidden_states, attention_mask, head_mask, output_attentions = inputs
|
|
||||||
|
|
||||||
batch_size = shape_list(hidden_states)[0]
|
batch_size = shape_list(hidden_states)[0]
|
||||||
mixed_query_layer = self.query(hidden_states)
|
mixed_query_layer = self.query(hidden_states)
|
||||||
mixed_key_layer = self.key(hidden_states)
|
mixed_key_layer = self.key(hidden_states)
|
||||||
@@ -233,9 +239,7 @@ class TFAlbertSelfAttention(tf.keras.layers.Layer):
|
|||||||
context_layer, (batch_size, -1, self.all_head_size)
|
context_layer, (batch_size, -1, self.all_head_size)
|
||||||
) # (batch_size, seq_len_q, all_head_size)
|
) # (batch_size, seq_len_q, all_head_size)
|
||||||
|
|
||||||
outputs = (
|
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
|
||||||
(context_layer, attention_probs) if cast_bool_to_primitive(output_attentions) is True else (context_layer,)
|
|
||||||
)
|
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
|
|
||||||
@@ -248,9 +252,7 @@ class TFAlbertSelfOutput(tf.keras.layers.Layer):
|
|||||||
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
|
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
|
||||||
self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
|
self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
|
||||||
|
|
||||||
def call(self, inputs, training=False):
|
def call(self, hidden_states, input_tensor, training=False):
|
||||||
hidden_states, input_tensor = inputs
|
|
||||||
|
|
||||||
hidden_states = self.dense(hidden_states)
|
hidden_states = self.dense(hidden_states)
|
||||||
hidden_states = self.dropout(hidden_states, training=training)
|
hidden_states = self.dropout(hidden_states, training=training)
|
||||||
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
||||||
@@ -262,6 +264,7 @@ class TFAlbertAttention(TFBertSelfAttention):
|
|||||||
super().__init__(config, **kwargs)
|
super().__init__(config, **kwargs)
|
||||||
|
|
||||||
self.hidden_size = config.hidden_size
|
self.hidden_size = config.hidden_size
|
||||||
|
self.output_attentions = config.output_attentions
|
||||||
self.dense = tf.keras.layers.Dense(
|
self.dense = tf.keras.layers.Dense(
|
||||||
config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
|
config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
|
||||||
)
|
)
|
||||||
@@ -271,9 +274,7 @@ class TFAlbertAttention(TFBertSelfAttention):
|
|||||||
def prune_heads(self, heads):
|
def prune_heads(self, heads):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def call(self, inputs, training=False):
|
def call(self, input_tensor, attention_mask, head_mask, output_attentions, training=False):
|
||||||
input_tensor, attention_mask, head_mask, output_attentions = inputs
|
|
||||||
|
|
||||||
batch_size = shape_list(input_tensor)[0]
|
batch_size = shape_list(input_tensor)[0]
|
||||||
mixed_query_layer = self.query(input_tensor)
|
mixed_query_layer = self.query(input_tensor)
|
||||||
mixed_key_layer = self.key(input_tensor)
|
mixed_key_layer = self.key(input_tensor)
|
||||||
@@ -312,9 +313,7 @@ class TFAlbertAttention(TFBertSelfAttention):
|
|||||||
context_layer, (batch_size, -1, self.all_head_size)
|
context_layer, (batch_size, -1, self.all_head_size)
|
||||||
) # (batch_size, seq_len_q, all_head_size)
|
) # (batch_size, seq_len_q, all_head_size)
|
||||||
|
|
||||||
self_outputs = (
|
self_outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
|
||||||
(context_layer, attention_probs) if cast_bool_to_primitive(output_attentions) is True else (context_layer,)
|
|
||||||
)
|
|
||||||
|
|
||||||
hidden_states = self_outputs[0]
|
hidden_states = self_outputs[0]
|
||||||
|
|
||||||
@@ -349,11 +348,9 @@ class TFAlbertLayer(tf.keras.layers.Layer):
|
|||||||
)
|
)
|
||||||
self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
|
self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
|
||||||
|
|
||||||
def call(self, inputs, training=False):
|
def call(self, hidden_states, attention_mask, head_mask, output_attentions, training=False):
|
||||||
hidden_states, attention_mask, head_mask, output_attentions = inputs
|
|
||||||
|
|
||||||
attention_outputs = self.attention(
|
attention_outputs = self.attention(
|
||||||
[hidden_states, attention_mask, head_mask, output_attentions], training=training
|
hidden_states, attention_mask, head_mask, output_attentions, training=training
|
||||||
)
|
)
|
||||||
ffn_output = self.ffn(attention_outputs[0])
|
ffn_output = self.ffn(attention_outputs[0])
|
||||||
ffn_output = self.activation(ffn_output)
|
ffn_output = self.activation(ffn_output)
|
||||||
@@ -371,32 +368,32 @@ class TFAlbertLayerGroup(tf.keras.layers.Layer):
|
|||||||
def __init__(self, config, **kwargs):
|
def __init__(self, config, **kwargs):
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
|
self.output_attentions = config.output_attentions
|
||||||
|
self.output_hidden_states = config.output_hidden_states
|
||||||
self.albert_layers = [
|
self.albert_layers = [
|
||||||
TFAlbertLayer(config, name="albert_layers_._{}".format(i)) for i in range(config.inner_group_num)
|
TFAlbertLayer(config, name="albert_layers_._{}".format(i)) for i in range(config.inner_group_num)
|
||||||
]
|
]
|
||||||
|
|
||||||
def call(self, inputs, training=False):
|
def call(self, hidden_states, attention_mask, head_mask, output_attentions, output_hidden_states, training=False):
|
||||||
hidden_states, attention_mask, head_mask, output_attentions, output_hidden_states = inputs
|
|
||||||
|
|
||||||
layer_hidden_states = ()
|
layer_hidden_states = ()
|
||||||
layer_attentions = ()
|
layer_attentions = ()
|
||||||
|
|
||||||
for layer_index, albert_layer in enumerate(self.albert_layers):
|
for layer_index, albert_layer in enumerate(self.albert_layers):
|
||||||
layer_output = albert_layer(
|
layer_output = albert_layer(
|
||||||
[hidden_states, attention_mask, head_mask[layer_index], output_attentions], training=training
|
hidden_states, attention_mask, head_mask[layer_index], output_attentions, training=training
|
||||||
)
|
)
|
||||||
hidden_states = layer_output[0]
|
hidden_states = layer_output[0]
|
||||||
|
|
||||||
if cast_bool_to_primitive(output_attentions) is True:
|
if output_attentions:
|
||||||
layer_attentions = layer_attentions + (layer_output[1],)
|
layer_attentions = layer_attentions + (layer_output[1],)
|
||||||
|
|
||||||
if cast_bool_to_primitive(output_hidden_states) is True:
|
if output_hidden_states:
|
||||||
layer_hidden_states = layer_hidden_states + (hidden_states,)
|
layer_hidden_states = layer_hidden_states + (hidden_states,)
|
||||||
|
|
||||||
outputs = (hidden_states,)
|
outputs = (hidden_states,)
|
||||||
if cast_bool_to_primitive(output_hidden_states) is True:
|
if output_hidden_states:
|
||||||
outputs = outputs + (layer_hidden_states,)
|
outputs = outputs + (layer_hidden_states,)
|
||||||
if cast_bool_to_primitive(output_attentions) is True:
|
if output_attentions:
|
||||||
outputs = outputs + (layer_attentions,)
|
outputs = outputs + (layer_attentions,)
|
||||||
# last-layer hidden state, (layer hidden states), (layer attentions)
|
# last-layer hidden state, (layer hidden states), (layer attentions)
|
||||||
return outputs
|
return outputs
|
||||||
@@ -417,13 +414,11 @@ class TFAlbertTransformer(tf.keras.layers.Layer):
|
|||||||
for i in range(config.num_hidden_groups)
|
for i in range(config.num_hidden_groups)
|
||||||
]
|
]
|
||||||
|
|
||||||
def call(self, inputs, training=False):
|
def call(self, hidden_states, attention_mask, head_mask, output_attentions, output_hidden_states, training=False):
|
||||||
hidden_states, attention_mask, head_mask, output_attentions, output_hidden_states = inputs
|
|
||||||
|
|
||||||
hidden_states = self.embedding_hidden_mapping_in(hidden_states)
|
hidden_states = self.embedding_hidden_mapping_in(hidden_states)
|
||||||
all_attentions = ()
|
all_attentions = ()
|
||||||
|
|
||||||
if cast_bool_to_primitive(output_hidden_states) is True:
|
if output_hidden_states:
|
||||||
all_hidden_states = (hidden_states,)
|
all_hidden_states = (hidden_states,)
|
||||||
|
|
||||||
for i in range(self.config.num_hidden_layers):
|
for i in range(self.config.num_hidden_layers):
|
||||||
@@ -434,27 +429,25 @@ class TFAlbertTransformer(tf.keras.layers.Layer):
|
|||||||
group_idx = int(i / (self.config.num_hidden_layers / self.config.num_hidden_groups))
|
group_idx = int(i / (self.config.num_hidden_layers / self.config.num_hidden_groups))
|
||||||
|
|
||||||
layer_group_output = self.albert_layer_groups[group_idx](
|
layer_group_output = self.albert_layer_groups[group_idx](
|
||||||
[
|
hidden_states,
|
||||||
hidden_states,
|
attention_mask,
|
||||||
attention_mask,
|
head_mask[group_idx * layers_per_group : (group_idx + 1) * layers_per_group],
|
||||||
head_mask[group_idx * layers_per_group : (group_idx + 1) * layers_per_group],
|
output_attentions,
|
||||||
output_attentions,
|
output_hidden_states,
|
||||||
output_hidden_states,
|
|
||||||
],
|
|
||||||
training=training,
|
training=training,
|
||||||
)
|
)
|
||||||
hidden_states = layer_group_output[0]
|
hidden_states = layer_group_output[0]
|
||||||
|
|
||||||
if cast_bool_to_primitive(output_attentions) is True:
|
if output_attentions:
|
||||||
all_attentions = all_attentions + layer_group_output[-1]
|
all_attentions = all_attentions + layer_group_output[-1]
|
||||||
|
|
||||||
if cast_bool_to_primitive(output_hidden_states) is True:
|
if output_hidden_states:
|
||||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||||
|
|
||||||
outputs = (hidden_states,)
|
outputs = (hidden_states,)
|
||||||
if cast_bool_to_primitive(output_hidden_states) is True:
|
if output_hidden_states:
|
||||||
outputs = outputs + (all_hidden_states,)
|
outputs = outputs + (all_hidden_states,)
|
||||||
if cast_bool_to_primitive(output_attentions) is True:
|
if output_attentions:
|
||||||
outputs = outputs + (all_attentions,)
|
outputs = outputs + (all_attentions,)
|
||||||
|
|
||||||
# last-layer hidden state, (all hidden states), (all attentions)
|
# last-layer hidden state, (all hidden states), (all attentions)
|
||||||
@@ -619,9 +612,13 @@ class TFAlbertMainLayer(tf.keras.layers.Layer):
|
|||||||
head_mask = [None] * self.num_hidden_layers
|
head_mask = [None] * self.num_hidden_layers
|
||||||
# head_mask = tf.constant([0] * self.num_hidden_layers)
|
# head_mask = tf.constant([0] * self.num_hidden_layers)
|
||||||
|
|
||||||
embedding_output = self.embeddings([input_ids, position_ids, token_type_ids, inputs_embeds], training=training)
|
embedding_output = self.embeddings(input_ids, position_ids, token_type_ids, inputs_embeds, training=training)
|
||||||
encoder_outputs = self.encoder(
|
encoder_outputs = self.encoder(
|
||||||
[embedding_output, extended_attention_mask, head_mask, output_attentions, output_hidden_states],
|
embedding_output,
|
||||||
|
extended_attention_mask,
|
||||||
|
head_mask,
|
||||||
|
output_attentions,
|
||||||
|
output_hidden_states,
|
||||||
training=training,
|
training=training,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -1274,7 +1271,7 @@ class TFAlbertForMultipleChoice(TFAlbertPreTrainedModel, TFMultipleChoiceLoss):
|
|||||||
flat_token_type_ids = tf.reshape(token_type_ids, (-1, seq_length)) if token_type_ids is not None else None
|
flat_token_type_ids = tf.reshape(token_type_ids, (-1, seq_length)) if token_type_ids is not None else None
|
||||||
flat_position_ids = tf.reshape(position_ids, (-1, seq_length)) if position_ids is not None else None
|
flat_position_ids = tf.reshape(position_ids, (-1, seq_length)) if position_ids is not None else None
|
||||||
|
|
||||||
flat_inputs = [
|
outputs = self.albert(
|
||||||
flat_input_ids,
|
flat_input_ids,
|
||||||
flat_attention_mask,
|
flat_attention_mask,
|
||||||
flat_token_type_ids,
|
flat_token_type_ids,
|
||||||
@@ -1283,9 +1280,8 @@ class TFAlbertForMultipleChoice(TFAlbertPreTrainedModel, TFMultipleChoiceLoss):
|
|||||||
inputs_embeds,
|
inputs_embeds,
|
||||||
output_attentions,
|
output_attentions,
|
||||||
output_hidden_states,
|
output_hidden_states,
|
||||||
]
|
training=training,
|
||||||
|
)
|
||||||
outputs = self.albert(flat_inputs, training=training)
|
|
||||||
|
|
||||||
pooled_output = outputs[1]
|
pooled_output = outputs[1]
|
||||||
|
|
||||||
|
|||||||
@@ -36,7 +36,6 @@ from .modeling_tf_utils import (
|
|||||||
TFQuestionAnsweringLoss,
|
TFQuestionAnsweringLoss,
|
||||||
TFSequenceClassificationLoss,
|
TFSequenceClassificationLoss,
|
||||||
TFTokenClassificationLoss,
|
TFTokenClassificationLoss,
|
||||||
cast_bool_to_primitive,
|
|
||||||
get_initializer,
|
get_initializer,
|
||||||
keras_serializable,
|
keras_serializable,
|
||||||
shape_list,
|
shape_list,
|
||||||
@@ -81,6 +80,7 @@ def gelu(x):
|
|||||||
Also see https://arxiv.org/abs/1606.08415
|
Also see https://arxiv.org/abs/1606.08415
|
||||||
"""
|
"""
|
||||||
cdf = 0.5 * (1.0 + tf.math.erf(x / tf.math.sqrt(2.0)))
|
cdf = 0.5 * (1.0 + tf.math.erf(x / tf.math.sqrt(2.0)))
|
||||||
|
|
||||||
return x * cdf
|
return x * cdf
|
||||||
|
|
||||||
|
|
||||||
@@ -94,6 +94,7 @@ def gelu_new(x):
|
|||||||
`x` with the GELU activation applied.
|
`x` with the GELU activation applied.
|
||||||
"""
|
"""
|
||||||
cdf = 0.5 * (1.0 + tf.tanh((np.sqrt(2 / np.pi) * (x + 0.044715 * tf.pow(x, 3)))))
|
cdf = 0.5 * (1.0 + tf.tanh((np.sqrt(2 / np.pi) * (x + 0.044715 * tf.pow(x, 3)))))
|
||||||
|
|
||||||
return x * cdf
|
return x * cdf
|
||||||
|
|
||||||
|
|
||||||
@@ -118,7 +119,6 @@ class TFBertEmbeddings(tf.keras.layers.Layer):
|
|||||||
self.vocab_size = config.vocab_size
|
self.vocab_size = config.vocab_size
|
||||||
self.hidden_size = config.hidden_size
|
self.hidden_size = config.hidden_size
|
||||||
self.initializer_range = config.initializer_range
|
self.initializer_range = config.initializer_range
|
||||||
|
|
||||||
self.position_embeddings = tf.keras.layers.Embedding(
|
self.position_embeddings = tf.keras.layers.Embedding(
|
||||||
config.max_position_embeddings,
|
config.max_position_embeddings,
|
||||||
config.hidden_size,
|
config.hidden_size,
|
||||||
@@ -149,7 +149,15 @@ class TFBertEmbeddings(tf.keras.layers.Layer):
|
|||||||
)
|
)
|
||||||
super().build(input_shape)
|
super().build(input_shape)
|
||||||
|
|
||||||
def call(self, inputs, mode="embedding", training=False):
|
def call(
|
||||||
|
self,
|
||||||
|
input_ids=None,
|
||||||
|
position_ids=None,
|
||||||
|
token_type_ids=None,
|
||||||
|
inputs_embeds=None,
|
||||||
|
mode="embedding",
|
||||||
|
training=False,
|
||||||
|
):
|
||||||
"""Get token embeddings of inputs.
|
"""Get token embeddings of inputs.
|
||||||
Args:
|
Args:
|
||||||
inputs: list of three int64 tensors with shape [batch_size, length]: (input_ids, position_ids, token_type_ids)
|
inputs: list of three int64 tensors with shape [batch_size, length]: (input_ids, position_ids, token_type_ids)
|
||||||
@@ -165,15 +173,15 @@ class TFBertEmbeddings(tf.keras.layers.Layer):
|
|||||||
https://github.com/tensorflow/models/blob/a009f4fb9d2fc4949e32192a944688925ef78659/official/transformer/v2/embedding_layer.py#L24
|
https://github.com/tensorflow/models/blob/a009f4fb9d2fc4949e32192a944688925ef78659/official/transformer/v2/embedding_layer.py#L24
|
||||||
"""
|
"""
|
||||||
if mode == "embedding":
|
if mode == "embedding":
|
||||||
return self._embedding(inputs, training=training)
|
return self._embedding(input_ids, position_ids, token_type_ids, inputs_embeds, training=training)
|
||||||
elif mode == "linear":
|
elif mode == "linear":
|
||||||
return self._linear(inputs)
|
return self._linear(input_ids)
|
||||||
else:
|
else:
|
||||||
raise ValueError("mode {} is not valid.".format(mode))
|
raise ValueError("mode {} is not valid.".format(mode))
|
||||||
|
|
||||||
def _embedding(self, inputs, training=False):
|
def _embedding(self, input_ids, position_ids, token_type_ids, inputs_embeds, training=False):
|
||||||
"""Applies embedding based on inputs tensor."""
|
"""Applies embedding based on inputs tensor."""
|
||||||
input_ids, position_ids, token_type_ids, inputs_embeds = inputs
|
assert not (input_ids is None and inputs_embeds is None)
|
||||||
|
|
||||||
if input_ids is not None:
|
if input_ids is not None:
|
||||||
input_shape = shape_list(input_ids)
|
input_shape = shape_list(input_ids)
|
||||||
@@ -181,19 +189,22 @@ class TFBertEmbeddings(tf.keras.layers.Layer):
|
|||||||
input_shape = shape_list(inputs_embeds)[:-1]
|
input_shape = shape_list(inputs_embeds)[:-1]
|
||||||
|
|
||||||
seq_length = input_shape[1]
|
seq_length = input_shape[1]
|
||||||
|
|
||||||
if position_ids is None:
|
if position_ids is None:
|
||||||
position_ids = tf.range(seq_length, dtype=tf.int32)[tf.newaxis, :]
|
position_ids = tf.range(seq_length, dtype=tf.int32)[tf.newaxis, :]
|
||||||
|
|
||||||
if token_type_ids is None:
|
if token_type_ids is None:
|
||||||
token_type_ids = tf.fill(input_shape, 0)
|
token_type_ids = tf.fill(input_shape, 0)
|
||||||
|
|
||||||
if inputs_embeds is None:
|
if inputs_embeds is None:
|
||||||
inputs_embeds = tf.gather(self.word_embeddings, input_ids)
|
inputs_embeds = tf.gather(self.word_embeddings, input_ids)
|
||||||
|
|
||||||
position_embeddings = self.position_embeddings(position_ids)
|
position_embeddings = self.position_embeddings(position_ids)
|
||||||
token_type_embeddings = self.token_type_embeddings(token_type_ids)
|
token_type_embeddings = self.token_type_embeddings(token_type_ids)
|
||||||
|
|
||||||
embeddings = inputs_embeds + position_embeddings + token_type_embeddings
|
embeddings = inputs_embeds + position_embeddings + token_type_embeddings
|
||||||
embeddings = self.LayerNorm(embeddings)
|
embeddings = self.LayerNorm(embeddings)
|
||||||
embeddings = self.dropout(embeddings, training=training)
|
embeddings = self.dropout(embeddings, training=training)
|
||||||
|
|
||||||
return embeddings
|
return embeddings
|
||||||
|
|
||||||
def _linear(self, inputs):
|
def _linear(self, inputs):
|
||||||
@@ -205,7 +216,6 @@ class TFBertEmbeddings(tf.keras.layers.Layer):
|
|||||||
"""
|
"""
|
||||||
batch_size = shape_list(inputs)[0]
|
batch_size = shape_list(inputs)[0]
|
||||||
length = shape_list(inputs)[1]
|
length = shape_list(inputs)[1]
|
||||||
|
|
||||||
x = tf.reshape(inputs, [-1, self.hidden_size])
|
x = tf.reshape(inputs, [-1, self.hidden_size])
|
||||||
logits = tf.matmul(x, self.word_embeddings, transpose_b=True)
|
logits = tf.matmul(x, self.word_embeddings, transpose_b=True)
|
||||||
|
|
||||||
@@ -215,6 +225,7 @@ class TFBertEmbeddings(tf.keras.layers.Layer):
|
|||||||
class TFBertSelfAttention(tf.keras.layers.Layer):
|
class TFBertSelfAttention(tf.keras.layers.Layer):
|
||||||
def __init__(self, config, **kwargs):
|
def __init__(self, config, **kwargs):
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
if config.hidden_size % config.num_attention_heads != 0:
|
if config.hidden_size % config.num_attention_heads != 0:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"The hidden size (%d) is not a multiple of the number of attention "
|
"The hidden size (%d) is not a multiple of the number of attention "
|
||||||
@@ -225,7 +236,6 @@ class TFBertSelfAttention(tf.keras.layers.Layer):
|
|||||||
assert config.hidden_size % config.num_attention_heads == 0
|
assert config.hidden_size % config.num_attention_heads == 0
|
||||||
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
|
self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
|
||||||
self.all_head_size = self.num_attention_heads * self.attention_head_size
|
self.all_head_size = self.num_attention_heads * self.attention_head_size
|
||||||
|
|
||||||
self.query = tf.keras.layers.Dense(
|
self.query = tf.keras.layers.Dense(
|
||||||
self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="query"
|
self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="query"
|
||||||
)
|
)
|
||||||
@@ -235,21 +245,18 @@ class TFBertSelfAttention(tf.keras.layers.Layer):
|
|||||||
self.value = tf.keras.layers.Dense(
|
self.value = tf.keras.layers.Dense(
|
||||||
self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="value"
|
self.all_head_size, kernel_initializer=get_initializer(config.initializer_range), name="value"
|
||||||
)
|
)
|
||||||
|
|
||||||
self.dropout = tf.keras.layers.Dropout(config.attention_probs_dropout_prob)
|
self.dropout = tf.keras.layers.Dropout(config.attention_probs_dropout_prob)
|
||||||
|
|
||||||
def transpose_for_scores(self, x, batch_size):
|
def transpose_for_scores(self, x, batch_size):
|
||||||
x = tf.reshape(x, (batch_size, -1, self.num_attention_heads, self.attention_head_size))
|
x = tf.reshape(x, (batch_size, -1, self.num_attention_heads, self.attention_head_size))
|
||||||
|
|
||||||
return tf.transpose(x, perm=[0, 2, 1, 3])
|
return tf.transpose(x, perm=[0, 2, 1, 3])
|
||||||
|
|
||||||
def call(self, inputs, training=False):
|
def call(self, hidden_states, attention_mask, head_mask, output_attentions, training=False):
|
||||||
hidden_states, attention_mask, head_mask, output_attentions = inputs
|
|
||||||
|
|
||||||
batch_size = shape_list(hidden_states)[0]
|
batch_size = shape_list(hidden_states)[0]
|
||||||
mixed_query_layer = self.query(hidden_states)
|
mixed_query_layer = self.query(hidden_states)
|
||||||
mixed_key_layer = self.key(hidden_states)
|
mixed_key_layer = self.key(hidden_states)
|
||||||
mixed_value_layer = self.value(hidden_states)
|
mixed_value_layer = self.value(hidden_states)
|
||||||
|
|
||||||
query_layer = self.transpose_for_scores(mixed_query_layer, batch_size)
|
query_layer = self.transpose_for_scores(mixed_query_layer, batch_size)
|
||||||
key_layer = self.transpose_for_scores(mixed_key_layer, batch_size)
|
key_layer = self.transpose_for_scores(mixed_key_layer, batch_size)
|
||||||
value_layer = self.transpose_for_scores(mixed_value_layer, batch_size)
|
value_layer = self.transpose_for_scores(mixed_value_layer, batch_size)
|
||||||
@@ -277,15 +284,11 @@ class TFBertSelfAttention(tf.keras.layers.Layer):
|
|||||||
attention_probs = attention_probs * head_mask
|
attention_probs = attention_probs * head_mask
|
||||||
|
|
||||||
context_layer = tf.matmul(attention_probs, value_layer)
|
context_layer = tf.matmul(attention_probs, value_layer)
|
||||||
|
|
||||||
context_layer = tf.transpose(context_layer, perm=[0, 2, 1, 3])
|
context_layer = tf.transpose(context_layer, perm=[0, 2, 1, 3])
|
||||||
context_layer = tf.reshape(
|
context_layer = tf.reshape(
|
||||||
context_layer, (batch_size, -1, self.all_head_size)
|
context_layer, (batch_size, -1, self.all_head_size)
|
||||||
) # (batch_size, seq_len_q, all_head_size)
|
) # (batch_size, seq_len_q, all_head_size)
|
||||||
|
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
|
||||||
outputs = (
|
|
||||||
(context_layer, attention_probs) if cast_bool_to_primitive(output_attentions) is True else (context_layer,)
|
|
||||||
)
|
|
||||||
|
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
@@ -299,12 +302,11 @@ class TFBertSelfOutput(tf.keras.layers.Layer):
|
|||||||
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
|
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
|
||||||
self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
|
self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
|
||||||
|
|
||||||
def call(self, inputs, training=False):
|
def call(self, hidden_states, input_tensor, training=False):
|
||||||
hidden_states, input_tensor = inputs
|
|
||||||
|
|
||||||
hidden_states = self.dense(hidden_states)
|
hidden_states = self.dense(hidden_states)
|
||||||
hidden_states = self.dropout(hidden_states, training=training)
|
hidden_states = self.dropout(hidden_states, training=training)
|
||||||
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
||||||
|
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
@@ -317,14 +319,13 @@ class TFBertAttention(tf.keras.layers.Layer):
|
|||||||
def prune_heads(self, heads):
|
def prune_heads(self, heads):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def call(self, inputs, training=False):
|
def call(self, input_tensor, attention_mask, head_mask, output_attentions, training=False):
|
||||||
input_tensor, attention_mask, head_mask, output_attentions = inputs
|
|
||||||
|
|
||||||
self_outputs = self.self_attention(
|
self_outputs = self.self_attention(
|
||||||
[input_tensor, attention_mask, head_mask, output_attentions], training=training
|
input_tensor, attention_mask, head_mask, output_attentions, training=training
|
||||||
)
|
)
|
||||||
attention_output = self.dense_output([self_outputs[0], input_tensor], training=training)
|
attention_output = self.dense_output(self_outputs[0], input_tensor, training=training)
|
||||||
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
|
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
|
||||||
|
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
|
|
||||||
@@ -334,6 +335,7 @@ class TFBertIntermediate(tf.keras.layers.Layer):
|
|||||||
self.dense = tf.keras.layers.Dense(
|
self.dense = tf.keras.layers.Dense(
|
||||||
config.intermediate_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
|
config.intermediate_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
|
||||||
)
|
)
|
||||||
|
|
||||||
if isinstance(config.hidden_act, str):
|
if isinstance(config.hidden_act, str):
|
||||||
self.intermediate_act_fn = ACT2FN[config.hidden_act]
|
self.intermediate_act_fn = ACT2FN[config.hidden_act]
|
||||||
else:
|
else:
|
||||||
@@ -342,6 +344,7 @@ class TFBertIntermediate(tf.keras.layers.Layer):
|
|||||||
def call(self, hidden_states):
|
def call(self, hidden_states):
|
||||||
hidden_states = self.dense(hidden_states)
|
hidden_states = self.dense(hidden_states)
|
||||||
hidden_states = self.intermediate_act_fn(hidden_states)
|
hidden_states = self.intermediate_act_fn(hidden_states)
|
||||||
|
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
@@ -354,12 +357,11 @@ class TFBertOutput(tf.keras.layers.Layer):
|
|||||||
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
|
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
|
||||||
self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
|
self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
|
||||||
|
|
||||||
def call(self, inputs, training=False):
|
def call(self, hidden_states, input_tensor, training=False):
|
||||||
hidden_states, input_tensor = inputs
|
|
||||||
|
|
||||||
hidden_states = self.dense(hidden_states)
|
hidden_states = self.dense(hidden_states)
|
||||||
hidden_states = self.dropout(hidden_states, training=training)
|
hidden_states = self.dropout(hidden_states, training=training)
|
||||||
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
hidden_states = self.LayerNorm(hidden_states + input_tensor)
|
||||||
|
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
@@ -370,16 +372,15 @@ class TFBertLayer(tf.keras.layers.Layer):
|
|||||||
self.intermediate = TFBertIntermediate(config, name="intermediate")
|
self.intermediate = TFBertIntermediate(config, name="intermediate")
|
||||||
self.bert_output = TFBertOutput(config, name="output")
|
self.bert_output = TFBertOutput(config, name="output")
|
||||||
|
|
||||||
def call(self, inputs, training=False):
|
def call(self, hidden_states, attention_mask, head_mask, output_attentions, training=False):
|
||||||
hidden_states, attention_mask, head_mask, output_attentions = inputs
|
|
||||||
|
|
||||||
attention_outputs = self.attention(
|
attention_outputs = self.attention(
|
||||||
[hidden_states, attention_mask, head_mask, output_attentions], training=training
|
hidden_states, attention_mask, head_mask, output_attentions, training=training
|
||||||
)
|
)
|
||||||
attention_output = attention_outputs[0]
|
attention_output = attention_outputs[0]
|
||||||
intermediate_output = self.intermediate(attention_output)
|
intermediate_output = self.intermediate(attention_output)
|
||||||
layer_output = self.bert_output([intermediate_output, attention_output], training=training)
|
layer_output = self.bert_output(intermediate_output, attention_output, training=training)
|
||||||
outputs = (layer_output,) + attention_outputs[1:] # add attentions if we output them
|
outputs = (layer_output,) + attention_outputs[1:] # add attentions if we output them
|
||||||
|
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
|
|
||||||
@@ -388,32 +389,34 @@ class TFBertEncoder(tf.keras.layers.Layer):
|
|||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
self.layer = [TFBertLayer(config, name="layer_._{}".format(i)) for i in range(config.num_hidden_layers)]
|
self.layer = [TFBertLayer(config, name="layer_._{}".format(i)) for i in range(config.num_hidden_layers)]
|
||||||
|
|
||||||
def call(self, inputs, training=False):
|
def call(self, hidden_states, attention_mask, head_mask, output_attentions, output_hidden_states, training=False):
|
||||||
hidden_states, attention_mask, head_mask, output_attentions, output_hidden_states = inputs
|
|
||||||
|
|
||||||
all_hidden_states = ()
|
all_hidden_states = ()
|
||||||
all_attentions = ()
|
all_attentions = ()
|
||||||
|
|
||||||
for i, layer_module in enumerate(self.layer):
|
for i, layer_module in enumerate(self.layer):
|
||||||
if cast_bool_to_primitive(output_hidden_states) is True:
|
if output_hidden_states:
|
||||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||||
|
|
||||||
layer_outputs = layer_module(
|
layer_outputs = layer_module(
|
||||||
[hidden_states, attention_mask, head_mask[i], output_attentions], training=training
|
hidden_states, attention_mask, head_mask[i], output_attentions, training=training
|
||||||
)
|
)
|
||||||
hidden_states = layer_outputs[0]
|
hidden_states = layer_outputs[0]
|
||||||
|
|
||||||
if cast_bool_to_primitive(output_attentions) is True:
|
if output_attentions:
|
||||||
all_attentions = all_attentions + (layer_outputs[1],)
|
all_attentions = all_attentions + (layer_outputs[1],)
|
||||||
|
|
||||||
# Add last layer
|
# Add last layer
|
||||||
if cast_bool_to_primitive(output_hidden_states) is True:
|
if output_hidden_states:
|
||||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||||
|
|
||||||
outputs = (hidden_states,)
|
outputs = (hidden_states,)
|
||||||
if cast_bool_to_primitive(output_hidden_states) is True:
|
|
||||||
|
if output_hidden_states:
|
||||||
outputs = outputs + (all_hidden_states,)
|
outputs = outputs + (all_hidden_states,)
|
||||||
if cast_bool_to_primitive(output_attentions) is True:
|
|
||||||
|
if output_attentions:
|
||||||
outputs = outputs + (all_attentions,)
|
outputs = outputs + (all_attentions,)
|
||||||
|
|
||||||
return outputs # outputs, (hidden states), (attentions)
|
return outputs # outputs, (hidden states), (attentions)
|
||||||
|
|
||||||
|
|
||||||
@@ -432,6 +435,7 @@ class TFBertPooler(tf.keras.layers.Layer):
|
|||||||
# to the first token.
|
# to the first token.
|
||||||
first_token_tensor = hidden_states[:, 0]
|
first_token_tensor = hidden_states[:, 0]
|
||||||
pooled_output = self.dense(first_token_tensor)
|
pooled_output = self.dense(first_token_tensor)
|
||||||
|
|
||||||
return pooled_output
|
return pooled_output
|
||||||
|
|
||||||
|
|
||||||
@@ -441,16 +445,19 @@ class TFBertPredictionHeadTransform(tf.keras.layers.Layer):
|
|||||||
self.dense = tf.keras.layers.Dense(
|
self.dense = tf.keras.layers.Dense(
|
||||||
config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
|
config.hidden_size, kernel_initializer=get_initializer(config.initializer_range), name="dense"
|
||||||
)
|
)
|
||||||
|
|
||||||
if isinstance(config.hidden_act, str):
|
if isinstance(config.hidden_act, str):
|
||||||
self.transform_act_fn = ACT2FN[config.hidden_act]
|
self.transform_act_fn = ACT2FN[config.hidden_act]
|
||||||
else:
|
else:
|
||||||
self.transform_act_fn = config.hidden_act
|
self.transform_act_fn = config.hidden_act
|
||||||
|
|
||||||
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
|
self.LayerNorm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="LayerNorm")
|
||||||
|
|
||||||
def call(self, hidden_states):
|
def call(self, hidden_states):
|
||||||
hidden_states = self.dense(hidden_states)
|
hidden_states = self.dense(hidden_states)
|
||||||
hidden_states = self.transform_act_fn(hidden_states)
|
hidden_states = self.transform_act_fn(hidden_states)
|
||||||
hidden_states = self.LayerNorm(hidden_states)
|
hidden_states = self.LayerNorm(hidden_states)
|
||||||
|
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
@@ -472,6 +479,7 @@ class TFBertLMPredictionHead(tf.keras.layers.Layer):
|
|||||||
hidden_states = self.transform(hidden_states)
|
hidden_states = self.transform(hidden_states)
|
||||||
hidden_states = self.input_embeddings(hidden_states, mode="linear")
|
hidden_states = self.input_embeddings(hidden_states, mode="linear")
|
||||||
hidden_states = hidden_states + self.bias
|
hidden_states = hidden_states + self.bias
|
||||||
|
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
@@ -482,6 +490,7 @@ class TFBertMLMHead(tf.keras.layers.Layer):
|
|||||||
|
|
||||||
def call(self, sequence_output):
|
def call(self, sequence_output):
|
||||||
prediction_scores = self.predictions(sequence_output)
|
prediction_scores = self.predictions(sequence_output)
|
||||||
|
|
||||||
return prediction_scores
|
return prediction_scores
|
||||||
|
|
||||||
|
|
||||||
@@ -494,6 +503,7 @@ class TFBertNSPHead(tf.keras.layers.Layer):
|
|||||||
|
|
||||||
def call(self, pooled_output):
|
def call(self, pooled_output):
|
||||||
seq_relationship_score = self.seq_relationship(pooled_output)
|
seq_relationship_score = self.seq_relationship(pooled_output)
|
||||||
|
|
||||||
return seq_relationship_score
|
return seq_relationship_score
|
||||||
|
|
||||||
|
|
||||||
@@ -507,7 +517,6 @@ class TFBertMainLayer(tf.keras.layers.Layer):
|
|||||||
self.initializer_range = config.initializer_range
|
self.initializer_range = config.initializer_range
|
||||||
self.output_attentions = config.output_attentions
|
self.output_attentions = config.output_attentions
|
||||||
self.output_hidden_states = config.output_hidden_states
|
self.output_hidden_states = config.output_hidden_states
|
||||||
|
|
||||||
self.embeddings = TFBertEmbeddings(config, name="embeddings")
|
self.embeddings = TFBertEmbeddings(config, name="embeddings")
|
||||||
self.encoder = TFBertEncoder(config, name="encoder")
|
self.encoder = TFBertEncoder(config, name="encoder")
|
||||||
self.pooler = TFBertPooler(config, name="pooler")
|
self.pooler = TFBertPooler(config, name="pooler")
|
||||||
@@ -605,18 +614,22 @@ class TFBertMainLayer(tf.keras.layers.Layer):
|
|||||||
head_mask = [None] * self.num_hidden_layers
|
head_mask = [None] * self.num_hidden_layers
|
||||||
# head_mask = tf.constant([0] * self.num_hidden_layers)
|
# head_mask = tf.constant([0] * self.num_hidden_layers)
|
||||||
|
|
||||||
embedding_output = self.embeddings([input_ids, position_ids, token_type_ids, inputs_embeds], training=training)
|
embedding_output = self.embeddings(input_ids, position_ids, token_type_ids, inputs_embeds, training=training)
|
||||||
encoder_outputs = self.encoder(
|
encoder_outputs = self.encoder(
|
||||||
[embedding_output, extended_attention_mask, head_mask, output_attentions, output_hidden_states],
|
embedding_output,
|
||||||
|
extended_attention_mask,
|
||||||
|
head_mask,
|
||||||
|
output_attentions,
|
||||||
|
output_hidden_states,
|
||||||
training=training,
|
training=training,
|
||||||
)
|
)
|
||||||
|
|
||||||
sequence_output = encoder_outputs[0]
|
sequence_output = encoder_outputs[0]
|
||||||
pooled_output = self.pooler(sequence_output)
|
pooled_output = self.pooler(sequence_output)
|
||||||
|
|
||||||
outputs = (sequence_output, pooled_output,) + encoder_outputs[
|
outputs = (sequence_output, pooled_output,) + encoder_outputs[
|
||||||
1:
|
1:
|
||||||
] # add hidden_states and attentions if they are here
|
] # add hidden_states and attentions if they are here
|
||||||
|
|
||||||
return outputs # sequence_output, pooled_output, (hidden_states), (attentions)
|
return outputs # sequence_output, pooled_output, (hidden_states), (attentions)
|
||||||
|
|
||||||
|
|
||||||
@@ -1211,8 +1224,7 @@ class TFBertForMultipleChoice(TFBertPreTrainedModel, TFMultipleChoiceLoss):
|
|||||||
if inputs_embeds is not None
|
if inputs_embeds is not None
|
||||||
else None
|
else None
|
||||||
)
|
)
|
||||||
|
outputs = self.bert(
|
||||||
flat_inputs = [
|
|
||||||
flat_input_ids,
|
flat_input_ids,
|
||||||
flat_attention_mask,
|
flat_attention_mask,
|
||||||
flat_token_type_ids,
|
flat_token_type_ids,
|
||||||
@@ -1221,16 +1233,12 @@ class TFBertForMultipleChoice(TFBertPreTrainedModel, TFMultipleChoiceLoss):
|
|||||||
flat_inputs_embeds,
|
flat_inputs_embeds,
|
||||||
output_attentions,
|
output_attentions,
|
||||||
output_hidden_states,
|
output_hidden_states,
|
||||||
]
|
training=training,
|
||||||
|
)
|
||||||
outputs = self.bert(flat_inputs, training=training)
|
|
||||||
|
|
||||||
pooled_output = outputs[1]
|
pooled_output = outputs[1]
|
||||||
|
|
||||||
pooled_output = self.dropout(pooled_output, training=training)
|
pooled_output = self.dropout(pooled_output, training=training)
|
||||||
logits = self.classifier(pooled_output)
|
logits = self.classifier(pooled_output)
|
||||||
reshaped_logits = tf.reshape(logits, (-1, num_choices))
|
reshaped_logits = tf.reshape(logits, (-1, num_choices))
|
||||||
|
|
||||||
outputs = (reshaped_logits,) + outputs[2:] # add hidden states and attention if they are here
|
outputs = (reshaped_logits,) + outputs[2:] # add hidden states and attention if they are here
|
||||||
|
|
||||||
if labels is not None:
|
if labels is not None:
|
||||||
|
|||||||
@@ -27,7 +27,6 @@ from .modeling_tf_utils import (
|
|||||||
TFCausalLanguageModelingLoss,
|
TFCausalLanguageModelingLoss,
|
||||||
TFPreTrainedModel,
|
TFPreTrainedModel,
|
||||||
TFSharedEmbeddings,
|
TFSharedEmbeddings,
|
||||||
cast_bool_to_primitive,
|
|
||||||
keras_serializable,
|
keras_serializable,
|
||||||
shape_list,
|
shape_list,
|
||||||
)
|
)
|
||||||
@@ -87,10 +86,11 @@ def scaled_dot_product_attention(q, k, v, mask, attention_mask=None, head_mask=N
|
|||||||
|
|
||||||
|
|
||||||
class TFMultiHeadAttention(tf.keras.layers.Layer):
|
class TFMultiHeadAttention(tf.keras.layers.Layer):
|
||||||
def __init__(self, d_model_size, num_heads, **kwargs):
|
def __init__(self, d_model_size, num_heads, output_attentions=False, **kwargs):
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
self.num_heads = num_heads
|
self.num_heads = num_heads
|
||||||
self.d_model_size = d_model_size
|
self.d_model_size = d_model_size
|
||||||
|
self.output_attentions = output_attentions
|
||||||
|
|
||||||
self.depth = int(d_model_size / self.num_heads)
|
self.depth = int(d_model_size / self.num_heads)
|
||||||
|
|
||||||
@@ -104,8 +104,7 @@ class TFMultiHeadAttention(tf.keras.layers.Layer):
|
|||||||
x = tf.reshape(x, (batch_size, -1, self.num_heads, self.depth))
|
x = tf.reshape(x, (batch_size, -1, self.num_heads, self.depth))
|
||||||
return tf.transpose(x, perm=[0, 2, 1, 3])
|
return tf.transpose(x, perm=[0, 2, 1, 3])
|
||||||
|
|
||||||
def call(self, inputs, training=False):
|
def call(self, v, k, q, mask, layer_past, attention_mask, head_mask, use_cache, output_attentions, training=False):
|
||||||
v, k, q, mask, layer_past, attention_mask, head_mask, use_cache, output_attentions = inputs
|
|
||||||
batch_size = shape_list(q)[0]
|
batch_size = shape_list(q)[0]
|
||||||
|
|
||||||
q = self.Wq(q)
|
q = self.Wq(q)
|
||||||
@@ -121,10 +120,7 @@ class TFMultiHeadAttention(tf.keras.layers.Layer):
|
|||||||
k = tf.concat((past_key, k), axis=-2)
|
k = tf.concat((past_key, k), axis=-2)
|
||||||
v = tf.concat((past_value, v), axis=-2)
|
v = tf.concat((past_value, v), axis=-2)
|
||||||
|
|
||||||
# to cope with keras serialization
|
if use_cache:
|
||||||
use_cache = cast_bool_to_primitive(use_cache, True)
|
|
||||||
|
|
||||||
if use_cache is True:
|
|
||||||
present = tf.stack((k, v), axis=0)
|
present = tf.stack((k, v), axis=0)
|
||||||
else:
|
else:
|
||||||
present = (None,)
|
present = (None,)
|
||||||
@@ -134,10 +130,11 @@ class TFMultiHeadAttention(tf.keras.layers.Layer):
|
|||||||
attn = output[1]
|
attn = output[1]
|
||||||
original_size_attention = tf.reshape(scaled_attention, (batch_size, -1, self.d_model_size))
|
original_size_attention = tf.reshape(scaled_attention, (batch_size, -1, self.d_model_size))
|
||||||
output = self.dense(original_size_attention)
|
output = self.dense(original_size_attention)
|
||||||
|
|
||||||
outputs = (output, present)
|
outputs = (output, present)
|
||||||
if cast_bool_to_primitive(output_attentions) is True:
|
|
||||||
|
if output_attentions:
|
||||||
outputs = outputs + (attn,)
|
outputs = outputs + (attn,)
|
||||||
|
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
|
|
||||||
@@ -156,10 +153,16 @@ class TFPointWiseFeedForwardLayer(tf.keras.layers.Layer):
|
|||||||
|
|
||||||
|
|
||||||
class TFEncoderLayer(tf.keras.layers.Layer):
|
class TFEncoderLayer(tf.keras.layers.Layer):
|
||||||
def __init__(self, d_model_size, num_heads, dff, rate=0.1, layer_norm_epsilon=1e-6, **kwargs):
|
def __init__(
|
||||||
|
self, d_model_size, num_heads, dff, rate=0.1, layer_norm_epsilon=1e-6, output_attentions=False, **kwargs
|
||||||
|
):
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
self.multi_head_attention = TFMultiHeadAttention(d_model_size, num_heads, name="multi_head_attention")
|
self.output_attentions = output_attentions
|
||||||
|
|
||||||
|
self.multi_head_attention = TFMultiHeadAttention(
|
||||||
|
d_model_size, num_heads, output_attentions=self.output_attentions, name="multi_head_attention"
|
||||||
|
)
|
||||||
self.ffn = TFPointWiseFeedForwardLayer(d_model_size, dff, name="ffn")
|
self.ffn = TFPointWiseFeedForwardLayer(d_model_size, dff, name="ffn")
|
||||||
|
|
||||||
self.layernorm1 = tf.keras.layers.LayerNormalization(epsilon=layer_norm_epsilon, name="layernorm1")
|
self.layernorm1 = tf.keras.layers.LayerNormalization(epsilon=layer_norm_epsilon, name="layernorm1")
|
||||||
@@ -168,11 +171,18 @@ class TFEncoderLayer(tf.keras.layers.Layer):
|
|||||||
self.dropout1 = tf.keras.layers.Dropout(rate)
|
self.dropout1 = tf.keras.layers.Dropout(rate)
|
||||||
self.dropout2 = tf.keras.layers.Dropout(rate)
|
self.dropout2 = tf.keras.layers.Dropout(rate)
|
||||||
|
|
||||||
def call(self, inputs, training=False):
|
def call(self, x, mask, layer_past, attention_mask, head_mask, use_cache, output_attentions, training=False):
|
||||||
x, mask, layer_past, attention_mask, head_mask, use_cache, output_attentions = inputs
|
|
||||||
normed = self.layernorm1(x)
|
normed = self.layernorm1(x)
|
||||||
attn_outputs = self.multi_head_attention(
|
attn_outputs = self.multi_head_attention(
|
||||||
[normed, normed, normed, mask, layer_past, attention_mask, head_mask, use_cache, output_attentions],
|
normed,
|
||||||
|
normed,
|
||||||
|
normed,
|
||||||
|
mask,
|
||||||
|
layer_past,
|
||||||
|
attention_mask,
|
||||||
|
head_mask,
|
||||||
|
use_cache,
|
||||||
|
output_attentions,
|
||||||
training=training,
|
training=training,
|
||||||
)
|
)
|
||||||
attn_output = attn_outputs[0]
|
attn_output = attn_outputs[0]
|
||||||
@@ -215,6 +225,7 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
|
|||||||
config.dff,
|
config.dff,
|
||||||
config.resid_pdrop,
|
config.resid_pdrop,
|
||||||
config.layer_norm_epsilon,
|
config.layer_norm_epsilon,
|
||||||
|
self.output_attentions,
|
||||||
name="h_._{}".format(i),
|
name="h_._{}".format(i),
|
||||||
)
|
)
|
||||||
for i in range(config.n_layer)
|
for i in range(config.n_layer)
|
||||||
@@ -367,31 +378,37 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
|
|||||||
all_hidden_states = ()
|
all_hidden_states = ()
|
||||||
all_attentions = []
|
all_attentions = []
|
||||||
for i, (h, layer_past) in enumerate(zip(self.h, past)):
|
for i, (h, layer_past) in enumerate(zip(self.h, past)):
|
||||||
if cast_bool_to_primitive(output_hidden_states) is True:
|
if output_hidden_states:
|
||||||
all_hidden_states = all_hidden_states + (tf.reshape(hidden_states, output_shape),)
|
all_hidden_states = all_hidden_states + (tf.reshape(hidden_states, output_shape),)
|
||||||
outputs = h(
|
outputs = h(
|
||||||
[hidden_states, mask, layer_past, attention_mask, head_mask[i], use_cache, output_attentions],
|
hidden_states,
|
||||||
|
mask,
|
||||||
|
layer_past,
|
||||||
|
attention_mask,
|
||||||
|
head_mask[i],
|
||||||
|
use_cache,
|
||||||
|
output_attentions,
|
||||||
training=training,
|
training=training,
|
||||||
)
|
)
|
||||||
hidden_states, present = outputs[:2]
|
hidden_states, present = outputs[:2]
|
||||||
|
|
||||||
if use_cache is True:
|
if use_cache:
|
||||||
presents = presents + (present,)
|
presents = presents + (present,)
|
||||||
|
|
||||||
if cast_bool_to_primitive(output_attentions) is True:
|
if output_attentions:
|
||||||
all_attentions.append(outputs[2])
|
all_attentions.append(outputs[2])
|
||||||
|
|
||||||
hidden_states = self.layernorm(hidden_states)
|
hidden_states = self.layernorm(hidden_states)
|
||||||
hidden_states = tf.reshape(hidden_states, output_shape)
|
hidden_states = tf.reshape(hidden_states, output_shape)
|
||||||
if cast_bool_to_primitive(output_hidden_states) is True:
|
if output_hidden_states:
|
||||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||||
|
|
||||||
outputs = (hidden_states,)
|
outputs = (hidden_states,)
|
||||||
if use_cache is True:
|
if use_cache:
|
||||||
outputs = outputs + (presents,)
|
outputs = outputs + (presents,)
|
||||||
if cast_bool_to_primitive(output_hidden_states) is True:
|
if output_hidden_states:
|
||||||
outputs = outputs + (all_hidden_states,)
|
outputs = outputs + (all_hidden_states,)
|
||||||
if cast_bool_to_primitive(output_attentions) is True:
|
if output_attentions:
|
||||||
# let the number of heads free (-1) so we can extract attention even after head pruning
|
# let the number of heads free (-1) so we can extract attention even after head pruning
|
||||||
attention_output_shape = input_shape[:-1] + [-1] + shape_list(all_attentions[0])[-2:]
|
attention_output_shape = input_shape[:-1] + [-1] + shape_list(all_attentions[0])[-2:]
|
||||||
all_attentions = tuple(tf.reshape(t, attention_output_shape) for t in all_attentions)
|
all_attentions = tuple(tf.reshape(t, attention_output_shape) for t in all_attentions)
|
||||||
|
|||||||
@@ -37,7 +37,6 @@ from .modeling_tf_utils import (
|
|||||||
TFSequenceClassificationLoss,
|
TFSequenceClassificationLoss,
|
||||||
TFSharedEmbeddings,
|
TFSharedEmbeddings,
|
||||||
TFTokenClassificationLoss,
|
TFTokenClassificationLoss,
|
||||||
cast_bool_to_primitive,
|
|
||||||
get_initializer,
|
get_initializer,
|
||||||
keras_serializable,
|
keras_serializable,
|
||||||
shape_list,
|
shape_list,
|
||||||
@@ -114,7 +113,7 @@ class TFEmbeddings(tf.keras.layers.Layer):
|
|||||||
)
|
)
|
||||||
super().build(input_shape)
|
super().build(input_shape)
|
||||||
|
|
||||||
def call(self, inputs, inputs_embeds=None, mode="embedding", training=False):
|
def call(self, input_ids=None, position_ids=None, inputs_embeds=None, mode="embedding", training=False):
|
||||||
"""Get token embeddings of inputs.
|
"""Get token embeddings of inputs.
|
||||||
Args:
|
Args:
|
||||||
inputs: list of two int64 tensors with shape [batch_size, length]: (input_ids, position_ids)
|
inputs: list of two int64 tensors with shape [batch_size, length]: (input_ids, position_ids)
|
||||||
@@ -130,13 +129,13 @@ class TFEmbeddings(tf.keras.layers.Layer):
|
|||||||
https://github.com/tensorflow/models/blob/a009f4fb9d2fc4949e32192a944688925ef78659/official/transformer/v2/embedding_layer.py#L24
|
https://github.com/tensorflow/models/blob/a009f4fb9d2fc4949e32192a944688925ef78659/official/transformer/v2/embedding_layer.py#L24
|
||||||
"""
|
"""
|
||||||
if mode == "embedding":
|
if mode == "embedding":
|
||||||
return self._embedding(inputs, inputs_embeds=inputs_embeds, training=training)
|
return self._embedding(input_ids, position_ids, inputs_embeds, training=training)
|
||||||
elif mode == "linear":
|
elif mode == "linear":
|
||||||
return self._linear(inputs)
|
return self._linear(input_ids)
|
||||||
else:
|
else:
|
||||||
raise ValueError("mode {} is not valid.".format(mode))
|
raise ValueError("mode {} is not valid.".format(mode))
|
||||||
|
|
||||||
def _embedding(self, inputs, inputs_embeds=None, training=False):
|
def _embedding(self, input_ids, position_ids, inputs_embeds, training=False):
|
||||||
"""
|
"""
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
@@ -148,11 +147,7 @@ class TFEmbeddings(tf.keras.layers.Layer):
|
|||||||
embeddings: tf.Tensor(bs, max_seq_length, dim)
|
embeddings: tf.Tensor(bs, max_seq_length, dim)
|
||||||
The embedded tokens (plus position embeddings, no token_type embeddings)
|
The embedded tokens (plus position embeddings, no token_type embeddings)
|
||||||
"""
|
"""
|
||||||
if not isinstance(inputs, (tuple, list)):
|
assert not (input_ids is None and inputs_embeds is None)
|
||||||
input_ids = inputs
|
|
||||||
position_ids = None
|
|
||||||
else:
|
|
||||||
input_ids, position_ids = inputs
|
|
||||||
|
|
||||||
if input_ids is not None:
|
if input_ids is not None:
|
||||||
seq_length = shape_list(input_ids)[1]
|
seq_length = shape_list(input_ids)[1]
|
||||||
@@ -194,6 +189,7 @@ class TFMultiHeadSelfAttention(tf.keras.layers.Layer):
|
|||||||
self.n_heads = config.n_heads
|
self.n_heads = config.n_heads
|
||||||
self.dim = config.dim
|
self.dim = config.dim
|
||||||
self.dropout = tf.keras.layers.Dropout(config.attention_dropout)
|
self.dropout = tf.keras.layers.Dropout(config.attention_dropout)
|
||||||
|
self.output_attentions = config.output_attentions
|
||||||
|
|
||||||
assert self.dim % self.n_heads == 0, f"Hidden size {self.dim} not dividable by number of heads {self.n_heads}"
|
assert self.dim % self.n_heads == 0, f"Hidden size {self.dim} not dividable by number of heads {self.n_heads}"
|
||||||
|
|
||||||
@@ -215,7 +211,7 @@ class TFMultiHeadSelfAttention(tf.keras.layers.Layer):
|
|||||||
def prune_heads(self, heads):
|
def prune_heads(self, heads):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def call(self, inputs, training=False):
|
def call(self, query, key, value, mask, head_mask, output_attentions, training=False):
|
||||||
"""
|
"""
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
@@ -231,7 +227,6 @@ class TFMultiHeadSelfAttention(tf.keras.layers.Layer):
|
|||||||
context: tf.Tensor(bs, seq_length, dim)
|
context: tf.Tensor(bs, seq_length, dim)
|
||||||
Contextualized layer. Optional: only if `output_attentions=True`
|
Contextualized layer. Optional: only if `output_attentions=True`
|
||||||
"""
|
"""
|
||||||
query, key, value, mask, head_mask, output_attentions = inputs
|
|
||||||
bs, q_length, dim = shape_list(query)
|
bs, q_length, dim = shape_list(query)
|
||||||
k_length = shape_list(key)[1]
|
k_length = shape_list(key)[1]
|
||||||
# assert dim == self.dim, 'Dimensions do not match: %s input vs %s configured' % (dim, self.dim)
|
# assert dim == self.dim, 'Dimensions do not match: %s input vs %s configured' % (dim, self.dim)
|
||||||
@@ -270,7 +265,7 @@ class TFMultiHeadSelfAttention(tf.keras.layers.Layer):
|
|||||||
context = unshape(context) # (bs, q_length, dim)
|
context = unshape(context) # (bs, q_length, dim)
|
||||||
context = self.out_lin(context) # (bs, q_length, dim)
|
context = self.out_lin(context) # (bs, q_length, dim)
|
||||||
|
|
||||||
if cast_bool_to_primitive(output_attentions) is True:
|
if output_attentions:
|
||||||
return (context, weights)
|
return (context, weights)
|
||||||
else:
|
else:
|
||||||
return (context,)
|
return (context,)
|
||||||
@@ -310,6 +305,7 @@ class TFTransformerBlock(tf.keras.layers.Layer):
|
|||||||
self.hidden_dim = config.hidden_dim
|
self.hidden_dim = config.hidden_dim
|
||||||
self.dropout = tf.keras.layers.Dropout(config.dropout)
|
self.dropout = tf.keras.layers.Dropout(config.dropout)
|
||||||
self.activation = config.activation
|
self.activation = config.activation
|
||||||
|
self.output_attentions = config.output_attentions
|
||||||
|
|
||||||
assert (
|
assert (
|
||||||
config.dim % config.n_heads == 0
|
config.dim % config.n_heads == 0
|
||||||
@@ -321,7 +317,7 @@ class TFTransformerBlock(tf.keras.layers.Layer):
|
|||||||
self.ffn = TFFFN(config, name="ffn")
|
self.ffn = TFFFN(config, name="ffn")
|
||||||
self.output_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-12, name="output_layer_norm")
|
self.output_layer_norm = tf.keras.layers.LayerNormalization(epsilon=1e-12, name="output_layer_norm")
|
||||||
|
|
||||||
def call(self, inputs, training=False): # removed: src_enc=None, src_len=None
|
def call(self, x, attn_mask, head_mask, output_attentions, training=False): # removed: src_enc=None, src_len=None
|
||||||
"""
|
"""
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
@@ -335,11 +331,9 @@ class TFTransformerBlock(tf.keras.layers.Layer):
|
|||||||
ffn_output: tf.Tensor(bs, seq_length, dim)
|
ffn_output: tf.Tensor(bs, seq_length, dim)
|
||||||
The output of the transformer block contextualization.
|
The output of the transformer block contextualization.
|
||||||
"""
|
"""
|
||||||
x, attn_mask, head_mask, output_attentions = inputs
|
|
||||||
|
|
||||||
# Self-Attention
|
# Self-Attention
|
||||||
sa_output = self.attention([x, x, x, attn_mask, head_mask, output_attentions], training=training)
|
sa_output = self.attention(x, x, x, attn_mask, head_mask, output_attentions, training=training)
|
||||||
if cast_bool_to_primitive(output_attentions) is True:
|
if output_attentions:
|
||||||
sa_output, sa_weights = sa_output # (bs, seq_length, dim), (bs, n_heads, seq_length, seq_length)
|
sa_output, sa_weights = sa_output # (bs, seq_length, dim), (bs, n_heads, seq_length, seq_length)
|
||||||
else: # To handle these `output_attention` or `output_hidden_states` cases returning tuples
|
else: # To handle these `output_attention` or `output_hidden_states` cases returning tuples
|
||||||
# assert type(sa_output) == tuple
|
# assert type(sa_output) == tuple
|
||||||
@@ -351,7 +345,7 @@ class TFTransformerBlock(tf.keras.layers.Layer):
|
|||||||
ffn_output = self.output_layer_norm(ffn_output + sa_output) # (bs, seq_length, dim)
|
ffn_output = self.output_layer_norm(ffn_output + sa_output) # (bs, seq_length, dim)
|
||||||
|
|
||||||
output = (ffn_output,)
|
output = (ffn_output,)
|
||||||
if cast_bool_to_primitive(output_attentions) is True:
|
if output_attentions:
|
||||||
output = (sa_weights,) + output
|
output = (sa_weights,) + output
|
||||||
return output
|
return output
|
||||||
|
|
||||||
@@ -360,10 +354,12 @@ class TFTransformer(tf.keras.layers.Layer):
|
|||||||
def __init__(self, config, **kwargs):
|
def __init__(self, config, **kwargs):
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
self.n_layers = config.n_layers
|
self.n_layers = config.n_layers
|
||||||
|
self.output_hidden_states = config.output_hidden_states
|
||||||
|
self.output_attentions = config.output_attentions
|
||||||
|
|
||||||
self.layer = [TFTransformerBlock(config, name="layer_._{}".format(i)) for i in range(config.n_layers)]
|
self.layer = [TFTransformerBlock(config, name="layer_._{}".format(i)) for i in range(config.n_layers)]
|
||||||
|
|
||||||
def call(self, inputs, training=False):
|
def call(self, x, attn_mask, head_mask, output_attentions, output_hidden_states, training=False):
|
||||||
"""
|
"""
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
@@ -383,34 +379,32 @@ class TFTransformer(tf.keras.layers.Layer):
|
|||||||
Tuple of length n_layers with the attention weights from each layer
|
Tuple of length n_layers with the attention weights from each layer
|
||||||
Optional: only if output_attentions=True
|
Optional: only if output_attentions=True
|
||||||
"""
|
"""
|
||||||
x, attn_mask, head_mask, output_attentions, output_hidden_states = inputs
|
|
||||||
|
|
||||||
all_hidden_states = ()
|
all_hidden_states = ()
|
||||||
all_attentions = ()
|
all_attentions = ()
|
||||||
|
|
||||||
hidden_state = x
|
hidden_state = x
|
||||||
for i, layer_module in enumerate(self.layer):
|
for i, layer_module in enumerate(self.layer):
|
||||||
if cast_bool_to_primitive(output_hidden_states) is True:
|
if output_hidden_states:
|
||||||
all_hidden_states = all_hidden_states + (hidden_state,)
|
all_hidden_states = all_hidden_states + (hidden_state,)
|
||||||
|
|
||||||
layer_outputs = layer_module([hidden_state, attn_mask, head_mask[i], output_attentions], training=training)
|
layer_outputs = layer_module(hidden_state, attn_mask, head_mask[i], output_attentions, training=training)
|
||||||
hidden_state = layer_outputs[-1]
|
hidden_state = layer_outputs[-1]
|
||||||
|
|
||||||
if cast_bool_to_primitive(output_attentions) is True:
|
if output_attentions:
|
||||||
assert len(layer_outputs) == 2, f"Incorrect number of outputs {len(layer_outputs)} instead of 2"
|
assert len(layer_outputs) == 2
|
||||||
attentions = layer_outputs[0]
|
attentions = layer_outputs[0]
|
||||||
all_attentions = all_attentions + (attentions,)
|
all_attentions = all_attentions + (attentions,)
|
||||||
else:
|
else:
|
||||||
assert len(layer_outputs) == 1, f"Incorrect number of outputs {len(layer_outputs)} instead of 1"
|
assert len(layer_outputs) == 1, f"Incorrect number of outputs {len(layer_outputs)} instead of 1"
|
||||||
|
|
||||||
# Add last layer
|
# Add last layer
|
||||||
if cast_bool_to_primitive(output_hidden_states) is True:
|
if output_hidden_states:
|
||||||
all_hidden_states = all_hidden_states + (hidden_state,)
|
all_hidden_states = all_hidden_states + (hidden_state,)
|
||||||
|
|
||||||
outputs = (hidden_state,)
|
outputs = (hidden_state,)
|
||||||
if cast_bool_to_primitive(output_hidden_states) is True:
|
if output_hidden_states:
|
||||||
outputs = outputs + (all_hidden_states,)
|
outputs = outputs + (all_hidden_states,)
|
||||||
if cast_bool_to_primitive(output_attentions) is True:
|
if output_attentions:
|
||||||
outputs = outputs + (all_attentions,)
|
outputs = outputs + (all_attentions,)
|
||||||
return outputs # last-layer hidden state, (all hidden states), (all attentions)
|
return outputs # last-layer hidden state, (all hidden states), (all attentions)
|
||||||
|
|
||||||
@@ -481,6 +475,7 @@ class TFDistilBertMainLayer(tf.keras.layers.Layer):
|
|||||||
|
|
||||||
if attention_mask is None:
|
if attention_mask is None:
|
||||||
attention_mask = tf.ones(input_shape) # (bs, seq_length)
|
attention_mask = tf.ones(input_shape) # (bs, seq_length)
|
||||||
|
|
||||||
attention_mask = tf.cast(attention_mask, dtype=tf.float32)
|
attention_mask = tf.cast(attention_mask, dtype=tf.float32)
|
||||||
|
|
||||||
# Prepare head mask if needed
|
# Prepare head mask if needed
|
||||||
@@ -491,11 +486,12 @@ class TFDistilBertMainLayer(tf.keras.layers.Layer):
|
|||||||
if head_mask is not None:
|
if head_mask is not None:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
else:
|
else:
|
||||||
|
|
||||||
head_mask = [None] * self.num_hidden_layers
|
head_mask = [None] * self.num_hidden_layers
|
||||||
|
|
||||||
embedding_output = self.embeddings(input_ids, inputs_embeds=inputs_embeds) # (bs, seq_length, dim)
|
embedding_output = self.embeddings(input_ids, inputs_embeds=inputs_embeds) # (bs, seq_length, dim)
|
||||||
tfmr_output = self.transformer(
|
tfmr_output = self.transformer(
|
||||||
[embedding_output, attention_mask, head_mask, output_attentions, output_hidden_states], training=training
|
embedding_output, attention_mask, head_mask, output_attentions, output_hidden_states, training=training
|
||||||
)
|
)
|
||||||
|
|
||||||
return tfmr_output # last-layer hidden-state, (all hidden_states), (all attentions)
|
return tfmr_output # last-layer hidden-state, (all hidden_states), (all attentions)
|
||||||
@@ -986,24 +982,21 @@ class TFDistilBertForMultipleChoice(TFDistilBertPreTrainedModel, TFMultipleChoic
|
|||||||
if inputs_embeds is not None
|
if inputs_embeds is not None
|
||||||
else None
|
else None
|
||||||
)
|
)
|
||||||
|
distilbert_output = self.distilbert(
|
||||||
flat_inputs = [
|
|
||||||
flat_input_ids,
|
flat_input_ids,
|
||||||
flat_attention_mask,
|
flat_attention_mask,
|
||||||
head_mask,
|
head_mask,
|
||||||
flat_inputs_embeds,
|
flat_inputs_embeds,
|
||||||
output_attentions,
|
output_attentions,
|
||||||
output_hidden_states,
|
output_hidden_states,
|
||||||
]
|
training=training,
|
||||||
|
)
|
||||||
distilbert_output = self.distilbert(flat_inputs, training=training)
|
|
||||||
hidden_state = distilbert_output[0] # (bs, seq_len, dim)
|
hidden_state = distilbert_output[0] # (bs, seq_len, dim)
|
||||||
pooled_output = hidden_state[:, 0] # (bs, dim)
|
pooled_output = hidden_state[:, 0] # (bs, dim)
|
||||||
pooled_output = self.pre_classifier(pooled_output) # (bs, dim)
|
pooled_output = self.pre_classifier(pooled_output) # (bs, dim)
|
||||||
pooled_output = self.dropout(pooled_output, training=training) # (bs, dim)
|
pooled_output = self.dropout(pooled_output, training=training) # (bs, dim)
|
||||||
logits = self.classifier(pooled_output)
|
logits = self.classifier(pooled_output)
|
||||||
reshaped_logits = tf.reshape(logits, (-1, num_choices))
|
reshaped_logits = tf.reshape(logits, (-1, num_choices))
|
||||||
|
|
||||||
outputs = (reshaped_logits,) + distilbert_output[1:] # add hidden states and attention if they are here
|
outputs = (reshaped_logits,) + distilbert_output[1:] # add hidden states and attention if they are here
|
||||||
|
|
||||||
if labels is not None:
|
if labels is not None:
|
||||||
|
|||||||
@@ -2,7 +2,8 @@ import logging
|
|||||||
|
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
from .configuration_electra import ElectraConfig
|
from transformers import ElectraConfig
|
||||||
|
|
||||||
from .file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_callable
|
from .file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_callable
|
||||||
from .modeling_tf_bert import ACT2FN, TFBertEncoder, TFBertPreTrainedModel
|
from .modeling_tf_bert import ACT2FN, TFBertEncoder, TFBertPreTrainedModel
|
||||||
from .modeling_tf_utils import (
|
from .modeling_tf_utils import (
|
||||||
@@ -71,7 +72,15 @@ class TFElectraEmbeddings(tf.keras.layers.Layer):
|
|||||||
)
|
)
|
||||||
super().build(input_shape)
|
super().build(input_shape)
|
||||||
|
|
||||||
def call(self, inputs, mode="embedding", training=False):
|
def call(
|
||||||
|
self,
|
||||||
|
input_ids=None,
|
||||||
|
position_ids=None,
|
||||||
|
token_type_ids=None,
|
||||||
|
inputs_embeds=None,
|
||||||
|
mode="embedding",
|
||||||
|
training=False,
|
||||||
|
):
|
||||||
"""Get token embeddings of inputs.
|
"""Get token embeddings of inputs.
|
||||||
Args:
|
Args:
|
||||||
inputs: list of three int64 tensors with shape [batch_size, length]: (input_ids, position_ids, token_type_ids)
|
inputs: list of three int64 tensors with shape [batch_size, length]: (input_ids, position_ids, token_type_ids)
|
||||||
@@ -87,15 +96,15 @@ class TFElectraEmbeddings(tf.keras.layers.Layer):
|
|||||||
https://github.com/tensorflow/models/blob/a009f4fb9d2fc4949e32192a944688925ef78659/official/transformer/v2/embedding_layer.py#L24
|
https://github.com/tensorflow/models/blob/a009f4fb9d2fc4949e32192a944688925ef78659/official/transformer/v2/embedding_layer.py#L24
|
||||||
"""
|
"""
|
||||||
if mode == "embedding":
|
if mode == "embedding":
|
||||||
return self._embedding(inputs, training=training)
|
return self._embedding(input_ids, position_ids, token_type_ids, inputs_embeds, training=training)
|
||||||
elif mode == "linear":
|
elif mode == "linear":
|
||||||
return self._linear(inputs)
|
return self._linear(input_ids)
|
||||||
else:
|
else:
|
||||||
raise ValueError("mode {} is not valid.".format(mode))
|
raise ValueError("mode {} is not valid.".format(mode))
|
||||||
|
|
||||||
def _embedding(self, inputs, training=False):
|
def _embedding(self, input_ids, position_ids, token_type_ids, inputs_embeds, training=False):
|
||||||
"""Applies embedding based on inputs tensor."""
|
"""Applies embedding based on inputs tensor."""
|
||||||
input_ids, position_ids, token_type_ids, inputs_embeds = inputs
|
assert not (input_ids is None and inputs_embeds is None)
|
||||||
|
|
||||||
if input_ids is not None:
|
if input_ids is not None:
|
||||||
input_shape = shape_list(input_ids)
|
input_shape = shape_list(input_ids)
|
||||||
@@ -289,13 +298,17 @@ class TFElectraMainLayer(TFElectraPreTrainedModel):
|
|||||||
extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape)
|
extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape)
|
||||||
head_mask = self.get_head_mask(head_mask)
|
head_mask = self.get_head_mask(head_mask)
|
||||||
|
|
||||||
hidden_states = self.embeddings([input_ids, position_ids, token_type_ids, inputs_embeds], training=training)
|
hidden_states = self.embeddings(input_ids, position_ids, token_type_ids, inputs_embeds, training=training)
|
||||||
|
|
||||||
if hasattr(self, "embeddings_project"):
|
if hasattr(self, "embeddings_project"):
|
||||||
hidden_states = self.embeddings_project(hidden_states, training=training)
|
hidden_states = self.embeddings_project(hidden_states, training=training)
|
||||||
|
|
||||||
hidden_states = self.encoder(
|
hidden_states = self.encoder(
|
||||||
[hidden_states, extended_attention_mask, head_mask, output_attentions, output_hidden_states],
|
hidden_states,
|
||||||
|
extended_attention_mask,
|
||||||
|
head_mask,
|
||||||
|
output_attentions,
|
||||||
|
output_hidden_states,
|
||||||
training=training,
|
training=training,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -22,7 +22,7 @@ import tensorflow as tf
|
|||||||
|
|
||||||
from .configuration_flaubert import FlaubertConfig
|
from .configuration_flaubert import FlaubertConfig
|
||||||
from .file_utils import add_start_docstrings
|
from .file_utils import add_start_docstrings
|
||||||
from .modeling_tf_utils import cast_bool_to_primitive, keras_serializable, shape_list
|
from .modeling_tf_utils import keras_serializable, shape_list
|
||||||
from .modeling_tf_xlm import (
|
from .modeling_tf_xlm import (
|
||||||
TFXLMForMultipleChoice,
|
TFXLMForMultipleChoice,
|
||||||
TFXLMForQuestionAnsweringSimple,
|
TFXLMForQuestionAnsweringSimple,
|
||||||
@@ -274,10 +274,10 @@ class TFFlaubertMainLayer(TFXLMMainLayer):
|
|||||||
# self attention
|
# self attention
|
||||||
if not self.pre_norm:
|
if not self.pre_norm:
|
||||||
attn_outputs = self.attentions[i](
|
attn_outputs = self.attentions[i](
|
||||||
[tensor, attn_mask, None, cache, head_mask[i], output_attentions], training=training
|
tensor, attn_mask, None, cache, head_mask[i], output_attentions, training=training
|
||||||
)
|
)
|
||||||
attn = attn_outputs[0]
|
attn = attn_outputs[0]
|
||||||
if cast_bool_to_primitive(output_attentions, self.output_attentions) is True:
|
if output_attentions:
|
||||||
attentions = attentions + (attn_outputs[1],)
|
attentions = attentions + (attn_outputs[1],)
|
||||||
attn = self.dropout(attn, training=training)
|
attn = self.dropout(attn, training=training)
|
||||||
tensor = tensor + attn
|
tensor = tensor + attn
|
||||||
@@ -285,10 +285,10 @@ class TFFlaubertMainLayer(TFXLMMainLayer):
|
|||||||
else:
|
else:
|
||||||
tensor_normalized = self.layer_norm1[i](tensor)
|
tensor_normalized = self.layer_norm1[i](tensor)
|
||||||
attn_outputs = self.attentions[i](
|
attn_outputs = self.attentions[i](
|
||||||
[tensor_normalized, attn_mask, None, cache, head_mask[i]], training=training
|
tensor_normalized, attn_mask, None, cache, head_mask[i], training=training
|
||||||
)
|
)
|
||||||
attn = attn_outputs[0]
|
attn = attn_outputs[0]
|
||||||
if cast_bool_to_primitive(output_attentions, self.output_attentions) is True:
|
if output_attentions:
|
||||||
attentions = attentions + (attn_outputs[1],)
|
attentions = attentions + (attn_outputs[1],)
|
||||||
attn = self.dropout(attn, training=training)
|
attn = self.dropout(attn, training=training)
|
||||||
tensor = tensor + attn
|
tensor = tensor + attn
|
||||||
@@ -311,7 +311,7 @@ class TFFlaubertMainLayer(TFXLMMainLayer):
|
|||||||
tensor = tensor * mask[..., tf.newaxis]
|
tensor = tensor * mask[..., tf.newaxis]
|
||||||
|
|
||||||
# Add last hidden state
|
# Add last hidden state
|
||||||
if cast_bool_to_primitive(output_hidden_states, self.output_hidden_states) is True:
|
if output_hidden_states:
|
||||||
hidden_states = hidden_states + (tensor,)
|
hidden_states = hidden_states + (tensor,)
|
||||||
|
|
||||||
# update cache length
|
# update cache length
|
||||||
@@ -322,9 +322,9 @@ class TFFlaubertMainLayer(TFXLMMainLayer):
|
|||||||
# tensor = tensor.transpose(0, 1)
|
# tensor = tensor.transpose(0, 1)
|
||||||
|
|
||||||
outputs = (tensor,)
|
outputs = (tensor,)
|
||||||
if cast_bool_to_primitive(output_hidden_states, self.output_hidden_states) is True:
|
if output_hidden_states:
|
||||||
outputs = outputs + (hidden_states,)
|
outputs = outputs + (hidden_states,)
|
||||||
if cast_bool_to_primitive(output_attentions, self.output_attentions) is True:
|
if output_attentions:
|
||||||
outputs = outputs + (attentions,)
|
outputs = outputs + (attentions,)
|
||||||
return outputs # outputs, (hidden_states), (attentions)
|
return outputs # outputs, (hidden_states), (attentions)
|
||||||
|
|
||||||
|
|||||||
@@ -29,7 +29,6 @@ from .modeling_tf_utils import (
|
|||||||
TFPreTrainedModel,
|
TFPreTrainedModel,
|
||||||
TFSequenceSummary,
|
TFSequenceSummary,
|
||||||
TFSharedEmbeddings,
|
TFSharedEmbeddings,
|
||||||
cast_bool_to_primitive,
|
|
||||||
get_initializer,
|
get_initializer,
|
||||||
keras_serializable,
|
keras_serializable,
|
||||||
shape_list,
|
shape_list,
|
||||||
@@ -75,6 +74,7 @@ class TFAttention(tf.keras.layers.Layer):
|
|||||||
self.n_head = config.n_head
|
self.n_head = config.n_head
|
||||||
self.split_size = n_state
|
self.split_size = n_state
|
||||||
self.scale = scale
|
self.scale = scale
|
||||||
|
self.output_attentions = config.output_attentions
|
||||||
|
|
||||||
self.c_attn = TFConv1D(n_state * 3, nx, initializer_range=config.initializer_range, name="c_attn")
|
self.c_attn = TFConv1D(n_state * 3, nx, initializer_range=config.initializer_range, name="c_attn")
|
||||||
self.c_proj = TFConv1D(n_state, nx, initializer_range=config.initializer_range, name="c_proj")
|
self.c_proj = TFConv1D(n_state, nx, initializer_range=config.initializer_range, name="c_proj")
|
||||||
@@ -95,8 +95,7 @@ class TFAttention(tf.keras.layers.Layer):
|
|||||||
m = i >= j - ns + nd
|
m = i >= j - ns + nd
|
||||||
return tf.cast(m, dtype)
|
return tf.cast(m, dtype)
|
||||||
|
|
||||||
def _attn(self, inputs, training=False):
|
def _attn(self, q, k, v, attention_mask, head_mask, output_attentions, training=False):
|
||||||
q, k, v, attention_mask, head_mask, output_attentions = inputs
|
|
||||||
# q, k, v have shape [batch, heads, sequence, features]
|
# q, k, v have shape [batch, heads, sequence, features]
|
||||||
w = tf.matmul(q, k, transpose_b=True)
|
w = tf.matmul(q, k, transpose_b=True)
|
||||||
if self.scale:
|
if self.scale:
|
||||||
@@ -121,7 +120,7 @@ class TFAttention(tf.keras.layers.Layer):
|
|||||||
w = w * head_mask
|
w = w * head_mask
|
||||||
|
|
||||||
outputs = [tf.matmul(w, v)]
|
outputs = [tf.matmul(w, v)]
|
||||||
if cast_bool_to_primitive(output_attentions) is True:
|
if output_attentions:
|
||||||
outputs.append(w)
|
outputs.append(w)
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
@@ -137,9 +136,7 @@ class TFAttention(tf.keras.layers.Layer):
|
|||||||
x = tf.reshape(x, new_x_shape)
|
x = tf.reshape(x, new_x_shape)
|
||||||
return tf.transpose(x, (0, 2, 1, 3)) # (batch, head, seq_length, head_features)
|
return tf.transpose(x, (0, 2, 1, 3)) # (batch, head, seq_length, head_features)
|
||||||
|
|
||||||
def call(self, inputs, training=False):
|
def call(self, x, layer_past, attention_mask, head_mask, use_cache, output_attentions, training=False):
|
||||||
x, layer_past, attention_mask, head_mask, use_cache, output_attentions = inputs
|
|
||||||
|
|
||||||
x = self.c_attn(x)
|
x = self.c_attn(x)
|
||||||
query, key, value = tf.split(x, 3, axis=2)
|
query, key, value = tf.split(x, 3, axis=2)
|
||||||
query = self.split_heads(query)
|
query = self.split_heads(query)
|
||||||
@@ -151,12 +148,12 @@ class TFAttention(tf.keras.layers.Layer):
|
|||||||
value = tf.concat([past_value, value], axis=-2)
|
value = tf.concat([past_value, value], axis=-2)
|
||||||
|
|
||||||
# to cope with keras serialization
|
# to cope with keras serialization
|
||||||
if cast_bool_to_primitive(use_cache, True) is True:
|
if use_cache:
|
||||||
present = tf.stack([key, value], axis=0)
|
present = tf.stack([key, value], axis=0)
|
||||||
else:
|
else:
|
||||||
present = (None,)
|
present = (None,)
|
||||||
|
|
||||||
attn_outputs = self._attn([query, key, value, attention_mask, head_mask, output_attentions], training=training)
|
attn_outputs = self._attn(query, key, value, attention_mask, head_mask, output_attentions, training=training)
|
||||||
a = attn_outputs[0]
|
a = attn_outputs[0]
|
||||||
|
|
||||||
a = self.merge_heads(a)
|
a = self.merge_heads(a)
|
||||||
@@ -192,12 +189,10 @@ class TFBlock(tf.keras.layers.Layer):
|
|||||||
self.ln_2 = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_epsilon, name="ln_2")
|
self.ln_2 = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_epsilon, name="ln_2")
|
||||||
self.mlp = TFMLP(4 * nx, config, name="mlp")
|
self.mlp = TFMLP(4 * nx, config, name="mlp")
|
||||||
|
|
||||||
def call(self, inputs, training=False):
|
def call(self, x, layer_past, attention_mask, head_mask, use_cache, output_attentions, training=False):
|
||||||
x, layer_past, attention_mask, head_mask, use_cache, output_attentions = inputs
|
|
||||||
|
|
||||||
a = self.ln_1(x)
|
a = self.ln_1(x)
|
||||||
output_attn = self.attn(
|
output_attn = self.attn(
|
||||||
[a, layer_past, attention_mask, head_mask, use_cache, output_attentions], training=training
|
a, layer_past, attention_mask, head_mask, use_cache, output_attentions, training=training
|
||||||
)
|
)
|
||||||
a = output_attn[0] # output_attn: a, present, (attentions)
|
a = output_attn[0] # output_attn: a, present, (attentions)
|
||||||
x = x + a
|
x = x + a
|
||||||
@@ -223,6 +218,8 @@ class TFGPT2MainLayer(tf.keras.layers.Layer):
|
|||||||
self.num_hidden_layers = config.n_layer
|
self.num_hidden_layers = config.n_layer
|
||||||
self.vocab_size = config.vocab_size
|
self.vocab_size = config.vocab_size
|
||||||
self.n_embd = config.n_embd
|
self.n_embd = config.n_embd
|
||||||
|
self.output_hidden_states = self.output_hidden_states
|
||||||
|
self.output_attentions = self.output_attentions
|
||||||
|
|
||||||
self.wte = TFSharedEmbeddings(
|
self.wte = TFSharedEmbeddings(
|
||||||
config.vocab_size, config.hidden_size, initializer_range=config.initializer_range, name="wte"
|
config.vocab_size, config.hidden_size, initializer_range=config.initializer_range, name="wte"
|
||||||
@@ -362,34 +359,39 @@ class TFGPT2MainLayer(tf.keras.layers.Layer):
|
|||||||
all_attentions = []
|
all_attentions = []
|
||||||
all_hidden_states = ()
|
all_hidden_states = ()
|
||||||
for i, (block, layer_past) in enumerate(zip(self.h, past)):
|
for i, (block, layer_past) in enumerate(zip(self.h, past)):
|
||||||
if cast_bool_to_primitive(output_hidden_states) is True:
|
if output_hidden_states:
|
||||||
all_hidden_states = all_hidden_states + (tf.reshape(hidden_states, output_shape),)
|
all_hidden_states = all_hidden_states + (tf.reshape(hidden_states, output_shape),)
|
||||||
|
|
||||||
outputs = block(
|
outputs = block(
|
||||||
[hidden_states, layer_past, attention_mask, head_mask[i], use_cache, output_attentions],
|
hidden_states,
|
||||||
|
layer_past,
|
||||||
|
attention_mask,
|
||||||
|
head_mask[i],
|
||||||
|
use_cache,
|
||||||
|
output_attentions,
|
||||||
training=training,
|
training=training,
|
||||||
)
|
)
|
||||||
|
|
||||||
hidden_states, present = outputs[:2]
|
hidden_states, present = outputs[:2]
|
||||||
presents = presents + (present,)
|
presents = presents + (present,)
|
||||||
|
|
||||||
if cast_bool_to_primitive(output_attentions) is True:
|
if output_attentions:
|
||||||
all_attentions.append(outputs[2])
|
all_attentions.append(outputs[2])
|
||||||
|
|
||||||
hidden_states = self.ln_f(hidden_states)
|
hidden_states = self.ln_f(hidden_states)
|
||||||
|
|
||||||
hidden_states = tf.reshape(hidden_states, output_shape)
|
hidden_states = tf.reshape(hidden_states, output_shape)
|
||||||
# Add last hidden state
|
# Add last hidden state
|
||||||
if cast_bool_to_primitive(output_hidden_states) is True:
|
if output_hidden_states:
|
||||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||||
|
|
||||||
outputs = (hidden_states,)
|
outputs = (hidden_states,)
|
||||||
|
|
||||||
if use_cache is True:
|
if use_cache:
|
||||||
outputs = outputs + (presents,)
|
outputs = outputs + (presents,)
|
||||||
if cast_bool_to_primitive(output_hidden_states) is True:
|
if output_hidden_states:
|
||||||
outputs = outputs + (all_hidden_states,)
|
outputs = outputs + (all_hidden_states,)
|
||||||
if cast_bool_to_primitive(output_attentions) is True:
|
if output_attentions:
|
||||||
# let the number of heads free (-1) so we can extract attention even after head pruning
|
# let the number of heads free (-1) so we can extract attention even after head pruning
|
||||||
attention_output_shape = input_shape[:-1] + [-1] + shape_list(all_attentions[0])[-2:]
|
attention_output_shape = input_shape[:-1] + [-1] + shape_list(all_attentions[0])[-2:]
|
||||||
all_attentions = tuple(tf.reshape(t, attention_output_shape) for t in all_attentions)
|
all_attentions = tuple(tf.reshape(t, attention_output_shape) for t in all_attentions)
|
||||||
@@ -738,13 +740,11 @@ class TFGPT2DoubleHeadsModel(TFGPT2PreTrainedModel):
|
|||||||
input_shapes = shape_list(inputs_embeds)[:-1]
|
input_shapes = shape_list(inputs_embeds)[:-1]
|
||||||
|
|
||||||
seq_length = input_shapes[-1]
|
seq_length = input_shapes[-1]
|
||||||
|
|
||||||
flat_input_ids = tf.reshape(input_ids, (-1, seq_length)) if input_ids is not None else None
|
flat_input_ids = tf.reshape(input_ids, (-1, seq_length)) if input_ids is not None else None
|
||||||
flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None
|
flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None
|
||||||
flat_token_type_ids = tf.reshape(token_type_ids, (-1, seq_length)) if token_type_ids is not None else None
|
flat_token_type_ids = tf.reshape(token_type_ids, (-1, seq_length)) if token_type_ids is not None else None
|
||||||
flat_position_ids = tf.reshape(position_ids, (-1, seq_length)) if position_ids is not None else None
|
flat_position_ids = tf.reshape(position_ids, (-1, seq_length)) if position_ids is not None else None
|
||||||
|
transformer_outputs = self.transformer(
|
||||||
flat_inputs = [
|
|
||||||
flat_input_ids,
|
flat_input_ids,
|
||||||
past,
|
past,
|
||||||
flat_attention_mask,
|
flat_attention_mask,
|
||||||
@@ -755,18 +755,13 @@ class TFGPT2DoubleHeadsModel(TFGPT2PreTrainedModel):
|
|||||||
use_cache,
|
use_cache,
|
||||||
output_attentions,
|
output_attentions,
|
||||||
output_hidden_states,
|
output_hidden_states,
|
||||||
]
|
training=training,
|
||||||
|
)
|
||||||
transformer_outputs = self.transformer(flat_inputs, training=training)
|
|
||||||
hidden_states = transformer_outputs[0]
|
hidden_states = transformer_outputs[0]
|
||||||
|
|
||||||
hidden_states = tf.reshape(hidden_states, input_shapes + shape_list(hidden_states)[-1:])
|
hidden_states = tf.reshape(hidden_states, input_shapes + shape_list(hidden_states)[-1:])
|
||||||
|
|
||||||
lm_logits = self.transformer.wte(hidden_states, mode="linear")
|
lm_logits = self.transformer.wte(hidden_states, mode="linear")
|
||||||
mc_logits = self.multiple_choice_head([hidden_states, mc_token_ids], training=training)
|
mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids, training=training)
|
||||||
|
|
||||||
mc_logits = tf.squeeze(mc_logits, axis=-1)
|
mc_logits = tf.squeeze(mc_logits, axis=-1)
|
||||||
|
|
||||||
outputs = (lm_logits, mc_logits) + transformer_outputs[1:]
|
outputs = (lm_logits, mc_logits) + transformer_outputs[1:]
|
||||||
|
|
||||||
return outputs # lm logits, mc logits, presents, (all hidden_states), (attentions)
|
return outputs # lm logits, mc logits, presents, (all hidden_states), (attentions)
|
||||||
|
|||||||
@@ -35,7 +35,6 @@ from .modeling_tf_utils import (
|
|||||||
TFQuestionAnsweringLoss,
|
TFQuestionAnsweringLoss,
|
||||||
TFSequenceClassificationLoss,
|
TFSequenceClassificationLoss,
|
||||||
TFTokenClassificationLoss,
|
TFTokenClassificationLoss,
|
||||||
cast_bool_to_primitive,
|
|
||||||
get_initializer,
|
get_initializer,
|
||||||
keras_serializable,
|
keras_serializable,
|
||||||
shape_list,
|
shape_list,
|
||||||
@@ -130,7 +129,15 @@ class TFMobileBertEmbeddings(tf.keras.layers.Layer):
|
|||||||
)
|
)
|
||||||
super().build(input_shape)
|
super().build(input_shape)
|
||||||
|
|
||||||
def call(self, inputs, mode="embedding", training=False):
|
def call(
|
||||||
|
self,
|
||||||
|
input_ids=None,
|
||||||
|
position_ids=None,
|
||||||
|
token_type_ids=None,
|
||||||
|
inputs_embeds=None,
|
||||||
|
mode="embedding",
|
||||||
|
training=False,
|
||||||
|
):
|
||||||
"""Get token embeddings of inputs.
|
"""Get token embeddings of inputs.
|
||||||
Args:
|
Args:
|
||||||
inputs: list of three int64 tensors with shape [batch_size, length]: (input_ids, position_ids, token_type_ids)
|
inputs: list of three int64 tensors with shape [batch_size, length]: (input_ids, position_ids, token_type_ids)
|
||||||
@@ -146,15 +153,15 @@ class TFMobileBertEmbeddings(tf.keras.layers.Layer):
|
|||||||
https://github.com/tensorflow/models/blob/a009f4fb9d2fc4949e32192a944688925ef78659/official/transformer/v2/embedding_layer.py#L24
|
https://github.com/tensorflow/models/blob/a009f4fb9d2fc4949e32192a944688925ef78659/official/transformer/v2/embedding_layer.py#L24
|
||||||
"""
|
"""
|
||||||
if mode == "embedding":
|
if mode == "embedding":
|
||||||
return self._embedding(inputs, training=training)
|
return self._embedding(input_ids, position_ids, token_type_ids, inputs_embeds, training=training)
|
||||||
elif mode == "linear":
|
elif mode == "linear":
|
||||||
return self._linear(inputs)
|
return self._linear(input_ids)
|
||||||
else:
|
else:
|
||||||
raise ValueError("mode {} is not valid.".format(mode))
|
raise ValueError("mode {} is not valid.".format(mode))
|
||||||
|
|
||||||
def _embedding(self, inputs, training=False):
|
def _embedding(self, input_ids, position_ids, token_type_ids, inputs_embeds, training=False):
|
||||||
"""Applies embedding based on inputs tensor."""
|
"""Applies embedding based on inputs tensor."""
|
||||||
input_ids, position_ids, token_type_ids, inputs_embeds = inputs
|
assert not (input_ids is None and inputs_embeds is None)
|
||||||
|
|
||||||
if input_ids is not None:
|
if input_ids is not None:
|
||||||
input_shape = shape_list(input_ids)
|
input_shape = shape_list(input_ids)
|
||||||
@@ -196,6 +203,7 @@ class TFMobileBertEmbeddings(tf.keras.layers.Layer):
|
|||||||
embeddings = inputs_embeds + position_embeddings + token_type_embeddings
|
embeddings = inputs_embeds + position_embeddings + token_type_embeddings
|
||||||
embeddings = self.LayerNorm(embeddings)
|
embeddings = self.LayerNorm(embeddings)
|
||||||
embeddings = self.dropout(embeddings, training=training)
|
embeddings = self.dropout(embeddings, training=training)
|
||||||
|
|
||||||
return embeddings
|
return embeddings
|
||||||
|
|
||||||
def _linear(self, inputs):
|
def _linear(self, inputs):
|
||||||
@@ -224,6 +232,7 @@ class TFMobileBertSelfAttention(tf.keras.layers.Layer):
|
|||||||
)
|
)
|
||||||
|
|
||||||
self.num_attention_heads = config.num_attention_heads
|
self.num_attention_heads = config.num_attention_heads
|
||||||
|
self.output_attentions = config.output_attentions
|
||||||
assert config.hidden_size % config.num_attention_heads == 0
|
assert config.hidden_size % config.num_attention_heads == 0
|
||||||
self.attention_head_size = int(config.true_hidden_size / config.num_attention_heads)
|
self.attention_head_size = int(config.true_hidden_size / config.num_attention_heads)
|
||||||
self.all_head_size = self.num_attention_heads * self.attention_head_size
|
self.all_head_size = self.num_attention_heads * self.attention_head_size
|
||||||
@@ -244,14 +253,13 @@ class TFMobileBertSelfAttention(tf.keras.layers.Layer):
|
|||||||
x = tf.reshape(x, (batch_size, -1, self.num_attention_heads, self.attention_head_size))
|
x = tf.reshape(x, (batch_size, -1, self.num_attention_heads, self.attention_head_size))
|
||||||
return tf.transpose(x, perm=[0, 2, 1, 3])
|
return tf.transpose(x, perm=[0, 2, 1, 3])
|
||||||
|
|
||||||
def call(self, inputs, training=False):
|
def call(
|
||||||
query_tensor, key_tensor, value_tensor, attention_mask, head_mask, output_attentions = inputs
|
self, query_tensor, key_tensor, value_tensor, attention_mask, head_mask, output_attentions, training=False
|
||||||
|
):
|
||||||
batch_size = shape_list(attention_mask)[0]
|
batch_size = shape_list(attention_mask)[0]
|
||||||
mixed_query_layer = self.query(query_tensor)
|
mixed_query_layer = self.query(query_tensor)
|
||||||
mixed_key_layer = self.key(key_tensor)
|
mixed_key_layer = self.key(key_tensor)
|
||||||
mixed_value_layer = self.value(value_tensor)
|
mixed_value_layer = self.value(value_tensor)
|
||||||
|
|
||||||
query_layer = self.transpose_for_scores(mixed_query_layer, batch_size)
|
query_layer = self.transpose_for_scores(mixed_query_layer, batch_size)
|
||||||
key_layer = self.transpose_for_scores(mixed_key_layer, batch_size)
|
key_layer = self.transpose_for_scores(mixed_key_layer, batch_size)
|
||||||
value_layer = self.transpose_for_scores(mixed_value_layer, batch_size)
|
value_layer = self.transpose_for_scores(mixed_value_layer, batch_size)
|
||||||
@@ -285,9 +293,7 @@ class TFMobileBertSelfAttention(tf.keras.layers.Layer):
|
|||||||
context_layer, (batch_size, -1, self.all_head_size)
|
context_layer, (batch_size, -1, self.all_head_size)
|
||||||
) # (batch_size, seq_len_q, all_head_size)
|
) # (batch_size, seq_len_q, all_head_size)
|
||||||
|
|
||||||
outputs = (
|
outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
|
||||||
(context_layer, attention_probs) if cast_bool_to_primitive(output_attentions) is True else (context_layer,)
|
|
||||||
)
|
|
||||||
|
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
@@ -305,8 +311,7 @@ class TFMobileBertSelfOutput(tf.keras.layers.Layer):
|
|||||||
if not self.use_bottleneck:
|
if not self.use_bottleneck:
|
||||||
self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
|
self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
|
||||||
|
|
||||||
def call(self, inputs, training=False):
|
def call(self, hidden_states, residual_tensor, training=False):
|
||||||
hidden_states, residual_tensor = inputs
|
|
||||||
hidden_states = self.dense(hidden_states)
|
hidden_states = self.dense(hidden_states)
|
||||||
if not self.use_bottleneck:
|
if not self.use_bottleneck:
|
||||||
hidden_states = self.dropout(hidden_states, training=training)
|
hidden_states = self.dropout(hidden_states, training=training)
|
||||||
@@ -323,13 +328,22 @@ class TFMobileBertAttention(tf.keras.layers.Layer):
|
|||||||
def prune_heads(self, heads):
|
def prune_heads(self, heads):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def call(self, inputs, training=False):
|
def call(
|
||||||
query_tensor, key_tensor, value_tensor, layer_input, attention_mask, head_mask, output_attentions = inputs
|
self,
|
||||||
|
query_tensor,
|
||||||
|
key_tensor,
|
||||||
|
value_tensor,
|
||||||
|
layer_input,
|
||||||
|
attention_mask,
|
||||||
|
head_mask,
|
||||||
|
output_attentions,
|
||||||
|
training=False,
|
||||||
|
):
|
||||||
self_outputs = self.self(
|
self_outputs = self.self(
|
||||||
[query_tensor, key_tensor, value_tensor, attention_mask, head_mask, output_attentions], training=training
|
query_tensor, key_tensor, value_tensor, attention_mask, head_mask, output_attentions, training=training
|
||||||
)
|
)
|
||||||
attention_output = self.mobilebert_output([self_outputs[0], layer_input], training=training)
|
|
||||||
|
attention_output = self.mobilebert_output(self_outputs[0], layer_input, training=training)
|
||||||
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
|
outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
@@ -349,8 +363,7 @@ class TFOutputBottleneck(tf.keras.layers.Layer):
|
|||||||
)
|
)
|
||||||
self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
|
self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
|
||||||
|
|
||||||
def call(self, inputs, training=False):
|
def call(self, hidden_states, residual_tensor, training=False):
|
||||||
hidden_states, residual_tensor = inputs
|
|
||||||
layer_outputs = self.dense(hidden_states)
|
layer_outputs = self.dense(hidden_states)
|
||||||
layer_outputs = self.dropout(layer_outputs, training=training)
|
layer_outputs = self.dropout(layer_outputs, training=training)
|
||||||
layer_outputs = self.LayerNorm(layer_outputs + residual_tensor)
|
layer_outputs = self.LayerNorm(layer_outputs + residual_tensor)
|
||||||
@@ -372,16 +385,14 @@ class TFMobileBertOutput(tf.keras.layers.Layer):
|
|||||||
else:
|
else:
|
||||||
self.bottleneck = TFOutputBottleneck(config, name="bottleneck")
|
self.bottleneck = TFOutputBottleneck(config, name="bottleneck")
|
||||||
|
|
||||||
def call(self, inputs, training=False):
|
def call(self, hidden_states, residual_tensor_1, residual_tensor_2, training=False):
|
||||||
hidden_states, residual_tensor_1, residual_tensor_2 = inputs
|
|
||||||
|
|
||||||
hidden_states = self.dense(hidden_states)
|
hidden_states = self.dense(hidden_states)
|
||||||
if not self.use_bottleneck:
|
if not self.use_bottleneck:
|
||||||
hidden_states = self.dropout(hidden_states, training=training)
|
hidden_states = self.dropout(hidden_states, training=training)
|
||||||
hidden_states = self.LayerNorm(hidden_states + residual_tensor_1)
|
hidden_states = self.LayerNorm(hidden_states + residual_tensor_1)
|
||||||
else:
|
else:
|
||||||
hidden_states = self.LayerNorm(hidden_states + residual_tensor_1)
|
hidden_states = self.LayerNorm(hidden_states + residual_tensor_1)
|
||||||
hidden_states = self.bottleneck([hidden_states, residual_tensor_2])
|
hidden_states = self.bottleneck(hidden_states, residual_tensor_2)
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
@@ -466,7 +477,6 @@ class TFMobileBertLayer(tf.keras.layers.Layer):
|
|||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
self.use_bottleneck = config.use_bottleneck
|
self.use_bottleneck = config.use_bottleneck
|
||||||
self.num_feedforward_networks = config.num_feedforward_networks
|
self.num_feedforward_networks = config.num_feedforward_networks
|
||||||
|
|
||||||
self.attention = TFMobileBertAttention(config, name="attention")
|
self.attention = TFMobileBertAttention(config, name="attention")
|
||||||
self.intermediate = TFMobileBertIntermediate(config, name="intermediate")
|
self.intermediate = TFMobileBertIntermediate(config, name="intermediate")
|
||||||
self.mobilebert_output = TFMobileBertOutput(config, name="output")
|
self.mobilebert_output = TFMobileBertOutput(config, name="output")
|
||||||
@@ -478,16 +488,20 @@ class TFMobileBertLayer(tf.keras.layers.Layer):
|
|||||||
TFFFNLayer(config, name="ffn.{}".format(i)) for i in range(config.num_feedforward_networks - 1)
|
TFFFNLayer(config, name="ffn.{}".format(i)) for i in range(config.num_feedforward_networks - 1)
|
||||||
]
|
]
|
||||||
|
|
||||||
def call(self, inputs, training=False):
|
def call(self, hidden_states, attention_mask, head_mask, output_attentions, training=False):
|
||||||
hidden_states, attention_mask, head_mask, output_attentions = inputs
|
|
||||||
|
|
||||||
if self.use_bottleneck:
|
if self.use_bottleneck:
|
||||||
query_tensor, key_tensor, value_tensor, layer_input = self.bottleneck(hidden_states)
|
query_tensor, key_tensor, value_tensor, layer_input = self.bottleneck(hidden_states)
|
||||||
else:
|
else:
|
||||||
query_tensor, key_tensor, value_tensor, layer_input = [hidden_states] * 4
|
query_tensor, key_tensor, value_tensor, layer_input = [hidden_states] * 4
|
||||||
|
|
||||||
attention_outputs = self.attention(
|
attention_outputs = self.attention(
|
||||||
[query_tensor, key_tensor, value_tensor, layer_input, attention_mask, head_mask, output_attentions],
|
query_tensor,
|
||||||
|
key_tensor,
|
||||||
|
value_tensor,
|
||||||
|
layer_input,
|
||||||
|
attention_mask,
|
||||||
|
head_mask,
|
||||||
|
output_attentions,
|
||||||
training=training,
|
training=training,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -500,48 +514,57 @@ class TFMobileBertLayer(tf.keras.layers.Layer):
|
|||||||
s += (attention_output,)
|
s += (attention_output,)
|
||||||
|
|
||||||
intermediate_output = self.intermediate(attention_output)
|
intermediate_output = self.intermediate(attention_output)
|
||||||
layer_output = self.mobilebert_output(
|
layer_output = self.mobilebert_output(intermediate_output, attention_output, hidden_states, training=training)
|
||||||
[intermediate_output, attention_output, hidden_states], training=training
|
|
||||||
)
|
|
||||||
outputs = (
|
outputs = (
|
||||||
(layer_output,)
|
(layer_output,)
|
||||||
+ attention_outputs[1:]
|
+ attention_outputs[1:]
|
||||||
+ (0, query_tensor, key_tensor, value_tensor, layer_input, attention_output, intermediate_output)
|
+ (
|
||||||
|
tf.constant(0),
|
||||||
|
query_tensor,
|
||||||
|
key_tensor,
|
||||||
|
value_tensor,
|
||||||
|
layer_input,
|
||||||
|
attention_output,
|
||||||
|
intermediate_output,
|
||||||
|
)
|
||||||
+ s
|
+ s
|
||||||
) # add attentions if we output them
|
) # add attentions if we output them
|
||||||
|
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
|
|
||||||
class TFMobileBertEncoder(tf.keras.layers.Layer):
|
class TFMobileBertEncoder(tf.keras.layers.Layer):
|
||||||
def __init__(self, config, **kwargs):
|
def __init__(self, config, **kwargs):
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
|
self.output_attentions = config.output_attentions
|
||||||
|
self.output_hidden_states = config.output_hidden_states
|
||||||
self.layer = [TFMobileBertLayer(config, name="layer_._{}".format(i)) for i in range(config.num_hidden_layers)]
|
self.layer = [TFMobileBertLayer(config, name="layer_._{}".format(i)) for i in range(config.num_hidden_layers)]
|
||||||
|
|
||||||
def call(self, inputs, training=False):
|
def call(self, hidden_states, attention_mask, head_mask, output_attentions, output_hidden_states, training=False):
|
||||||
hidden_states, attention_mask, head_mask, output_attentions, output_hidden_states = inputs
|
|
||||||
|
|
||||||
all_hidden_states = ()
|
all_hidden_states = ()
|
||||||
all_attentions = ()
|
all_attentions = ()
|
||||||
for i, layer_module in enumerate(self.layer):
|
for i, layer_module in enumerate(self.layer):
|
||||||
if cast_bool_to_primitive(output_hidden_states) is True:
|
if output_hidden_states:
|
||||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||||
|
|
||||||
layer_outputs = layer_module(
|
layer_outputs = layer_module(
|
||||||
[hidden_states, attention_mask, head_mask[i], output_attentions], training=training
|
hidden_states, attention_mask, head_mask[i], output_attentions, training=training
|
||||||
)
|
)
|
||||||
|
|
||||||
hidden_states = layer_outputs[0]
|
hidden_states = layer_outputs[0]
|
||||||
|
|
||||||
if cast_bool_to_primitive(output_attentions) is True:
|
if output_attentions:
|
||||||
all_attentions = all_attentions + (layer_outputs[1],)
|
all_attentions = all_attentions + (layer_outputs[1],)
|
||||||
|
|
||||||
# Add last layer
|
# Add last layer
|
||||||
if cast_bool_to_primitive(output_hidden_states) is True:
|
if output_hidden_states:
|
||||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||||
|
|
||||||
outputs = (hidden_states,)
|
outputs = (hidden_states,)
|
||||||
if cast_bool_to_primitive(output_hidden_states) is True:
|
if output_hidden_states:
|
||||||
outputs = outputs + (all_hidden_states,)
|
outputs = outputs + (all_hidden_states,)
|
||||||
if cast_bool_to_primitive(output_attentions) is True:
|
if output_attentions:
|
||||||
outputs = outputs + (all_attentions,)
|
outputs = outputs + (all_attentions,)
|
||||||
return outputs # outputs, (hidden states), (attentions)
|
return outputs # outputs, (hidden states), (attentions)
|
||||||
|
|
||||||
@@ -732,11 +755,14 @@ class TFMobileBertMainLayer(tf.keras.layers.Layer):
|
|||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
else:
|
else:
|
||||||
head_mask = [None] * self.num_hidden_layers
|
head_mask = [None] * self.num_hidden_layers
|
||||||
# head_mask = tf.constant([0] * self.num_hidden_layers)
|
|
||||||
|
|
||||||
embedding_output = self.embeddings([input_ids, position_ids, token_type_ids, inputs_embeds], training=training)
|
embedding_output = self.embeddings(input_ids, position_ids, token_type_ids, inputs_embeds, training=training)
|
||||||
encoder_outputs = self.encoder(
|
encoder_outputs = self.encoder(
|
||||||
[embedding_output, extended_attention_mask, head_mask, output_attentions, output_hidden_states],
|
embedding_output,
|
||||||
|
extended_attention_mask,
|
||||||
|
head_mask,
|
||||||
|
output_attentions,
|
||||||
|
output_hidden_states,
|
||||||
training=training,
|
training=training,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -1360,8 +1386,7 @@ class TFMobileBertForMultipleChoice(TFMobileBertPreTrainedModel, TFMultipleChoic
|
|||||||
if inputs_embeds is not None
|
if inputs_embeds is not None
|
||||||
else None
|
else None
|
||||||
)
|
)
|
||||||
|
outputs = self.mobilebert(
|
||||||
flat_inputs = [
|
|
||||||
flat_input_ids,
|
flat_input_ids,
|
||||||
flat_attention_mask,
|
flat_attention_mask,
|
||||||
flat_token_type_ids,
|
flat_token_type_ids,
|
||||||
@@ -1370,16 +1395,12 @@ class TFMobileBertForMultipleChoice(TFMobileBertPreTrainedModel, TFMultipleChoic
|
|||||||
flat_inputs_embeds,
|
flat_inputs_embeds,
|
||||||
output_attentions,
|
output_attentions,
|
||||||
output_hidden_states,
|
output_hidden_states,
|
||||||
]
|
training=training,
|
||||||
|
)
|
||||||
outputs = self.mobilebert(flat_inputs, training=training)
|
|
||||||
|
|
||||||
pooled_output = outputs[1]
|
pooled_output = outputs[1]
|
||||||
|
|
||||||
pooled_output = self.dropout(pooled_output, training=training)
|
pooled_output = self.dropout(pooled_output, training=training)
|
||||||
logits = self.classifier(pooled_output)
|
logits = self.classifier(pooled_output)
|
||||||
reshaped_logits = tf.reshape(logits, (-1, num_choices))
|
reshaped_logits = tf.reshape(logits, (-1, num_choices))
|
||||||
|
|
||||||
outputs = (reshaped_logits,) + outputs[2:] # add hidden states and attention if they are here
|
outputs = (reshaped_logits,) + outputs[2:] # add hidden states and attention if they are here
|
||||||
|
|
||||||
if labels is not None:
|
if labels is not None:
|
||||||
|
|||||||
@@ -29,7 +29,6 @@ from .modeling_tf_utils import (
|
|||||||
TFPreTrainedModel,
|
TFPreTrainedModel,
|
||||||
TFSequenceSummary,
|
TFSequenceSummary,
|
||||||
TFSharedEmbeddings,
|
TFSharedEmbeddings,
|
||||||
cast_bool_to_primitive,
|
|
||||||
get_initializer,
|
get_initializer,
|
||||||
keras_serializable,
|
keras_serializable,
|
||||||
shape_list,
|
shape_list,
|
||||||
@@ -84,6 +83,7 @@ class TFAttention(tf.keras.layers.Layer):
|
|||||||
self.n_head = config.n_head
|
self.n_head = config.n_head
|
||||||
self.split_size = n_state
|
self.split_size = n_state
|
||||||
self.scale = scale
|
self.scale = scale
|
||||||
|
self.output_attentions = config.output_attentions
|
||||||
|
|
||||||
self.c_attn = TFConv1D(n_state * 3, nx, initializer_range=config.initializer_range, name="c_attn")
|
self.c_attn = TFConv1D(n_state * 3, nx, initializer_range=config.initializer_range, name="c_attn")
|
||||||
self.c_proj = TFConv1D(n_state, nx, initializer_range=config.initializer_range, name="c_proj")
|
self.c_proj = TFConv1D(n_state, nx, initializer_range=config.initializer_range, name="c_proj")
|
||||||
@@ -104,8 +104,7 @@ class TFAttention(tf.keras.layers.Layer):
|
|||||||
m = i >= j - ns + nd
|
m = i >= j - ns + nd
|
||||||
return tf.cast(m, dtype)
|
return tf.cast(m, dtype)
|
||||||
|
|
||||||
def _attn(self, inputs, training=False):
|
def _attn(self, q, k, v, attention_mask, head_mask, output_attentions, training=False):
|
||||||
q, k, v, attention_mask, head_mask, output_attentions = inputs
|
|
||||||
# q, k, v have shape [batch, heads, sequence, features]
|
# q, k, v have shape [batch, heads, sequence, features]
|
||||||
w = tf.matmul(q, k, transpose_b=True)
|
w = tf.matmul(q, k, transpose_b=True)
|
||||||
if self.scale:
|
if self.scale:
|
||||||
@@ -130,7 +129,7 @@ class TFAttention(tf.keras.layers.Layer):
|
|||||||
w = w * head_mask
|
w = w * head_mask
|
||||||
|
|
||||||
outputs = [tf.matmul(w, v)]
|
outputs = [tf.matmul(w, v)]
|
||||||
if cast_bool_to_primitive(output_attentions) is True:
|
if output_attentions:
|
||||||
outputs.append(w)
|
outputs.append(w)
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
@@ -146,16 +145,14 @@ class TFAttention(tf.keras.layers.Layer):
|
|||||||
x = tf.reshape(x, new_x_shape)
|
x = tf.reshape(x, new_x_shape)
|
||||||
return tf.transpose(x, (0, 2, 1, 3)) # (batch, head, seq_length, head_features)
|
return tf.transpose(x, (0, 2, 1, 3)) # (batch, head, seq_length, head_features)
|
||||||
|
|
||||||
def call(self, inputs, training=False):
|
def call(self, x, attention_mask, head_mask, output_attentions, training=False):
|
||||||
x, attention_mask, head_mask, output_attentions = inputs
|
|
||||||
|
|
||||||
x = self.c_attn(x)
|
x = self.c_attn(x)
|
||||||
query, key, value = tf.split(x, 3, axis=2)
|
query, key, value = tf.split(x, 3, axis=2)
|
||||||
query = self.split_heads(query)
|
query = self.split_heads(query)
|
||||||
key = self.split_heads(key)
|
key = self.split_heads(key)
|
||||||
value = self.split_heads(value)
|
value = self.split_heads(value)
|
||||||
|
|
||||||
attn_outputs = self._attn([query, key, value, attention_mask, head_mask, output_attentions], training=training)
|
attn_outputs = self._attn(query, key, value, attention_mask, head_mask, output_attentions, training=training)
|
||||||
a = attn_outputs[0]
|
a = attn_outputs[0]
|
||||||
|
|
||||||
a = self.merge_heads(a)
|
a = self.merge_heads(a)
|
||||||
@@ -191,10 +188,8 @@ class TFBlock(tf.keras.layers.Layer):
|
|||||||
self.mlp = TFMLP(4 * nx, config, name="mlp")
|
self.mlp = TFMLP(4 * nx, config, name="mlp")
|
||||||
self.ln_2 = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_epsilon, name="ln_2")
|
self.ln_2 = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_epsilon, name="ln_2")
|
||||||
|
|
||||||
def call(self, inputs, training=False):
|
def call(self, x, attention_mask, head_mask, output_attentions, training=False):
|
||||||
x, attention_mask, head_mask, output_attentions = inputs
|
output_attn = self.attn(x, attention_mask, head_mask, output_attentions, training=training)
|
||||||
|
|
||||||
output_attn = self.attn([x, attention_mask, head_mask, output_attentions], training=training)
|
|
||||||
a = output_attn[0] # output_attn: a, (attentions)
|
a = output_attn[0] # output_attn: a, (attentions)
|
||||||
|
|
||||||
n = self.ln_1(x + a)
|
n = self.ln_1(x + a)
|
||||||
@@ -341,23 +336,23 @@ class TFOpenAIGPTMainLayer(tf.keras.layers.Layer):
|
|||||||
all_attentions = []
|
all_attentions = []
|
||||||
all_hidden_states = ()
|
all_hidden_states = ()
|
||||||
for i, block in enumerate(self.h):
|
for i, block in enumerate(self.h):
|
||||||
if cast_bool_to_primitive(output_hidden_states) is True:
|
if output_hidden_states:
|
||||||
all_hidden_states = all_hidden_states + (tf.reshape(hidden_states, output_shape),)
|
all_hidden_states = all_hidden_states + (tf.reshape(hidden_states, output_shape),)
|
||||||
|
|
||||||
outputs = block([hidden_states, attention_mask, head_mask[i], output_attentions], training=training)
|
outputs = block(hidden_states, attention_mask, head_mask[i], output_attentions, training=training)
|
||||||
hidden_states = outputs[0]
|
hidden_states = outputs[0]
|
||||||
if cast_bool_to_primitive(output_attentions) is True:
|
if output_attentions:
|
||||||
all_attentions.append(outputs[1])
|
all_attentions.append(outputs[1])
|
||||||
|
|
||||||
hidden_states = tf.reshape(hidden_states, output_shape)
|
hidden_states = tf.reshape(hidden_states, output_shape)
|
||||||
# Add last hidden state
|
# Add last hidden state
|
||||||
if cast_bool_to_primitive(output_hidden_states) is True:
|
if output_hidden_states:
|
||||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||||
|
|
||||||
outputs = (hidden_states,)
|
outputs = (hidden_states,)
|
||||||
if cast_bool_to_primitive(output_hidden_states) is True:
|
if output_hidden_states:
|
||||||
outputs = outputs + (all_hidden_states,)
|
outputs = outputs + (all_hidden_states,)
|
||||||
if cast_bool_to_primitive(output_attentions) is True:
|
if output_attentions:
|
||||||
# let the number of heads free (-1) so we can extract attention even after head pruning
|
# let the number of heads free (-1) so we can extract attention even after head pruning
|
||||||
attention_output_shape = input_shape[:-1] + [-1] + shape_list(all_attentions[0])[-2:]
|
attention_output_shape = input_shape[:-1] + [-1] + shape_list(all_attentions[0])[-2:]
|
||||||
all_attentions = tuple(tf.reshape(t, attention_output_shape) for t in all_attentions)
|
all_attentions = tuple(tf.reshape(t, attention_output_shape) for t in all_attentions)
|
||||||
@@ -671,13 +666,11 @@ class TFOpenAIGPTDoubleHeadsModel(TFOpenAIGPTPreTrainedModel):
|
|||||||
input_shapes = shape_list(inputs_embeds)[:-1]
|
input_shapes = shape_list(inputs_embeds)[:-1]
|
||||||
|
|
||||||
seq_length = input_shapes[-1]
|
seq_length = input_shapes[-1]
|
||||||
|
|
||||||
flat_input_ids = tf.reshape(input_ids, (-1, seq_length)) if input_ids is not None else None
|
flat_input_ids = tf.reshape(input_ids, (-1, seq_length)) if input_ids is not None else None
|
||||||
flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None
|
flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None
|
||||||
flat_token_type_ids = tf.reshape(token_type_ids, (-1, seq_length)) if token_type_ids is not None else None
|
flat_token_type_ids = tf.reshape(token_type_ids, (-1, seq_length)) if token_type_ids is not None else None
|
||||||
flat_position_ids = tf.reshape(position_ids, (-1, seq_length)) if position_ids is not None else None
|
flat_position_ids = tf.reshape(position_ids, (-1, seq_length)) if position_ids is not None else None
|
||||||
|
transformer_outputs = self.transformer(
|
||||||
flat_inputs = [
|
|
||||||
flat_input_ids,
|
flat_input_ids,
|
||||||
flat_attention_mask,
|
flat_attention_mask,
|
||||||
flat_token_type_ids,
|
flat_token_type_ids,
|
||||||
@@ -686,18 +679,13 @@ class TFOpenAIGPTDoubleHeadsModel(TFOpenAIGPTPreTrainedModel):
|
|||||||
inputs_embeds,
|
inputs_embeds,
|
||||||
output_attentions,
|
output_attentions,
|
||||||
output_hidden_states,
|
output_hidden_states,
|
||||||
]
|
training=training,
|
||||||
|
)
|
||||||
transformer_outputs = self.transformer(flat_inputs, training=training)
|
|
||||||
hidden_states = transformer_outputs[0]
|
hidden_states = transformer_outputs[0]
|
||||||
|
|
||||||
hidden_states = tf.reshape(hidden_states, input_shapes + shape_list(hidden_states)[-1:])
|
hidden_states = tf.reshape(hidden_states, input_shapes + shape_list(hidden_states)[-1:])
|
||||||
|
|
||||||
lm_logits = self.transformer.tokens_embed(hidden_states, mode="linear")
|
lm_logits = self.transformer.tokens_embed(hidden_states, mode="linear")
|
||||||
mc_logits = self.multiple_choice_head([hidden_states, mc_token_ids], training=training)
|
mc_logits = self.multiple_choice_head(hidden_states, mc_token_ids, training=training)
|
||||||
|
|
||||||
mc_logits = tf.squeeze(mc_logits, axis=-1)
|
mc_logits = tf.squeeze(mc_logits, axis=-1)
|
||||||
|
|
||||||
outputs = (lm_logits, mc_logits) + transformer_outputs[1:]
|
outputs = (lm_logits, mc_logits) + transformer_outputs[1:]
|
||||||
|
|
||||||
return outputs # lm logits, mc logits, (all hidden_states), (attentions)
|
return outputs # lm logits, mc logits, (all hidden_states), (attentions)
|
||||||
|
|||||||
@@ -86,9 +86,9 @@ class TFRobertaEmbeddings(TFBertEmbeddings):
|
|||||||
position_ids = tf.range(self.padding_idx + 1, seq_length + self.padding_idx + 1, dtype=tf.int32)[tf.newaxis, :]
|
position_ids = tf.range(self.padding_idx + 1, seq_length + self.padding_idx + 1, dtype=tf.int32)[tf.newaxis, :]
|
||||||
return position_ids
|
return position_ids
|
||||||
|
|
||||||
def _embedding(self, inputs, training=False):
|
def _embedding(self, input_ids, position_ids, token_type_ids, inputs_embeds, training=False):
|
||||||
"""Applies embedding based on inputs tensor."""
|
"""Applies embedding based on inputs tensor."""
|
||||||
input_ids, position_ids, token_type_ids, inputs_embeds = inputs
|
assert not (input_ids is None and inputs_embeds is None)
|
||||||
|
|
||||||
if position_ids is None:
|
if position_ids is None:
|
||||||
if input_ids is not None:
|
if input_ids is not None:
|
||||||
@@ -97,7 +97,7 @@ class TFRobertaEmbeddings(TFBertEmbeddings):
|
|||||||
else:
|
else:
|
||||||
position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds)
|
position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds)
|
||||||
|
|
||||||
return super()._embedding([input_ids, position_ids, token_type_ids, inputs_embeds], training=training)
|
return super()._embedding(input_ids, position_ids, token_type_ids, inputs_embeds, training=training)
|
||||||
|
|
||||||
|
|
||||||
@keras_serializable
|
@keras_serializable
|
||||||
@@ -546,8 +546,7 @@ class TFRobertaForMultipleChoice(TFRobertaPreTrainedModel, TFMultipleChoiceLoss)
|
|||||||
flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None
|
flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None
|
||||||
flat_token_type_ids = tf.reshape(token_type_ids, (-1, seq_length)) if token_type_ids is not None else None
|
flat_token_type_ids = tf.reshape(token_type_ids, (-1, seq_length)) if token_type_ids is not None else None
|
||||||
flat_position_ids = tf.reshape(position_ids, (-1, seq_length)) if position_ids is not None else None
|
flat_position_ids = tf.reshape(position_ids, (-1, seq_length)) if position_ids is not None else None
|
||||||
|
outputs = self.roberta(
|
||||||
flat_inputs = [
|
|
||||||
flat_input_ids,
|
flat_input_ids,
|
||||||
flat_attention_mask,
|
flat_attention_mask,
|
||||||
flat_token_type_ids,
|
flat_token_type_ids,
|
||||||
@@ -556,16 +555,12 @@ class TFRobertaForMultipleChoice(TFRobertaPreTrainedModel, TFMultipleChoiceLoss)
|
|||||||
inputs_embeds,
|
inputs_embeds,
|
||||||
output_attentions,
|
output_attentions,
|
||||||
output_hidden_states,
|
output_hidden_states,
|
||||||
]
|
training=training,
|
||||||
|
)
|
||||||
outputs = self.roberta(flat_inputs, training=training)
|
|
||||||
|
|
||||||
pooled_output = outputs[1]
|
pooled_output = outputs[1]
|
||||||
|
|
||||||
pooled_output = self.dropout(pooled_output, training=training)
|
pooled_output = self.dropout(pooled_output, training=training)
|
||||||
logits = self.classifier(pooled_output)
|
logits = self.classifier(pooled_output)
|
||||||
reshaped_logits = tf.reshape(logits, (-1, num_choices))
|
reshaped_logits = tf.reshape(logits, (-1, num_choices))
|
||||||
|
|
||||||
outputs = (reshaped_logits,) + outputs[2:] # add hidden states and attention if they are here
|
outputs = (reshaped_logits,) + outputs[2:] # add hidden states and attention if they are here
|
||||||
|
|
||||||
if labels is not None:
|
if labels is not None:
|
||||||
|
|||||||
@@ -115,6 +115,7 @@ class TFT5Attention(tf.keras.layers.Layer):
|
|||||||
self.is_decoder = config.is_decoder
|
self.is_decoder = config.is_decoder
|
||||||
self.use_cache = config.use_cache
|
self.use_cache = config.use_cache
|
||||||
self.has_relative_attention_bias = has_relative_attention_bias
|
self.has_relative_attention_bias = has_relative_attention_bias
|
||||||
|
self.output_attentions = config.output_attentions
|
||||||
|
|
||||||
self.relative_attention_num_buckets = config.relative_attention_num_buckets
|
self.relative_attention_num_buckets = config.relative_attention_num_buckets
|
||||||
self.d_model = config.d_model
|
self.d_model = config.d_model
|
||||||
@@ -296,7 +297,7 @@ class TFT5Attention(tf.keras.layers.Layer):
|
|||||||
|
|
||||||
outputs = (context,) + present_key_value_state
|
outputs = (context,) + present_key_value_state
|
||||||
|
|
||||||
if cast_bool_to_primitive(output_attentions, True) is True:
|
if output_attentions:
|
||||||
outputs = outputs + (weights,)
|
outputs = outputs + (weights,)
|
||||||
if self.has_relative_attention_bias:
|
if self.has_relative_attention_bias:
|
||||||
outputs = outputs + (position_bias,)
|
outputs = outputs + (position_bias,)
|
||||||
@@ -699,7 +700,7 @@ class TFT5MainLayer(tf.keras.layers.Layer):
|
|||||||
hidden_states = self.dropout(inputs_embeds, training=training)
|
hidden_states = self.dropout(inputs_embeds, training=training)
|
||||||
|
|
||||||
for i, (layer_module, past_key_value_state) in enumerate(zip(self.block, past_key_value_states)):
|
for i, (layer_module, past_key_value_state) in enumerate(zip(self.block, past_key_value_states)):
|
||||||
if cast_bool_to_primitive(output_hidden_states) is True:
|
if output_hidden_states:
|
||||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||||
|
|
||||||
layer_outputs = layer_module(
|
layer_outputs = layer_module(
|
||||||
@@ -727,23 +728,23 @@ class TFT5MainLayer(tf.keras.layers.Layer):
|
|||||||
# append next layer key value states
|
# append next layer key value states
|
||||||
present_key_value_states = present_key_value_states + (present_key_value_state,)
|
present_key_value_states = present_key_value_states + (present_key_value_state,)
|
||||||
|
|
||||||
if cast_bool_to_primitive(output_attentions) is True:
|
if output_attentions:
|
||||||
all_attentions = all_attentions + (layer_outputs[2],)
|
all_attentions = all_attentions + (layer_outputs[2],)
|
||||||
|
|
||||||
hidden_states = self.final_layer_norm(hidden_states)
|
hidden_states = self.final_layer_norm(hidden_states)
|
||||||
hidden_states = self.dropout(hidden_states, training=training)
|
hidden_states = self.dropout(hidden_states, training=training)
|
||||||
|
|
||||||
# Add last layer
|
# Add last layer
|
||||||
if cast_bool_to_primitive(output_hidden_states) is True:
|
if output_hidden_states:
|
||||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||||
|
|
||||||
outputs = (hidden_states,)
|
outputs = (hidden_states,)
|
||||||
# need to check if is decoder here as well for special cases when using keras compile
|
# need to check if is decoder here as well for special cases when using keras compile
|
||||||
if cast_bool_to_primitive(use_cache, self.use_cache) is True and self.is_decoder:
|
if cast_bool_to_primitive(use_cache, self.use_cache) is True and self.is_decoder:
|
||||||
outputs = outputs + (present_key_value_states,)
|
outputs = outputs + (present_key_value_states,)
|
||||||
if cast_bool_to_primitive(output_hidden_states) is True:
|
if output_hidden_states:
|
||||||
outputs = outputs + (all_hidden_states,)
|
outputs = outputs + (all_hidden_states,)
|
||||||
if cast_bool_to_primitive(output_attentions) is True:
|
if output_attentions:
|
||||||
outputs = outputs + (all_attentions,)
|
outputs = outputs + (all_attentions,)
|
||||||
return outputs # last-layer hidden state, (all hidden states), (all attentions)
|
return outputs # last-layer hidden state, (all hidden states), (all attentions)
|
||||||
|
|
||||||
|
|||||||
@@ -24,13 +24,7 @@ import tensorflow as tf
|
|||||||
from .configuration_transfo_xl import TransfoXLConfig
|
from .configuration_transfo_xl import TransfoXLConfig
|
||||||
from .file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_callable
|
from .file_utils import add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_callable
|
||||||
from .modeling_tf_transfo_xl_utilities import TFAdaptiveSoftmaxMask
|
from .modeling_tf_transfo_xl_utilities import TFAdaptiveSoftmaxMask
|
||||||
from .modeling_tf_utils import (
|
from .modeling_tf_utils import TFPreTrainedModel, get_initializer, keras_serializable, shape_list
|
||||||
TFPreTrainedModel,
|
|
||||||
cast_bool_to_primitive,
|
|
||||||
get_initializer,
|
|
||||||
keras_serializable,
|
|
||||||
shape_list,
|
|
||||||
)
|
|
||||||
from .tokenization_utils import BatchEncoding
|
from .tokenization_utils import BatchEncoding
|
||||||
|
|
||||||
|
|
||||||
@@ -119,6 +113,7 @@ class TFRelPartialLearnableMultiHeadAttn(tf.keras.layers.Layer):
|
|||||||
r_w_bias=None,
|
r_w_bias=None,
|
||||||
layer_norm_epsilon=1e-5,
|
layer_norm_epsilon=1e-5,
|
||||||
init_std=0.02,
|
init_std=0.02,
|
||||||
|
output_attentions=False,
|
||||||
**kwargs
|
**kwargs
|
||||||
):
|
):
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
@@ -127,6 +122,7 @@ class TFRelPartialLearnableMultiHeadAttn(tf.keras.layers.Layer):
|
|||||||
self.d_model = d_model
|
self.d_model = d_model
|
||||||
self.d_head = d_head
|
self.d_head = d_head
|
||||||
self.dropout = dropout
|
self.dropout = dropout
|
||||||
|
self.output_attentions = output_attentions
|
||||||
|
|
||||||
self.qkv_net = tf.keras.layers.Dense(
|
self.qkv_net = tf.keras.layers.Dense(
|
||||||
3 * n_head * d_head, kernel_initializer=get_initializer(init_std), use_bias=False, name="qkv_net"
|
3 * n_head * d_head, kernel_initializer=get_initializer(init_std), use_bias=False, name="qkv_net"
|
||||||
@@ -175,8 +171,7 @@ class TFRelPartialLearnableMultiHeadAttn(tf.keras.layers.Layer):
|
|||||||
|
|
||||||
return x
|
return x
|
||||||
|
|
||||||
def call(self, inputs, training=False):
|
def call(self, w, r, attn_mask, mems, head_mask, output_attentions, training=False):
|
||||||
w, r, attn_mask, mems, head_mask, output_attentions = inputs
|
|
||||||
qlen, rlen, bsz = shape_list(w)[0], shape_list(r)[0], shape_list(w)[1]
|
qlen, rlen, bsz = shape_list(w)[0], shape_list(r)[0], shape_list(w)[1]
|
||||||
|
|
||||||
if mems is not None:
|
if mems is not None:
|
||||||
@@ -249,7 +244,7 @@ class TFRelPartialLearnableMultiHeadAttn(tf.keras.layers.Layer):
|
|||||||
# residual connection + layer normalization
|
# residual connection + layer normalization
|
||||||
outputs = [self.layer_norm(w + attn_out)]
|
outputs = [self.layer_norm(w + attn_out)]
|
||||||
|
|
||||||
if cast_bool_to_primitive(output_attentions) is True:
|
if output_attentions:
|
||||||
outputs.append(attn_prob)
|
outputs.append(attn_prob)
|
||||||
|
|
||||||
return outputs
|
return outputs
|
||||||
@@ -272,6 +267,7 @@ class TFRelPartialLearnableDecoderLayer(tf.keras.layers.Layer):
|
|||||||
r_r_bias=None,
|
r_r_bias=None,
|
||||||
layer_norm_epsilon=1e-5,
|
layer_norm_epsilon=1e-5,
|
||||||
init_std=0.02,
|
init_std=0.02,
|
||||||
|
output_attentions=False,
|
||||||
**kwargs
|
**kwargs
|
||||||
):
|
):
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
@@ -290,6 +286,7 @@ class TFRelPartialLearnableDecoderLayer(tf.keras.layers.Layer):
|
|||||||
r_r_bias=r_r_bias,
|
r_r_bias=r_r_bias,
|
||||||
init_std=init_std,
|
init_std=init_std,
|
||||||
layer_norm_epsilon=layer_norm_epsilon,
|
layer_norm_epsilon=layer_norm_epsilon,
|
||||||
|
output_attentions=output_attentions,
|
||||||
name="dec_attn",
|
name="dec_attn",
|
||||||
)
|
)
|
||||||
self.pos_ff = TFPositionwiseFF(
|
self.pos_ff = TFPositionwiseFF(
|
||||||
@@ -302,11 +299,8 @@ class TFRelPartialLearnableDecoderLayer(tf.keras.layers.Layer):
|
|||||||
name="pos_ff",
|
name="pos_ff",
|
||||||
)
|
)
|
||||||
|
|
||||||
def call(self, inputs, training=False):
|
def call(self, dec_inp, r, dec_attn_mask, mems, head_mask, output_attentions, training=False):
|
||||||
dec_inp, r, dec_attn_mask, mems, head_mask, output_attentions = inputs
|
attn_outputs = self.dec_attn(dec_inp, r, dec_attn_mask, mems, head_mask, output_attentions, training=training)
|
||||||
attn_outputs = self.dec_attn(
|
|
||||||
[dec_inp, r, dec_attn_mask, mems, head_mask, output_attentions], training=training
|
|
||||||
)
|
|
||||||
ff_output = self.pos_ff(attn_outputs[0], training=training)
|
ff_output = self.pos_ff(attn_outputs[0], training=training)
|
||||||
|
|
||||||
outputs = [ff_output] + attn_outputs[1:]
|
outputs = [ff_output] + attn_outputs[1:]
|
||||||
@@ -443,6 +437,7 @@ class TFTransfoXLMainLayer(tf.keras.layers.Layer):
|
|||||||
r_r_bias=None if self.untie_r else self.r_r_bias,
|
r_r_bias=None if self.untie_r else self.r_r_bias,
|
||||||
layer_norm_epsilon=config.layer_norm_epsilon,
|
layer_norm_epsilon=config.layer_norm_epsilon,
|
||||||
init_std=config.init_std,
|
init_std=config.init_std,
|
||||||
|
output_attentions=self.output_attentions,
|
||||||
name="layers_._{}".format(i),
|
name="layers_._{}".format(i),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
@@ -625,10 +620,10 @@ class TFTransfoXLMainLayer(tf.keras.layers.Layer):
|
|||||||
hids.append(core_out)
|
hids.append(core_out)
|
||||||
mems_i = None if mems is None else mems[i]
|
mems_i = None if mems is None else mems[i]
|
||||||
layer_outputs = layer(
|
layer_outputs = layer(
|
||||||
[core_out, pos_emb, dec_attn_mask, mems_i, head_mask[i], output_attentions], training=training,
|
core_out, pos_emb, dec_attn_mask, mems_i, head_mask[i], output_attentions, training=training,
|
||||||
)
|
)
|
||||||
core_out = layer_outputs[0]
|
core_out = layer_outputs[0]
|
||||||
if cast_bool_to_primitive(output_attentions) is True:
|
if output_attentions:
|
||||||
attentions.append(layer_outputs[1])
|
attentions.append(layer_outputs[1])
|
||||||
else: # learnable embeddings and absolute embeddings
|
else: # learnable embeddings and absolute embeddings
|
||||||
raise NotImplementedError # Removed these to avoid maintaining dead code - They are not used in our pretrained checkpoint
|
raise NotImplementedError # Removed these to avoid maintaining dead code - They are not used in our pretrained checkpoint
|
||||||
@@ -639,12 +634,12 @@ class TFTransfoXLMainLayer(tf.keras.layers.Layer):
|
|||||||
|
|
||||||
# We transpose back here to shape [bsz, len, hidden_dim]
|
# We transpose back here to shape [bsz, len, hidden_dim]
|
||||||
outputs = [tf.transpose(core_out, perm=(1, 0, 2)), new_mems]
|
outputs = [tf.transpose(core_out, perm=(1, 0, 2)), new_mems]
|
||||||
if cast_bool_to_primitive(output_hidden_states):
|
if output_hidden_states:
|
||||||
# Add last layer and transpose to library standard shape [bsz, len, hidden_dim]
|
# Add last layer and transpose to library standard shape [bsz, len, hidden_dim]
|
||||||
hids.append(core_out)
|
hids.append(core_out)
|
||||||
hids = list(tf.transpose(t, perm=(1, 0, 2)) for t in hids)
|
hids = list(tf.transpose(t, perm=(1, 0, 2)) for t in hids)
|
||||||
outputs.append(hids)
|
outputs.append(hids)
|
||||||
if cast_bool_to_primitive(output_attentions) is True:
|
if output_attentions:
|
||||||
# Transpose to library standard shape [bsz, n_heads, query_seq_len, key_seq_len]
|
# Transpose to library standard shape [bsz, n_heads, query_seq_len, key_seq_len]
|
||||||
attentions = list(tf.transpose(t, perm=(2, 3, 0, 1)) for t in attentions)
|
attentions = list(tf.transpose(t, perm=(2, 3, 0, 1)) for t in attentions)
|
||||||
outputs.append(attentions)
|
outputs.append(attentions)
|
||||||
@@ -860,14 +855,14 @@ class TFTransfoXLLMHeadModel(TFTransfoXLPreTrainedModel):
|
|||||||
bsz, tgt_len = shape_list(inputs_embeds)[:2]
|
bsz, tgt_len = shape_list(inputs_embeds)[:2]
|
||||||
|
|
||||||
transformer_outputs = self.transformer(
|
transformer_outputs = self.transformer(
|
||||||
[input_ids, mems, head_mask, inputs_embeds, output_attentions, output_hidden_states], training=training
|
input_ids, mems, head_mask, inputs_embeds, output_attentions, output_hidden_states, training=training
|
||||||
)
|
)
|
||||||
|
|
||||||
last_hidden = transformer_outputs[0]
|
last_hidden = transformer_outputs[0]
|
||||||
pred_hid = last_hidden[:, -tgt_len:]
|
pred_hid = last_hidden[:, -tgt_len:]
|
||||||
outputs = transformer_outputs[1:]
|
outputs = transformer_outputs[1:]
|
||||||
|
|
||||||
softmax_output = self.crit([pred_hid, labels], training=training)
|
softmax_output = self.crit(pred_hid, labels, training=training)
|
||||||
outputs = [softmax_output] + outputs
|
outputs = [softmax_output] + outputs
|
||||||
|
|
||||||
return outputs # logits, new_mems, (all hidden states), (all attentions)
|
return outputs # logits, new_mems, (all hidden states), (all attentions)
|
||||||
|
|||||||
@@ -114,8 +114,7 @@ class TFAdaptiveSoftmaxMask(tf.keras.layers.Layer):
|
|||||||
idx = tf.stack([r, target], 1)
|
idx = tf.stack([r, target], 1)
|
||||||
return tf.gather_nd(logprob, idx)
|
return tf.gather_nd(logprob, idx)
|
||||||
|
|
||||||
def call(self, inputs, return_mean=True, training=False):
|
def call(self, hidden, target, return_mean=True, training=False):
|
||||||
hidden, target = inputs
|
|
||||||
head_logprob = 0
|
head_logprob = 0
|
||||||
if self.n_clusters == 0:
|
if self.n_clusters == 0:
|
||||||
output = self._logit(hidden, self.out_layers[0][0], self.out_layers[0][1], self.out_projs[0])
|
output = self._logit(hidden, self.out_layers[0][0], self.out_layers[0][1], self.out_projs[0])
|
||||||
|
|||||||
@@ -17,6 +17,7 @@
|
|||||||
import functools
|
import functools
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
import warnings
|
||||||
from typing import Dict, List, Optional, Union
|
from typing import Dict, List, Optional, Union
|
||||||
|
|
||||||
import h5py
|
import h5py
|
||||||
@@ -173,7 +174,11 @@ class TFTokenClassificationLoss:
|
|||||||
)
|
)
|
||||||
# make sure only labels that are not equal to -100
|
# make sure only labels that are not equal to -100
|
||||||
# are taken into account as loss
|
# are taken into account as loss
|
||||||
active_loss = tf.reshape(labels, (-1,)) != -100
|
if tf.math.reduce_any(labels == -1).numpy() is True:
|
||||||
|
warnings.warn("Using `-1` to mask the loss for the token is deprecated. Please use `-100` instead.")
|
||||||
|
active_loss = tf.reshape(labels, (-1,)) != -1
|
||||||
|
else:
|
||||||
|
active_loss = tf.reshape(labels, (-1,)) != -100
|
||||||
reduced_logits = tf.boolean_mask(tf.reshape(logits, (-1, shape_list(logits)[2])), active_loss)
|
reduced_logits = tf.boolean_mask(tf.reshape(logits, (-1, shape_list(logits)[2])), active_loss)
|
||||||
labels = tf.boolean_mask(tf.reshape(labels, (-1,)), active_loss)
|
labels = tf.boolean_mask(tf.reshape(labels, (-1,)), active_loss)
|
||||||
|
|
||||||
@@ -233,7 +238,10 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin):
|
|||||||
@property
|
@property
|
||||||
def dummy_inputs(self) -> Dict[str, tf.Tensor]:
|
def dummy_inputs(self) -> Dict[str, tf.Tensor]:
|
||||||
"""
|
"""
|
||||||
:obj:`Dict[str, tf.Tensor]`: Dummy inputs to build the network.
|
Dummy inputs to build the network.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
:obj:`Dict[str, tf.Tensor]`: The dummy inputs.
|
||||||
"""
|
"""
|
||||||
return {"input_ids": tf.constant(DUMMY_INPUTS)}
|
return {"input_ids": tf.constant(DUMMY_INPUTS)}
|
||||||
|
|
||||||
@@ -774,14 +782,16 @@ class TFSharedEmbeddings(tf.keras.layers.Layer):
|
|||||||
return tf.gather(self.weight, input_ids)
|
return tf.gather(self.weight, input_ids)
|
||||||
|
|
||||||
def _linear(self, inputs):
|
def _linear(self, inputs):
|
||||||
"""Computes logits by running inputs through a linear layer.
|
"""
|
||||||
Args:
|
Computes logits by running inputs through a linear layer.
|
||||||
inputs: A float32 tensor with shape [..., hidden_size]
|
|
||||||
Returns:
|
Args:
|
||||||
float32 tensor with shape [..., vocab_size].
|
inputs: A float32 tensor with shape [..., hidden_size]
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
float32 tensor with shape [..., vocab_size].
|
||||||
"""
|
"""
|
||||||
first_dims = shape_list(inputs)[:-1]
|
first_dims = shape_list(inputs)[:-1]
|
||||||
|
|
||||||
x = tf.reshape(inputs, [-1, self.hidden_size])
|
x = tf.reshape(inputs, [-1, self.hidden_size])
|
||||||
logits = tf.matmul(x, self.weight, transpose_b=True)
|
logits = tf.matmul(x, self.weight, transpose_b=True)
|
||||||
|
|
||||||
@@ -789,7 +799,7 @@ class TFSharedEmbeddings(tf.keras.layers.Layer):
|
|||||||
|
|
||||||
|
|
||||||
class TFSequenceSummary(tf.keras.layers.Layer):
|
class TFSequenceSummary(tf.keras.layers.Layer):
|
||||||
r"""
|
"""
|
||||||
Compute a single vector summary of a sequence hidden states.
|
Compute a single vector summary of a sequence hidden states.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -852,26 +862,9 @@ class TFSequenceSummary(tf.keras.layers.Layer):
|
|||||||
if self.has_last_dropout:
|
if self.has_last_dropout:
|
||||||
self.last_dropout = tf.keras.layers.Dropout(config.summary_last_dropout)
|
self.last_dropout = tf.keras.layers.Dropout(config.summary_last_dropout)
|
||||||
|
|
||||||
def call(self, inputs, training=False) -> tf.Tensor:
|
def call(self, inputs, cls_index=None, training=False):
|
||||||
"""
|
|
||||||
Compute a single vector summary of a sequence hidden states.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
inputs (:obj:`Union[tf.Tensor, Tuple[tf.Tensor], List[tf.Tensor], Dict[str, tf.Tensor]]`):
|
|
||||||
One or two tensors representing:
|
|
||||||
|
|
||||||
- **hidden_states** (:obj:`tf.Tensor` of shape :obj:`[batch_size, seq_len, hidden_size]`) -- The hidden
|
|
||||||
states of the last layer.
|
|
||||||
- **cls_index** :obj:`tf.Tensor` of shape :obj:`[batch_size]` or :obj:`[batch_size, ...]` where ... are
|
|
||||||
optional leading dimensions of :obj:`hidden_states`. Used if :obj:`summary_type == "cls_index"` and
|
|
||||||
takes the last token of the sequence as classification token.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
:obj:`tf.Tensor`: The summary of the sequence hidden states.
|
|
||||||
"""
|
|
||||||
if not isinstance(inputs, (dict, tuple, list)):
|
if not isinstance(inputs, (dict, tuple, list)):
|
||||||
hidden_states = inputs
|
hidden_states = inputs
|
||||||
cls_index = None
|
|
||||||
elif isinstance(inputs, (tuple, list)):
|
elif isinstance(inputs, (tuple, list)):
|
||||||
hidden_states = inputs[0]
|
hidden_states = inputs[0]
|
||||||
cls_index = inputs[1] if len(inputs) > 1 else None
|
cls_index = inputs[1] if len(inputs) > 1 else None
|
||||||
|
|||||||
@@ -39,7 +39,6 @@ from .modeling_tf_utils import (
|
|||||||
TFSequenceSummary,
|
TFSequenceSummary,
|
||||||
TFSharedEmbeddings,
|
TFSharedEmbeddings,
|
||||||
TFTokenClassificationLoss,
|
TFTokenClassificationLoss,
|
||||||
cast_bool_to_primitive,
|
|
||||||
get_initializer,
|
get_initializer,
|
||||||
keras_serializable,
|
keras_serializable,
|
||||||
shape_list,
|
shape_list,
|
||||||
@@ -123,6 +122,7 @@ class TFMultiHeadAttention(tf.keras.layers.Layer):
|
|||||||
self.layer_id = next(TFMultiHeadAttention.NEW_ID)
|
self.layer_id = next(TFMultiHeadAttention.NEW_ID)
|
||||||
self.dim = dim
|
self.dim = dim
|
||||||
self.n_heads = n_heads
|
self.n_heads = n_heads
|
||||||
|
self.output_attentions = config.output_attentions
|
||||||
assert self.dim % self.n_heads == 0
|
assert self.dim % self.n_heads == 0
|
||||||
|
|
||||||
self.q_lin = tf.keras.layers.Dense(dim, kernel_initializer=get_initializer(config.init_std), name="q_lin")
|
self.q_lin = tf.keras.layers.Dense(dim, kernel_initializer=get_initializer(config.init_std), name="q_lin")
|
||||||
@@ -135,11 +135,10 @@ class TFMultiHeadAttention(tf.keras.layers.Layer):
|
|||||||
def prune_heads(self, heads):
|
def prune_heads(self, heads):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def call(self, inputs, training=False):
|
def call(self, input, mask, kv, cache, head_mask, output_attentions, training=False):
|
||||||
"""
|
"""
|
||||||
Self-attention (if kv is None) or attention over source sentence (provided by kv).
|
Self-attention (if kv is None) or attention over source sentence (provided by kv).
|
||||||
"""
|
"""
|
||||||
input, mask, kv, cache, head_mask, output_attentions = inputs
|
|
||||||
# Input is (bs, qlen, dim)
|
# Input is (bs, qlen, dim)
|
||||||
# Mask is (bs, klen) (non-causal) or (bs, klen, klen)
|
# Mask is (bs, klen) (non-causal) or (bs, klen, klen)
|
||||||
bs, qlen, dim = shape_list(input)
|
bs, qlen, dim = shape_list(input)
|
||||||
@@ -196,7 +195,7 @@ class TFMultiHeadAttention(tf.keras.layers.Layer):
|
|||||||
context = unshape(context) # (bs, qlen, dim)
|
context = unshape(context) # (bs, qlen, dim)
|
||||||
|
|
||||||
outputs = (self.out_lin(context),)
|
outputs = (self.out_lin(context),)
|
||||||
if cast_bool_to_primitive(output_attentions) is True:
|
if output_attentions:
|
||||||
outputs = outputs + (weights,)
|
outputs = outputs + (weights,)
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
@@ -445,6 +444,7 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
|
|||||||
inputs_embeds = self.embeddings(input_ids)
|
inputs_embeds = self.embeddings(input_ids)
|
||||||
|
|
||||||
tensor = inputs_embeds + self.position_embeddings(position_ids)
|
tensor = inputs_embeds + self.position_embeddings(position_ids)
|
||||||
|
|
||||||
if langs is not None and self.use_lang_emb and self.n_langs > 1:
|
if langs is not None and self.use_lang_emb and self.n_langs > 1:
|
||||||
tensor = tensor + self.lang_embeddings(langs)
|
tensor = tensor + self.lang_embeddings(langs)
|
||||||
if token_type_ids is not None:
|
if token_type_ids is not None:
|
||||||
@@ -457,15 +457,15 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
|
|||||||
hidden_states = ()
|
hidden_states = ()
|
||||||
attentions = ()
|
attentions = ()
|
||||||
for i in range(self.n_layers):
|
for i in range(self.n_layers):
|
||||||
if cast_bool_to_primitive(output_hidden_states) is True:
|
if output_hidden_states:
|
||||||
hidden_states = hidden_states + (tensor,)
|
hidden_states = hidden_states + (tensor,)
|
||||||
|
|
||||||
# self attention
|
# self attention
|
||||||
attn_outputs = self.attentions[i](
|
attn_outputs = self.attentions[i](
|
||||||
[tensor, attn_mask, None, cache, head_mask[i], output_attentions], training=training
|
tensor, attn_mask, None, cache, head_mask[i], output_attentions, training=training
|
||||||
)
|
)
|
||||||
attn = attn_outputs[0]
|
attn = attn_outputs[0]
|
||||||
if cast_bool_to_primitive(output_attentions) is True:
|
if output_attentions:
|
||||||
attentions = attentions + (attn_outputs[1],)
|
attentions = attentions + (attn_outputs[1],)
|
||||||
attn = self.dropout(attn, training=training)
|
attn = self.dropout(attn, training=training)
|
||||||
tensor = tensor + attn
|
tensor = tensor + attn
|
||||||
@@ -484,7 +484,7 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
|
|||||||
tensor = tensor * mask[..., tf.newaxis]
|
tensor = tensor * mask[..., tf.newaxis]
|
||||||
|
|
||||||
# Add last hidden state
|
# Add last hidden state
|
||||||
if cast_bool_to_primitive(output_hidden_states) is True:
|
if output_hidden_states:
|
||||||
hidden_states = hidden_states + (tensor,)
|
hidden_states = hidden_states + (tensor,)
|
||||||
|
|
||||||
# update cache length
|
# update cache length
|
||||||
@@ -495,9 +495,9 @@ class TFXLMMainLayer(tf.keras.layers.Layer):
|
|||||||
# tensor = tensor.transpose(0, 1)
|
# tensor = tensor.transpose(0, 1)
|
||||||
|
|
||||||
outputs = (tensor,)
|
outputs = (tensor,)
|
||||||
if cast_bool_to_primitive(output_hidden_states) is True:
|
if output_hidden_states:
|
||||||
outputs = outputs + (hidden_states,)
|
outputs = outputs + (hidden_states,)
|
||||||
if cast_bool_to_primitive(output_attentions) is True:
|
if output_attentions:
|
||||||
outputs = outputs + (attentions,)
|
outputs = outputs + (attentions,)
|
||||||
return outputs # outputs, (hidden_states), (attentions)
|
return outputs # outputs, (hidden_states), (attentions)
|
||||||
|
|
||||||
@@ -930,7 +930,7 @@ class TFXLMForMultipleChoice(TFXLMPreTrainedModel, TFMultipleChoiceLoss):
|
|||||||
flat_position_ids = tf.reshape(position_ids, (-1, seq_length)) if position_ids is not None else None
|
flat_position_ids = tf.reshape(position_ids, (-1, seq_length)) if position_ids is not None else None
|
||||||
flat_langs = tf.reshape(langs, (-1, seq_length)) if langs is not None else None
|
flat_langs = tf.reshape(langs, (-1, seq_length)) if langs is not None else None
|
||||||
flat_inputs_embeds = (
|
flat_inputs_embeds = (
|
||||||
tf.reshape(inputs_embeds, (-1, inputs_embeds.shape[-2], inputs_embeds.shape[-1]))
|
tf.reshape(inputs_embeds, (-1, seq_length, shape_list(inputs_embeds)[3]))
|
||||||
if inputs_embeds is not None
|
if inputs_embeds is not None
|
||||||
else None
|
else None
|
||||||
)
|
)
|
||||||
@@ -943,7 +943,7 @@ class TFXLMForMultipleChoice(TFXLMPreTrainedModel, TFMultipleChoiceLoss):
|
|||||||
)
|
)
|
||||||
lengths = None
|
lengths = None
|
||||||
|
|
||||||
flat_inputs = [
|
transformer_outputs = self.transformer(
|
||||||
flat_input_ids,
|
flat_input_ids,
|
||||||
flat_attention_mask,
|
flat_attention_mask,
|
||||||
flat_langs,
|
flat_langs,
|
||||||
@@ -955,14 +955,12 @@ class TFXLMForMultipleChoice(TFXLMPreTrainedModel, TFMultipleChoiceLoss):
|
|||||||
flat_inputs_embeds,
|
flat_inputs_embeds,
|
||||||
output_attentions,
|
output_attentions,
|
||||||
output_hidden_states,
|
output_hidden_states,
|
||||||
]
|
training=training,
|
||||||
|
)
|
||||||
transformer_outputs = self.transformer(flat_inputs, training=training)
|
|
||||||
output = transformer_outputs[0]
|
output = transformer_outputs[0]
|
||||||
logits = self.sequence_summary(output)
|
logits = self.sequence_summary(output)
|
||||||
logits = self.logits_proj(logits)
|
logits = self.logits_proj(logits)
|
||||||
reshaped_logits = tf.reshape(logits, (-1, num_choices))
|
reshaped_logits = tf.reshape(logits, (-1, num_choices))
|
||||||
|
|
||||||
outputs = (reshaped_logits,) + transformer_outputs[1:] # add hidden states and attention if they are here
|
outputs = (reshaped_logits,) + transformer_outputs[1:] # add hidden states and attention if they are here
|
||||||
|
|
||||||
if labels is not None:
|
if labels is not None:
|
||||||
|
|||||||
@@ -38,7 +38,6 @@ from .modeling_tf_utils import (
|
|||||||
TFSequenceSummary,
|
TFSequenceSummary,
|
||||||
TFSharedEmbeddings,
|
TFSharedEmbeddings,
|
||||||
TFTokenClassificationLoss,
|
TFTokenClassificationLoss,
|
||||||
cast_bool_to_primitive,
|
|
||||||
get_initializer,
|
get_initializer,
|
||||||
keras_serializable,
|
keras_serializable,
|
||||||
shape_list,
|
shape_list,
|
||||||
@@ -92,6 +91,7 @@ class TFXLNetRelativeAttention(tf.keras.layers.Layer):
|
|||||||
self.d_model = config.d_model
|
self.d_model = config.d_model
|
||||||
self.scale = 1 / (config.d_head ** 0.5)
|
self.scale = 1 / (config.d_head ** 0.5)
|
||||||
self.initializer_range = config.initializer_range
|
self.initializer_range = config.initializer_range
|
||||||
|
self.output_attentions = config.output_attentions
|
||||||
|
|
||||||
self.layer_norm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm")
|
self.layer_norm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm")
|
||||||
self.dropout = tf.keras.layers.Dropout(config.dropout)
|
self.dropout = tf.keras.layers.Dropout(config.dropout)
|
||||||
@@ -142,11 +142,10 @@ class TFXLNetRelativeAttention(tf.keras.layers.Layer):
|
|||||||
|
|
||||||
return x
|
return x
|
||||||
|
|
||||||
def rel_attn_core(self, inputs, training=False):
|
def rel_attn_core(
|
||||||
|
self, q_head, k_head_h, v_head_h, k_head_r, seg_mat, attn_mask, head_mask, output_attentions, training=False
|
||||||
|
):
|
||||||
"""Core relative positional attention operations."""
|
"""Core relative positional attention operations."""
|
||||||
|
|
||||||
q_head, k_head_h, v_head_h, k_head_r, seg_mat, attn_mask, head_mask, output_attentions = inputs
|
|
||||||
|
|
||||||
# content based attention score
|
# content based attention score
|
||||||
ac = tf.einsum("ibnd,jbnd->ijbn", q_head + self.r_w_bias, k_head_h)
|
ac = tf.einsum("ibnd,jbnd->ijbn", q_head + self.r_w_bias, k_head_h)
|
||||||
|
|
||||||
@@ -182,16 +181,14 @@ class TFXLNetRelativeAttention(tf.keras.layers.Layer):
|
|||||||
# attention output
|
# attention output
|
||||||
attn_vec = tf.einsum("ijbn,jbnd->ibnd", attn_prob, v_head_h)
|
attn_vec = tf.einsum("ijbn,jbnd->ibnd", attn_prob, v_head_h)
|
||||||
|
|
||||||
if cast_bool_to_primitive(output_attentions) is True:
|
if output_attentions:
|
||||||
return attn_vec, attn_prob
|
return attn_vec, attn_prob
|
||||||
|
|
||||||
return attn_vec
|
return attn_vec
|
||||||
|
|
||||||
def post_attention(self, inputs, residual=True, training=False):
|
def post_attention(self, h, attn_vec, residual=True, training=False):
|
||||||
"""Post-attention processing."""
|
"""Post-attention processing."""
|
||||||
# post-attention projection (back to `d_model`)
|
# post-attention projection (back to `d_model`)
|
||||||
h, attn_vec = inputs
|
|
||||||
|
|
||||||
attn_out = tf.einsum("ibnd,hnd->ibh", attn_vec, self.o)
|
attn_out = tf.einsum("ibnd,hnd->ibh", attn_vec, self.o)
|
||||||
|
|
||||||
attn_out = self.dropout(attn_out, training=training)
|
attn_out = self.dropout(attn_out, training=training)
|
||||||
@@ -202,9 +199,20 @@ class TFXLNetRelativeAttention(tf.keras.layers.Layer):
|
|||||||
|
|
||||||
return output
|
return output
|
||||||
|
|
||||||
def call(self, inputs, training=False):
|
def call(
|
||||||
(h, g, attn_mask_h, attn_mask_g, r, seg_mat, mems, target_mapping, head_mask, output_attentions) = inputs
|
self,
|
||||||
|
h,
|
||||||
|
g,
|
||||||
|
attn_mask_h,
|
||||||
|
attn_mask_g,
|
||||||
|
r,
|
||||||
|
seg_mat,
|
||||||
|
mems,
|
||||||
|
target_mapping,
|
||||||
|
head_mask,
|
||||||
|
output_attentions,
|
||||||
|
training=False,
|
||||||
|
):
|
||||||
if g is not None:
|
if g is not None:
|
||||||
# Two-stream attention with relative positional encoding.
|
# Two-stream attention with relative positional encoding.
|
||||||
# content based attention score
|
# content based attention score
|
||||||
@@ -228,15 +236,22 @@ class TFXLNetRelativeAttention(tf.keras.layers.Layer):
|
|||||||
|
|
||||||
# core attention ops
|
# core attention ops
|
||||||
attn_vec_h = self.rel_attn_core(
|
attn_vec_h = self.rel_attn_core(
|
||||||
[q_head_h, k_head_h, v_head_h, k_head_r, seg_mat, attn_mask_h, head_mask, output_attentions],
|
q_head_h,
|
||||||
|
k_head_h,
|
||||||
|
v_head_h,
|
||||||
|
k_head_r,
|
||||||
|
seg_mat,
|
||||||
|
attn_mask_h,
|
||||||
|
head_mask,
|
||||||
|
output_attentions,
|
||||||
training=training,
|
training=training,
|
||||||
)
|
)
|
||||||
|
|
||||||
if cast_bool_to_primitive(output_attentions) is True:
|
if output_attentions:
|
||||||
attn_vec_h, attn_prob_h = attn_vec_h
|
attn_vec_h, attn_prob_h = attn_vec_h
|
||||||
|
|
||||||
# post processing
|
# post processing
|
||||||
output_h = self.post_attention([h, attn_vec_h], training=training)
|
output_h = self.post_attention(h, attn_vec_h, training=training)
|
||||||
|
|
||||||
# g-stream
|
# g-stream
|
||||||
# query-stream query head
|
# query-stream query head
|
||||||
@@ -246,27 +261,41 @@ class TFXLNetRelativeAttention(tf.keras.layers.Layer):
|
|||||||
if target_mapping is not None:
|
if target_mapping is not None:
|
||||||
q_head_g = tf.einsum("mbnd,mlb->lbnd", q_head_g, target_mapping)
|
q_head_g = tf.einsum("mbnd,mlb->lbnd", q_head_g, target_mapping)
|
||||||
attn_vec_g = self.rel_attn_core(
|
attn_vec_g = self.rel_attn_core(
|
||||||
[q_head_g, k_head_h, v_head_h, k_head_r, seg_mat, attn_mask_g, head_mask, output_attentions],
|
q_head_g,
|
||||||
|
k_head_h,
|
||||||
|
v_head_h,
|
||||||
|
k_head_r,
|
||||||
|
seg_mat,
|
||||||
|
attn_mask_g,
|
||||||
|
head_mask,
|
||||||
|
output_attentions,
|
||||||
training=training,
|
training=training,
|
||||||
)
|
)
|
||||||
|
|
||||||
if cast_bool_to_primitive(output_attentions) is True:
|
if output_attentions:
|
||||||
attn_vec_g, attn_prob_g = attn_vec_g
|
attn_vec_g, attn_prob_g = attn_vec_g
|
||||||
|
|
||||||
attn_vec_g = tf.einsum("lbnd,mlb->mbnd", attn_vec_g, target_mapping)
|
attn_vec_g = tf.einsum("lbnd,mlb->mbnd", attn_vec_g, target_mapping)
|
||||||
else:
|
else:
|
||||||
attn_vec_g = self.rel_attn_core(
|
attn_vec_g = self.rel_attn_core(
|
||||||
[q_head_g, k_head_h, v_head_h, k_head_r, seg_mat, attn_mask_g, head_mask, output_attentions],
|
q_head_g,
|
||||||
|
k_head_h,
|
||||||
|
v_head_h,
|
||||||
|
k_head_r,
|
||||||
|
seg_mat,
|
||||||
|
attn_mask_g,
|
||||||
|
head_mask,
|
||||||
|
output_attentions,
|
||||||
training=training,
|
training=training,
|
||||||
)
|
)
|
||||||
|
|
||||||
if cast_bool_to_primitive(output_attentions) is True:
|
if output_attentions:
|
||||||
attn_vec_g, attn_prob_g = attn_vec_g
|
attn_vec_g, attn_prob_g = attn_vec_g
|
||||||
|
|
||||||
# post processing
|
# post processing
|
||||||
output_g = self.post_attention([g, attn_vec_g], training=training)
|
output_g = self.post_attention(g, attn_vec_g, training=training)
|
||||||
|
|
||||||
if cast_bool_to_primitive(output_attentions) is True:
|
if output_attentions:
|
||||||
attn_prob = attn_prob_h, attn_prob_g
|
attn_prob = attn_prob_h, attn_prob_g
|
||||||
|
|
||||||
else:
|
else:
|
||||||
@@ -286,19 +315,26 @@ class TFXLNetRelativeAttention(tf.keras.layers.Layer):
|
|||||||
|
|
||||||
# core attention ops
|
# core attention ops
|
||||||
attn_vec = self.rel_attn_core(
|
attn_vec = self.rel_attn_core(
|
||||||
[q_head_h, k_head_h, v_head_h, k_head_r, seg_mat, attn_mask_h, head_mask, output_attentions],
|
q_head_h,
|
||||||
|
k_head_h,
|
||||||
|
v_head_h,
|
||||||
|
k_head_r,
|
||||||
|
seg_mat,
|
||||||
|
attn_mask_h,
|
||||||
|
head_mask,
|
||||||
|
output_attentions,
|
||||||
training=training,
|
training=training,
|
||||||
)
|
)
|
||||||
|
|
||||||
if cast_bool_to_primitive(output_attentions) is True:
|
if output_attentions:
|
||||||
attn_vec, attn_prob = attn_vec
|
attn_vec, attn_prob = attn_vec
|
||||||
|
|
||||||
# post processing
|
# post processing
|
||||||
output_h = self.post_attention([h, attn_vec], training=training)
|
output_h = self.post_attention(h, attn_vec, training=training)
|
||||||
output_g = None
|
output_g = None
|
||||||
|
|
||||||
outputs = (output_h, output_g)
|
outputs = (output_h, output_g)
|
||||||
if cast_bool_to_primitive(output_attentions) is True:
|
if output_attentions:
|
||||||
outputs = outputs + (attn_prob,)
|
outputs = outputs + (attn_prob,)
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
@@ -337,8 +373,33 @@ class TFXLNetLayer(tf.keras.layers.Layer):
|
|||||||
self.ff = TFXLNetFeedForward(config, name="ff")
|
self.ff = TFXLNetFeedForward(config, name="ff")
|
||||||
self.dropout = tf.keras.layers.Dropout(config.dropout)
|
self.dropout = tf.keras.layers.Dropout(config.dropout)
|
||||||
|
|
||||||
def call(self, inputs, training=False):
|
def call(
|
||||||
outputs = self.rel_attn(inputs, training=training)
|
self,
|
||||||
|
output_h,
|
||||||
|
output_g,
|
||||||
|
non_tgt_mask,
|
||||||
|
attn_mask,
|
||||||
|
pos_emb,
|
||||||
|
seg_mat,
|
||||||
|
mems,
|
||||||
|
target_mapping,
|
||||||
|
head_mask,
|
||||||
|
output_attentions,
|
||||||
|
training=False,
|
||||||
|
):
|
||||||
|
outputs = self.rel_attn(
|
||||||
|
output_h,
|
||||||
|
output_g,
|
||||||
|
non_tgt_mask,
|
||||||
|
attn_mask,
|
||||||
|
pos_emb,
|
||||||
|
seg_mat,
|
||||||
|
mems,
|
||||||
|
target_mapping,
|
||||||
|
head_mask,
|
||||||
|
output_attentions,
|
||||||
|
training=training,
|
||||||
|
)
|
||||||
output_h, output_g = outputs[:2]
|
output_h, output_g = outputs[:2]
|
||||||
|
|
||||||
if output_g is not None:
|
if output_g is not None:
|
||||||
@@ -686,32 +747,30 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
|
|||||||
hidden_states = []
|
hidden_states = []
|
||||||
for i, layer_module in enumerate(self.layer):
|
for i, layer_module in enumerate(self.layer):
|
||||||
# cache new mems
|
# cache new mems
|
||||||
if self.mem_len is not None and self.mem_len > 0 and use_cache is True:
|
if self.mem_len is not None and self.mem_len > 0 and use_cache:
|
||||||
new_mems = new_mems + (self.cache_mem(output_h, mems[i]),)
|
new_mems = new_mems + (self.cache_mem(output_h, mems[i]),)
|
||||||
if cast_bool_to_primitive(output_hidden_states) is True:
|
if output_hidden_states:
|
||||||
hidden_states.append((output_h, output_g) if output_g is not None else output_h)
|
hidden_states.append((output_h, output_g) if output_g is not None else output_h)
|
||||||
|
|
||||||
outputs = layer_module(
|
outputs = layer_module(
|
||||||
[
|
output_h,
|
||||||
output_h,
|
output_g,
|
||||||
output_g,
|
non_tgt_mask,
|
||||||
non_tgt_mask,
|
attn_mask,
|
||||||
attn_mask,
|
pos_emb,
|
||||||
pos_emb,
|
seg_mat,
|
||||||
seg_mat,
|
mems[i],
|
||||||
mems[i],
|
target_mapping,
|
||||||
target_mapping,
|
head_mask[i],
|
||||||
head_mask[i],
|
output_attentions,
|
||||||
output_attentions,
|
|
||||||
],
|
|
||||||
training=training,
|
training=training,
|
||||||
)
|
)
|
||||||
output_h, output_g = outputs[:2]
|
output_h, output_g = outputs[:2]
|
||||||
if cast_bool_to_primitive(output_attentions) is True:
|
if output_attentions:
|
||||||
attentions.append(outputs[2])
|
attentions.append(outputs[2])
|
||||||
|
|
||||||
# Add last hidden state
|
# Add last hidden state
|
||||||
if cast_bool_to_primitive(output_hidden_states) is True:
|
if output_hidden_states:
|
||||||
hidden_states.append((output_h, output_g) if output_g is not None else output_h)
|
hidden_states.append((output_h, output_g) if output_g is not None else output_h)
|
||||||
|
|
||||||
output = self.dropout(output_g if output_g is not None else output_h, training=training)
|
output = self.dropout(output_g if output_g is not None else output_h, training=training)
|
||||||
@@ -719,16 +778,16 @@ class TFXLNetMainLayer(tf.keras.layers.Layer):
|
|||||||
# Prepare outputs, we transpose back here to shape [bsz, len, hidden_dim] (cf. beginning of forward() method)
|
# Prepare outputs, we transpose back here to shape [bsz, len, hidden_dim] (cf. beginning of forward() method)
|
||||||
outputs = (tf.transpose(output, perm=(1, 0, 2)),)
|
outputs = (tf.transpose(output, perm=(1, 0, 2)),)
|
||||||
|
|
||||||
if self.mem_len is not None and self.mem_len > 0 and use_cache is True:
|
if self.mem_len is not None and self.mem_len > 0 and use_cache:
|
||||||
outputs = outputs + (new_mems,)
|
outputs = outputs + (new_mems,)
|
||||||
|
|
||||||
if cast_bool_to_primitive(output_hidden_states) is True:
|
if output_hidden_states:
|
||||||
if output_g is not None:
|
if output_g is not None:
|
||||||
hidden_states = tuple(tf.transpose(h, perm=(1, 0, 2)) for hs in hidden_states for h in hs)
|
hidden_states = tuple(tf.transpose(h, perm=(1, 0, 2)) for hs in hidden_states for h in hs)
|
||||||
else:
|
else:
|
||||||
hidden_states = tuple(tf.transpose(hs, perm=(1, 0, 2)) for hs in hidden_states)
|
hidden_states = tuple(tf.transpose(hs, perm=(1, 0, 2)) for hs in hidden_states)
|
||||||
outputs = outputs + (hidden_states,)
|
outputs = outputs + (hidden_states,)
|
||||||
if cast_bool_to_primitive(output_attentions) is True:
|
if output_attentions:
|
||||||
attentions = tuple(tf.transpose(t, perm=(2, 3, 0, 1)) for t in attentions)
|
attentions = tuple(tf.transpose(t, perm=(2, 3, 0, 1)) for t in attentions)
|
||||||
outputs = outputs + (attentions,)
|
outputs = outputs + (attentions,)
|
||||||
|
|
||||||
@@ -1240,8 +1299,7 @@ class TFXLNetForMultipleChoice(TFXLNetPreTrainedModel, TFMultipleChoiceLoss):
|
|||||||
if inputs_embeds is not None
|
if inputs_embeds is not None
|
||||||
else None
|
else None
|
||||||
)
|
)
|
||||||
|
transformer_outputs = self.transformer(
|
||||||
flat_inputs = [
|
|
||||||
flat_input_ids,
|
flat_input_ids,
|
||||||
flat_attention_mask,
|
flat_attention_mask,
|
||||||
mems,
|
mems,
|
||||||
@@ -1254,14 +1312,12 @@ class TFXLNetForMultipleChoice(TFXLNetPreTrainedModel, TFMultipleChoiceLoss):
|
|||||||
use_cache,
|
use_cache,
|
||||||
output_attentions,
|
output_attentions,
|
||||||
output_hidden_states,
|
output_hidden_states,
|
||||||
]
|
training=training,
|
||||||
|
)
|
||||||
transformer_outputs = self.transformer(flat_inputs, training=training)
|
|
||||||
output = transformer_outputs[0]
|
output = transformer_outputs[0]
|
||||||
logits = self.sequence_summary(output)
|
logits = self.sequence_summary(output)
|
||||||
logits = self.logits_proj(logits)
|
logits = self.logits_proj(logits)
|
||||||
reshaped_logits = tf.reshape(logits, (-1, num_choices))
|
reshaped_logits = tf.reshape(logits, (-1, num_choices))
|
||||||
|
|
||||||
outputs = (reshaped_logits,) + transformer_outputs[1:] # add hidden states and attention if they are here
|
outputs = (reshaped_logits,) + transformer_outputs[1:] # add hidden states and attention if they are here
|
||||||
|
|
||||||
if labels is not None:
|
if labels is not None:
|
||||||
|
|||||||
@@ -4,7 +4,6 @@ import datetime
|
|||||||
import logging
|
import logging
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
import sys
|
|
||||||
import warnings
|
import warnings
|
||||||
from typing import Callable, Dict, Optional, Tuple
|
from typing import Callable, Dict, Optional, Tuple
|
||||||
|
|
||||||
@@ -25,15 +24,6 @@ if is_wandb_available():
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
if parse(tf.__version__).release < (2, 2, 0):
|
|
||||||
logger.info(
|
|
||||||
"You need to run the TensorFlow trainer with at least the version 2.2.0, your version is {}".format(
|
|
||||||
tf.__version__
|
|
||||||
)
|
|
||||||
)
|
|
||||||
sys.exit(1)
|
|
||||||
|
|
||||||
|
|
||||||
class TFTrainer:
|
class TFTrainer:
|
||||||
"""
|
"""
|
||||||
TFTrainer is a simple but feature-complete training and eval loop for TensorFlow,
|
TFTrainer is a simple but feature-complete training and eval loop for TensorFlow,
|
||||||
@@ -77,6 +67,11 @@ class TFTrainer:
|
|||||||
None,
|
None,
|
||||||
),
|
),
|
||||||
):
|
):
|
||||||
|
assert parse(tf.__version__).release >= (2, 2, 0), (
|
||||||
|
"You need to run the TensorFlow trainer with at least the version 2.2.0, your version is %r "
|
||||||
|
% tf.__version__
|
||||||
|
)
|
||||||
|
|
||||||
self.model = model
|
self.model = model
|
||||||
self.args = args
|
self.args = args
|
||||||
self.train_dataset = train_dataset
|
self.train_dataset = train_dataset
|
||||||
|
|||||||
@@ -23,7 +23,7 @@ import unittest
|
|||||||
from importlib import import_module
|
from importlib import import_module
|
||||||
|
|
||||||
from transformers import is_tf_available, is_torch_available
|
from transformers import is_tf_available, is_torch_available
|
||||||
from transformers.testing_utils import _tf_gpu_memory_limit, require_tf
|
from transformers.testing_utils import _tf_gpu_memory_limit, require_tf, slow
|
||||||
|
|
||||||
|
|
||||||
if is_tf_available():
|
if is_tf_available():
|
||||||
@@ -130,6 +130,61 @@ class TFModelTesterMixin:
|
|||||||
|
|
||||||
self.assert_outputs_same(after_outputs, outputs)
|
self.assert_outputs_same(after_outputs, outputs)
|
||||||
|
|
||||||
|
@slow
|
||||||
|
def test_saved_model_with_hidden_states_output(self):
|
||||||
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
config.output_hidden_states = True
|
||||||
|
|
||||||
|
for model_class in self.all_model_classes:
|
||||||
|
inputs_dict = self._prepare_for_class(inputs_dict, model_class)
|
||||||
|
model = model_class(config)
|
||||||
|
num_out = len(model(inputs_dict))
|
||||||
|
model._saved_model_inputs_spec = None
|
||||||
|
model._set_save_spec(inputs_dict)
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||||
|
tf.saved_model.save(model, tmpdirname)
|
||||||
|
model = tf.keras.models.load_model(tmpdirname)
|
||||||
|
outputs = model(inputs_dict)
|
||||||
|
hidden_states = [t.numpy() for t in outputs[-1]]
|
||||||
|
self.assertEqual(len(outputs), num_out)
|
||||||
|
self.assertEqual(len(hidden_states), self.model_tester.num_hidden_layers + 1)
|
||||||
|
self.assertListEqual(
|
||||||
|
list(hidden_states[0].shape[-2:]), [self.model_tester.seq_length, self.model_tester.hidden_size],
|
||||||
|
)
|
||||||
|
|
||||||
|
@slow
|
||||||
|
def test_saved_model_with_attentions_output(self):
|
||||||
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
config.output_attentions = True
|
||||||
|
encoder_seq_length = (
|
||||||
|
self.model_tester.encoder_seq_length
|
||||||
|
if hasattr(self.model_tester, "encoder_seq_length")
|
||||||
|
else self.model_tester.seq_length
|
||||||
|
)
|
||||||
|
encoder_key_length = (
|
||||||
|
self.model_tester.key_length if hasattr(self.model_tester, "key_length") else encoder_seq_length
|
||||||
|
)
|
||||||
|
|
||||||
|
for model_class in self.all_model_classes:
|
||||||
|
inputs_dict = self._prepare_for_class(inputs_dict, model_class)
|
||||||
|
model = model_class(config)
|
||||||
|
num_out = len(model(inputs_dict))
|
||||||
|
model._saved_model_inputs_spec = None
|
||||||
|
model._set_save_spec(inputs_dict)
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||||
|
tf.saved_model.save(model, tmpdirname)
|
||||||
|
model = tf.keras.models.load_model(tmpdirname)
|
||||||
|
outputs = model(inputs_dict)
|
||||||
|
attentions = [t.numpy() for t in outputs[-1]]
|
||||||
|
self.assertEqual(len(outputs), num_out)
|
||||||
|
self.assertEqual(len(attentions), self.model_tester.num_hidden_layers)
|
||||||
|
self.assertListEqual(
|
||||||
|
list(attentions[0].shape[-3:]),
|
||||||
|
[self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
|
||||||
|
)
|
||||||
|
|
||||||
def test_keras_save_load(self):
|
def test_keras_save_load(self):
|
||||||
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
|
||||||
|
|
||||||
|
|||||||
@@ -342,11 +342,17 @@ class TFXLNetModelTester:
|
|||||||
"attention_mask": multiple_choice_input_mask,
|
"attention_mask": multiple_choice_input_mask,
|
||||||
"token_type_ids": multiple_choice_token_type_ids,
|
"token_type_ids": multiple_choice_token_type_ids,
|
||||||
}
|
}
|
||||||
(logits,) = model(inputs)
|
(logits, mems_1) = model(inputs)
|
||||||
result = {
|
result = {
|
||||||
|
"mems_1": [mem.numpy() for mem in mems_1],
|
||||||
"logits": logits.numpy(),
|
"logits": logits.numpy(),
|
||||||
}
|
}
|
||||||
|
|
||||||
self.parent.assertListEqual(list(result["logits"].shape), [self.batch_size, self.num_choices])
|
self.parent.assertListEqual(list(result["logits"].shape), [self.batch_size, self.num_choices])
|
||||||
|
self.parent.assertListEqual(
|
||||||
|
list(list(mem.shape) for mem in result["mems_1"]),
|
||||||
|
[[self.seq_length, self.batch_size * self.num_choices, self.hidden_size]] * self.num_hidden_layers,
|
||||||
|
)
|
||||||
|
|
||||||
def prepare_config_and_inputs_for_common(self):
|
def prepare_config_and_inputs_for_common(self):
|
||||||
config_and_inputs = self.prepare_config_and_inputs()
|
config_and_inputs = self.prepare_config_and_inputs()
|
||||||
|
|||||||
Reference in New Issue
Block a user