From 3b469cb4229c983ee8b4fed58284742f6ac93f9a Mon Sep 17 00:00:00 2001 From: thomwolf Date: Mon, 15 Jul 2019 15:28:37 +0200 Subject: [PATCH] updating squad for compatibility with XLNet --- examples/run_squad.py | 67 ++++-- examples/utils_squad.py | 270 +++++++++++++++++++++++-- examples/utils_squad_evaluate.py | 43 +++- pytorch_transformers/modeling_utils.py | 79 ++++++-- pytorch_transformers/modeling_xlnet.py | 5 +- 5 files changed, 402 insertions(+), 62 deletions(-) diff --git a/examples/run_squad.py b/examples/run_squad.py index 24f00e0518..2025217454 100644 --- a/examples/run_squad.py +++ b/examples/run_squad.py @@ -41,7 +41,9 @@ from pytorch_transformers import (WEIGHTS_NAME, BertConfig, from pytorch_transformers import AdamW, WarmupLinearSchedule -from utils_squad import read_squad_examples, convert_examples_to_features, RawResult, write_predictions +from utils_squad import (read_squad_examples, convert_examples_to_features, + RawResult, write_predictions, + RawResultExtended, write_predictions_extended) # The follwing import is the official SQuAD evaluation script (2.0). # You can remove it from the dependencies if you are using this script outside of the library @@ -66,6 +68,8 @@ def set_seed(args): if args.n_gpu > 0: torch.cuda.manual_seed_all(args.seed) +def to_list(tensor): + return tensor.detach().cpu().tolist() def train(args, train_dataset, model, tokenizer): """ Train the model """ @@ -118,10 +122,13 @@ def train(args, train_dataset, model, tokenizer): model.train() batch = tuple(t.to(args.device) for t in batch) inputs = {'input_ids': batch[0], - 'token_type_ids': batch[1] if args.model_type in ['bert', 'xlnet'] else None, # XLM don't use segment_ids + 'token_type_ids': None if args.model_type == 'xlm' else batch[1], # XLM don't use segment_ids 'attention_mask': batch[2], 'start_positions': batch[3], 'end_positions': batch[4]} + if args.model_type in ['xlnet', 'xlm']: + inputs.update({'cls_index': batch[5], + 'p_mask': batch[6]}) ouputs = model(**inputs) loss = ouputs[0] # model outputs are always tuple in pytorch-transformers (see doc) @@ -197,31 +204,50 @@ def evaluate(args, model, tokenizer, prefix=""): for batch in tqdm(eval_dataloader, desc="Evaluating"): model.eval() batch = tuple(t.to(args.device) for t in batch) - example_indices = batch[3] with torch.no_grad(): inputs = {'input_ids': batch[0], - 'token_type_ids': batch[1] if args.model_type in ['bert', 'xlnet'] else None, # XLM don't use segment_ids - 'attention_mask': batch[2]} + 'token_type_ids': None if args.model_type == 'xlm' else batch[1], # XLM don't use segment_ids + 'attention_mask': batch[2]} + example_indices = batch[3] + if args.model_type in ['xlnet', 'xlm']: + inputs.update({'cls_index': batch[4], + 'p_mask': batch[5]}) outputs = model(**inputs) batch_start_logits, batch_end_logits = outputs[:2] for i, example_index in enumerate(example_indices): - start_logits = batch_start_logits[i].detach().cpu().tolist() - end_logits = batch_end_logits[i].detach().cpu().tolist() eval_feature = features[example_index.item()] unique_id = int(eval_feature.unique_id) - all_results.append(RawResult(unique_id=unique_id, - start_logits=start_logits, - end_logits=end_logits)) + if args.model_type in ['xlnet', 'xlm']: + # XLNet uses a more complex post-processing procedure + result = RawResultExtended(unique_id = unique_id, + start_top_log_probs = to_list(outputs[0][i]), + start_top_index = to_list(outputs[1][i]), + end_top_log_probs = to_list(outputs[2][i]), + end_top_index = to_list(outputs[3][i]), + cls_logits = to_list(outputs[4][i])) + else: + result = RawResult(unique_id = unique_id, + start_logits = to_list(outputs[0][i]), + end_logits = to_list(outputs[1][i])) + all_results.append(result) # Compute predictions output_prediction_file = os.path.join(args.output_dir, "predictions_{}.json".format(prefix)) output_nbest_file = os.path.join(args.output_dir, "nbest_predictions_{}.json".format(prefix)) output_null_log_odds_file = os.path.join(args.output_dir, "null_odds_{}.json".format(prefix)) - write_predictions(examples, features, all_results, args.n_best_size, args.max_answer_length, - args.do_lower_case, output_prediction_file, output_nbest_file, - output_null_log_odds_file, args.verbose_logging, - args.version_2_with_negative, args.null_score_diff_threshold) + + if args.model_type in ['xlnet', 'xlm']: + # XLNet uses a more complex post-processing procedure + write_predictions_extended(examples, features, all_results, args.n_best_size, + args.max_answer_length, output_prediction_file, + output_nbest_file, output_null_log_odds_file, args.predict_file, + args.start_n_top, args.end_n_top, args.version_2_with_negative) + else: + write_predictions(examples, features, all_results, args.n_best_size, + args.max_answer_length, args.do_lower_case, output_prediction_file, + output_nbest_file, output_null_log_odds_file, args.verbose_logging, + args.version_2_with_negative, args.null_score_diff_threshold) # Evaluate with the official SQuAD script evaluate_options = EVAL_OPTS(data_file=args.predict_file, @@ -244,8 +270,8 @@ def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=Fal else: logger.info("Creating features from dataset file at %s", input_file) examples = read_squad_examples(input_file=input_file, - is_training=not evaluate, - version_2_with_negative=args.version_2_with_negative) + is_training=not evaluate, + version_2_with_negative=args.version_2_with_negative) features = convert_examples_to_features(examples=examples, tokenizer=tokenizer, max_seq_length=args.max_seq_length, @@ -260,13 +286,18 @@ def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=Fal all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long) all_input_mask = torch.tensor([f.input_mask for f in features], dtype=torch.long) all_segment_ids = torch.tensor([f.segment_ids for f in features], dtype=torch.long) + all_cls_index = torch.tensor([f.cls_index for f in features], dtype=torch.long) + all_p_mask = torch.tensor([f.p_mask for f in features], dtype=torch.float) if evaluate: all_example_index = torch.arange(all_input_ids.size(0), dtype=torch.long) - dataset = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_example_index) + dataset = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, + all_example_index, all_cls_index, all_p_mask) else: all_start_positions = torch.tensor([f.start_position for f in features], dtype=torch.long) all_end_positions = torch.tensor([f.end_position for f in features], dtype=torch.long) - dataset = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_start_positions, all_end_positions) + dataset = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, + all_start_positions, all_end_positions, + all_cls_index, all_p_mask) if output_examples: return dataset, examples, features diff --git a/examples/utils_squad.py b/examples/utils_squad.py index 305eeb7b40..d898a0a17e 100644 --- a/examples/utils_squad.py +++ b/examples/utils_squad.py @@ -26,6 +26,9 @@ from io import open from pytorch_transformers.tokenization_bert import BasicTokenizer, whitespace_tokenize +# Required by XLNet evaluation method to compute optimal threshold (see write_predictions_extended() method) +from utils_squad_evaluate import find_all_best_thresh_v2, make_qid_to_has_ans, get_raw_scores + logger = logging.getLogger(__name__) @@ -82,6 +85,8 @@ class InputFeatures(object): input_ids, input_mask, segment_ids, + cls_index, + p_mask, start_position=None, end_position=None, is_impossible=None): @@ -94,6 +99,8 @@ class InputFeatures(object): self.input_ids = input_ids self.input_mask = input_mask self.segment_ids = segment_ids + self.cls_index = cls_index + self.p_mask = p_mask self.start_position = start_position self.end_position = end_position self.is_impossible = is_impossible @@ -178,13 +185,25 @@ def read_squad_examples(input_file, is_training, version_2_with_negative): def convert_examples_to_features(examples, tokenizer, max_seq_length, - doc_stride, max_query_length, is_training): + doc_stride, max_query_length, is_training, + cls_token_at_end=False, + cls_token='[CLS]', sep_token='[SEP]', pad_token=0, + sequence_a_segment_id=0, sequence_b_segment_id=1, + cls_token_segment_id=0, pad_token_segment_id=0, + mask_padding_with_zero=True): """Loads a data file into a list of `InputBatch`s.""" unique_id = 1000000000 + # cnt_pos, cnt_neg = 0, 0 + # max_N, max_M = 1024, 1024 + # f = np.zeros((max_N, max_M), dtype=np.float32) features = [] for (example_index, example) in enumerate(examples): + + # if example_index % 100 == 0: + # logger.info('Converting %s/%s pos %s neg %s', example_index, len(examples), cnt_pos, cnt_neg) + query_tokens = tokenizer.tokenize(example.question_text) if len(query_tokens) > max_query_length: @@ -239,14 +258,30 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length, token_to_orig_map = {} token_is_max_context = {} segment_ids = [] - tokens.append("[CLS]") - segment_ids.append(0) + + # p_mask: mask with 1 for token than cannot be in the answer (0 for token which can be in an answer) + # Original TF implem also keep the classification token (set to 0) (not sure why...) + p_mask = [] + + # CLS token at the beginning + if not cls_token_at_end: + tokens.append(cls_token) + segment_ids.append(cls_token_segment_id) + p_mask.append(0) + cls_index = 0 + + # Query for token in query_tokens: tokens.append(token) - segment_ids.append(0) - tokens.append("[SEP]") - segment_ids.append(0) + segment_ids.append(sequence_a_segment_id) + p_mask.append(1) + # SEP token + tokens.append(sep_token) + segment_ids.append(sequence_a_segment_id) + p_mask.append(1) + + # Paragraph for i in range(doc_span.length): split_token_index = doc_span.start + i token_to_orig_map[len(tokens)] = tok_to_orig_index[split_token_index] @@ -255,29 +290,42 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length, split_token_index) token_is_max_context[len(tokens)] = is_max_context tokens.append(all_doc_tokens[split_token_index]) - segment_ids.append(1) - tokens.append("[SEP]") - segment_ids.append(1) + segment_ids.append(sequence_b_segment_id) + p_mask.append(0) + + # SEP token + tokens.append(sep_token) + segment_ids.append(sequence_b_segment_id) + p_mask.append(1) + + # CLS token at the end + if cls_token_at_end: + tokens.append(cls_token) + segment_ids.append(cls_token_segment_id) + p_mask.append(0) + cls_index = len(tokens) - 1 # Index of classification token input_ids = tokenizer.convert_tokens_to_ids(tokens) # The mask has 1 for real tokens and 0 for padding tokens. Only real # tokens are attended to. - input_mask = [1] * len(input_ids) + input_mask = [1 if mask_padding_with_zero else 0] * len(input_ids) # Zero-pad up to the sequence length. while len(input_ids) < max_seq_length: - input_ids.append(0) - input_mask.append(0) - segment_ids.append(0) + input_ids.append(pad_token) + input_mask.append(0 if mask_padding_with_zero else 1) + segment_ids.append(pad_token_segment_id) + p_mask.append(1) assert len(input_ids) == max_seq_length assert len(input_mask) == max_seq_length assert len(segment_ids) == max_seq_length + span_is_impossible = example.is_impossible start_position = None end_position = None - if is_training and not example.is_impossible: + if is_training and not span_is_impossible: # For training, if our document chunk does not contain an annotation # we throw it out, since there is nothing to predict. doc_start = doc_span.start @@ -289,13 +337,16 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length, if out_of_span: start_position = 0 end_position = 0 + span_is_impossible = True else: doc_offset = len(query_tokens) + 2 start_position = tok_start_position - doc_start + doc_offset end_position = tok_end_position - doc_start + doc_offset - if is_training and example.is_impossible: - start_position = 0 - end_position = 0 + + if is_training and span_is_impossible: + start_position = cls_index + end_position = cls_index + if example_index < 20: logger.info("*** Example ***") logger.info("unique_id: %s" % (unique_id)) @@ -312,9 +363,9 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length, "input_mask: %s" % " ".join([str(x) for x in input_mask])) logger.info( "segment_ids: %s" % " ".join([str(x) for x in segment_ids])) - if is_training and example.is_impossible: + if is_training and span_is_impossible: logger.info("impossible example") - if is_training and not example.is_impossible: + if is_training and not span_is_impossible: answer_text = " ".join(tokens[start_position:(end_position + 1)]) logger.info("start_position: %d" % (start_position)) logger.info("end_position: %d" % (end_position)) @@ -332,9 +383,11 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length, input_ids=input_ids, input_mask=input_mask, segment_ids=segment_ids, + cls_index=cls_index, + p_mask=p_mask, start_position=start_position, end_position=end_position, - is_impossible=example.is_impossible)) + is_impossible=span_is_impossible)) unique_id += 1 return features @@ -417,7 +470,6 @@ def _check_is_max_context(doc_spans, cur_span_index, position): RawResult = collections.namedtuple("RawResult", ["unique_id", "start_logits", "end_logits"]) - def write_predictions(all_examples, all_features, all_results, n_best_size, max_answer_length, do_lower_case, output_prediction_file, output_nbest_file, output_null_log_odds_file, verbose_logging, @@ -612,6 +664,182 @@ def write_predictions(all_examples, all_features, all_results, n_best_size, return all_predictions +# For XLNet (and XLM which uses the same head) +RawResultExtended = collections.namedtuple("RawResultExtended", + ["unique_id", "start_top_log_probs", "start_top_index", + "end_top_log_probs", "end_top_index", "cls_logits"]) + + +def write_predictions_extended(all_examples, all_features, all_results, n_best_size, + max_answer_length, output_prediction_file, + output_nbest_file, + output_null_log_odds_file, orig_data, + start_n_top, end_n_top, version_2_with_negative): + """ XLNet write prediction logic (more complex than Bert's). + Write final predictions to the json file and log-odds of null if needed. + + Requires utils_squad_evaluate.py + """ + _PrelimPrediction = collections.namedtuple( # pylint: disable=invalid-name + "PrelimPrediction", + ["feature_index", "start_index", "end_index", + "start_log_prob", "end_log_prob"]) + + _NbestPrediction = collections.namedtuple( # pylint: disable=invalid-name + "NbestPrediction", ["text", "start_log_prob", "end_log_prob"]) + + logger.info("Writing predictions to: %s", output_prediction_file) + # logger.info("Writing nbest to: %s" % (output_nbest_file)) + + example_index_to_features = collections.defaultdict(list) + for feature in all_features: + example_index_to_features[feature.example_index].append(feature) + + unique_id_to_result = {} + for result in all_results: + unique_id_to_result[result.unique_id] = result + + all_predictions = collections.OrderedDict() + all_nbest_json = collections.OrderedDict() + scores_diff_json = collections.OrderedDict() + + for (example_index, example) in enumerate(all_examples): + features = example_index_to_features[example_index] + + prelim_predictions = [] + # keep track of the minimum score of null start+end of position 0 + score_null = 1000000 # large and positive + + for (feature_index, feature) in enumerate(features): + result = unique_id_to_result[feature.unique_id] + + cur_null_score = result.cls_logits + + # if we could have irrelevant answers, get the min score of irrelevant + score_null = min(score_null, cur_null_score) + + for i in range(start_n_top): + for j in range(end_n_top): + start_log_prob = result.start_top_log_probs[i] + start_index = result.start_top_index[i] + + j_index = i * end_n_top + j + + end_log_prob = result.end_top_log_probs[j_index] + end_index = result.end_top_index[j_index] + + # We could hypothetically create invalid predictions, e.g., predict + # that the start of the span is in the question. We throw out all + # invalid predictions. + if start_index >= feature.paragraph_len - 1: + continue + if end_index >= feature.paragraph_len - 1: + continue + + if not feature.token_is_max_context.get(start_index, False): + continue + if end_index < start_index: + continue + length = end_index - start_index + 1 + if length > max_answer_length: + continue + + prelim_predictions.append( + _PrelimPrediction( + feature_index=feature_index, + start_index=start_index, + end_index=end_index, + start_log_prob=start_log_prob, + end_log_prob=end_log_prob)) + + prelim_predictions = sorted( + prelim_predictions, + key=lambda x: (x.start_log_prob + x.end_log_prob), + reverse=True) + + seen_predictions = {} + nbest = [] + for pred in prelim_predictions: + if len(nbest) >= n_best_size: + break + feature = features[pred.feature_index] + + tok_start_to_orig_index = feature.tok_start_to_orig_index + tok_end_to_orig_index = feature.tok_end_to_orig_index + start_orig_pos = tok_start_to_orig_index[pred.start_index] + end_orig_pos = tok_end_to_orig_index[pred.end_index] + + paragraph_text = example.paragraph_text + final_text = paragraph_text[start_orig_pos: end_orig_pos + 1].strip() + + if final_text in seen_predictions: + continue + + seen_predictions[final_text] = True + + nbest.append( + _NbestPrediction( + text=final_text, + start_log_prob=pred.start_log_prob, + end_log_prob=pred.end_log_prob)) + + # In very rare edge cases we could have no valid predictions. So we + # just create a nonce prediction in this case to avoid failure. + if not nbest: + nbest.append( + _NbestPrediction(text="", start_log_prob=-1e6, + end_log_prob=-1e6)) + + total_scores = [] + best_non_null_entry = None + for entry in nbest: + total_scores.append(entry.start_log_prob + entry.end_log_prob) + if not best_non_null_entry: + best_non_null_entry = entry + + probs = _compute_softmax(total_scores) + + nbest_json = [] + for (i, entry) in enumerate(nbest): + output = collections.OrderedDict() + output["text"] = entry.text + output["probability"] = probs[i] + output["start_log_prob"] = entry.start_log_prob + output["end_log_prob"] = entry.end_log_prob + nbest_json.append(output) + + assert len(nbest_json) >= 1 + assert best_non_null_entry is not None + + score_diff = score_null + scores_diff_json[example.qas_id] = score_diff + # note(zhiliny): always predict best_non_null_entry + # and the evaluation script will search for the best threshold + all_predictions[example.qas_id] = best_non_null_entry.text + + all_nbest_json[example.qas_id] = nbest_json + + with open(output_prediction_file, "w") as writer: + writer.write(json.dumps(all_predictions, indent=4) + "\n") + + with open(output_nbest_file, "w") as writer: + writer.write(json.dumps(all_nbest_json, indent=4) + "\n") + + if version_2_with_negative: + with open(output_null_log_odds_file, "w") as writer: + writer.write(json.dumps(scores_diff_json, indent=4) + "\n") + + qid_to_has_ans = make_qid_to_has_ans(orig_data) + has_ans_qids = [k for k, v in qid_to_has_ans.items() if v] + no_ans_qids = [k for k, v in qid_to_has_ans.items() if not v] + exact_raw, f1_raw = get_raw_scores(orig_data, all_predictions) + out_eval = {} + + find_all_best_thresh_v2(out_eval, all_predictions, exact_raw, f1_raw, scores_diff_json, qid_to_has_ans) + + return out_eval + + def get_final_text(pred_text, orig_text, do_lower_case, verbose_logging=False): """Project the tokenized prediction back to the original text.""" diff --git a/examples/utils_squad_evaluate.py b/examples/utils_squad_evaluate.py index d0cf643fe3..ed162e6fe6 100644 --- a/examples/utils_squad_evaluate.py +++ b/examples/utils_squad_evaluate.py @@ -1,4 +1,5 @@ -"""Official evaluation script for SQuAD version 2.0. +""" Official evaluation script for SQuAD version 2.0. + Modified by XLNet authors to update `find_best_threshold` scripts for SQuAD V2.0 In addition to basic functionality, we also compute additional statistics and plot precision-recall curves if an additional na_prob.json file is provided. @@ -232,6 +233,36 @@ def find_best_thresh(preds, scores, na_probs, qid_to_has_ans): best_thresh = na_probs[qid] return 100.0 * best_score / len(scores), best_thresh +def find_best_thresh_v2(preds, scores, na_probs, qid_to_has_ans): + num_no_ans = sum(1 for k in qid_to_has_ans if not qid_to_has_ans[k]) + cur_score = num_no_ans + best_score = cur_score + best_thresh = 0.0 + qid_list = sorted(na_probs, key=lambda k: na_probs[k]) + for i, qid in enumerate(qid_list): + if qid not in scores: continue + if qid_to_has_ans[qid]: + diff = scores[qid] + else: + if preds[qid]: + diff = -1 + else: + diff = 0 + cur_score += diff + if cur_score > best_score: + best_score = cur_score + best_thresh = na_probs[qid] + + has_ans_score, has_ans_cnt = 0, 0 + for qid in qid_list: + if not qid_to_has_ans[qid]: continue + has_ans_cnt += 1 + + if qid not in scores: continue + has_ans_score += scores[qid] + + return 100.0 * best_score / len(scores), best_thresh, 1.0 * has_ans_score / has_ans_cnt + def find_all_best_thresh(main_eval, preds, exact_raw, f1_raw, na_probs, qid_to_has_ans): best_exact, exact_thresh = find_best_thresh(preds, exact_raw, na_probs, qid_to_has_ans) best_f1, f1_thresh = find_best_thresh(preds, f1_raw, na_probs, qid_to_has_ans) @@ -240,6 +271,16 @@ def find_all_best_thresh(main_eval, preds, exact_raw, f1_raw, na_probs, qid_to_h main_eval['best_f1'] = best_f1 main_eval['best_f1_thresh'] = f1_thresh +def find_all_best_thresh_v2(main_eval, preds, exact_raw, f1_raw, na_probs, qid_to_has_ans): + best_exact, exact_thresh, has_ans_exact = find_best_thresh_v2(preds, exact_raw, na_probs, qid_to_has_ans) + best_f1, f1_thresh, has_ans_f1 = find_best_thresh_v2(preds, f1_raw, na_probs, qid_to_has_ans) + main_eval['best_exact'] = best_exact + main_eval['best_exact_thresh'] = exact_thresh + main_eval['best_f1'] = best_f1 + main_eval['best_f1_thresh'] = f1_thresh + main_eval['has_ans_exact'] = has_ans_exact + main_eval['has_ans_f1'] = has_ans_f1 + def main(OPTS): with open(OPTS.data_file) as f: dataset_json = json.load(f) diff --git a/pytorch_transformers/modeling_utils.py b/pytorch_transformers/modeling_utils.py index 3f21c98b04..ebee4fac1d 100644 --- a/pytorch_transformers/modeling_utils.py +++ b/pytorch_transformers/modeling_utils.py @@ -493,8 +493,9 @@ class PoolerStartLogits(nn.Module): def forward(self, hidden_states, p_mask=None): """ Args: - `p_mask`: [optional] invalid position mask such as query and special symbols (PAD, SEP, CLS) - shape [batch_size, seq_len]. 1.0 means token should be masked. + **p_mask**: (`optional`) ``torch.FloatTensor`` of shape `(batch_size, seq_len)` + invalid position mask such as query and special symbols (PAD, SEP, CLS) + 1.0 means token should be masked. """ x = self.dense(hidden_states).squeeze(-1) @@ -516,11 +517,16 @@ class PoolerEndLogits(nn.Module): def forward(self, hidden_states, start_states=None, start_positions=None, p_mask=None): """ Args: - One of start_states, start_positions should be not None. If both are set, start_positions overrides start_states. - `start_states`: hidden states of the first tokens for the labeled span: torch.LongTensor of shape identical to hidden_states. - `start_positions`: position of the first token for the labeled span: torch.LongTensor of shape [batch_size]. - `p_mask`: [optional] invalid position mask such as query and special symbols (PAD, SEP, CLS) - shape [batch_size, seq_len]. 1.0 means token should be masked. + One of ``start_states``, ``start_positions`` should be not None. + If both are set, ``start_positions`` overrides ``start_states``. + + **start_states**: ``torch.LongTensor`` of shape identical to hidden_states + hidden states of the first tokens for the labeled span. + **start_positions**: ``torch.LongTensor`` of shape ``(batch_size,)`` + position of the first token for the labeled span: + **p_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, seq_len)`` + Mask of invalid position such as query and special symbols (PAD, SEP, CLS) + 1.0 means token should be masked. """ slen, hsz = hidden_states.shape[-2:] assert start_states is not None or start_positions is not None, "One of start_states, start_positions should be not None" @@ -549,13 +555,21 @@ class PoolerAnswerClass(nn.Module): self.dense_1 = nn.Linear(config.hidden_size, 1, bias=False) def forward(self, hidden_states, start_states=None, start_positions=None, cls_index=None): - """ Args: - One of start_states, start_positions should be not None. If both are set, start_positions overrides start_states. - `start_states`: hidden states of the first tokens for the labeled span: torch.LongTensor of shape identical to hidden_states. - `start_positions`: position of the first token for the labeled span: torch.LongTensor of shape [batch_size]. - `cls_index`: position of the CLS token: torch.LongTensor of shape [batch_size]. If None, take the last token. + """ + Args: + One of ``start_states``, ``start_positions`` should be not None. + If both are set, ``start_positions`` overrides ``start_states``. - # note(zhiliny): no dependency on end_feature so that we can obtain one single `cls_logits` for each sample + **start_states**: ``torch.LongTensor`` of shape identical to ``hidden_states``. + hidden states of the first tokens for the labeled span. + **start_positions**: ``torch.LongTensor`` of shape ``(batch_size,)`` + position of the first token for the labeled span. + **cls_index**: torch.LongTensor of shape ``(batch_size,)`` + position of the CLS token. If None, take the last token. + + note(Original repo): + no dependency on end_feature so that we can obtain one single `cls_logits` + for each sample """ slen, hsz = hidden_states.shape[-2:] assert start_states is not None or start_positions is not None, "One of start_states, start_positions should be not None" @@ -577,7 +591,35 @@ class PoolerAnswerClass(nn.Module): class SQuADHead(nn.Module): - """ A SQuAD head inspired by XLNet. + r""" A SQuAD head inspired by XLNet. + + Parameters: + config (:class:`~pytorch_transformers.XLNetConfig`): Model configuration class with all the parameters of the model. + + Inputs: + **hidden_states**: ``torch.FloatTensor`` of shape ``(batch_size, seq_len, hidden_size)`` + hidden states of sequence tokens + **start_positions**: ``torch.LongTensor`` of shape ``(batch_size,)`` + position of the first token for the labeled span. + **end_positions**: ``torch.LongTensor`` of shape ``(batch_size,)`` + position of the last token for the labeled span. + **cls_index**: torch.LongTensor of shape ``(batch_size,)`` + position of the CLS token. If None, take the last token. + **is_impossible**: ``torch.LongTensor`` of shape ``(batch_size,)`` + Whether the question has a possible answer in the paragraph or not. + **p_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, seq_len)`` + Mask of invalid position such as query and special symbols (PAD, SEP, CLS) + 1.0 means token should be masked. + + Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs: + **loss**: (`optional`, returned if both ``start_positions`` and ``end_positions`` are provided) ``torch.FloatTensor`` of shape ``(1,)``: + Classification loss as the sum of start token, end token (and is_impossible if provided) classification losses. + **last_hidden_state**: `(`optional`, returned if ``start_positions`` or ``end_positions`` is not provided) `torch.FloatTensor`` of shape ``(batch_size, sequence_length, hidden_size)`` + Sequence of hidden-states at the last layer of the model. + **mems**: + list of ``torch.FloatTensor`` (one for each layer): + that contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model + (see `mems` input above). Can be used to speed up sequential decoding and attend to longer context. """ def __init__(self, config): super(SQuADHead, self).__init__() @@ -590,8 +632,6 @@ class SQuADHead(nn.Module): def forward(self, hidden_states, start_positions=None, end_positions=None, cls_index=None, is_impossible=None, p_mask=None): - """ hidden_states: float Tensor in shape [bsz, seq_len, hidden_size], the hidden-states of the last layer. - """ outputs = () start_logits = self.start_logits(hidden_states, p_mask=p_mask) @@ -618,9 +658,8 @@ class SQuADHead(nn.Module): # note(zhiliny): by default multiply the loss by 0.5 so that the scale is comparable to start_loss and end_loss total_loss += cls_loss * 0.5 - outputs = (total_loss, start_logits, end_logits, cls_logits) + outputs - else: - outputs = (total_loss, start_logits, end_logits) + outputs + + outputs = (total_loss,) + outputs else: # during inference, compute the end logits based on beam search @@ -647,7 +686,7 @@ class SQuADHead(nn.Module): outputs = (start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits) + outputs # return start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits - # or (if labels are provided) total_loss, start_logits, end_logits, (cls_logits) + # or (if labels are provided) (total_loss,) return outputs diff --git a/pytorch_transformers/modeling_xlnet.py b/pytorch_transformers/modeling_xlnet.py index 5e576c51c1..6de4d02103 100644 --- a/pytorch_transformers/modeling_xlnet.py +++ b/pytorch_transformers/modeling_xlnet.py @@ -1162,8 +1162,9 @@ class XLNetForQuestionAnswering(XLNetPreTrainedModel): Labels whether a question has an answer or no answer (SQuAD 2.0) **cls_index**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``: Labels for position (index) of the classification token to use as input for computing plausibility of the answer. - **p_mask**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``: - Optional mask of tokens which can't be in answers (e.g. [CLS], [PAD], ...) + **p_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length)``: + Optional mask of tokens which can't be in answers (e.g. [CLS], [PAD], ...). + 1.0 means token should be masked. 0.0 mean token is not masked. Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs: **loss**: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``: