Fix token order in xlnet preprocessing.
This commit is contained in:
@@ -302,7 +302,11 @@ def load_and_cache_examples(args, tokenizer, evaluate=False, output_examples=Fal
|
|||||||
max_seq_length=args.max_seq_length,
|
max_seq_length=args.max_seq_length,
|
||||||
doc_stride=args.doc_stride,
|
doc_stride=args.doc_stride,
|
||||||
max_query_length=args.max_query_length,
|
max_query_length=args.max_query_length,
|
||||||
is_training=not evaluate)
|
is_training=not evaluate,
|
||||||
|
cls_token_segment_id=2 if args.model_type in ['xlnet'] else 0,
|
||||||
|
pad_token_segment_id=3 if args.model_type in ['xlnet'] else 0,
|
||||||
|
cls_token_at_end=True if args.model_type in ['xlnet'] else False,
|
||||||
|
sequence_a_is_doc=True if args.model_type in ['xlnet'] else False)
|
||||||
if args.local_rank in [-1, 0]:
|
if args.local_rank in [-1, 0]:
|
||||||
logger.info("Saving features into cached file %s", cached_features_file)
|
logger.info("Saving features into cached file %s", cached_features_file)
|
||||||
torch.save(features, cached_features_file)
|
torch.save(features, cached_features_file)
|
||||||
|
|||||||
@@ -192,7 +192,8 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length,
|
|||||||
cls_token='[CLS]', sep_token='[SEP]', pad_token=0,
|
cls_token='[CLS]', sep_token='[SEP]', pad_token=0,
|
||||||
sequence_a_segment_id=0, sequence_b_segment_id=1,
|
sequence_a_segment_id=0, sequence_b_segment_id=1,
|
||||||
cls_token_segment_id=0, pad_token_segment_id=0,
|
cls_token_segment_id=0, pad_token_segment_id=0,
|
||||||
mask_padding_with_zero=True):
|
mask_padding_with_zero=True,
|
||||||
|
sequence_a_is_doc=False):
|
||||||
"""Loads a data file into a list of `InputBatch`s."""
|
"""Loads a data file into a list of `InputBatch`s."""
|
||||||
|
|
||||||
unique_id = 1000000000
|
unique_id = 1000000000
|
||||||
@@ -272,17 +273,19 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length,
|
|||||||
p_mask.append(0)
|
p_mask.append(0)
|
||||||
cls_index = 0
|
cls_index = 0
|
||||||
|
|
||||||
# Query
|
# XLNet: P SEP Q SEP CLS
|
||||||
for token in query_tokens:
|
# Others: CLS Q SEP P SEP
|
||||||
tokens.append(token)
|
if not sequence_a_is_doc:
|
||||||
|
# Query
|
||||||
|
tokens += query_tokens
|
||||||
|
segment_ids += [sequence_a_segment_id] * len(query_tokens)
|
||||||
|
p_mask += [1] * len(query_tokens)
|
||||||
|
|
||||||
|
# SEP token
|
||||||
|
tokens.append(sep_token)
|
||||||
segment_ids.append(sequence_a_segment_id)
|
segment_ids.append(sequence_a_segment_id)
|
||||||
p_mask.append(1)
|
p_mask.append(1)
|
||||||
|
|
||||||
# SEP token
|
|
||||||
tokens.append(sep_token)
|
|
||||||
segment_ids.append(sequence_a_segment_id)
|
|
||||||
p_mask.append(1)
|
|
||||||
|
|
||||||
# Paragraph
|
# Paragraph
|
||||||
for i in range(doc_span.length):
|
for i in range(doc_span.length):
|
||||||
split_token_index = doc_span.start + i
|
split_token_index = doc_span.start + i
|
||||||
@@ -292,10 +295,23 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length,
|
|||||||
split_token_index)
|
split_token_index)
|
||||||
token_is_max_context[len(tokens)] = is_max_context
|
token_is_max_context[len(tokens)] = is_max_context
|
||||||
tokens.append(all_doc_tokens[split_token_index])
|
tokens.append(all_doc_tokens[split_token_index])
|
||||||
segment_ids.append(sequence_b_segment_id)
|
if not sequence_a_is_doc:
|
||||||
|
segment_ids.append(sequence_b_segment_id)
|
||||||
|
else:
|
||||||
|
segment_ids.append(sequence_a_segment_id)
|
||||||
p_mask.append(0)
|
p_mask.append(0)
|
||||||
paragraph_len = doc_span.length
|
paragraph_len = doc_span.length
|
||||||
|
|
||||||
|
if sequence_a_is_doc:
|
||||||
|
# SEP token
|
||||||
|
tokens.append(sep_token)
|
||||||
|
segment_ids.append(sequence_a_segment_id)
|
||||||
|
p_mask.append(1)
|
||||||
|
|
||||||
|
tokens += query_tokens
|
||||||
|
segment_ids += [sequence_b_segment_id] * len(query_tokens)
|
||||||
|
p_mask += [1] * len(query_tokens)
|
||||||
|
|
||||||
# SEP token
|
# SEP token
|
||||||
tokens.append(sep_token)
|
tokens.append(sep_token)
|
||||||
segment_ids.append(sequence_b_segment_id)
|
segment_ids.append(sequence_b_segment_id)
|
||||||
@@ -342,7 +358,10 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length,
|
|||||||
end_position = 0
|
end_position = 0
|
||||||
span_is_impossible = True
|
span_is_impossible = True
|
||||||
else:
|
else:
|
||||||
doc_offset = len(query_tokens) + 2
|
if sequence_a_is_doc:
|
||||||
|
doc_offset = 0
|
||||||
|
else:
|
||||||
|
doc_offset = len(query_tokens) + 2
|
||||||
start_position = tok_start_position - doc_start + doc_offset
|
start_position = tok_start_position - doc_start + doc_offset
|
||||||
end_position = tok_end_position - doc_start + doc_offset
|
end_position = tok_end_position - doc_start + doc_offset
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user