updating squad for compatibility with XLNet
This commit is contained in:
@@ -41,7 +41,9 @@ from pytorch_transformers import (WEIGHTS_NAME, BertConfig,
|
|||||||
|
|
||||||
from pytorch_transformers import AdamW, WarmupLinearSchedule
|
from pytorch_transformers import AdamW, WarmupLinearSchedule
|
||||||
|
|
||||||
from utils_squad import read_squad_examples, convert_examples_to_features, RawResult, write_predictions
|
from utils_squad import (read_squad_examples, convert_examples_to_features,
|
||||||
|
RawResult, write_predictions,
|
||||||
|
RawResultExtended, write_predictions_extended)
|
||||||
|
|
||||||
# The follwing import is the official SQuAD evaluation script (2.0).
|
# The follwing import is the official SQuAD evaluation script (2.0).
|
||||||
# You can remove it from the dependencies if you are using this script outside of the library
|
# You can remove it from the dependencies if you are using this script outside of the library
|
||||||
@@ -66,6 +68,8 @@ def set_seed(args):
|
|||||||
if args.n_gpu > 0:
|
if args.n_gpu > 0:
|
||||||
torch.cuda.manual_seed_all(args.seed)
|
torch.cuda.manual_seed_all(args.seed)
|
||||||
|
|
||||||
|
def to_list(tensor):
|
||||||
|
return tensor.detach().cpu().tolist()
|
||||||
|
|
||||||
def train(args, train_dataset, model, tokenizer):
|
def train(args, train_dataset, model, tokenizer):
|
||||||
""" Train the model """
|
""" Train the model """
|
||||||
@@ -118,10 +122,13 @@ def train(args, train_dataset, model, tokenizer):
|
|||||||
model.train()
|
model.train()
|
||||||
batch = tuple(t.to(args.device) for t in batch)
|
batch = tuple(t.to(args.device) for t in batch)
|
||||||
inputs = {'input_ids': batch[0],
|
inputs = {'input_ids': batch[0],
|
||||||
'token_type_ids': batch[1] if args.model_type in ['bert', 'xlnet'] else None, # XLM don't use segment_ids
|
'token_type_ids': None if args.model_type == 'xlm' else batch[1], # XLM don't use segment_ids
|
||||||
'attention_mask': batch[2],
|
'attention_mask': batch[2],
|
||||||
'start_positions': batch[3],
|
'start_positions': batch[3],
|
||||||
'end_positions': batch[4]}
|
'end_positions': batch[4]}
|
||||||
|
if args.model_type in ['xlnet', 'xlm']:
|
||||||
|
inputs.update({'cls_index': batch[5],
|
||||||
|
'p_mask': batch[6]})
|
||||||
ouputs = model(**inputs)
|
ouputs = model(**inputs)
|
||||||
loss = ouputs[0] # model outputs are always tuple in pytorch-transformers (see doc)
|
loss = ouputs[0] # model outputs are always tuple in pytorch-transformers (see doc)
|
||||||
|
|
||||||
@@ -197,30 +204,49 @@ def evaluate(args, model, tokenizer, prefix=""):
|
|||||||
for batch in tqdm(eval_dataloader, desc="Evaluating"):
|
for batch in tqdm(eval_dataloader, desc="Evaluating"):
|
||||||
model.eval()
|
model.eval()
|
||||||
batch = tuple(t.to(args.device) for t in batch)
|
batch = tuple(t.to(args.device) for t in batch)
|
||||||
example_indices = batch[3]
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
inputs = {'input_ids': batch[0],
|
inputs = {'input_ids': batch[0],
|
||||||
'token_type_ids': batch[1] if args.model_type in ['bert', 'xlnet'] else None, # XLM don't use segment_ids
|
'token_type_ids': None if args.model_type == 'xlm' else batch[1], # XLM don't use segment_ids
|
||||||
'attention_mask': batch[2]}
|
'attention_mask': batch[2]}
|
||||||
|
example_indices = batch[3]
|
||||||
|
if args.model_type in ['xlnet', 'xlm']:
|
||||||
|
inputs.update({'cls_index': batch[4],
|
||||||
|
'p_mask': batch[5]})
|
||||||
outputs = model(**inputs)
|
outputs = model(**inputs)
|
||||||
batch_start_logits, batch_end_logits = outputs[:2]
|
batch_start_logits, batch_end_logits = outputs[:2]
|
||||||
|
|
||||||
for i, example_index in enumerate(example_indices):
|
for i, example_index in enumerate(example_indices):
|
||||||
start_logits = batch_start_logits[i].detach().cpu().tolist()
|
|
||||||
end_logits = batch_end_logits[i].detach().cpu().tolist()
|
|
||||||
eval_feature = features[example_index.item()]
|
eval_feature = features[example_index.item()]
|
||||||
unique_id = int(eval_feature.unique_id)
|
unique_id = int(eval_feature.unique_id)
|
||||||
all_results.append(RawResult(unique_id=unique_id,
|
if args.model_type in ['xlnet', 'xlm']:
|
||||||
start_logits=start_logits,
|
# XLNet uses a more complex post-processing procedure
|
||||||
end_logits=end_logits))
|
result = RawResultExtended(unique_id = unique_id,
|
||||||
|
start_top_log_probs = to_list(outputs[0][i]),
|
||||||
|
start_top_index = to_list(outputs[1][i]),
|
||||||
|
end_top_log_probs = to_list(outputs[2][i]),
|
||||||
|
end_top_index = to_list(outputs[3][i]),
|
||||||
|
cls_logits = to_list(outputs[4][i]))
|
||||||
|
else:
|
||||||
|
result = RawResult(unique_id = unique_id,
|
||||||
|
start_logits = to_list(outputs[0][i]),
|
||||||
|
end_logits = to_list(outputs[1][i]))
|
||||||
|
all_results.append(result)
|
||||||
|
|
||||||
# Compute predictions
|
# Compute predictions
|
||||||
output_prediction_file = os.path.join(args.output_dir, "predictions_{}.json".format(prefix))
|
output_prediction_file = os.path.join(args.output_dir, "predictions_{}.json".format(prefix))
|
||||||
output_nbest_file = os.path.join(args.output_dir, "nbest_predictions_{}.json".format(prefix))
|
output_nbest_file = os.path.join(args.output_dir, "nbest_predictions_{}.json".format(prefix))
|
||||||
output_null_log_odds_file = os.path.join(args.output_dir, "null_odds_{}.json".format(prefix))
|
output_null_log_odds_file = os.path.join(args.output_dir, "null_odds_{}.json".format(prefix))
|
||||||
write_predictions(examples, features, all_results, args.n_best_size, args.max_answer_length,
|
|
||||||
args.do_lower_case, output_prediction_file, output_nbest_file,
|
if args.model_type in ['xlnet', 'xlm']:
|
||||||
output_null_log_odds_file, args.verbose_logging,
|
# XLNet uses a more complex post-processing procedure
|
||||||
|
write_predictions_extended(examples, features, all_results, args.n_best_size,
|
||||||
|
args.max_answer_length, output_prediction_file,
|
||||||
|
output_nbest_file, output_null_log_odds_file, args.predict_file,
|
||||||
|
args.start_n_top, args.end_n_top, args.version_2_with_negative)
|
||||||
|
else:
|
||||||
|
write_predictions(examples, features, all_results, args.n_best_size,
|
||||||
|
args.max_answer_length, args.do_lower_case, output_prediction_file,
|
||||||
|
output_nbest_file, output_null_log_odds_file, args.verbose_logging,
|
||||||
args.version_2_with_negative, args.null_score_diff_threshold)
|
args.version_2_with_negative, args.null_score_diff_threshold)
|
||||||
|
|
||||||
# Evaluate with the official SQuAD script
|
# Evaluate with the official SQuAD script
|
||||||
@@ -260,13 +286,18 @@ def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=Fal
|
|||||||
all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
|
all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
|
||||||
all_input_mask = torch.tensor([f.input_mask for f in features], dtype=torch.long)
|
all_input_mask = torch.tensor([f.input_mask for f in features], dtype=torch.long)
|
||||||
all_segment_ids = torch.tensor([f.segment_ids for f in features], dtype=torch.long)
|
all_segment_ids = torch.tensor([f.segment_ids for f in features], dtype=torch.long)
|
||||||
|
all_cls_index = torch.tensor([f.cls_index for f in features], dtype=torch.long)
|
||||||
|
all_p_mask = torch.tensor([f.p_mask for f in features], dtype=torch.float)
|
||||||
if evaluate:
|
if evaluate:
|
||||||
all_example_index = torch.arange(all_input_ids.size(0), dtype=torch.long)
|
all_example_index = torch.arange(all_input_ids.size(0), dtype=torch.long)
|
||||||
dataset = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_example_index)
|
dataset = TensorDataset(all_input_ids, all_input_mask, all_segment_ids,
|
||||||
|
all_example_index, all_cls_index, all_p_mask)
|
||||||
else:
|
else:
|
||||||
all_start_positions = torch.tensor([f.start_position for f in features], dtype=torch.long)
|
all_start_positions = torch.tensor([f.start_position for f in features], dtype=torch.long)
|
||||||
all_end_positions = torch.tensor([f.end_position for f in features], dtype=torch.long)
|
all_end_positions = torch.tensor([f.end_position for f in features], dtype=torch.long)
|
||||||
dataset = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_start_positions, all_end_positions)
|
dataset = TensorDataset(all_input_ids, all_input_mask, all_segment_ids,
|
||||||
|
all_start_positions, all_end_positions,
|
||||||
|
all_cls_index, all_p_mask)
|
||||||
|
|
||||||
if output_examples:
|
if output_examples:
|
||||||
return dataset, examples, features
|
return dataset, examples, features
|
||||||
|
|||||||
@@ -26,6 +26,9 @@ from io import open
|
|||||||
|
|
||||||
from pytorch_transformers.tokenization_bert import BasicTokenizer, whitespace_tokenize
|
from pytorch_transformers.tokenization_bert import BasicTokenizer, whitespace_tokenize
|
||||||
|
|
||||||
|
# Required by XLNet evaluation method to compute optimal threshold (see write_predictions_extended() method)
|
||||||
|
from utils_squad_evaluate import find_all_best_thresh_v2, make_qid_to_has_ans, get_raw_scores
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@@ -82,6 +85,8 @@ class InputFeatures(object):
|
|||||||
input_ids,
|
input_ids,
|
||||||
input_mask,
|
input_mask,
|
||||||
segment_ids,
|
segment_ids,
|
||||||
|
cls_index,
|
||||||
|
p_mask,
|
||||||
start_position=None,
|
start_position=None,
|
||||||
end_position=None,
|
end_position=None,
|
||||||
is_impossible=None):
|
is_impossible=None):
|
||||||
@@ -94,6 +99,8 @@ class InputFeatures(object):
|
|||||||
self.input_ids = input_ids
|
self.input_ids = input_ids
|
||||||
self.input_mask = input_mask
|
self.input_mask = input_mask
|
||||||
self.segment_ids = segment_ids
|
self.segment_ids = segment_ids
|
||||||
|
self.cls_index = cls_index
|
||||||
|
self.p_mask = p_mask
|
||||||
self.start_position = start_position
|
self.start_position = start_position
|
||||||
self.end_position = end_position
|
self.end_position = end_position
|
||||||
self.is_impossible = is_impossible
|
self.is_impossible = is_impossible
|
||||||
@@ -178,13 +185,25 @@ def read_squad_examples(input_file, is_training, version_2_with_negative):
|
|||||||
|
|
||||||
|
|
||||||
def convert_examples_to_features(examples, tokenizer, max_seq_length,
|
def convert_examples_to_features(examples, tokenizer, max_seq_length,
|
||||||
doc_stride, max_query_length, is_training):
|
doc_stride, max_query_length, is_training,
|
||||||
|
cls_token_at_end=False,
|
||||||
|
cls_token='[CLS]', sep_token='[SEP]', pad_token=0,
|
||||||
|
sequence_a_segment_id=0, sequence_b_segment_id=1,
|
||||||
|
cls_token_segment_id=0, pad_token_segment_id=0,
|
||||||
|
mask_padding_with_zero=True):
|
||||||
"""Loads a data file into a list of `InputBatch`s."""
|
"""Loads a data file into a list of `InputBatch`s."""
|
||||||
|
|
||||||
unique_id = 1000000000
|
unique_id = 1000000000
|
||||||
|
# cnt_pos, cnt_neg = 0, 0
|
||||||
|
# max_N, max_M = 1024, 1024
|
||||||
|
# f = np.zeros((max_N, max_M), dtype=np.float32)
|
||||||
|
|
||||||
features = []
|
features = []
|
||||||
for (example_index, example) in enumerate(examples):
|
for (example_index, example) in enumerate(examples):
|
||||||
|
|
||||||
|
# if example_index % 100 == 0:
|
||||||
|
# logger.info('Converting %s/%s pos %s neg %s', example_index, len(examples), cnt_pos, cnt_neg)
|
||||||
|
|
||||||
query_tokens = tokenizer.tokenize(example.question_text)
|
query_tokens = tokenizer.tokenize(example.question_text)
|
||||||
|
|
||||||
if len(query_tokens) > max_query_length:
|
if len(query_tokens) > max_query_length:
|
||||||
@@ -239,14 +258,30 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length,
|
|||||||
token_to_orig_map = {}
|
token_to_orig_map = {}
|
||||||
token_is_max_context = {}
|
token_is_max_context = {}
|
||||||
segment_ids = []
|
segment_ids = []
|
||||||
tokens.append("[CLS]")
|
|
||||||
segment_ids.append(0)
|
# p_mask: mask with 1 for token than cannot be in the answer (0 for token which can be in an answer)
|
||||||
|
# Original TF implem also keep the classification token (set to 0) (not sure why...)
|
||||||
|
p_mask = []
|
||||||
|
|
||||||
|
# CLS token at the beginning
|
||||||
|
if not cls_token_at_end:
|
||||||
|
tokens.append(cls_token)
|
||||||
|
segment_ids.append(cls_token_segment_id)
|
||||||
|
p_mask.append(0)
|
||||||
|
cls_index = 0
|
||||||
|
|
||||||
|
# Query
|
||||||
for token in query_tokens:
|
for token in query_tokens:
|
||||||
tokens.append(token)
|
tokens.append(token)
|
||||||
segment_ids.append(0)
|
segment_ids.append(sequence_a_segment_id)
|
||||||
tokens.append("[SEP]")
|
p_mask.append(1)
|
||||||
segment_ids.append(0)
|
|
||||||
|
|
||||||
|
# SEP token
|
||||||
|
tokens.append(sep_token)
|
||||||
|
segment_ids.append(sequence_a_segment_id)
|
||||||
|
p_mask.append(1)
|
||||||
|
|
||||||
|
# Paragraph
|
||||||
for i in range(doc_span.length):
|
for i in range(doc_span.length):
|
||||||
split_token_index = doc_span.start + i
|
split_token_index = doc_span.start + i
|
||||||
token_to_orig_map[len(tokens)] = tok_to_orig_index[split_token_index]
|
token_to_orig_map[len(tokens)] = tok_to_orig_index[split_token_index]
|
||||||
@@ -255,29 +290,42 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length,
|
|||||||
split_token_index)
|
split_token_index)
|
||||||
token_is_max_context[len(tokens)] = is_max_context
|
token_is_max_context[len(tokens)] = is_max_context
|
||||||
tokens.append(all_doc_tokens[split_token_index])
|
tokens.append(all_doc_tokens[split_token_index])
|
||||||
segment_ids.append(1)
|
segment_ids.append(sequence_b_segment_id)
|
||||||
tokens.append("[SEP]")
|
p_mask.append(0)
|
||||||
segment_ids.append(1)
|
|
||||||
|
# SEP token
|
||||||
|
tokens.append(sep_token)
|
||||||
|
segment_ids.append(sequence_b_segment_id)
|
||||||
|
p_mask.append(1)
|
||||||
|
|
||||||
|
# CLS token at the end
|
||||||
|
if cls_token_at_end:
|
||||||
|
tokens.append(cls_token)
|
||||||
|
segment_ids.append(cls_token_segment_id)
|
||||||
|
p_mask.append(0)
|
||||||
|
cls_index = len(tokens) - 1 # Index of classification token
|
||||||
|
|
||||||
input_ids = tokenizer.convert_tokens_to_ids(tokens)
|
input_ids = tokenizer.convert_tokens_to_ids(tokens)
|
||||||
|
|
||||||
# The mask has 1 for real tokens and 0 for padding tokens. Only real
|
# The mask has 1 for real tokens and 0 for padding tokens. Only real
|
||||||
# tokens are attended to.
|
# tokens are attended to.
|
||||||
input_mask = [1] * len(input_ids)
|
input_mask = [1 if mask_padding_with_zero else 0] * len(input_ids)
|
||||||
|
|
||||||
# Zero-pad up to the sequence length.
|
# Zero-pad up to the sequence length.
|
||||||
while len(input_ids) < max_seq_length:
|
while len(input_ids) < max_seq_length:
|
||||||
input_ids.append(0)
|
input_ids.append(pad_token)
|
||||||
input_mask.append(0)
|
input_mask.append(0 if mask_padding_with_zero else 1)
|
||||||
segment_ids.append(0)
|
segment_ids.append(pad_token_segment_id)
|
||||||
|
p_mask.append(1)
|
||||||
|
|
||||||
assert len(input_ids) == max_seq_length
|
assert len(input_ids) == max_seq_length
|
||||||
assert len(input_mask) == max_seq_length
|
assert len(input_mask) == max_seq_length
|
||||||
assert len(segment_ids) == max_seq_length
|
assert len(segment_ids) == max_seq_length
|
||||||
|
|
||||||
|
span_is_impossible = example.is_impossible
|
||||||
start_position = None
|
start_position = None
|
||||||
end_position = None
|
end_position = None
|
||||||
if is_training and not example.is_impossible:
|
if is_training and not span_is_impossible:
|
||||||
# For training, if our document chunk does not contain an annotation
|
# For training, if our document chunk does not contain an annotation
|
||||||
# we throw it out, since there is nothing to predict.
|
# we throw it out, since there is nothing to predict.
|
||||||
doc_start = doc_span.start
|
doc_start = doc_span.start
|
||||||
@@ -289,13 +337,16 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length,
|
|||||||
if out_of_span:
|
if out_of_span:
|
||||||
start_position = 0
|
start_position = 0
|
||||||
end_position = 0
|
end_position = 0
|
||||||
|
span_is_impossible = True
|
||||||
else:
|
else:
|
||||||
doc_offset = len(query_tokens) + 2
|
doc_offset = len(query_tokens) + 2
|
||||||
start_position = tok_start_position - doc_start + doc_offset
|
start_position = tok_start_position - doc_start + doc_offset
|
||||||
end_position = tok_end_position - doc_start + doc_offset
|
end_position = tok_end_position - doc_start + doc_offset
|
||||||
if is_training and example.is_impossible:
|
|
||||||
start_position = 0
|
if is_training and span_is_impossible:
|
||||||
end_position = 0
|
start_position = cls_index
|
||||||
|
end_position = cls_index
|
||||||
|
|
||||||
if example_index < 20:
|
if example_index < 20:
|
||||||
logger.info("*** Example ***")
|
logger.info("*** Example ***")
|
||||||
logger.info("unique_id: %s" % (unique_id))
|
logger.info("unique_id: %s" % (unique_id))
|
||||||
@@ -312,9 +363,9 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length,
|
|||||||
"input_mask: %s" % " ".join([str(x) for x in input_mask]))
|
"input_mask: %s" % " ".join([str(x) for x in input_mask]))
|
||||||
logger.info(
|
logger.info(
|
||||||
"segment_ids: %s" % " ".join([str(x) for x in segment_ids]))
|
"segment_ids: %s" % " ".join([str(x) for x in segment_ids]))
|
||||||
if is_training and example.is_impossible:
|
if is_training and span_is_impossible:
|
||||||
logger.info("impossible example")
|
logger.info("impossible example")
|
||||||
if is_training and not example.is_impossible:
|
if is_training and not span_is_impossible:
|
||||||
answer_text = " ".join(tokens[start_position:(end_position + 1)])
|
answer_text = " ".join(tokens[start_position:(end_position + 1)])
|
||||||
logger.info("start_position: %d" % (start_position))
|
logger.info("start_position: %d" % (start_position))
|
||||||
logger.info("end_position: %d" % (end_position))
|
logger.info("end_position: %d" % (end_position))
|
||||||
@@ -332,9 +383,11 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length,
|
|||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
input_mask=input_mask,
|
input_mask=input_mask,
|
||||||
segment_ids=segment_ids,
|
segment_ids=segment_ids,
|
||||||
|
cls_index=cls_index,
|
||||||
|
p_mask=p_mask,
|
||||||
start_position=start_position,
|
start_position=start_position,
|
||||||
end_position=end_position,
|
end_position=end_position,
|
||||||
is_impossible=example.is_impossible))
|
is_impossible=span_is_impossible))
|
||||||
unique_id += 1
|
unique_id += 1
|
||||||
|
|
||||||
return features
|
return features
|
||||||
@@ -417,7 +470,6 @@ def _check_is_max_context(doc_spans, cur_span_index, position):
|
|||||||
RawResult = collections.namedtuple("RawResult",
|
RawResult = collections.namedtuple("RawResult",
|
||||||
["unique_id", "start_logits", "end_logits"])
|
["unique_id", "start_logits", "end_logits"])
|
||||||
|
|
||||||
|
|
||||||
def write_predictions(all_examples, all_features, all_results, n_best_size,
|
def write_predictions(all_examples, all_features, all_results, n_best_size,
|
||||||
max_answer_length, do_lower_case, output_prediction_file,
|
max_answer_length, do_lower_case, output_prediction_file,
|
||||||
output_nbest_file, output_null_log_odds_file, verbose_logging,
|
output_nbest_file, output_null_log_odds_file, verbose_logging,
|
||||||
@@ -612,6 +664,182 @@ def write_predictions(all_examples, all_features, all_results, n_best_size,
|
|||||||
return all_predictions
|
return all_predictions
|
||||||
|
|
||||||
|
|
||||||
|
# For XLNet (and XLM which uses the same head)
|
||||||
|
RawResultExtended = collections.namedtuple("RawResultExtended",
|
||||||
|
["unique_id", "start_top_log_probs", "start_top_index",
|
||||||
|
"end_top_log_probs", "end_top_index", "cls_logits"])
|
||||||
|
|
||||||
|
|
||||||
|
def write_predictions_extended(all_examples, all_features, all_results, n_best_size,
|
||||||
|
max_answer_length, output_prediction_file,
|
||||||
|
output_nbest_file,
|
||||||
|
output_null_log_odds_file, orig_data,
|
||||||
|
start_n_top, end_n_top, version_2_with_negative):
|
||||||
|
""" XLNet write prediction logic (more complex than Bert's).
|
||||||
|
Write final predictions to the json file and log-odds of null if needed.
|
||||||
|
|
||||||
|
Requires utils_squad_evaluate.py
|
||||||
|
"""
|
||||||
|
_PrelimPrediction = collections.namedtuple( # pylint: disable=invalid-name
|
||||||
|
"PrelimPrediction",
|
||||||
|
["feature_index", "start_index", "end_index",
|
||||||
|
"start_log_prob", "end_log_prob"])
|
||||||
|
|
||||||
|
_NbestPrediction = collections.namedtuple( # pylint: disable=invalid-name
|
||||||
|
"NbestPrediction", ["text", "start_log_prob", "end_log_prob"])
|
||||||
|
|
||||||
|
logger.info("Writing predictions to: %s", output_prediction_file)
|
||||||
|
# logger.info("Writing nbest to: %s" % (output_nbest_file))
|
||||||
|
|
||||||
|
example_index_to_features = collections.defaultdict(list)
|
||||||
|
for feature in all_features:
|
||||||
|
example_index_to_features[feature.example_index].append(feature)
|
||||||
|
|
||||||
|
unique_id_to_result = {}
|
||||||
|
for result in all_results:
|
||||||
|
unique_id_to_result[result.unique_id] = result
|
||||||
|
|
||||||
|
all_predictions = collections.OrderedDict()
|
||||||
|
all_nbest_json = collections.OrderedDict()
|
||||||
|
scores_diff_json = collections.OrderedDict()
|
||||||
|
|
||||||
|
for (example_index, example) in enumerate(all_examples):
|
||||||
|
features = example_index_to_features[example_index]
|
||||||
|
|
||||||
|
prelim_predictions = []
|
||||||
|
# keep track of the minimum score of null start+end of position 0
|
||||||
|
score_null = 1000000 # large and positive
|
||||||
|
|
||||||
|
for (feature_index, feature) in enumerate(features):
|
||||||
|
result = unique_id_to_result[feature.unique_id]
|
||||||
|
|
||||||
|
cur_null_score = result.cls_logits
|
||||||
|
|
||||||
|
# if we could have irrelevant answers, get the min score of irrelevant
|
||||||
|
score_null = min(score_null, cur_null_score)
|
||||||
|
|
||||||
|
for i in range(start_n_top):
|
||||||
|
for j in range(end_n_top):
|
||||||
|
start_log_prob = result.start_top_log_probs[i]
|
||||||
|
start_index = result.start_top_index[i]
|
||||||
|
|
||||||
|
j_index = i * end_n_top + j
|
||||||
|
|
||||||
|
end_log_prob = result.end_top_log_probs[j_index]
|
||||||
|
end_index = result.end_top_index[j_index]
|
||||||
|
|
||||||
|
# We could hypothetically create invalid predictions, e.g., predict
|
||||||
|
# that the start of the span is in the question. We throw out all
|
||||||
|
# invalid predictions.
|
||||||
|
if start_index >= feature.paragraph_len - 1:
|
||||||
|
continue
|
||||||
|
if end_index >= feature.paragraph_len - 1:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if not feature.token_is_max_context.get(start_index, False):
|
||||||
|
continue
|
||||||
|
if end_index < start_index:
|
||||||
|
continue
|
||||||
|
length = end_index - start_index + 1
|
||||||
|
if length > max_answer_length:
|
||||||
|
continue
|
||||||
|
|
||||||
|
prelim_predictions.append(
|
||||||
|
_PrelimPrediction(
|
||||||
|
feature_index=feature_index,
|
||||||
|
start_index=start_index,
|
||||||
|
end_index=end_index,
|
||||||
|
start_log_prob=start_log_prob,
|
||||||
|
end_log_prob=end_log_prob))
|
||||||
|
|
||||||
|
prelim_predictions = sorted(
|
||||||
|
prelim_predictions,
|
||||||
|
key=lambda x: (x.start_log_prob + x.end_log_prob),
|
||||||
|
reverse=True)
|
||||||
|
|
||||||
|
seen_predictions = {}
|
||||||
|
nbest = []
|
||||||
|
for pred in prelim_predictions:
|
||||||
|
if len(nbest) >= n_best_size:
|
||||||
|
break
|
||||||
|
feature = features[pred.feature_index]
|
||||||
|
|
||||||
|
tok_start_to_orig_index = feature.tok_start_to_orig_index
|
||||||
|
tok_end_to_orig_index = feature.tok_end_to_orig_index
|
||||||
|
start_orig_pos = tok_start_to_orig_index[pred.start_index]
|
||||||
|
end_orig_pos = tok_end_to_orig_index[pred.end_index]
|
||||||
|
|
||||||
|
paragraph_text = example.paragraph_text
|
||||||
|
final_text = paragraph_text[start_orig_pos: end_orig_pos + 1].strip()
|
||||||
|
|
||||||
|
if final_text in seen_predictions:
|
||||||
|
continue
|
||||||
|
|
||||||
|
seen_predictions[final_text] = True
|
||||||
|
|
||||||
|
nbest.append(
|
||||||
|
_NbestPrediction(
|
||||||
|
text=final_text,
|
||||||
|
start_log_prob=pred.start_log_prob,
|
||||||
|
end_log_prob=pred.end_log_prob))
|
||||||
|
|
||||||
|
# In very rare edge cases we could have no valid predictions. So we
|
||||||
|
# just create a nonce prediction in this case to avoid failure.
|
||||||
|
if not nbest:
|
||||||
|
nbest.append(
|
||||||
|
_NbestPrediction(text="", start_log_prob=-1e6,
|
||||||
|
end_log_prob=-1e6))
|
||||||
|
|
||||||
|
total_scores = []
|
||||||
|
best_non_null_entry = None
|
||||||
|
for entry in nbest:
|
||||||
|
total_scores.append(entry.start_log_prob + entry.end_log_prob)
|
||||||
|
if not best_non_null_entry:
|
||||||
|
best_non_null_entry = entry
|
||||||
|
|
||||||
|
probs = _compute_softmax(total_scores)
|
||||||
|
|
||||||
|
nbest_json = []
|
||||||
|
for (i, entry) in enumerate(nbest):
|
||||||
|
output = collections.OrderedDict()
|
||||||
|
output["text"] = entry.text
|
||||||
|
output["probability"] = probs[i]
|
||||||
|
output["start_log_prob"] = entry.start_log_prob
|
||||||
|
output["end_log_prob"] = entry.end_log_prob
|
||||||
|
nbest_json.append(output)
|
||||||
|
|
||||||
|
assert len(nbest_json) >= 1
|
||||||
|
assert best_non_null_entry is not None
|
||||||
|
|
||||||
|
score_diff = score_null
|
||||||
|
scores_diff_json[example.qas_id] = score_diff
|
||||||
|
# note(zhiliny): always predict best_non_null_entry
|
||||||
|
# and the evaluation script will search for the best threshold
|
||||||
|
all_predictions[example.qas_id] = best_non_null_entry.text
|
||||||
|
|
||||||
|
all_nbest_json[example.qas_id] = nbest_json
|
||||||
|
|
||||||
|
with open(output_prediction_file, "w") as writer:
|
||||||
|
writer.write(json.dumps(all_predictions, indent=4) + "\n")
|
||||||
|
|
||||||
|
with open(output_nbest_file, "w") as writer:
|
||||||
|
writer.write(json.dumps(all_nbest_json, indent=4) + "\n")
|
||||||
|
|
||||||
|
if version_2_with_negative:
|
||||||
|
with open(output_null_log_odds_file, "w") as writer:
|
||||||
|
writer.write(json.dumps(scores_diff_json, indent=4) + "\n")
|
||||||
|
|
||||||
|
qid_to_has_ans = make_qid_to_has_ans(orig_data)
|
||||||
|
has_ans_qids = [k for k, v in qid_to_has_ans.items() if v]
|
||||||
|
no_ans_qids = [k for k, v in qid_to_has_ans.items() if not v]
|
||||||
|
exact_raw, f1_raw = get_raw_scores(orig_data, all_predictions)
|
||||||
|
out_eval = {}
|
||||||
|
|
||||||
|
find_all_best_thresh_v2(out_eval, all_predictions, exact_raw, f1_raw, scores_diff_json, qid_to_has_ans)
|
||||||
|
|
||||||
|
return out_eval
|
||||||
|
|
||||||
|
|
||||||
def get_final_text(pred_text, orig_text, do_lower_case, verbose_logging=False):
|
def get_final_text(pred_text, orig_text, do_lower_case, verbose_logging=False):
|
||||||
"""Project the tokenized prediction back to the original text."""
|
"""Project the tokenized prediction back to the original text."""
|
||||||
|
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
""" Official evaluation script for SQuAD version 2.0.
|
""" Official evaluation script for SQuAD version 2.0.
|
||||||
|
Modified by XLNet authors to update `find_best_threshold` scripts for SQuAD V2.0
|
||||||
|
|
||||||
In addition to basic functionality, we also compute additional statistics and
|
In addition to basic functionality, we also compute additional statistics and
|
||||||
plot precision-recall curves if an additional na_prob.json file is provided.
|
plot precision-recall curves if an additional na_prob.json file is provided.
|
||||||
@@ -232,6 +233,36 @@ def find_best_thresh(preds, scores, na_probs, qid_to_has_ans):
|
|||||||
best_thresh = na_probs[qid]
|
best_thresh = na_probs[qid]
|
||||||
return 100.0 * best_score / len(scores), best_thresh
|
return 100.0 * best_score / len(scores), best_thresh
|
||||||
|
|
||||||
|
def find_best_thresh_v2(preds, scores, na_probs, qid_to_has_ans):
|
||||||
|
num_no_ans = sum(1 for k in qid_to_has_ans if not qid_to_has_ans[k])
|
||||||
|
cur_score = num_no_ans
|
||||||
|
best_score = cur_score
|
||||||
|
best_thresh = 0.0
|
||||||
|
qid_list = sorted(na_probs, key=lambda k: na_probs[k])
|
||||||
|
for i, qid in enumerate(qid_list):
|
||||||
|
if qid not in scores: continue
|
||||||
|
if qid_to_has_ans[qid]:
|
||||||
|
diff = scores[qid]
|
||||||
|
else:
|
||||||
|
if preds[qid]:
|
||||||
|
diff = -1
|
||||||
|
else:
|
||||||
|
diff = 0
|
||||||
|
cur_score += diff
|
||||||
|
if cur_score > best_score:
|
||||||
|
best_score = cur_score
|
||||||
|
best_thresh = na_probs[qid]
|
||||||
|
|
||||||
|
has_ans_score, has_ans_cnt = 0, 0
|
||||||
|
for qid in qid_list:
|
||||||
|
if not qid_to_has_ans[qid]: continue
|
||||||
|
has_ans_cnt += 1
|
||||||
|
|
||||||
|
if qid not in scores: continue
|
||||||
|
has_ans_score += scores[qid]
|
||||||
|
|
||||||
|
return 100.0 * best_score / len(scores), best_thresh, 1.0 * has_ans_score / has_ans_cnt
|
||||||
|
|
||||||
def find_all_best_thresh(main_eval, preds, exact_raw, f1_raw, na_probs, qid_to_has_ans):
|
def find_all_best_thresh(main_eval, preds, exact_raw, f1_raw, na_probs, qid_to_has_ans):
|
||||||
best_exact, exact_thresh = find_best_thresh(preds, exact_raw, na_probs, qid_to_has_ans)
|
best_exact, exact_thresh = find_best_thresh(preds, exact_raw, na_probs, qid_to_has_ans)
|
||||||
best_f1, f1_thresh = find_best_thresh(preds, f1_raw, na_probs, qid_to_has_ans)
|
best_f1, f1_thresh = find_best_thresh(preds, f1_raw, na_probs, qid_to_has_ans)
|
||||||
@@ -240,6 +271,16 @@ def find_all_best_thresh(main_eval, preds, exact_raw, f1_raw, na_probs, qid_to_h
|
|||||||
main_eval['best_f1'] = best_f1
|
main_eval['best_f1'] = best_f1
|
||||||
main_eval['best_f1_thresh'] = f1_thresh
|
main_eval['best_f1_thresh'] = f1_thresh
|
||||||
|
|
||||||
|
def find_all_best_thresh_v2(main_eval, preds, exact_raw, f1_raw, na_probs, qid_to_has_ans):
|
||||||
|
best_exact, exact_thresh, has_ans_exact = find_best_thresh_v2(preds, exact_raw, na_probs, qid_to_has_ans)
|
||||||
|
best_f1, f1_thresh, has_ans_f1 = find_best_thresh_v2(preds, f1_raw, na_probs, qid_to_has_ans)
|
||||||
|
main_eval['best_exact'] = best_exact
|
||||||
|
main_eval['best_exact_thresh'] = exact_thresh
|
||||||
|
main_eval['best_f1'] = best_f1
|
||||||
|
main_eval['best_f1_thresh'] = f1_thresh
|
||||||
|
main_eval['has_ans_exact'] = has_ans_exact
|
||||||
|
main_eval['has_ans_f1'] = has_ans_f1
|
||||||
|
|
||||||
def main(OPTS):
|
def main(OPTS):
|
||||||
with open(OPTS.data_file) as f:
|
with open(OPTS.data_file) as f:
|
||||||
dataset_json = json.load(f)
|
dataset_json = json.load(f)
|
||||||
|
|||||||
@@ -493,8 +493,9 @@ class PoolerStartLogits(nn.Module):
|
|||||||
|
|
||||||
def forward(self, hidden_states, p_mask=None):
|
def forward(self, hidden_states, p_mask=None):
|
||||||
""" Args:
|
""" Args:
|
||||||
`p_mask`: [optional] invalid position mask such as query and special symbols (PAD, SEP, CLS)
|
**p_mask**: (`optional`) ``torch.FloatTensor`` of shape `(batch_size, seq_len)`
|
||||||
shape [batch_size, seq_len]. 1.0 means token should be masked.
|
invalid position mask such as query and special symbols (PAD, SEP, CLS)
|
||||||
|
1.0 means token should be masked.
|
||||||
"""
|
"""
|
||||||
x = self.dense(hidden_states).squeeze(-1)
|
x = self.dense(hidden_states).squeeze(-1)
|
||||||
|
|
||||||
@@ -516,11 +517,16 @@ class PoolerEndLogits(nn.Module):
|
|||||||
|
|
||||||
def forward(self, hidden_states, start_states=None, start_positions=None, p_mask=None):
|
def forward(self, hidden_states, start_states=None, start_positions=None, p_mask=None):
|
||||||
""" Args:
|
""" Args:
|
||||||
One of start_states, start_positions should be not None. If both are set, start_positions overrides start_states.
|
One of ``start_states``, ``start_positions`` should be not None.
|
||||||
`start_states`: hidden states of the first tokens for the labeled span: torch.LongTensor of shape identical to hidden_states.
|
If both are set, ``start_positions`` overrides ``start_states``.
|
||||||
`start_positions`: position of the first token for the labeled span: torch.LongTensor of shape [batch_size].
|
|
||||||
`p_mask`: [optional] invalid position mask such as query and special symbols (PAD, SEP, CLS)
|
**start_states**: ``torch.LongTensor`` of shape identical to hidden_states
|
||||||
shape [batch_size, seq_len]. 1.0 means token should be masked.
|
hidden states of the first tokens for the labeled span.
|
||||||
|
**start_positions**: ``torch.LongTensor`` of shape ``(batch_size,)``
|
||||||
|
position of the first token for the labeled span:
|
||||||
|
**p_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, seq_len)``
|
||||||
|
Mask of invalid position such as query and special symbols (PAD, SEP, CLS)
|
||||||
|
1.0 means token should be masked.
|
||||||
"""
|
"""
|
||||||
slen, hsz = hidden_states.shape[-2:]
|
slen, hsz = hidden_states.shape[-2:]
|
||||||
assert start_states is not None or start_positions is not None, "One of start_states, start_positions should be not None"
|
assert start_states is not None or start_positions is not None, "One of start_states, start_positions should be not None"
|
||||||
@@ -549,13 +555,21 @@ class PoolerAnswerClass(nn.Module):
|
|||||||
self.dense_1 = nn.Linear(config.hidden_size, 1, bias=False)
|
self.dense_1 = nn.Linear(config.hidden_size, 1, bias=False)
|
||||||
|
|
||||||
def forward(self, hidden_states, start_states=None, start_positions=None, cls_index=None):
|
def forward(self, hidden_states, start_states=None, start_positions=None, cls_index=None):
|
||||||
""" Args:
|
"""
|
||||||
One of start_states, start_positions should be not None. If both are set, start_positions overrides start_states.
|
Args:
|
||||||
`start_states`: hidden states of the first tokens for the labeled span: torch.LongTensor of shape identical to hidden_states.
|
One of ``start_states``, ``start_positions`` should be not None.
|
||||||
`start_positions`: position of the first token for the labeled span: torch.LongTensor of shape [batch_size].
|
If both are set, ``start_positions`` overrides ``start_states``.
|
||||||
`cls_index`: position of the CLS token: torch.LongTensor of shape [batch_size]. If None, take the last token.
|
|
||||||
|
|
||||||
# note(zhiliny): no dependency on end_feature so that we can obtain one single `cls_logits` for each sample
|
**start_states**: ``torch.LongTensor`` of shape identical to ``hidden_states``.
|
||||||
|
hidden states of the first tokens for the labeled span.
|
||||||
|
**start_positions**: ``torch.LongTensor`` of shape ``(batch_size,)``
|
||||||
|
position of the first token for the labeled span.
|
||||||
|
**cls_index**: torch.LongTensor of shape ``(batch_size,)``
|
||||||
|
position of the CLS token. If None, take the last token.
|
||||||
|
|
||||||
|
note(Original repo):
|
||||||
|
no dependency on end_feature so that we can obtain one single `cls_logits`
|
||||||
|
for each sample
|
||||||
"""
|
"""
|
||||||
slen, hsz = hidden_states.shape[-2:]
|
slen, hsz = hidden_states.shape[-2:]
|
||||||
assert start_states is not None or start_positions is not None, "One of start_states, start_positions should be not None"
|
assert start_states is not None or start_positions is not None, "One of start_states, start_positions should be not None"
|
||||||
@@ -577,7 +591,35 @@ class PoolerAnswerClass(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class SQuADHead(nn.Module):
|
class SQuADHead(nn.Module):
|
||||||
""" A SQuAD head inspired by XLNet.
|
r""" A SQuAD head inspired by XLNet.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
config (:class:`~pytorch_transformers.XLNetConfig`): Model configuration class with all the parameters of the model.
|
||||||
|
|
||||||
|
Inputs:
|
||||||
|
**hidden_states**: ``torch.FloatTensor`` of shape ``(batch_size, seq_len, hidden_size)``
|
||||||
|
hidden states of sequence tokens
|
||||||
|
**start_positions**: ``torch.LongTensor`` of shape ``(batch_size,)``
|
||||||
|
position of the first token for the labeled span.
|
||||||
|
**end_positions**: ``torch.LongTensor`` of shape ``(batch_size,)``
|
||||||
|
position of the last token for the labeled span.
|
||||||
|
**cls_index**: torch.LongTensor of shape ``(batch_size,)``
|
||||||
|
position of the CLS token. If None, take the last token.
|
||||||
|
**is_impossible**: ``torch.LongTensor`` of shape ``(batch_size,)``
|
||||||
|
Whether the question has a possible answer in the paragraph or not.
|
||||||
|
**p_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, seq_len)``
|
||||||
|
Mask of invalid position such as query and special symbols (PAD, SEP, CLS)
|
||||||
|
1.0 means token should be masked.
|
||||||
|
|
||||||
|
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
|
||||||
|
**loss**: (`optional`, returned if both ``start_positions`` and ``end_positions`` are provided) ``torch.FloatTensor`` of shape ``(1,)``:
|
||||||
|
Classification loss as the sum of start token, end token (and is_impossible if provided) classification losses.
|
||||||
|
**last_hidden_state**: `(`optional`, returned if ``start_positions`` or ``end_positions`` is not provided) `torch.FloatTensor`` of shape ``(batch_size, sequence_length, hidden_size)``
|
||||||
|
Sequence of hidden-states at the last layer of the model.
|
||||||
|
**mems**:
|
||||||
|
list of ``torch.FloatTensor`` (one for each layer):
|
||||||
|
that contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
|
||||||
|
(see `mems` input above). Can be used to speed up sequential decoding and attend to longer context.
|
||||||
"""
|
"""
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super(SQuADHead, self).__init__()
|
super(SQuADHead, self).__init__()
|
||||||
@@ -590,8 +632,6 @@ class SQuADHead(nn.Module):
|
|||||||
|
|
||||||
def forward(self, hidden_states, start_positions=None, end_positions=None,
|
def forward(self, hidden_states, start_positions=None, end_positions=None,
|
||||||
cls_index=None, is_impossible=None, p_mask=None):
|
cls_index=None, is_impossible=None, p_mask=None):
|
||||||
""" hidden_states: float Tensor in shape [bsz, seq_len, hidden_size], the hidden-states of the last layer.
|
|
||||||
"""
|
|
||||||
outputs = ()
|
outputs = ()
|
||||||
|
|
||||||
start_logits = self.start_logits(hidden_states, p_mask=p_mask)
|
start_logits = self.start_logits(hidden_states, p_mask=p_mask)
|
||||||
@@ -618,9 +658,8 @@ class SQuADHead(nn.Module):
|
|||||||
|
|
||||||
# note(zhiliny): by default multiply the loss by 0.5 so that the scale is comparable to start_loss and end_loss
|
# note(zhiliny): by default multiply the loss by 0.5 so that the scale is comparable to start_loss and end_loss
|
||||||
total_loss += cls_loss * 0.5
|
total_loss += cls_loss * 0.5
|
||||||
outputs = (total_loss, start_logits, end_logits, cls_logits) + outputs
|
|
||||||
else:
|
outputs = (total_loss,) + outputs
|
||||||
outputs = (total_loss, start_logits, end_logits) + outputs
|
|
||||||
|
|
||||||
else:
|
else:
|
||||||
# during inference, compute the end logits based on beam search
|
# during inference, compute the end logits based on beam search
|
||||||
@@ -647,7 +686,7 @@ class SQuADHead(nn.Module):
|
|||||||
outputs = (start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits) + outputs
|
outputs = (start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits) + outputs
|
||||||
|
|
||||||
# return start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits
|
# return start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits
|
||||||
# or (if labels are provided) total_loss, start_logits, end_logits, (cls_logits)
|
# or (if labels are provided) (total_loss,)
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1162,8 +1162,9 @@ class XLNetForQuestionAnswering(XLNetPreTrainedModel):
|
|||||||
Labels whether a question has an answer or no answer (SQuAD 2.0)
|
Labels whether a question has an answer or no answer (SQuAD 2.0)
|
||||||
**cls_index**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
|
**cls_index**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
|
||||||
Labels for position (index) of the classification token to use as input for computing plausibility of the answer.
|
Labels for position (index) of the classification token to use as input for computing plausibility of the answer.
|
||||||
**p_mask**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
|
**p_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length)``:
|
||||||
Optional mask of tokens which can't be in answers (e.g. [CLS], [PAD], ...)
|
Optional mask of tokens which can't be in answers (e.g. [CLS], [PAD], ...).
|
||||||
|
1.0 means token should be masked. 0.0 mean token is not masked.
|
||||||
|
|
||||||
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
|
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
|
||||||
**loss**: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
|
**loss**: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
|
||||||
|
|||||||
Reference in New Issue
Block a user