Fix saved model creation (#5468)

* Fix TF Serving when output_hidden_states and output_attentions are True

* Add tests for saved model creation + bug fix for multiple choices models

* remove unused import

* Fix the input for several layers

* Fix test

* Fix conflict printing

* Apply style

* Fix XLM and Flaubert for TensorFlow

* Apply style

* Fix TF check version

* Apply style

* Trigger CI
This commit is contained in:
Julien Plu
2020-08-03 14:10:40 +02:00
committed by GitHub
parent 5a0dac53bf
commit 9996f697e3
19 changed files with 572 additions and 448 deletions

View File

@@ -27,7 +27,6 @@ from .modeling_tf_utils import (
TFCausalLanguageModelingLoss,
TFPreTrainedModel,
TFSharedEmbeddings,
cast_bool_to_primitive,
keras_serializable,
shape_list,
)
@@ -87,10 +86,11 @@ def scaled_dot_product_attention(q, k, v, mask, attention_mask=None, head_mask=N
class TFMultiHeadAttention(tf.keras.layers.Layer):
def __init__(self, d_model_size, num_heads, **kwargs):
def __init__(self, d_model_size, num_heads, output_attentions=False, **kwargs):
super().__init__(**kwargs)
self.num_heads = num_heads
self.d_model_size = d_model_size
self.output_attentions = output_attentions
self.depth = int(d_model_size / self.num_heads)
@@ -104,8 +104,7 @@ class TFMultiHeadAttention(tf.keras.layers.Layer):
x = tf.reshape(x, (batch_size, -1, self.num_heads, self.depth))
return tf.transpose(x, perm=[0, 2, 1, 3])
def call(self, inputs, training=False):
v, k, q, mask, layer_past, attention_mask, head_mask, use_cache, output_attentions = inputs
def call(self, v, k, q, mask, layer_past, attention_mask, head_mask, use_cache, output_attentions, training=False):
batch_size = shape_list(q)[0]
q = self.Wq(q)
@@ -121,10 +120,7 @@ class TFMultiHeadAttention(tf.keras.layers.Layer):
k = tf.concat((past_key, k), axis=-2)
v = tf.concat((past_value, v), axis=-2)
# to cope with keras serialization
use_cache = cast_bool_to_primitive(use_cache, True)
if use_cache is True:
if use_cache:
present = tf.stack((k, v), axis=0)
else:
present = (None,)
@@ -134,10 +130,11 @@ class TFMultiHeadAttention(tf.keras.layers.Layer):
attn = output[1]
original_size_attention = tf.reshape(scaled_attention, (batch_size, -1, self.d_model_size))
output = self.dense(original_size_attention)
outputs = (output, present)
if cast_bool_to_primitive(output_attentions) is True:
if output_attentions:
outputs = outputs + (attn,)
return outputs
@@ -156,10 +153,16 @@ class TFPointWiseFeedForwardLayer(tf.keras.layers.Layer):
class TFEncoderLayer(tf.keras.layers.Layer):
def __init__(self, d_model_size, num_heads, dff, rate=0.1, layer_norm_epsilon=1e-6, **kwargs):
def __init__(
self, d_model_size, num_heads, dff, rate=0.1, layer_norm_epsilon=1e-6, output_attentions=False, **kwargs
):
super().__init__(**kwargs)
self.multi_head_attention = TFMultiHeadAttention(d_model_size, num_heads, name="multi_head_attention")
self.output_attentions = output_attentions
self.multi_head_attention = TFMultiHeadAttention(
d_model_size, num_heads, output_attentions=self.output_attentions, name="multi_head_attention"
)
self.ffn = TFPointWiseFeedForwardLayer(d_model_size, dff, name="ffn")
self.layernorm1 = tf.keras.layers.LayerNormalization(epsilon=layer_norm_epsilon, name="layernorm1")
@@ -168,11 +171,18 @@ class TFEncoderLayer(tf.keras.layers.Layer):
self.dropout1 = tf.keras.layers.Dropout(rate)
self.dropout2 = tf.keras.layers.Dropout(rate)
def call(self, inputs, training=False):
x, mask, layer_past, attention_mask, head_mask, use_cache, output_attentions = inputs
def call(self, x, mask, layer_past, attention_mask, head_mask, use_cache, output_attentions, training=False):
normed = self.layernorm1(x)
attn_outputs = self.multi_head_attention(
[normed, normed, normed, mask, layer_past, attention_mask, head_mask, use_cache, output_attentions],
normed,
normed,
normed,
mask,
layer_past,
attention_mask,
head_mask,
use_cache,
output_attentions,
training=training,
)
attn_output = attn_outputs[0]
@@ -215,6 +225,7 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
config.dff,
config.resid_pdrop,
config.layer_norm_epsilon,
self.output_attentions,
name="h_._{}".format(i),
)
for i in range(config.n_layer)
@@ -367,31 +378,37 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
all_hidden_states = ()
all_attentions = []
for i, (h, layer_past) in enumerate(zip(self.h, past)):
if cast_bool_to_primitive(output_hidden_states) is True:
if output_hidden_states:
all_hidden_states = all_hidden_states + (tf.reshape(hidden_states, output_shape),)
outputs = h(
[hidden_states, mask, layer_past, attention_mask, head_mask[i], use_cache, output_attentions],
hidden_states,
mask,
layer_past,
attention_mask,
head_mask[i],
use_cache,
output_attentions,
training=training,
)
hidden_states, present = outputs[:2]
if use_cache is True:
if use_cache:
presents = presents + (present,)
if cast_bool_to_primitive(output_attentions) is True:
if output_attentions:
all_attentions.append(outputs[2])
hidden_states = self.layernorm(hidden_states)
hidden_states = tf.reshape(hidden_states, output_shape)
if cast_bool_to_primitive(output_hidden_states) is True:
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
outputs = (hidden_states,)
if use_cache is True:
if use_cache:
outputs = outputs + (presents,)
if cast_bool_to_primitive(output_hidden_states) is True:
if output_hidden_states:
outputs = outputs + (all_hidden_states,)
if cast_bool_to_primitive(output_attentions) is True:
if output_attentions:
# let the number of heads free (-1) so we can extract attention even after head pruning
attention_output_shape = input_shape[:-1] + [-1] + shape_list(all_attentions[0])[-2:]
all_attentions = tuple(tf.reshape(t, attention_output_shape) for t in all_attentions)