From 74c5035808de18b016c88b1f864d609bc684b367 Mon Sep 17 00:00:00 2001 From: hlums Date: Mon, 14 Oct 2019 21:27:11 +0000 Subject: [PATCH] Fix token order in xlnet preprocessing. --- examples/run_squad.py | 6 +++++- examples/utils_squad.py | 41 ++++++++++++++++++++++++++++++----------- 2 files changed, 35 insertions(+), 12 deletions(-) diff --git a/examples/run_squad.py b/examples/run_squad.py index 43b65d2c3c..a746d441df 100644 --- a/examples/run_squad.py +++ b/examples/run_squad.py @@ -302,7 +302,11 @@ def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=Fal max_seq_length=args.max_seq_length, doc_stride=args.doc_stride, max_query_length=args.max_query_length, - is_training=not evaluate) + is_training=not evaluate, + cls_token_segment_id=2 if args.model_type in ['xlnet'] else 0, + pad_token_segment_id=3 if args.model_type in ['xlnet'] else 0, + cls_token_at_end=True if args.model_type in ['xlnet'] else False, + sequence_a_is_doc=True if args.model_type in ['xlnet'] else False) if args.local_rank in [-1, 0]: logger.info("Saving features into cached file %s", cached_features_file) torch.save(features, cached_features_file) diff --git a/examples/utils_squad.py b/examples/utils_squad.py index b990ecc842..6d1c86493d 100644 --- a/examples/utils_squad.py +++ b/examples/utils_squad.py @@ -192,7 +192,8 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length, cls_token='[CLS]', sep_token='[SEP]', pad_token=0, sequence_a_segment_id=0, sequence_b_segment_id=1, cls_token_segment_id=0, pad_token_segment_id=0, - mask_padding_with_zero=True): + mask_padding_with_zero=True, + sequence_a_is_doc=False): """Loads a data file into a list of `InputBatch`s.""" unique_id = 1000000000 @@ -272,17 +273,19 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length, p_mask.append(0) cls_index = 0 - # Query - for token in query_tokens: - tokens.append(token) + # XLNet: P SEP Q SEP CLS + # Others: CLS Q SEP P SEP + if not sequence_a_is_doc: + # Query + tokens += query_tokens + segment_ids += [sequence_a_segment_id] * len(query_tokens) + p_mask += [1] * len(query_tokens) + + # SEP token + tokens.append(sep_token) segment_ids.append(sequence_a_segment_id) p_mask.append(1) - # SEP token - tokens.append(sep_token) - segment_ids.append(sequence_a_segment_id) - p_mask.append(1) - # Paragraph for i in range(doc_span.length): split_token_index = doc_span.start + i @@ -292,10 +295,23 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length, split_token_index) token_is_max_context[len(tokens)] = is_max_context tokens.append(all_doc_tokens[split_token_index]) - segment_ids.append(sequence_b_segment_id) + if not sequence_a_is_doc: + segment_ids.append(sequence_b_segment_id) + else: + segment_ids.append(sequence_a_segment_id) p_mask.append(0) paragraph_len = doc_span.length + if sequence_a_is_doc: + # SEP token + tokens.append(sep_token) + segment_ids.append(sequence_a_segment_id) + p_mask.append(1) + + tokens += query_tokens + segment_ids += [sequence_b_segment_id] * len(query_tokens) + p_mask += [1] * len(query_tokens) + # SEP token tokens.append(sep_token) segment_ids.append(sequence_b_segment_id) @@ -342,7 +358,10 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length, end_position = 0 span_is_impossible = True else: - doc_offset = len(query_tokens) + 2 + if sequence_a_is_doc: + doc_offset = 0 + else: + doc_offset = len(query_tokens) + 2 start_position = tok_start_position - doc_start + doc_offset end_position = tok_end_position - doc_start + doc_offset