update tokenizer - update squad example for xlnet
This commit is contained in:
@@ -528,9 +528,9 @@ class PoolerEndLogits(nn.Module):
|
||||
Mask of invalid positions such as query and special symbols (PAD, SEP, CLS)
|
||||
1.0 means token should be masked.
|
||||
"""
|
||||
slen, hsz = hidden_states.shape[-2:]
|
||||
assert start_states is not None or start_positions is not None, "One of start_states, start_positions should be not None"
|
||||
if start_positions is not None:
|
||||
slen, hsz = hidden_states.shape[-2:]
|
||||
start_positions = start_positions[:, None, None].expand(-1, -1, hsz) # shape (bsz, 1, hsz)
|
||||
start_states = hidden_states.gather(-2, start_positions) # shape (bsz, 1, hsz)
|
||||
start_states = start_states.expand(-1, slen, -1) # shape (bsz, slen, hsz)
|
||||
@@ -571,7 +571,7 @@ class PoolerAnswerClass(nn.Module):
|
||||
no dependency on end_feature so that we can obtain one single `cls_logits`
|
||||
for each sample
|
||||
"""
|
||||
slen, hsz = hidden_states.shape[-2:]
|
||||
hsz = hidden_states.shape[-1]
|
||||
assert start_states is not None or start_positions is not None, "One of start_states, start_positions should be not None"
|
||||
if start_positions is not None:
|
||||
start_positions = start_positions[:, None, None].expand(-1, -1, hsz) # shape (bsz, 1, hsz)
|
||||
@@ -614,12 +614,21 @@ class SQuADHead(nn.Module):
|
||||
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
|
||||
**loss**: (`optional`, returned if both ``start_positions`` and ``end_positions`` are provided) ``torch.FloatTensor`` of shape ``(1,)``:
|
||||
Classification loss as the sum of start token, end token (and is_impossible if provided) classification losses.
|
||||
**last_hidden_state**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, hidden_size)``
|
||||
Sequence of hidden-states at the last layer of the model.
|
||||
**mems**:
|
||||
list of ``torch.FloatTensor`` (one for each layer):
|
||||
that contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
|
||||
(see `mems` input above). Can be used to speed up sequential decoding and attend to longer context.
|
||||
**start_top_log_probs**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided)
|
||||
``torch.FloatTensor`` of shape ``(batch_size, config.start_n_top)``
|
||||
Log probabilities for the top config.start_n_top start token possibilities (beam-search).
|
||||
**start_top_index**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided)
|
||||
``torch.LongTensor`` of shape ``(batch_size, config.start_n_top)``
|
||||
Indices for the top config.start_n_top start token possibilities (beam-search).
|
||||
**end_top_log_probs**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided)
|
||||
``torch.FloatTensor`` of shape ``(batch_size, config.start_n_top * config.end_n_top)``
|
||||
Log probabilities for the top ``config.start_n_top * config.end_n_top`` end token possibilities (beam-search).
|
||||
**end_top_index**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided)
|
||||
``torch.LongTensor`` of shape ``(batch_size, config.start_n_top * config.end_n_top)``
|
||||
Indices for the top ``config.start_n_top * config.end_n_top`` end token possibilities (beam-search).
|
||||
**cls_logits**: (`optional`, returned if ``start_positions`` or ``end_positions`` is not provided)
|
||||
``torch.FloatTensor`` of shape ``(batch_size,)``
|
||||
Log probabilities for the ``is_impossible`` label of the answers.
|
||||
"""
|
||||
def __init__(self, config):
|
||||
super(SQuADHead, self).__init__()
|
||||
@@ -667,8 +676,8 @@ class SQuADHead(nn.Module):
|
||||
start_log_probs = F.softmax(start_logits, dim=-1) # shape (bsz, slen)
|
||||
|
||||
start_top_log_probs, start_top_index = torch.topk(start_log_probs, self.start_n_top, dim=-1) # shape (bsz, start_n_top)
|
||||
start_top_index = start_top_index.unsqueeze(-1).expand(-1, -1, hsz) # shape (bsz, start_n_top, hsz)
|
||||
start_states = torch.gather(hidden_states, -2, start_top_index) # shape (bsz, start_n_top, hsz)
|
||||
start_top_index_exp = start_top_index.unsqueeze(-1).expand(-1, -1, hsz) # shape (bsz, start_n_top, hsz)
|
||||
start_states = torch.gather(hidden_states, -2, start_top_index_exp) # shape (bsz, start_n_top, hsz)
|
||||
start_states = start_states.unsqueeze(1).expand(-1, slen, -1, -1) # shape (bsz, slen, start_n_top, hsz)
|
||||
|
||||
hidden_states_expanded = hidden_states.unsqueeze(2).expand_as(start_states) # shape (bsz, slen, start_n_top, hsz)
|
||||
|
||||
Reference in New Issue
Block a user