pad on left for xlnet
This commit is contained in:
@@ -198,14 +198,17 @@ def main():
|
|||||||
list(filter(None, args.xlnet_model.split('/'))).pop(),
|
list(filter(None, args.xlnet_model.split('/'))).pop(),
|
||||||
str(args.max_seq_length),
|
str(args.max_seq_length),
|
||||||
str(task_name)))
|
str(task_name)))
|
||||||
try:
|
if os.path.exists(cached_train_features_file):
|
||||||
|
logger.info("Loading train features for cache file %s", cached_train_features_file)
|
||||||
with open(cached_train_features_file, "rb") as reader:
|
with open(cached_train_features_file, "rb") as reader:
|
||||||
train_features = pickle.load(reader)
|
train_features = pickle.load(reader)
|
||||||
except:
|
else:
|
||||||
|
logger.info("No cache file at %s, preparing train features", cached_train_features_file)
|
||||||
train_features = convert_examples_to_features(
|
train_features = convert_examples_to_features(
|
||||||
train_examples, label_list, args.max_seq_length, tokenizer, output_mode,
|
train_examples, label_list, args.max_seq_length, tokenizer, output_mode,
|
||||||
cls_token_at_end=True, cls_token=tokenizer.CLS_TOKEN,
|
cls_token_at_end=True, cls_token=tokenizer.CLS_TOKEN,
|
||||||
sep_token=tokenizer.SEP_TOKEN, cls_token_segment_id=2)
|
sep_token=tokenizer.SEP_TOKEN, cls_token_segment_id=2,
|
||||||
|
pad_on_left=True, pad_token_segment_id=4)
|
||||||
if args.local_rank == -1 or torch.distributed.get_rank() == 0:
|
if args.local_rank == -1 or torch.distributed.get_rank() == 0:
|
||||||
logger.info(" Saving train features into cached file %s", cached_train_features_file)
|
logger.info(" Saving train features into cached file %s", cached_train_features_file)
|
||||||
with open(cached_train_features_file, "wb") as writer:
|
with open(cached_train_features_file, "wb") as writer:
|
||||||
@@ -344,14 +347,17 @@ def main():
|
|||||||
list(filter(None, args.xlnet_model.split('/'))).pop(),
|
list(filter(None, args.xlnet_model.split('/'))).pop(),
|
||||||
str(args.max_seq_length),
|
str(args.max_seq_length),
|
||||||
str(task_name)))
|
str(task_name)))
|
||||||
try:
|
if os.path.exists(cached_eval_features_file):
|
||||||
|
logger.info("Loading eval features for cache file %s", cached_eval_features_file)
|
||||||
with open(cached_eval_features_file, "rb") as reader:
|
with open(cached_eval_features_file, "rb") as reader:
|
||||||
eval_features = pickle.load(reader)
|
eval_features = pickle.load(reader)
|
||||||
except:
|
else:
|
||||||
|
logger.info("No cache file at %s, preparing eval features", cached_eval_features_file)
|
||||||
eval_features = convert_examples_to_features(
|
eval_features = convert_examples_to_features(
|
||||||
eval_examples, label_list, args.max_seq_length, tokenizer, output_mode,
|
eval_examples, label_list, args.max_seq_length, tokenizer, output_mode,
|
||||||
cls_token_at_end=True, cls_token=tokenizer.CLS_TOKEN,
|
cls_token_at_end=True, cls_token=tokenizer.CLS_TOKEN,
|
||||||
sep_token=tokenizer.SEP_TOKEN, cls_token_segment_id=2)
|
sep_token=tokenizer.SEP_TOKEN, cls_token_segment_id=2,
|
||||||
|
pad_on_left=True, pad_token_segment_id=4)
|
||||||
if args.local_rank == -1 or torch.distributed.get_rank() == 0:
|
if args.local_rank == -1 or torch.distributed.get_rank() == 0:
|
||||||
logger.info(" Saving eval features into cached file %s", cached_eval_features_file)
|
logger.info(" Saving eval features into cached file %s", cached_eval_features_file)
|
||||||
with open(cached_eval_features_file, "wb") as writer:
|
with open(cached_eval_features_file, "wb") as writer:
|
||||||
|
|||||||
@@ -389,8 +389,11 @@ class WnliProcessor(DataProcessor):
|
|||||||
|
|
||||||
def convert_examples_to_features(examples, label_list, max_seq_length,
|
def convert_examples_to_features(examples, label_list, max_seq_length,
|
||||||
tokenizer, output_mode,
|
tokenizer, output_mode,
|
||||||
cls_token_at_end=False, cls_token='[CLS]',
|
cls_token_at_end=False, pad_on_left=False,
|
||||||
sep_token='[SEP]', cls_token_segment_id=0):
|
cls_token='[CLS]', sep_token='[SEP]', pad_token=0,
|
||||||
|
sequence_a_segment_id=0, sequence_b_segment_id=1,
|
||||||
|
cls_token_segment_id=1, pad_token_segment_id=0,
|
||||||
|
mask_padding_with_zero=True):
|
||||||
""" Loads a data file into a list of `InputBatch`s
|
""" Loads a data file into a list of `InputBatch`s
|
||||||
`cls_token_at_end` defines the location of the CLS token:
|
`cls_token_at_end` defines the location of the CLS token:
|
||||||
- False (BERT pattern): [CLS] + A + [SEP] + B + [SEP]
|
- False (BERT pattern): [CLS] + A + [SEP] + B + [SEP]
|
||||||
@@ -438,11 +441,11 @@ def convert_examples_to_features(examples, label_list, max_seq_length,
|
|||||||
# used as the "sentence vector". Note that this only makes sense because
|
# used as the "sentence vector". Note that this only makes sense because
|
||||||
# the entire model is fine-tuned.
|
# the entire model is fine-tuned.
|
||||||
tokens = tokens_a + [sep_token]
|
tokens = tokens_a + [sep_token]
|
||||||
segment_ids = [0] * len(tokens)
|
segment_ids = [sequence_a_segment_id] * len(tokens)
|
||||||
|
|
||||||
if tokens_b:
|
if tokens_b:
|
||||||
tokens += tokens_b + [sep_token]
|
tokens += tokens_b + [sep_token]
|
||||||
segment_ids += [1] * (len(tokens_b) + 1)
|
segment_ids += [sequence_b_segment_id] * (len(tokens_b) + 1)
|
||||||
|
|
||||||
if cls_token_at_end:
|
if cls_token_at_end:
|
||||||
tokens = tokens + [cls_token]
|
tokens = tokens + [cls_token]
|
||||||
@@ -455,13 +458,18 @@ def convert_examples_to_features(examples, label_list, max_seq_length,
|
|||||||
|
|
||||||
# The mask has 1 for real tokens and 0 for padding tokens. Only real
|
# The mask has 1 for real tokens and 0 for padding tokens. Only real
|
||||||
# tokens are attended to.
|
# tokens are attended to.
|
||||||
input_mask = [1] * len(input_ids)
|
input_mask = [1 if mask_padding_with_zero else 0] * len(input_ids)
|
||||||
|
|
||||||
# Zero-pad up to the sequence length.
|
# Pad up to the sequence length (on the left when `pad_on_left` is set).
|
||||||
padding = [0] * (max_seq_length - len(input_ids))
|
padding_length = max_seq_length - len(input_ids)
|
||||||
input_ids += padding
|
if pad_on_left:
|
||||||
input_mask += padding
|
input_ids = ([pad_token] * padding_length) + input_ids
|
||||||
segment_ids += padding
|
input_mask = ([0 if mask_padding_with_zero else 1] * padding_length) + input_mask
|
||||||
|
segment_ids = ([pad_token_segment_id] * padding_length) + segment_ids
|
||||||
|
else:
|
||||||
|
input_ids = input_ids + ([pad_token] * padding_length)
|
||||||
|
input_mask = input_mask + ([0 if mask_padding_with_zero else 1] * padding_length)
|
||||||
|
segment_ids = segment_ids + ([pad_token_segment_id] * padding_length)
|
||||||
|
|
||||||
assert len(input_ids) == max_seq_length
|
assert len(input_ids) == max_seq_length
|
||||||
assert len(input_mask) == max_seq_length
|
assert len(input_mask) == max_seq_length
|
||||||
|
|||||||
Reference in New Issue
Block a user