updating squad for compatibility with XLNet
This commit is contained in:
@@ -493,8 +493,9 @@ class PoolerStartLogits(nn.Module):
|
||||
|
||||
def forward(self, hidden_states, p_mask=None):
|
||||
""" Args:
|
||||
`p_mask`: [optional] invalid position mask such as query and special symbols (PAD, SEP, CLS)
|
||||
shape [batch_size, seq_len]. 1.0 means token should be masked.
|
||||
**p_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, seq_len)``
|
||||
invalid position mask such as query and special symbols (PAD, SEP, CLS)
|
||||
1.0 means token should be masked.
|
||||
"""
|
||||
x = self.dense(hidden_states).squeeze(-1)
|
||||
|
||||
@@ -516,11 +517,16 @@ class PoolerEndLogits(nn.Module):
|
||||
|
||||
def forward(self, hidden_states, start_states=None, start_positions=None, p_mask=None):
|
||||
""" Args:
|
||||
One of start_states, start_positions should be not None. If both are set, start_positions overrides start_states.
|
||||
`start_states`: hidden states of the first tokens for the labeled span: torch.LongTensor of shape identical to hidden_states.
|
||||
`start_positions`: position of the first token for the labeled span: torch.LongTensor of shape [batch_size].
|
||||
`p_mask`: [optional] invalid position mask such as query and special symbols (PAD, SEP, CLS)
|
||||
shape [batch_size, seq_len]. 1.0 means token should be masked.
|
||||
One of ``start_states``, ``start_positions`` should not be None.
|
||||
If both are set, ``start_positions`` overrides ``start_states``.
|
||||
|
||||
**start_states**: ``torch.FloatTensor`` of shape identical to ``hidden_states``
|
||||
hidden states of the first tokens for the labeled span.
|
||||
**start_positions**: ``torch.LongTensor`` of shape ``(batch_size,)``
|
||||
position of the first token for the labeled span.
|
||||
**p_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, seq_len)``
|
||||
Mask of invalid positions such as query and special symbols (PAD, SEP, CLS)
|
||||
1.0 means token should be masked.
|
||||
"""
|
||||
slen, hsz = hidden_states.shape[-2:]
|
||||
assert start_states is not None or start_positions is not None, "One of start_states, start_positions should be not None"
|
||||
@@ -549,13 +555,21 @@ class PoolerAnswerClass(nn.Module):
|
||||
self.dense_1 = nn.Linear(config.hidden_size, 1, bias=False)
|
||||
|
||||
def forward(self, hidden_states, start_states=None, start_positions=None, cls_index=None):
|
||||
""" Args:
|
||||
One of start_states, start_positions should be not None. If both are set, start_positions overrides start_states.
|
||||
`start_states`: hidden states of the first tokens for the labeled span: torch.LongTensor of shape identical to hidden_states.
|
||||
`start_positions`: position of the first token for the labeled span: torch.LongTensor of shape [batch_size].
|
||||
`cls_index`: position of the CLS token: torch.LongTensor of shape [batch_size]. If None, take the last token.
|
||||
"""
|
||||
Args:
|
||||
One of ``start_states``, ``start_positions`` should not be None.
|
||||
If both are set, ``start_positions`` overrides ``start_states``.
|
||||
|
||||
# note(zhiliny): no dependency on end_feature so that we can obtain one single `cls_logits` for each sample
|
||||
**start_states**: ``torch.FloatTensor`` of shape identical to ``hidden_states``.
|
||||
hidden states of the first tokens for the labeled span.
|
||||
**start_positions**: ``torch.LongTensor`` of shape ``(batch_size,)``
|
||||
position of the first token for the labeled span.
|
||||
**cls_index**: ``torch.LongTensor`` of shape ``(batch_size,)``
|
||||
position of the CLS token. If None, take the last token.
|
||||
|
||||
note(Original repo):
|
||||
no dependency on end_feature so that we can obtain one single `cls_logits`
|
||||
for each sample
|
||||
"""
|
||||
slen, hsz = hidden_states.shape[-2:]
|
||||
assert start_states is not None or start_positions is not None, "One of start_states, start_positions should be not None"
|
||||
@@ -577,7 +591,35 @@ class PoolerAnswerClass(nn.Module):
|
||||
|
||||
|
||||
class SQuADHead(nn.Module):
|
||||
""" A SQuAD head inspired by XLNet.
|
||||
r""" A SQuAD head inspired by XLNet.
|
||||
|
||||
Parameters:
|
||||
config (:class:`~pytorch_transformers.XLNetConfig`): Model configuration class with all the parameters of the model.
|
||||
|
||||
Inputs:
|
||||
**hidden_states**: ``torch.FloatTensor`` of shape ``(batch_size, seq_len, hidden_size)``
|
||||
hidden states of sequence tokens
|
||||
**start_positions**: ``torch.LongTensor`` of shape ``(batch_size,)``
|
||||
position of the first token for the labeled span.
|
||||
**end_positions**: ``torch.LongTensor`` of shape ``(batch_size,)``
|
||||
position of the last token for the labeled span.
|
||||
**cls_index**: ``torch.LongTensor`` of shape ``(batch_size,)``
|
||||
position of the CLS token. If None, take the last token.
|
||||
**is_impossible**: ``torch.LongTensor`` of shape ``(batch_size,)``
|
||||
Whether the question has a possible answer in the paragraph or not.
|
||||
**p_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, seq_len)``
|
||||
Mask of invalid positions such as query and special symbols (PAD, SEP, CLS)
|
||||
1.0 means token should be masked.
|
||||
|
||||
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
|
||||
**loss**: (`optional`, returned if both ``start_positions`` and ``end_positions`` are provided) ``torch.FloatTensor`` of shape ``(1,)``:
|
||||
Classification loss as the sum of start token, end token (and is_impossible if provided) classification losses.
|
||||
**last_hidden_state**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, hidden_size)``
|
||||
Sequence of hidden-states at the last layer of the model.
|
||||
**mems**:
|
||||
list of ``torch.FloatTensor`` (one for each layer):
|
||||
that contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
|
||||
(see `mems` input above). Can be used to speed up sequential decoding and attend to longer context.
|
||||
"""
|
||||
def __init__(self, config):
|
||||
super(SQuADHead, self).__init__()
|
||||
@@ -590,8 +632,6 @@ class SQuADHead(nn.Module):
|
||||
|
||||
def forward(self, hidden_states, start_positions=None, end_positions=None,
|
||||
cls_index=None, is_impossible=None, p_mask=None):
|
||||
""" hidden_states: float Tensor in shape [bsz, seq_len, hidden_size], the hidden-states of the last layer.
|
||||
"""
|
||||
outputs = ()
|
||||
|
||||
start_logits = self.start_logits(hidden_states, p_mask=p_mask)
|
||||
@@ -618,9 +658,8 @@ class SQuADHead(nn.Module):
|
||||
|
||||
# note(zhiliny): by default multiply the loss by 0.5 so that the scale is comparable to start_loss and end_loss
|
||||
total_loss += cls_loss * 0.5
|
||||
outputs = (total_loss, start_logits, end_logits, cls_logits) + outputs
|
||||
else:
|
||||
outputs = (total_loss, start_logits, end_logits) + outputs
|
||||
|
||||
outputs = (total_loss,) + outputs
|
||||
|
||||
else:
|
||||
# during inference, compute the end logits based on beam search
|
||||
@@ -647,7 +686,7 @@ class SQuADHead(nn.Module):
|
||||
outputs = (start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits) + outputs
|
||||
|
||||
# return start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits
|
||||
# or (if labels are provided) total_loss, start_logits, end_logits, (cls_logits)
|
||||
# or (if labels are provided) (total_loss,)
|
||||
return outputs
|
||||
|
||||
|
||||
|
||||
@@ -1162,8 +1162,9 @@ class XLNetForQuestionAnswering(XLNetPreTrainedModel):
|
||||
Labels whether a question has an answer or no answer (SQuAD 2.0)
|
||||
**cls_index**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
|
||||
Labels for position (index) of the classification token to use as input for computing plausibility of the answer.
|
||||
**p_mask**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
|
||||
Optional mask of tokens which can't be in answers (e.g. [CLS], [PAD], ...)
|
||||
**p_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length)``:
|
||||
Optional mask of tokens which can't be in answers (e.g. [CLS], [PAD], ...).
|
||||
1.0 means token should be masked. 0.0 means token is not masked.
|
||||
|
||||
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
|
||||
**loss**: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
|
||||
|
||||
Reference in New Issue
Block a user