Output hidden states (#4978)

* Configure all models to use output_hidden_states as argument passed to foward()

* Pass all tests

* Remove cast_bool_to_primitive in TF Flaubert model

* correct tf xlnet

* add pytorch test

* add tf test

* Fix broken tests

* Configure all models to use output_hidden_states as argument passed to foward()

* Pass all tests

* Remove cast_bool_to_primitive in TF Flaubert model

* correct tf xlnet

* add pytorch test

* add tf test

* Fix broken tests

* Refactor output_hidden_states for mobilebert

* Reset and remerge to master

Co-authored-by: Joseph Liu <joseph.liu@coinflex.com>
Co-authored-by: patrickvonplaten <patrick.v.platen@gmail.com>
This commit is contained in:
Joseph Liu
2020-06-22 22:10:45 +08:00
committed by GitHub
parent 866a8ccabb
commit f4e1f02210
34 changed files with 814 additions and 349 deletions

View File

@@ -514,7 +514,6 @@ class MobileBertLayer(nn.Module):
class MobileBertEncoder(nn.Module):
def __init__(self, config):
super().__init__()
self.output_hidden_states = config.output_hidden_states
self.layer = nn.ModuleList([MobileBertLayer(config) for _ in range(config.num_hidden_layers)])
def forward(
@@ -525,11 +524,12 @@ class MobileBertEncoder(nn.Module):
encoder_hidden_states=None,
encoder_attention_mask=None,
output_attentions=False,
output_hidden_states=False,
):
all_hidden_states = ()
all_attentions = ()
for i, layer_module in enumerate(self.layer):
if self.output_hidden_states:
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
layer_outputs = layer_module(
@@ -546,11 +546,11 @@ class MobileBertEncoder(nn.Module):
all_attentions = all_attentions + (layer_outputs[1],)
# Add last layer
if self.output_hidden_states:
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
outputs = (hidden_states,)
if self.output_hidden_states:
if output_hidden_states:
outputs = outputs + (all_hidden_states,)
if output_attentions:
outputs = outputs + (all_attentions,)
@@ -757,6 +757,7 @@ class MobileBertModel(MobileBertPreTrainedModel):
inputs_embeds=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
output_hidden_states=None,
output_attentions=None,
):
r"""
@@ -773,7 +774,7 @@ class MobileBertModel(MobileBertPreTrainedModel):
This output is usually *not* a good summary
of the semantic content of the input, you're often better with averaging or pooling
the sequence of hidden-states for the whole input sequence.
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
@@ -801,6 +802,9 @@ class MobileBertModel(MobileBertPreTrainedModel):
"""
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
@@ -852,6 +856,7 @@ class MobileBertModel(MobileBertPreTrainedModel):
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_extended_attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
sequence_output = encoder_outputs[0]
pooled_output = self.pooler(sequence_output)
@@ -911,6 +916,7 @@ class MobileBertForPreTraining(MobileBertPreTrainedModel):
labels=None,
next_sentence_label=None,
output_attentions=None,
output_hidden_states=None,
):
r"""
labels (``torch.LongTensor`` of shape ``(batch_size, sequence_length)``, `optional`, defaults to :obj:`None`):
@@ -932,7 +938,7 @@ class MobileBertForPreTraining(MobileBertPreTrainedModel):
seq_relationship_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, 2)`):
Prediction scores of the next sequence prediction (classification) head (scores of True/False
continuation before SoftMax).
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when :obj:`config.output_hidden_states=True`):
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
@@ -962,6 +968,7 @@ class MobileBertForPreTraining(MobileBertPreTrainedModel):
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
sequence_output, pooled_output = outputs[:2]
prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)
@@ -1027,6 +1034,7 @@ class MobileBertForMaskedLM(MobileBertPreTrainedModel):
encoder_hidden_states=None,
encoder_attention_mask=None,
output_attentions=None,
output_hidden_states=None,
**kwargs
):
r"""
@@ -1044,7 +1052,7 @@ class MobileBertForMaskedLM(MobileBertPreTrainedModel):
Masked language modeling loss.
prediction_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`)
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
@@ -1087,6 +1095,7 @@ class MobileBertForMaskedLM(MobileBertPreTrainedModel):
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
sequence_output = outputs[0]
@@ -1136,6 +1145,7 @@ class MobileBertForNextSentencePrediction(MobileBertPreTrainedModel):
inputs_embeds=None,
next_sentence_label=None,
output_attentions=None,
output_hidden_states=None,
):
r"""
next_sentence_label (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
@@ -1150,7 +1160,7 @@ class MobileBertForNextSentencePrediction(MobileBertPreTrainedModel):
Next sequence prediction (classification) loss.
seq_relationship_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, 2)`):
Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation before SoftMax).
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
@@ -1186,6 +1196,7 @@ class MobileBertForNextSentencePrediction(MobileBertPreTrainedModel):
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
pooled_output = outputs[1]
@@ -1227,6 +1238,7 @@ class MobileBertForSequenceClassification(MobileBertPreTrainedModel):
inputs_embeds=None,
labels=None,
output_attentions=None,
output_hidden_states=None,
):
r"""
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
@@ -1240,7 +1252,7 @@ class MobileBertForSequenceClassification(MobileBertPreTrainedModel):
Classification (or regression if config.num_labels==1) loss.
logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, config.num_labels)`):
Classification (or regression if config.num_labels==1) scores (before SoftMax).
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
@@ -1273,6 +1285,7 @@ class MobileBertForSequenceClassification(MobileBertPreTrainedModel):
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
pooled_output = outputs[1]
pooled_output = self.dropout(pooled_output)
@@ -1317,6 +1330,7 @@ class MobileBertForQuestionAnswering(MobileBertPreTrainedModel):
start_positions=None,
end_positions=None,
output_attentions=None,
output_hidden_states=None,
):
r"""
start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
@@ -1336,7 +1350,7 @@ class MobileBertForQuestionAnswering(MobileBertPreTrainedModel):
Span-start scores (before SoftMax).
end_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length,)`):
Span-end scores (before SoftMax).
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
@@ -1376,6 +1390,7 @@ class MobileBertForQuestionAnswering(MobileBertPreTrainedModel):
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
sequence_output = outputs[0]
@@ -1432,6 +1447,7 @@ class MobileBertForMultipleChoice(MobileBertPreTrainedModel):
inputs_embeds=None,
labels=None,
output_attentions=None,
output_hidden_states=None,
):
r"""
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
@@ -1447,7 +1463,7 @@ class MobileBertForMultipleChoice(MobileBertPreTrainedModel):
`num_choices` is the second dimension of the input tensors. (see `input_ids` above).
Classification scores (before SoftMax).
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
@@ -1498,6 +1514,7 @@ class MobileBertForMultipleChoice(MobileBertPreTrainedModel):
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
pooled_output = outputs[1]
@@ -1543,6 +1560,7 @@ class MobileBertForTokenClassification(MobileBertPreTrainedModel):
inputs_embeds=None,
labels=None,
output_attentions=None,
output_hidden_states=None,
):
r"""
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
@@ -1555,7 +1573,7 @@ class MobileBertForTokenClassification(MobileBertPreTrainedModel):
Classification loss.
scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.num_labels)`)
Classification scores (before SoftMax).
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or ``config.output_hidden_states=True``):
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
@@ -1591,6 +1609,7 @@ class MobileBertForTokenClassification(MobileBertPreTrainedModel):
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
)
sequence_output = outputs[0]