Fix token order in xlnet preprocessing.
This commit is contained in:
@@ -192,7 +192,8 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length,
|
||||
cls_token='[CLS]', sep_token='[SEP]', pad_token=0,
|
||||
sequence_a_segment_id=0, sequence_b_segment_id=1,
|
||||
cls_token_segment_id=0, pad_token_segment_id=0,
|
||||
mask_padding_with_zero=True):
|
||||
mask_padding_with_zero=True,
|
||||
sequence_a_is_doc=False):
|
||||
"""Loads a data file into a list of `InputBatch`s."""
|
||||
|
||||
unique_id = 1000000000
|
||||
@@ -272,17 +273,19 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length,
|
||||
p_mask.append(0)
|
||||
cls_index = 0
|
||||
|
||||
# Query
|
||||
for token in query_tokens:
|
||||
tokens.append(token)
|
||||
# XLNet: P SEP Q SEP CLS
|
||||
# Others: CLS Q SEP P SEP
|
||||
if not sequence_a_is_doc:
|
||||
# Query
|
||||
tokens += query_tokens
|
||||
segment_ids += [sequence_a_segment_id] * len(query_tokens)
|
||||
p_mask += [1] * len(query_tokens)
|
||||
|
||||
# SEP token
|
||||
tokens.append(sep_token)
|
||||
segment_ids.append(sequence_a_segment_id)
|
||||
p_mask.append(1)
|
||||
|
||||
# SEP token
|
||||
tokens.append(sep_token)
|
||||
segment_ids.append(sequence_a_segment_id)
|
||||
p_mask.append(1)
|
||||
|
||||
# Paragraph
|
||||
for i in range(doc_span.length):
|
||||
split_token_index = doc_span.start + i
|
||||
@@ -292,10 +295,23 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length,
|
||||
split_token_index)
|
||||
token_is_max_context[len(tokens)] = is_max_context
|
||||
tokens.append(all_doc_tokens[split_token_index])
|
||||
segment_ids.append(sequence_b_segment_id)
|
||||
if not sequence_a_is_doc:
|
||||
segment_ids.append(sequence_b_segment_id)
|
||||
else:
|
||||
segment_ids.append(sequence_a_segment_id)
|
||||
p_mask.append(0)
|
||||
paragraph_len = doc_span.length
|
||||
|
||||
if sequence_a_is_doc:
|
||||
# SEP token
|
||||
tokens.append(sep_token)
|
||||
segment_ids.append(sequence_a_segment_id)
|
||||
p_mask.append(1)
|
||||
|
||||
tokens += query_tokens
|
||||
segment_ids += [sequence_b_segment_id] * len(query_tokens)
|
||||
p_mask += [1] * len(query_tokens)
|
||||
|
||||
# SEP token
|
||||
tokens.append(sep_token)
|
||||
segment_ids.append(sequence_b_segment_id)
|
||||
@@ -342,7 +358,10 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length,
|
||||
end_position = 0
|
||||
span_is_impossible = True
|
||||
else:
|
||||
doc_offset = len(query_tokens) + 2
|
||||
if sequence_a_is_doc:
|
||||
doc_offset = 0
|
||||
else:
|
||||
doc_offset = len(query_tokens) + 2
|
||||
start_position = tok_start_position - doc_start + doc_offset
|
||||
end_position = tok_end_position - doc_start + doc_offset
|
||||
|
||||
|
||||
Reference in New Issue
Block a user