pad on left for xlnet

This commit is contained in:
thomwolf
2019-06-24 15:05:11 +02:00
parent c888663f18
commit 7334bf6c21
2 changed files with 29 additions and 15 deletions

View File

@@ -389,8 +389,11 @@ class WnliProcessor(DataProcessor):
def convert_examples_to_features(examples, label_list, max_seq_length,
tokenizer, output_mode,
cls_token_at_end=False, cls_token='[CLS]',
sep_token='[SEP]', cls_token_segment_id=0):
cls_token_at_end=False, pad_on_left=False,
cls_token='[CLS]', sep_token='[SEP]', pad_token=0,
sequence_a_segment_id=0, sequence_b_segment_id=1,
cls_token_segment_id=1, pad_token_segment_id=0,
mask_padding_with_zero=True):
""" Loads a data file into a list of `InputBatch`s
`cls_token_at_end` define the location of the CLS token:
- False (BERT pattern): [CLS] + A + [SEP] + B + [SEP]
@@ -438,11 +441,11 @@ def convert_examples_to_features(examples, label_list, max_seq_length,
# used as as the "sentence vector". Note that this only makes sense because
# the entire model is fine-tuned.
tokens = tokens_a + [sep_token]
segment_ids = [0] * len(tokens)
segment_ids = [sequence_a_segment_id] * len(tokens)
if tokens_b:
tokens += tokens_b + [sep_token]
segment_ids += [1] * (len(tokens_b) + 1)
segment_ids += [sequence_b_segment_id] * (len(tokens_b) + 1)
if cls_token_at_end:
tokens = tokens + [cls_token]
@@ -455,13 +458,18 @@ def convert_examples_to_features(examples, label_list, max_seq_length,
# The mask has 1 for real tokens and 0 for padding tokens. Only real
# tokens are attended to.
input_mask = [1] * len(input_ids)
input_mask = [1 if mask_padding_with_zero else 0] * len(input_ids)
# Zero-pad up to the sequence length.
padding = [0] * (max_seq_length - len(input_ids))
input_ids += padding
input_mask += padding
segment_ids += padding
padding_length = max_seq_length - len(input_ids)
if pad_on_left:
input_ids = ([pad_token] * padding_length) + input_ids
input_mask = ([0 if mask_padding_with_zero else 1] * padding_length) + input_mask
segment_ids = ([pad_token_segment_id] * padding_length) + segment_ids
else:
input_ids = input_ids + ([pad_token] * padding_length)
input_mask = input_mask + ([0 if mask_padding_with_zero else 1] * padding_length)
segment_ids = segment_ids + ([pad_token_segment_id] * padding_length)
assert len(input_ids) == max_seq_length
assert len(input_mask) == max_seq_length