Output hidden states (#4978)

* Configure all models to use output_hidden_states as argument passed to forward()

* Pass all tests

* Remove cast_bool_to_primitive in TF Flaubert model

* correct tf xlnet

* add pytorch test

* add tf test

* Fix broken tests

* Configure all models to use output_hidden_states as argument passed to forward()

* Pass all tests

* Remove cast_bool_to_primitive in TF Flaubert model

* correct tf xlnet

* add pytorch test

* add tf test

* Fix broken tests

* Refactor output_hidden_states for mobilebert

* Reset and remerge to master

Co-authored-by: Joseph Liu <joseph.liu@coinflex.com>
Co-authored-by: patrickvonplaten <patrick.v.platen@gmail.com>
This commit is contained in:
Joseph Liu
2020-06-22 22:10:45 +08:00
committed by GitHub
parent 866a8ccabb
commit f4e1f02210
34 changed files with 814 additions and 349 deletions

View File

@@ -237,6 +237,7 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
inputs_embeds=None,
use_cache=True,
output_attentions=None,
output_hidden_states=None,
training=False,
):
@@ -250,7 +251,8 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
inputs_embeds = inputs[6] if len(inputs) > 6 else inputs_embeds
use_cache = inputs[7] if len(inputs) > 7 else use_cache
output_attentions = inputs[8] if len(inputs) > 8 else output_attentions
assert len(inputs) <= 9, "Too many inputs."
output_hidden_states = inputs[9] if len(inputs) > 9 else output_hidden_states
assert len(inputs) <= 10, "Too many inputs."
elif isinstance(inputs, (dict, BatchEncoding)):
input_ids = inputs.get("input_ids")
past = inputs.get("past", past)
@@ -261,11 +263,13 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
inputs_embeds = inputs.get("inputs_embeds", inputs_embeds)
use_cache = inputs.get("use_cache", use_cache)
output_attentions = inputs.get("output_attentions", output_attentions)
assert len(inputs) <= 9, "Too many inputs."
output_hidden_states = inputs.get("output_hidden_states", output_hidden_states)
assert len(inputs) <= 10, "Too many inputs."
else:
input_ids = inputs
output_attentions = output_attentions if output_attentions is not None else self.output_attentions
output_hidden_states = output_hidden_states if output_hidden_states is not None else self.output_hidden_states
# If using past key value states, only the last tokens
# should be given as an input
@@ -351,7 +355,7 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
all_hidden_states = ()
all_attentions = []
for i, (h, layer_past) in enumerate(zip(self.h, past)):
if self.output_hidden_states:
if cast_bool_to_primitive(output_hidden_states) is True:
all_hidden_states = all_hidden_states + (tf.reshape(hidden_states, output_shape),)
outputs = h(
[hidden_states, mask, layer_past, attention_mask, head_mask[i], use_cache, output_attentions],
@@ -367,13 +371,13 @@ class TFCTRLMainLayer(tf.keras.layers.Layer):
hidden_states = self.layernorm(hidden_states)
hidden_states = tf.reshape(hidden_states, output_shape)
if self.output_hidden_states:
if cast_bool_to_primitive(output_hidden_states) is True:
all_hidden_states = all_hidden_states + (hidden_states,)
outputs = (hidden_states,)
if use_cache is True:
outputs = outputs + (presents,)
if self.output_hidden_states:
if cast_bool_to_primitive(output_hidden_states) is True:
outputs = outputs + (all_hidden_states,)
if cast_bool_to_primitive(output_attentions) is True:
# let the number of heads free (-1) so we can extract attention even after head pruning
@@ -493,7 +497,7 @@ class TFCTRLModel(TFCTRLPreTrainedModel):
Contains pre-computed hidden-states (key and values in the attention blocks).
Can be used (see `past` input) to speed up sequential decoding. The token ids which have their past given to this model
should not be passed as input ids as they have already been computed.
hidden_states (:obj:`tuple(tf.Tensor)` `optional`, returned when ``config.output_hidden_states=True``):
hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
@@ -573,7 +577,7 @@ class TFCTRLLMHeadModel(TFCTRLPreTrainedModel):
Contains pre-computed hidden-states (key and values in the attention blocks).
Can be used (see `past` input) to speed up sequential decoding. The token ids which have their past given to this model
should not be passed as input ids as they have already been computed.
hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``config.output_hidden_states=True``):
hidden_states (:obj:`tuple(tf.Tensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
Tuple of :obj:`tf.Tensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.