adding option to desactivate past/memory outputs
This commit is contained in:
@@ -555,7 +555,7 @@ class XLNetModel(XLNetPreTrainedModel):
|
||||
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
|
||||
**last_hidden_state**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, hidden_size)``
|
||||
Sequence of hidden-states at the last layer of the model.
|
||||
**mems**:
|
||||
**mems**: (`optional`, returned when ``config.mem_len > 0``)
|
||||
list of ``torch.FloatTensor`` (one for each layer):
|
||||
that contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
|
||||
if config.mem_len > 0 else tuple of None. Can be used to speed up sequential decoding and attend to longer context.
|
||||
@@ -581,6 +581,7 @@ class XLNetModel(XLNetPreTrainedModel):
|
||||
super(XLNetModel, self).__init__(config)
|
||||
self.output_attentions = config.output_attentions
|
||||
self.output_hidden_states = config.output_hidden_states
|
||||
self.output_past = config.output_past
|
||||
|
||||
self.mem_len = config.mem_len
|
||||
self.reuse_len = config.reuse_len
|
||||
@@ -637,16 +638,13 @@ class XLNetModel(XLNetPreTrainedModel):
|
||||
|
||||
def cache_mem(self, curr_out, prev_mem):
|
||||
"""cache hidden states into memory."""
|
||||
if self.mem_len is None or self.mem_len == 0:
|
||||
return None
|
||||
else:
|
||||
if self.reuse_len is not None and self.reuse_len > 0:
|
||||
curr_out = curr_out[:self.reuse_len]
|
||||
if self.reuse_len is not None and self.reuse_len > 0:
|
||||
curr_out = curr_out[:self.reuse_len]
|
||||
|
||||
if prev_mem is None:
|
||||
new_mem = curr_out[-self.mem_len:]
|
||||
else:
|
||||
new_mem = torch.cat([prev_mem, curr_out], dim=0)[-self.mem_len:]
|
||||
if prev_mem is None:
|
||||
new_mem = curr_out[-self.mem_len:]
|
||||
else:
|
||||
new_mem = torch.cat([prev_mem, curr_out], dim=0)[-self.mem_len:]
|
||||
|
||||
return new_mem.detach()
|
||||
|
||||
@@ -817,8 +815,9 @@ class XLNetModel(XLNetPreTrainedModel):
|
||||
attentions = []
|
||||
hidden_states = []
|
||||
for i, layer_module in enumerate(self.layer):
|
||||
# cache new mems
|
||||
new_mems = new_mems + (self.cache_mem(output_h, mems[i]),)
|
||||
if self.mem_len is not None and self.mem_len > 0 and self.output_past:
|
||||
# cache new mems
|
||||
new_mems = new_mems + (self.cache_mem(output_h, mems[i]),)
|
||||
if self.output_hidden_states:
|
||||
hidden_states.append((output_h, output_g) if output_g is not None else output_h)
|
||||
|
||||
@@ -836,7 +835,11 @@ class XLNetModel(XLNetPreTrainedModel):
|
||||
output = self.dropout(output_g if output_g is not None else output_h)
|
||||
|
||||
# Prepare outputs, we transpose back here to shape [bsz, len, hidden_dim] (cf. beginning of forward() method)
|
||||
outputs = (output.permute(1, 0, 2).contiguous(), new_mems)
|
||||
outputs = (output.permute(1, 0, 2).contiguous(),)
|
||||
|
||||
if self.mem_len is not None and self.mem_len > 0 and self.output_past:
|
||||
outputs = outputs + (new_mems,)
|
||||
|
||||
if self.output_hidden_states:
|
||||
if output_g is not None:
|
||||
hidden_states = tuple(h.permute(1, 0, 2).contiguous() for hs in hidden_states for h in hs)
|
||||
@@ -847,7 +850,7 @@ class XLNetModel(XLNetPreTrainedModel):
|
||||
attentions = tuple(t.permute(2, 3, 0, 1).contiguous() for t in attentions)
|
||||
outputs = outputs + (attentions,)
|
||||
|
||||
return outputs # outputs, new_mems, (hidden_states), (attentions)
|
||||
return outputs # outputs, (new_mems), (hidden_states), (attentions)
|
||||
|
||||
|
||||
@add_start_docstrings("""XLNet Model with a language modeling head on top
|
||||
@@ -867,7 +870,7 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
|
||||
Language modeling loss.
|
||||
**prediction_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, config.vocab_size)``
|
||||
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
|
||||
**mems**:
|
||||
**mems**: (`optional`, returned when ``config.mem_len > 0``)
|
||||
list of ``torch.FloatTensor`` (one for each layer):
|
||||
that contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
|
||||
if config.mem_len > 0 else tuple of None. Can be used to speed up sequential decoding and attend to longer context.
|
||||
@@ -932,7 +935,7 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
|
||||
labels.view(-1))
|
||||
outputs = (loss,) + outputs
|
||||
|
||||
return outputs # return (loss), logits, mems, (hidden states), (attentions)
|
||||
return outputs # return (loss), logits, (mems), (hidden states), (attentions)
|
||||
|
||||
|
||||
@add_start_docstrings("""XLNet Model with a sequence classification/regression head on top (a linear layer on top of
|
||||
@@ -951,7 +954,7 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
|
||||
Classification (or regression if config.num_labels==1) loss.
|
||||
**logits**: ``torch.FloatTensor`` of shape ``(batch_size, config.num_labels)``
|
||||
Classification (or regression if config.num_labels==1) scores (before SoftMax).
|
||||
**mems**:
|
||||
**mems**: (`optional`, returned when ``config.mem_len > 0``)
|
||||
list of ``torch.FloatTensor`` (one for each layer):
|
||||
that contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
|
||||
if config.mem_len > 0 else tuple of None. Can be used to speed up sequential decoding and attend to longer context.
|
||||
@@ -1011,7 +1014,7 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
|
||||
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
||||
outputs = (loss,) + outputs
|
||||
|
||||
return outputs # return (loss), logits, mems, (hidden states), (attentions)
|
||||
return outputs # return (loss), logits, (mems), (hidden states), (attentions)
|
||||
|
||||
@add_start_docstrings("""XLNet Model with a multiple choice classification head on top (a linear layer on top of
|
||||
the pooled output and a softmax) e.g. for RACE/SWAG tasks. """,
|
||||
@@ -1046,6 +1049,11 @@ class XLNetForMultipleChoice(XLNetPreTrainedModel):
|
||||
**classification_scores**: ``torch.FloatTensor`` of shape ``(batch_size, num_choices)`` where `num_choices` is the size of the second dimension
|
||||
of the input tensors. (see `input_ids` above).
|
||||
Classification scores (before SoftMax).
|
||||
**mems**: (`optional`, returned when ``config.mem_len > 0``)
|
||||
list of ``torch.FloatTensor`` (one for each layer):
|
||||
that contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
|
||||
if config.mem_len > 0 else tuple of None. Can be used to speed up sequential decoding and attend to longer context.
|
||||
See details in the docstring of the `mems` input above.
|
||||
**hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
|
||||
list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
|
||||
of shape ``(batch_size, sequence_length, hidden_size)``:
|
||||
@@ -1102,7 +1110,7 @@ class XLNetForMultipleChoice(XLNetPreTrainedModel):
|
||||
loss = loss_fct(reshaped_logits, labels.view(-1))
|
||||
outputs = (loss,) + outputs
|
||||
|
||||
return outputs # return (loss), logits, mems, (hidden states), (attentions)
|
||||
return outputs # return (loss), logits, (mems), (hidden states), (attentions)
|
||||
|
||||
|
||||
@add_start_docstrings("""XLNet Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layers on top of
|
||||
@@ -1126,7 +1134,7 @@ class XLNetForQuestionAnsweringSimple(XLNetPreTrainedModel):
|
||||
Span-start scores (before SoftMax).
|
||||
**end_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length,)``
|
||||
Span-end scores (before SoftMax).
|
||||
**mems**:
|
||||
**mems**: (`optional`, returned when ``config.mem_len > 0``)
|
||||
list of ``torch.FloatTensor`` (one for each layer):
|
||||
that contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
|
||||
if config.mem_len > 0 else tuple of None. Can be used to speed up sequential decoding and attend to longer context.
|
||||
@@ -1197,7 +1205,7 @@ class XLNetForQuestionAnsweringSimple(XLNetPreTrainedModel):
|
||||
total_loss = (start_loss + end_loss) / 2
|
||||
outputs = (total_loss,) + outputs
|
||||
|
||||
return outputs # (loss), start_logits, end_logits, (hidden_states), (attentions)
|
||||
return outputs # (loss), start_logits, end_logits, (mems), (hidden_states), (attentions)
|
||||
|
||||
|
||||
@add_start_docstrings("""XLNet Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layers on top of
|
||||
@@ -1239,7 +1247,7 @@ class XLNetForQuestionAnswering(XLNetPreTrainedModel):
|
||||
**cls_logits**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided)
|
||||
``torch.FloatTensor`` of shape ``(batch_size,)``
|
||||
Log probabilities for the ``is_impossible`` label of the answers.
|
||||
**mems**:
|
||||
**mems**: (`optional`, returned when ``config.mem_len > 0``)
|
||||
list of ``torch.FloatTensor`` (one for each layer):
|
||||
that contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
|
||||
if config.mem_len > 0 else tuple of None. Can be used to speed up sequential decoding and attend to longer context.
|
||||
|
||||
Reference in New Issue
Block a user