Output hidden states (#4978)
* Configure all models to use output_hidden_states as argument passed to foward() * Pass all tests * Remove cast_bool_to_primitive in TF Flaubert model * correct tf xlnet * add pytorch test * add tf test * Fix broken tests * Configure all models to use output_hidden_states as argument passed to foward() * Pass all tests * Remove cast_bool_to_primitive in TF Flaubert model * correct tf xlnet * add pytorch test * add tf test * Fix broken tests * Refactor output_hidden_states for mobilebert * Reset and remerge to master Co-authored-by: Joseph Liu <joseph.liu@coinflex.com> Co-authored-by: patrickvonplaten <patrick.v.platen@gmail.com>
This commit is contained in:
@@ -514,7 +514,6 @@ class MobileBertLayer(nn.Module):
|
||||
class MobileBertEncoder(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.output_hidden_states = config.output_hidden_states
|
||||
self.layer = nn.ModuleList([MobileBertLayer(config) for _ in range(config.num_hidden_layers)])
|
||||
|
||||
def forward(
|
||||
@@ -525,11 +524,12 @@ class MobileBertEncoder(nn.Module):
|
||||
encoder_hidden_states=None,
|
||||
encoder_attention_mask=None,
|
||||
output_attentions=False,
|
||||
output_hidden_states=False,
|
||||
):
|
||||
all_hidden_states = ()
|
||||
all_attentions = ()
|
||||
for i, layer_module in enumerate(self.layer):
|
||||
if self.output_hidden_states:
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
layer_outputs = layer_module(
|
||||
@@ -546,11 +546,11 @@ class MobileBertEncoder(nn.Module):
|
||||
all_attentions = all_attentions + (layer_outputs[1],)
|
||||
|
||||
# Add last layer
|
||||
if self.output_hidden_states:
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
outputs = (hidden_states,)
|
||||
if self.output_hidden_states:
|
||||
if output_hidden_states:
|
||||
outputs = outputs + (all_hidden_states,)
|
||||
if output_attentions:
|
||||
outputs = outputs + (all_attentions,)
|
||||
@@ -757,6 +757,7 @@ class MobileBertModel(MobileBertPreTrainedModel):
|
||||
inputs_embeds=None,
|
||||
encoder_hidden_states=None,
|
||||
encoder_attention_mask=None,
|
||||
output_hidden_states=None,
|
||||
output_attentions=None,
|
||||
):
|
||||
r"""
|
||||
@@ -773,7 +774,7 @@ class MobileBertModel(MobileBertPreTrainedModel):
|
||||
This output is usually *not* a good summary
|
||||
of the semantic content of the input, you're often better with averaging or pooling
|
||||
the sequence of hidden-states for the whole input sequence.
|
||||
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
|
||||
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True``):
|
||||
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
|
||||
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
||||
|
||||
@@ -801,6 +802,9 @@ class MobileBertModel(MobileBertPreTrainedModel):
|
||||
"""
|
||||
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
||||
@@ -852,6 +856,7 @@ class MobileBertModel(MobileBertPreTrainedModel):
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_extended_attention_mask,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
)
|
||||
sequence_output = encoder_outputs[0]
|
||||
pooled_output = self.pooler(sequence_output)
|
||||
@@ -911,6 +916,7 @@ class MobileBertForPreTraining(MobileBertPreTrainedModel):
|
||||
labels=None,
|
||||
next_sentence_label=None,
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
):
|
||||
r"""
|
||||
labels (``torch.LongTensor`` of shape ``(batch_size, sequence_length)``, `optional`, defaults to :obj:`None`):
|
||||
@@ -932,7 +938,7 @@ class MobileBertForPreTraining(MobileBertPreTrainedModel):
|
||||
seq_relationship_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, 2)`):
|
||||
Prediction scores of the next sequence prediction (classification) head (scores of True/False
|
||||
continuation before SoftMax).
|
||||
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when :obj:`config.output_hidden_states=True`):
|
||||
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True``):
|
||||
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
|
||||
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
||||
|
||||
@@ -962,6 +968,7 @@ class MobileBertForPreTraining(MobileBertPreTrainedModel):
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
)
|
||||
sequence_output, pooled_output = outputs[:2]
|
||||
prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)
|
||||
@@ -1027,6 +1034,7 @@ class MobileBertForMaskedLM(MobileBertPreTrainedModel):
|
||||
encoder_hidden_states=None,
|
||||
encoder_attention_mask=None,
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
**kwargs
|
||||
):
|
||||
r"""
|
||||
@@ -1044,7 +1052,7 @@ class MobileBertForMaskedLM(MobileBertPreTrainedModel):
|
||||
Masked language modeling loss.
|
||||
prediction_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.vocab_size)`)
|
||||
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
|
||||
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
|
||||
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or ``config.output_hidden_states=True``):
|
||||
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
|
||||
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
||||
|
||||
@@ -1087,6 +1095,7 @@ class MobileBertForMaskedLM(MobileBertPreTrainedModel):
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
)
|
||||
|
||||
sequence_output = outputs[0]
|
||||
@@ -1136,6 +1145,7 @@ class MobileBertForNextSentencePrediction(MobileBertPreTrainedModel):
|
||||
inputs_embeds=None,
|
||||
next_sentence_label=None,
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
):
|
||||
r"""
|
||||
next_sentence_label (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
|
||||
@@ -1150,7 +1160,7 @@ class MobileBertForNextSentencePrediction(MobileBertPreTrainedModel):
|
||||
Next sequence prediction (classification) loss.
|
||||
seq_relationship_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, 2)`):
|
||||
Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation before SoftMax).
|
||||
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
|
||||
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or ``config.output_hidden_states=True``):
|
||||
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
|
||||
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
||||
|
||||
@@ -1186,6 +1196,7 @@ class MobileBertForNextSentencePrediction(MobileBertPreTrainedModel):
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
)
|
||||
|
||||
pooled_output = outputs[1]
|
||||
@@ -1227,6 +1238,7 @@ class MobileBertForSequenceClassification(MobileBertPreTrainedModel):
|
||||
inputs_embeds=None,
|
||||
labels=None,
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
):
|
||||
r"""
|
||||
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
|
||||
@@ -1240,7 +1252,7 @@ class MobileBertForSequenceClassification(MobileBertPreTrainedModel):
|
||||
Classification (or regression if config.num_labels==1) loss.
|
||||
logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, config.num_labels)`):
|
||||
Classification (or regression if config.num_labels==1) scores (before SoftMax).
|
||||
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
|
||||
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True``):
|
||||
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
|
||||
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
||||
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
||||
@@ -1273,6 +1285,7 @@ class MobileBertForSequenceClassification(MobileBertPreTrainedModel):
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
)
|
||||
pooled_output = outputs[1]
|
||||
pooled_output = self.dropout(pooled_output)
|
||||
@@ -1317,6 +1330,7 @@ class MobileBertForQuestionAnswering(MobileBertPreTrainedModel):
|
||||
start_positions=None,
|
||||
end_positions=None,
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
):
|
||||
r"""
|
||||
start_positions (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
|
||||
@@ -1336,7 +1350,7 @@ class MobileBertForQuestionAnswering(MobileBertPreTrainedModel):
|
||||
Span-start scores (before SoftMax).
|
||||
end_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length,)`):
|
||||
Span-end scores (before SoftMax).
|
||||
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
|
||||
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True``):
|
||||
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
|
||||
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
||||
|
||||
@@ -1376,6 +1390,7 @@ class MobileBertForQuestionAnswering(MobileBertPreTrainedModel):
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
)
|
||||
|
||||
sequence_output = outputs[0]
|
||||
@@ -1432,6 +1447,7 @@ class MobileBertForMultipleChoice(MobileBertPreTrainedModel):
|
||||
inputs_embeds=None,
|
||||
labels=None,
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
):
|
||||
r"""
|
||||
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
|
||||
@@ -1447,7 +1463,7 @@ class MobileBertForMultipleChoice(MobileBertPreTrainedModel):
|
||||
`num_choices` is the second dimension of the input tensors. (see `input_ids` above).
|
||||
|
||||
Classification scores (before SoftMax).
|
||||
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
|
||||
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or ``config.output_hidden_states=True``):
|
||||
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
|
||||
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
||||
|
||||
@@ -1498,6 +1514,7 @@ class MobileBertForMultipleChoice(MobileBertPreTrainedModel):
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
)
|
||||
|
||||
pooled_output = outputs[1]
|
||||
@@ -1543,6 +1560,7 @@ class MobileBertForTokenClassification(MobileBertPreTrainedModel):
|
||||
inputs_embeds=None,
|
||||
labels=None,
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
):
|
||||
r"""
|
||||
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`, defaults to :obj:`None`):
|
||||
@@ -1555,7 +1573,7 @@ class MobileBertForTokenClassification(MobileBertPreTrainedModel):
|
||||
Classification loss.
|
||||
scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, config.num_labels)`)
|
||||
Classification scores (before SoftMax).
|
||||
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
|
||||
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or ``config.output_hidden_states=True``):
|
||||
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
|
||||
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
||||
|
||||
@@ -1591,6 +1609,7 @@ class MobileBertForTokenClassification(MobileBertPreTrainedModel):
|
||||
head_mask=head_mask,
|
||||
inputs_embeds=inputs_embeds,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
)
|
||||
|
||||
sequence_output = outputs[0]
|
||||
|
||||
Reference in New Issue
Block a user