Fix token order in XLNet preprocessing (XLNet: P SEP Q SEP CLS; others: CLS Q SEP P SEP).
This commit is contained in:
@@ -302,7 +302,11 @@ def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=Fal
|
||||
max_seq_length=args.max_seq_length,
|
||||
doc_stride=args.doc_stride,
|
||||
max_query_length=args.max_query_length,
|
||||
is_training=not evaluate)
|
||||
is_training=not evaluate,
|
||||
cls_token_segment_id=2 if args.model_type in ['xlnet'] else 0,
|
||||
pad_token_segment_id=3 if args.model_type in ['xlnet'] else 0,
|
||||
cls_token_at_end=True if args.model_type in ['xlnet'] else False,
|
||||
sequence_a_is_doc=True if args.model_type in ['xlnet'] else False)
|
||||
if args.local_rank in [-1, 0]:
|
||||
logger.info("Saving features into cached file %s", cached_features_file)
|
||||
torch.save(features, cached_features_file)
|
||||
|
||||
@@ -192,7 +192,8 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length,
|
||||
cls_token='[CLS]', sep_token='[SEP]', pad_token=0,
|
||||
sequence_a_segment_id=0, sequence_b_segment_id=1,
|
||||
cls_token_segment_id=0, pad_token_segment_id=0,
|
||||
mask_padding_with_zero=True):
|
||||
mask_padding_with_zero=True,
|
||||
sequence_a_is_doc=False):
|
||||
"""Loads a data file into a list of `InputBatch`s."""
|
||||
|
||||
unique_id = 1000000000
|
||||
@@ -272,11 +273,13 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length,
|
||||
p_mask.append(0)
|
||||
cls_index = 0
|
||||
|
||||
# XLNet: P SEP Q SEP CLS
|
||||
# Others: CLS Q SEP P SEP
|
||||
if not sequence_a_is_doc:
|
||||
# Query
|
||||
for token in query_tokens:
|
||||
tokens.append(token)
|
||||
segment_ids.append(sequence_a_segment_id)
|
||||
p_mask.append(1)
|
||||
tokens += query_tokens
|
||||
segment_ids += [sequence_a_segment_id] * len(query_tokens)
|
||||
p_mask += [1] * len(query_tokens)
|
||||
|
||||
# SEP token
|
||||
tokens.append(sep_token)
|
||||
@@ -292,10 +295,23 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length,
|
||||
split_token_index)
|
||||
token_is_max_context[len(tokens)] = is_max_context
|
||||
tokens.append(all_doc_tokens[split_token_index])
|
||||
if not sequence_a_is_doc:
|
||||
segment_ids.append(sequence_b_segment_id)
|
||||
else:
|
||||
segment_ids.append(sequence_a_segment_id)
|
||||
p_mask.append(0)
|
||||
paragraph_len = doc_span.length
|
||||
|
||||
if sequence_a_is_doc:
|
||||
# SEP token
|
||||
tokens.append(sep_token)
|
||||
segment_ids.append(sequence_a_segment_id)
|
||||
p_mask.append(1)
|
||||
|
||||
tokens += query_tokens
|
||||
segment_ids += [sequence_b_segment_id] * len(query_tokens)
|
||||
p_mask += [1] * len(query_tokens)
|
||||
|
||||
# SEP token
|
||||
tokens.append(sep_token)
|
||||
segment_ids.append(sequence_b_segment_id)
|
||||
@@ -341,6 +357,9 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length,
|
||||
start_position = 0
|
||||
end_position = 0
|
||||
span_is_impossible = True
|
||||
else:
|
||||
if sequence_a_is_doc:
|
||||
doc_offset = 0
|
||||
else:
|
||||
doc_offset = len(query_tokens) + 2
|
||||
start_position = tok_start_position - doc_start + doc_offset
|
||||
|
||||
Reference in New Issue
Block a user