updating squad for compatibility with XLNet
This commit is contained in:
@@ -41,7 +41,9 @@ from pytorch_transformers import (WEIGHTS_NAME, BertConfig,
|
||||
|
||||
from pytorch_transformers import AdamW, WarmupLinearSchedule
|
||||
|
||||
from utils_squad import read_squad_examples, convert_examples_to_features, RawResult, write_predictions
|
||||
from utils_squad import (read_squad_examples, convert_examples_to_features,
|
||||
RawResult, write_predictions,
|
||||
RawResultExtended, write_predictions_extended)
|
||||
|
||||
# The follwing import is the official SQuAD evaluation script (2.0).
|
||||
# You can remove it from the dependencies if you are using this script outside of the library
|
||||
@@ -66,6 +68,8 @@ def set_seed(args):
|
||||
if args.n_gpu > 0:
|
||||
torch.cuda.manual_seed_all(args.seed)
|
||||
|
||||
def to_list(tensor):
|
||||
return tensor.detach().cpu().tolist()
|
||||
|
||||
def train(args, train_dataset, model, tokenizer):
|
||||
""" Train the model """
|
||||
@@ -118,10 +122,13 @@ def train(args, train_dataset, model, tokenizer):
|
||||
model.train()
|
||||
batch = tuple(t.to(args.device) for t in batch)
|
||||
inputs = {'input_ids': batch[0],
|
||||
'token_type_ids': batch[1] if args.model_type in ['bert', 'xlnet'] else None, # XLM don't use segment_ids
|
||||
'token_type_ids': None if args.model_type == 'xlm' else batch[1], # XLM don't use segment_ids
|
||||
'attention_mask': batch[2],
|
||||
'start_positions': batch[3],
|
||||
'end_positions': batch[4]}
|
||||
if args.model_type in ['xlnet', 'xlm']:
|
||||
inputs.update({'cls_index': batch[5],
|
||||
'p_mask': batch[6]})
|
||||
ouputs = model(**inputs)
|
||||
loss = ouputs[0] # model outputs are always tuple in pytorch-transformers (see doc)
|
||||
|
||||
@@ -197,30 +204,49 @@ def evaluate(args, model, tokenizer, prefix=""):
|
||||
for batch in tqdm(eval_dataloader, desc="Evaluating"):
|
||||
model.eval()
|
||||
batch = tuple(t.to(args.device) for t in batch)
|
||||
example_indices = batch[3]
|
||||
with torch.no_grad():
|
||||
inputs = {'input_ids': batch[0],
|
||||
'token_type_ids': batch[1] if args.model_type in ['bert', 'xlnet'] else None, # XLM don't use segment_ids
|
||||
'token_type_ids': None if args.model_type == 'xlm' else batch[1], # XLM don't use segment_ids
|
||||
'attention_mask': batch[2]}
|
||||
example_indices = batch[3]
|
||||
if args.model_type in ['xlnet', 'xlm']:
|
||||
inputs.update({'cls_index': batch[4],
|
||||
'p_mask': batch[5]})
|
||||
outputs = model(**inputs)
|
||||
batch_start_logits, batch_end_logits = outputs[:2]
|
||||
|
||||
for i, example_index in enumerate(example_indices):
|
||||
start_logits = batch_start_logits[i].detach().cpu().tolist()
|
||||
end_logits = batch_end_logits[i].detach().cpu().tolist()
|
||||
eval_feature = features[example_index.item()]
|
||||
unique_id = int(eval_feature.unique_id)
|
||||
all_results.append(RawResult(unique_id=unique_id,
|
||||
start_logits=start_logits,
|
||||
end_logits=end_logits))
|
||||
if args.model_type in ['xlnet', 'xlm']:
|
||||
# XLNet uses a more complex post-processing procedure
|
||||
result = RawResultExtended(unique_id = unique_id,
|
||||
start_top_log_probs = to_list(outputs[0][i]),
|
||||
start_top_index = to_list(outputs[1][i]),
|
||||
end_top_log_probs = to_list(outputs[2][i]),
|
||||
end_top_index = to_list(outputs[3][i]),
|
||||
cls_logits = to_list(outputs[4][i]))
|
||||
else:
|
||||
result = RawResult(unique_id = unique_id,
|
||||
start_logits = to_list(outputs[0][i]),
|
||||
end_logits = to_list(outputs[1][i]))
|
||||
all_results.append(result)
|
||||
|
||||
# Compute predictions
|
||||
output_prediction_file = os.path.join(args.output_dir, "predictions_{}.json".format(prefix))
|
||||
output_nbest_file = os.path.join(args.output_dir, "nbest_predictions_{}.json".format(prefix))
|
||||
output_null_log_odds_file = os.path.join(args.output_dir, "null_odds_{}.json".format(prefix))
|
||||
write_predictions(examples, features, all_results, args.n_best_size, args.max_answer_length,
|
||||
args.do_lower_case, output_prediction_file, output_nbest_file,
|
||||
output_null_log_odds_file, args.verbose_logging,
|
||||
|
||||
if args.model_type in ['xlnet', 'xlm']:
|
||||
# XLNet uses a more complex post-processing procedure
|
||||
write_predictions_extended(examples, features, all_results, args.n_best_size,
|
||||
args.max_answer_length, output_prediction_file,
|
||||
output_nbest_file, output_null_log_odds_file, args.predict_file,
|
||||
args.start_n_top, args.end_n_top, args.version_2_with_negative)
|
||||
else:
|
||||
write_predictions(examples, features, all_results, args.n_best_size,
|
||||
args.max_answer_length, args.do_lower_case, output_prediction_file,
|
||||
output_nbest_file, output_null_log_odds_file, args.verbose_logging,
|
||||
args.version_2_with_negative, args.null_score_diff_threshold)
|
||||
|
||||
# Evaluate with the official SQuAD script
|
||||
@@ -260,13 +286,18 @@ def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=Fal
|
||||
all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
|
||||
all_input_mask = torch.tensor([f.input_mask for f in features], dtype=torch.long)
|
||||
all_segment_ids = torch.tensor([f.segment_ids for f in features], dtype=torch.long)
|
||||
all_cls_index = torch.tensor([f.cls_index for f in features], dtype=torch.long)
|
||||
all_p_mask = torch.tensor([f.p_mask for f in features], dtype=torch.float)
|
||||
if evaluate:
|
||||
all_example_index = torch.arange(all_input_ids.size(0), dtype=torch.long)
|
||||
dataset = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_example_index)
|
||||
dataset = TensorDataset(all_input_ids, all_input_mask, all_segment_ids,
|
||||
all_example_index, all_cls_index, all_p_mask)
|
||||
else:
|
||||
all_start_positions = torch.tensor([f.start_position for f in features], dtype=torch.long)
|
||||
all_end_positions = torch.tensor([f.end_position for f in features], dtype=torch.long)
|
||||
dataset = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_start_positions, all_end_positions)
|
||||
dataset = TensorDataset(all_input_ids, all_input_mask, all_segment_ids,
|
||||
all_start_positions, all_end_positions,
|
||||
all_cls_index, all_p_mask)
|
||||
|
||||
if output_examples:
|
||||
return dataset, examples, features
|
||||
|
||||
@@ -26,6 +26,9 @@ from io import open
|
||||
|
||||
from pytorch_transformers.tokenization_bert import BasicTokenizer, whitespace_tokenize
|
||||
|
||||
# Required by XLNet evaluation method to compute optimal threshold (see write_predictions_extended() method)
|
||||
from utils_squad_evaluate import find_all_best_thresh_v2, make_qid_to_has_ans, get_raw_scores
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -82,6 +85,8 @@ class InputFeatures(object):
|
||||
input_ids,
|
||||
input_mask,
|
||||
segment_ids,
|
||||
cls_index,
|
||||
p_mask,
|
||||
start_position=None,
|
||||
end_position=None,
|
||||
is_impossible=None):
|
||||
@@ -94,6 +99,8 @@ class InputFeatures(object):
|
||||
self.input_ids = input_ids
|
||||
self.input_mask = input_mask
|
||||
self.segment_ids = segment_ids
|
||||
self.cls_index = cls_index
|
||||
self.p_mask = p_mask
|
||||
self.start_position = start_position
|
||||
self.end_position = end_position
|
||||
self.is_impossible = is_impossible
|
||||
@@ -178,13 +185,25 @@ def read_squad_examples(input_file, is_training, version_2_with_negative):
|
||||
|
||||
|
||||
def convert_examples_to_features(examples, tokenizer, max_seq_length,
|
||||
doc_stride, max_query_length, is_training):
|
||||
doc_stride, max_query_length, is_training,
|
||||
cls_token_at_end=False,
|
||||
cls_token='[CLS]', sep_token='[SEP]', pad_token=0,
|
||||
sequence_a_segment_id=0, sequence_b_segment_id=1,
|
||||
cls_token_segment_id=0, pad_token_segment_id=0,
|
||||
mask_padding_with_zero=True):
|
||||
"""Loads a data file into a list of `InputBatch`s."""
|
||||
|
||||
unique_id = 1000000000
|
||||
# cnt_pos, cnt_neg = 0, 0
|
||||
# max_N, max_M = 1024, 1024
|
||||
# f = np.zeros((max_N, max_M), dtype=np.float32)
|
||||
|
||||
features = []
|
||||
for (example_index, example) in enumerate(examples):
|
||||
|
||||
# if example_index % 100 == 0:
|
||||
# logger.info('Converting %s/%s pos %s neg %s', example_index, len(examples), cnt_pos, cnt_neg)
|
||||
|
||||
query_tokens = tokenizer.tokenize(example.question_text)
|
||||
|
||||
if len(query_tokens) > max_query_length:
|
||||
@@ -239,14 +258,30 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length,
|
||||
token_to_orig_map = {}
|
||||
token_is_max_context = {}
|
||||
segment_ids = []
|
||||
tokens.append("[CLS]")
|
||||
segment_ids.append(0)
|
||||
|
||||
# p_mask: mask with 1 for token than cannot be in the answer (0 for token which can be in an answer)
|
||||
# Original TF implem also keep the classification token (set to 0) (not sure why...)
|
||||
p_mask = []
|
||||
|
||||
# CLS token at the beginning
|
||||
if not cls_token_at_end:
|
||||
tokens.append(cls_token)
|
||||
segment_ids.append(cls_token_segment_id)
|
||||
p_mask.append(0)
|
||||
cls_index = 0
|
||||
|
||||
# Query
|
||||
for token in query_tokens:
|
||||
tokens.append(token)
|
||||
segment_ids.append(0)
|
||||
tokens.append("[SEP]")
|
||||
segment_ids.append(0)
|
||||
segment_ids.append(sequence_a_segment_id)
|
||||
p_mask.append(1)
|
||||
|
||||
# SEP token
|
||||
tokens.append(sep_token)
|
||||
segment_ids.append(sequence_a_segment_id)
|
||||
p_mask.append(1)
|
||||
|
||||
# Paragraph
|
||||
for i in range(doc_span.length):
|
||||
split_token_index = doc_span.start + i
|
||||
token_to_orig_map[len(tokens)] = tok_to_orig_index[split_token_index]
|
||||
@@ -255,29 +290,42 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length,
|
||||
split_token_index)
|
||||
token_is_max_context[len(tokens)] = is_max_context
|
||||
tokens.append(all_doc_tokens[split_token_index])
|
||||
segment_ids.append(1)
|
||||
tokens.append("[SEP]")
|
||||
segment_ids.append(1)
|
||||
segment_ids.append(sequence_b_segment_id)
|
||||
p_mask.append(0)
|
||||
|
||||
# SEP token
|
||||
tokens.append(sep_token)
|
||||
segment_ids.append(sequence_b_segment_id)
|
||||
p_mask.append(1)
|
||||
|
||||
# CLS token at the end
|
||||
if cls_token_at_end:
|
||||
tokens.append(cls_token)
|
||||
segment_ids.append(cls_token_segment_id)
|
||||
p_mask.append(0)
|
||||
cls_index = len(tokens) - 1 # Index of classification token
|
||||
|
||||
input_ids = tokenizer.convert_tokens_to_ids(tokens)
|
||||
|
||||
# The mask has 1 for real tokens and 0 for padding tokens. Only real
|
||||
# tokens are attended to.
|
||||
input_mask = [1] * len(input_ids)
|
||||
input_mask = [1 if mask_padding_with_zero else 0] * len(input_ids)
|
||||
|
||||
# Zero-pad up to the sequence length.
|
||||
while len(input_ids) < max_seq_length:
|
||||
input_ids.append(0)
|
||||
input_mask.append(0)
|
||||
segment_ids.append(0)
|
||||
input_ids.append(pad_token)
|
||||
input_mask.append(0 if mask_padding_with_zero else 1)
|
||||
segment_ids.append(pad_token_segment_id)
|
||||
p_mask.append(1)
|
||||
|
||||
assert len(input_ids) == max_seq_length
|
||||
assert len(input_mask) == max_seq_length
|
||||
assert len(segment_ids) == max_seq_length
|
||||
|
||||
span_is_impossible = example.is_impossible
|
||||
start_position = None
|
||||
end_position = None
|
||||
if is_training and not example.is_impossible:
|
||||
if is_training and not span_is_impossible:
|
||||
# For training, if our document chunk does not contain an annotation
|
||||
# we throw it out, since there is nothing to predict.
|
||||
doc_start = doc_span.start
|
||||
@@ -289,13 +337,16 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length,
|
||||
if out_of_span:
|
||||
start_position = 0
|
||||
end_position = 0
|
||||
span_is_impossible = True
|
||||
else:
|
||||
doc_offset = len(query_tokens) + 2
|
||||
start_position = tok_start_position - doc_start + doc_offset
|
||||
end_position = tok_end_position - doc_start + doc_offset
|
||||
if is_training and example.is_impossible:
|
||||
start_position = 0
|
||||
end_position = 0
|
||||
|
||||
if is_training and span_is_impossible:
|
||||
start_position = cls_index
|
||||
end_position = cls_index
|
||||
|
||||
if example_index < 20:
|
||||
logger.info("*** Example ***")
|
||||
logger.info("unique_id: %s" % (unique_id))
|
||||
@@ -312,9 +363,9 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length,
|
||||
"input_mask: %s" % " ".join([str(x) for x in input_mask]))
|
||||
logger.info(
|
||||
"segment_ids: %s" % " ".join([str(x) for x in segment_ids]))
|
||||
if is_training and example.is_impossible:
|
||||
if is_training and span_is_impossible:
|
||||
logger.info("impossible example")
|
||||
if is_training and not example.is_impossible:
|
||||
if is_training and not span_is_impossible:
|
||||
answer_text = " ".join(tokens[start_position:(end_position + 1)])
|
||||
logger.info("start_position: %d" % (start_position))
|
||||
logger.info("end_position: %d" % (end_position))
|
||||
@@ -332,9 +383,11 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length,
|
||||
input_ids=input_ids,
|
||||
input_mask=input_mask,
|
||||
segment_ids=segment_ids,
|
||||
cls_index=cls_index,
|
||||
p_mask=p_mask,
|
||||
start_position=start_position,
|
||||
end_position=end_position,
|
||||
is_impossible=example.is_impossible))
|
||||
is_impossible=span_is_impossible))
|
||||
unique_id += 1
|
||||
|
||||
return features
|
||||
@@ -417,7 +470,6 @@ def _check_is_max_context(doc_spans, cur_span_index, position):
|
||||
RawResult = collections.namedtuple("RawResult",
|
||||
["unique_id", "start_logits", "end_logits"])
|
||||
|
||||
|
||||
def write_predictions(all_examples, all_features, all_results, n_best_size,
|
||||
max_answer_length, do_lower_case, output_prediction_file,
|
||||
output_nbest_file, output_null_log_odds_file, verbose_logging,
|
||||
@@ -612,6 +664,182 @@ def write_predictions(all_examples, all_features, all_results, n_best_size,
|
||||
return all_predictions
|
||||
|
||||
|
||||
# For XLNet (and XLM which uses the same head)
|
||||
RawResultExtended = collections.namedtuple("RawResultExtended",
|
||||
["unique_id", "start_top_log_probs", "start_top_index",
|
||||
"end_top_log_probs", "end_top_index", "cls_logits"])
|
||||
|
||||
|
||||
def write_predictions_extended(all_examples, all_features, all_results, n_best_size,
|
||||
max_answer_length, output_prediction_file,
|
||||
output_nbest_file,
|
||||
output_null_log_odds_file, orig_data,
|
||||
start_n_top, end_n_top, version_2_with_negative):
|
||||
""" XLNet write prediction logic (more complex than Bert's).
|
||||
Write final predictions to the json file and log-odds of null if needed.
|
||||
|
||||
Requires utils_squad_evaluate.py
|
||||
"""
|
||||
_PrelimPrediction = collections.namedtuple( # pylint: disable=invalid-name
|
||||
"PrelimPrediction",
|
||||
["feature_index", "start_index", "end_index",
|
||||
"start_log_prob", "end_log_prob"])
|
||||
|
||||
_NbestPrediction = collections.namedtuple( # pylint: disable=invalid-name
|
||||
"NbestPrediction", ["text", "start_log_prob", "end_log_prob"])
|
||||
|
||||
logger.info("Writing predictions to: %s", output_prediction_file)
|
||||
# logger.info("Writing nbest to: %s" % (output_nbest_file))
|
||||
|
||||
example_index_to_features = collections.defaultdict(list)
|
||||
for feature in all_features:
|
||||
example_index_to_features[feature.example_index].append(feature)
|
||||
|
||||
unique_id_to_result = {}
|
||||
for result in all_results:
|
||||
unique_id_to_result[result.unique_id] = result
|
||||
|
||||
all_predictions = collections.OrderedDict()
|
||||
all_nbest_json = collections.OrderedDict()
|
||||
scores_diff_json = collections.OrderedDict()
|
||||
|
||||
for (example_index, example) in enumerate(all_examples):
|
||||
features = example_index_to_features[example_index]
|
||||
|
||||
prelim_predictions = []
|
||||
# keep track of the minimum score of null start+end of position 0
|
||||
score_null = 1000000 # large and positive
|
||||
|
||||
for (feature_index, feature) in enumerate(features):
|
||||
result = unique_id_to_result[feature.unique_id]
|
||||
|
||||
cur_null_score = result.cls_logits
|
||||
|
||||
# if we could have irrelevant answers, get the min score of irrelevant
|
||||
score_null = min(score_null, cur_null_score)
|
||||
|
||||
for i in range(start_n_top):
|
||||
for j in range(end_n_top):
|
||||
start_log_prob = result.start_top_log_probs[i]
|
||||
start_index = result.start_top_index[i]
|
||||
|
||||
j_index = i * end_n_top + j
|
||||
|
||||
end_log_prob = result.end_top_log_probs[j_index]
|
||||
end_index = result.end_top_index[j_index]
|
||||
|
||||
# We could hypothetically create invalid predictions, e.g., predict
|
||||
# that the start of the span is in the question. We throw out all
|
||||
# invalid predictions.
|
||||
if start_index >= feature.paragraph_len - 1:
|
||||
continue
|
||||
if end_index >= feature.paragraph_len - 1:
|
||||
continue
|
||||
|
||||
if not feature.token_is_max_context.get(start_index, False):
|
||||
continue
|
||||
if end_index < start_index:
|
||||
continue
|
||||
length = end_index - start_index + 1
|
||||
if length > max_answer_length:
|
||||
continue
|
||||
|
||||
prelim_predictions.append(
|
||||
_PrelimPrediction(
|
||||
feature_index=feature_index,
|
||||
start_index=start_index,
|
||||
end_index=end_index,
|
||||
start_log_prob=start_log_prob,
|
||||
end_log_prob=end_log_prob))
|
||||
|
||||
prelim_predictions = sorted(
|
||||
prelim_predictions,
|
||||
key=lambda x: (x.start_log_prob + x.end_log_prob),
|
||||
reverse=True)
|
||||
|
||||
seen_predictions = {}
|
||||
nbest = []
|
||||
for pred in prelim_predictions:
|
||||
if len(nbest) >= n_best_size:
|
||||
break
|
||||
feature = features[pred.feature_index]
|
||||
|
||||
tok_start_to_orig_index = feature.tok_start_to_orig_index
|
||||
tok_end_to_orig_index = feature.tok_end_to_orig_index
|
||||
start_orig_pos = tok_start_to_orig_index[pred.start_index]
|
||||
end_orig_pos = tok_end_to_orig_index[pred.end_index]
|
||||
|
||||
paragraph_text = example.paragraph_text
|
||||
final_text = paragraph_text[start_orig_pos: end_orig_pos + 1].strip()
|
||||
|
||||
if final_text in seen_predictions:
|
||||
continue
|
||||
|
||||
seen_predictions[final_text] = True
|
||||
|
||||
nbest.append(
|
||||
_NbestPrediction(
|
||||
text=final_text,
|
||||
start_log_prob=pred.start_log_prob,
|
||||
end_log_prob=pred.end_log_prob))
|
||||
|
||||
# In very rare edge cases we could have no valid predictions. So we
|
||||
# just create a nonce prediction in this case to avoid failure.
|
||||
if not nbest:
|
||||
nbest.append(
|
||||
_NbestPrediction(text="", start_log_prob=-1e6,
|
||||
end_log_prob=-1e6))
|
||||
|
||||
total_scores = []
|
||||
best_non_null_entry = None
|
||||
for entry in nbest:
|
||||
total_scores.append(entry.start_log_prob + entry.end_log_prob)
|
||||
if not best_non_null_entry:
|
||||
best_non_null_entry = entry
|
||||
|
||||
probs = _compute_softmax(total_scores)
|
||||
|
||||
nbest_json = []
|
||||
for (i, entry) in enumerate(nbest):
|
||||
output = collections.OrderedDict()
|
||||
output["text"] = entry.text
|
||||
output["probability"] = probs[i]
|
||||
output["start_log_prob"] = entry.start_log_prob
|
||||
output["end_log_prob"] = entry.end_log_prob
|
||||
nbest_json.append(output)
|
||||
|
||||
assert len(nbest_json) >= 1
|
||||
assert best_non_null_entry is not None
|
||||
|
||||
score_diff = score_null
|
||||
scores_diff_json[example.qas_id] = score_diff
|
||||
# note(zhiliny): always predict best_non_null_entry
|
||||
# and the evaluation script will search for the best threshold
|
||||
all_predictions[example.qas_id] = best_non_null_entry.text
|
||||
|
||||
all_nbest_json[example.qas_id] = nbest_json
|
||||
|
||||
with open(output_prediction_file, "w") as writer:
|
||||
writer.write(json.dumps(all_predictions, indent=4) + "\n")
|
||||
|
||||
with open(output_nbest_file, "w") as writer:
|
||||
writer.write(json.dumps(all_nbest_json, indent=4) + "\n")
|
||||
|
||||
if version_2_with_negative:
|
||||
with open(output_null_log_odds_file, "w") as writer:
|
||||
writer.write(json.dumps(scores_diff_json, indent=4) + "\n")
|
||||
|
||||
qid_to_has_ans = make_qid_to_has_ans(orig_data)
|
||||
has_ans_qids = [k for k, v in qid_to_has_ans.items() if v]
|
||||
no_ans_qids = [k for k, v in qid_to_has_ans.items() if not v]
|
||||
exact_raw, f1_raw = get_raw_scores(orig_data, all_predictions)
|
||||
out_eval = {}
|
||||
|
||||
find_all_best_thresh_v2(out_eval, all_predictions, exact_raw, f1_raw, scores_diff_json, qid_to_has_ans)
|
||||
|
||||
return out_eval
|
||||
|
||||
|
||||
def get_final_text(pred_text, orig_text, do_lower_case, verbose_logging=False):
|
||||
"""Project the tokenized prediction back to the original text."""
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
""" Official evaluation script for SQuAD version 2.0.
|
||||
Modified by XLNet authors to update `find_best_threshold` scripts for SQuAD V2.0
|
||||
|
||||
In addition to basic functionality, we also compute additional statistics and
|
||||
plot precision-recall curves if an additional na_prob.json file is provided.
|
||||
@@ -232,6 +233,36 @@ def find_best_thresh(preds, scores, na_probs, qid_to_has_ans):
|
||||
best_thresh = na_probs[qid]
|
||||
return 100.0 * best_score / len(scores), best_thresh
|
||||
|
||||
def find_best_thresh_v2(preds, scores, na_probs, qid_to_has_ans):
|
||||
num_no_ans = sum(1 for k in qid_to_has_ans if not qid_to_has_ans[k])
|
||||
cur_score = num_no_ans
|
||||
best_score = cur_score
|
||||
best_thresh = 0.0
|
||||
qid_list = sorted(na_probs, key=lambda k: na_probs[k])
|
||||
for i, qid in enumerate(qid_list):
|
||||
if qid not in scores: continue
|
||||
if qid_to_has_ans[qid]:
|
||||
diff = scores[qid]
|
||||
else:
|
||||
if preds[qid]:
|
||||
diff = -1
|
||||
else:
|
||||
diff = 0
|
||||
cur_score += diff
|
||||
if cur_score > best_score:
|
||||
best_score = cur_score
|
||||
best_thresh = na_probs[qid]
|
||||
|
||||
has_ans_score, has_ans_cnt = 0, 0
|
||||
for qid in qid_list:
|
||||
if not qid_to_has_ans[qid]: continue
|
||||
has_ans_cnt += 1
|
||||
|
||||
if qid not in scores: continue
|
||||
has_ans_score += scores[qid]
|
||||
|
||||
return 100.0 * best_score / len(scores), best_thresh, 1.0 * has_ans_score / has_ans_cnt
|
||||
|
||||
def find_all_best_thresh(main_eval, preds, exact_raw, f1_raw, na_probs, qid_to_has_ans):
|
||||
best_exact, exact_thresh = find_best_thresh(preds, exact_raw, na_probs, qid_to_has_ans)
|
||||
best_f1, f1_thresh = find_best_thresh(preds, f1_raw, na_probs, qid_to_has_ans)
|
||||
@@ -240,6 +271,16 @@ def find_all_best_thresh(main_eval, preds, exact_raw, f1_raw, na_probs, qid_to_h
|
||||
main_eval['best_f1'] = best_f1
|
||||
main_eval['best_f1_thresh'] = f1_thresh
|
||||
|
||||
def find_all_best_thresh_v2(main_eval, preds, exact_raw, f1_raw, na_probs, qid_to_has_ans):
|
||||
best_exact, exact_thresh, has_ans_exact = find_best_thresh_v2(preds, exact_raw, na_probs, qid_to_has_ans)
|
||||
best_f1, f1_thresh, has_ans_f1 = find_best_thresh_v2(preds, f1_raw, na_probs, qid_to_has_ans)
|
||||
main_eval['best_exact'] = best_exact
|
||||
main_eval['best_exact_thresh'] = exact_thresh
|
||||
main_eval['best_f1'] = best_f1
|
||||
main_eval['best_f1_thresh'] = f1_thresh
|
||||
main_eval['has_ans_exact'] = has_ans_exact
|
||||
main_eval['has_ans_f1'] = has_ans_f1
|
||||
|
||||
def main(OPTS):
|
||||
with open(OPTS.data_file) as f:
|
||||
dataset_json = json.load(f)
|
||||
|
||||
@@ -493,8 +493,9 @@ class PoolerStartLogits(nn.Module):
|
||||
|
||||
def forward(self, hidden_states, p_mask=None):
|
||||
""" Args:
|
||||
`p_mask`: [optional] invalid position mask such as query and special symbols (PAD, SEP, CLS)
|
||||
shape [batch_size, seq_len]. 1.0 means token should be masked.
|
||||
**p_mask**: (`optional`) ``torch.FloatTensor`` of shape `(batch_size, seq_len)`
|
||||
invalid position mask such as query and special symbols (PAD, SEP, CLS)
|
||||
1.0 means token should be masked.
|
||||
"""
|
||||
x = self.dense(hidden_states).squeeze(-1)
|
||||
|
||||
@@ -516,11 +517,16 @@ class PoolerEndLogits(nn.Module):
|
||||
|
||||
def forward(self, hidden_states, start_states=None, start_positions=None, p_mask=None):
|
||||
""" Args:
|
||||
One of start_states, start_positions should be not None. If both are set, start_positions overrides start_states.
|
||||
`start_states`: hidden states of the first tokens for the labeled span: torch.LongTensor of shape identical to hidden_states.
|
||||
`start_positions`: position of the first token for the labeled span: torch.LongTensor of shape [batch_size].
|
||||
`p_mask`: [optional] invalid position mask such as query and special symbols (PAD, SEP, CLS)
|
||||
shape [batch_size, seq_len]. 1.0 means token should be masked.
|
||||
One of ``start_states``, ``start_positions`` should be not None.
|
||||
If both are set, ``start_positions`` overrides ``start_states``.
|
||||
|
||||
**start_states**: ``torch.LongTensor`` of shape identical to hidden_states
|
||||
hidden states of the first tokens for the labeled span.
|
||||
**start_positions**: ``torch.LongTensor`` of shape ``(batch_size,)``
|
||||
position of the first token for the labeled span:
|
||||
**p_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, seq_len)``
|
||||
Mask of invalid position such as query and special symbols (PAD, SEP, CLS)
|
||||
1.0 means token should be masked.
|
||||
"""
|
||||
slen, hsz = hidden_states.shape[-2:]
|
||||
assert start_states is not None or start_positions is not None, "One of start_states, start_positions should be not None"
|
||||
@@ -549,13 +555,21 @@ class PoolerAnswerClass(nn.Module):
|
||||
self.dense_1 = nn.Linear(config.hidden_size, 1, bias=False)
|
||||
|
||||
def forward(self, hidden_states, start_states=None, start_positions=None, cls_index=None):
|
||||
""" Args:
|
||||
One of start_states, start_positions should be not None. If both are set, start_positions overrides start_states.
|
||||
`start_states`: hidden states of the first tokens for the labeled span: torch.LongTensor of shape identical to hidden_states.
|
||||
`start_positions`: position of the first token for the labeled span: torch.LongTensor of shape [batch_size].
|
||||
`cls_index`: position of the CLS token: torch.LongTensor of shape [batch_size]. If None, take the last token.
|
||||
"""
|
||||
Args:
|
||||
One of ``start_states``, ``start_positions`` should be not None.
|
||||
If both are set, ``start_positions`` overrides ``start_states``.
|
||||
|
||||
# note(zhiliny): no dependency on end_feature so that we can obtain one single `cls_logits` for each sample
|
||||
**start_states**: ``torch.LongTensor`` of shape identical to ``hidden_states``.
|
||||
hidden states of the first tokens for the labeled span.
|
||||
**start_positions**: ``torch.LongTensor`` of shape ``(batch_size,)``
|
||||
position of the first token for the labeled span.
|
||||
**cls_index**: torch.LongTensor of shape ``(batch_size,)``
|
||||
position of the CLS token. If None, take the last token.
|
||||
|
||||
note(Original repo):
|
||||
no dependency on end_feature so that we can obtain one single `cls_logits`
|
||||
for each sample
|
||||
"""
|
||||
slen, hsz = hidden_states.shape[-2:]
|
||||
assert start_states is not None or start_positions is not None, "One of start_states, start_positions should be not None"
|
||||
@@ -577,7 +591,35 @@ class PoolerAnswerClass(nn.Module):
|
||||
|
||||
|
||||
class SQuADHead(nn.Module):
|
||||
""" A SQuAD head inspired by XLNet.
|
||||
r""" A SQuAD head inspired by XLNet.
|
||||
|
||||
Parameters:
|
||||
config (:class:`~pytorch_transformers.XLNetConfig`): Model configuration class with all the parameters of the model.
|
||||
|
||||
Inputs:
|
||||
**hidden_states**: ``torch.FloatTensor`` of shape ``(batch_size, seq_len, hidden_size)``
|
||||
hidden states of sequence tokens
|
||||
**start_positions**: ``torch.LongTensor`` of shape ``(batch_size,)``
|
||||
position of the first token for the labeled span.
|
||||
**end_positions**: ``torch.LongTensor`` of shape ``(batch_size,)``
|
||||
position of the last token for the labeled span.
|
||||
**cls_index**: torch.LongTensor of shape ``(batch_size,)``
|
||||
position of the CLS token. If None, take the last token.
|
||||
**is_impossible**: ``torch.LongTensor`` of shape ``(batch_size,)``
|
||||
Whether the question has a possible answer in the paragraph or not.
|
||||
**p_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, seq_len)``
|
||||
Mask of invalid position such as query and special symbols (PAD, SEP, CLS)
|
||||
1.0 means token should be masked.
|
||||
|
||||
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
|
||||
**loss**: (`optional`, returned if both ``start_positions`` and ``end_positions`` are provided) ``torch.FloatTensor`` of shape ``(1,)``:
|
||||
Classification loss as the sum of start token, end token (and is_impossible if provided) classification losses.
|
||||
**last_hidden_state**: `(`optional`, returned if ``start_positions`` or ``end_positions`` is not provided) `torch.FloatTensor`` of shape ``(batch_size, sequence_length, hidden_size)``
|
||||
Sequence of hidden-states at the last layer of the model.
|
||||
**mems**:
|
||||
list of ``torch.FloatTensor`` (one for each layer):
|
||||
that contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
|
||||
(see `mems` input above). Can be used to speed up sequential decoding and attend to longer context.
|
||||
"""
|
||||
def __init__(self, config):
|
||||
super(SQuADHead, self).__init__()
|
||||
@@ -590,8 +632,6 @@ class SQuADHead(nn.Module):
|
||||
|
||||
def forward(self, hidden_states, start_positions=None, end_positions=None,
|
||||
cls_index=None, is_impossible=None, p_mask=None):
|
||||
""" hidden_states: float Tensor in shape [bsz, seq_len, hidden_size], the hidden-states of the last layer.
|
||||
"""
|
||||
outputs = ()
|
||||
|
||||
start_logits = self.start_logits(hidden_states, p_mask=p_mask)
|
||||
@@ -618,9 +658,8 @@ class SQuADHead(nn.Module):
|
||||
|
||||
# note(zhiliny): by default multiply the loss by 0.5 so that the scale is comparable to start_loss and end_loss
|
||||
total_loss += cls_loss * 0.5
|
||||
outputs = (total_loss, start_logits, end_logits, cls_logits) + outputs
|
||||
else:
|
||||
outputs = (total_loss, start_logits, end_logits) + outputs
|
||||
|
||||
outputs = (total_loss,) + outputs
|
||||
|
||||
else:
|
||||
# during inference, compute the end logits based on beam search
|
||||
@@ -647,7 +686,7 @@ class SQuADHead(nn.Module):
|
||||
outputs = (start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits) + outputs
|
||||
|
||||
# return start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits
|
||||
# or (if labels are provided) total_loss, start_logits, end_logits, (cls_logits)
|
||||
# or (if labels are provided) (total_loss,)
|
||||
return outputs
|
||||
|
||||
|
||||
|
||||
@@ -1162,8 +1162,9 @@ class XLNetForQuestionAnswering(XLNetPreTrainedModel):
|
||||
Labels whether a question has an answer or no answer (SQuAD 2.0)
|
||||
**cls_index**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
|
||||
Labels for position (index) of the classification token to use as input for computing plausibility of the answer.
|
||||
**p_mask**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
|
||||
Optional mask of tokens which can't be in answers (e.g. [CLS], [PAD], ...)
|
||||
**p_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length)``:
|
||||
Optional mask of tokens which can't be in answers (e.g. [CLS], [PAD], ...).
|
||||
1.0 means token should be masked. 0.0 mean token is not masked.
|
||||
|
||||
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
|
||||
**loss**: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
|
||||
|
||||
Reference in New Issue
Block a user