From 7334bf6c21c65b6090e16247e294b737581c6d2e Mon Sep 17 00:00:00 2001 From: thomwolf Date: Mon, 24 Jun 2019 15:05:11 +0200 Subject: [PATCH] pad on left for xlnet --- examples/run_xlnet_classifier.py | 18 ++++++++++++------ examples/utils_glue.py | 26 +++++++++++++++++--------- 2 files changed, 29 insertions(+), 15 deletions(-) diff --git a/examples/run_xlnet_classifier.py b/examples/run_xlnet_classifier.py index 6733a25573..514776b242 100644 --- a/examples/run_xlnet_classifier.py +++ b/examples/run_xlnet_classifier.py @@ -198,14 +198,17 @@ def main(): list(filter(None, args.xlnet_model.split('/'))).pop(), str(args.max_seq_length), str(task_name))) - try: + if os.path.exists(cached_train_features_file): + logger.info("Loading train features for cache file %s", cached_train_features_file) with open(cached_train_features_file, "rb") as reader: train_features = pickle.load(reader) - except: + else: + logger.info("No cache file at %s, preparing train features", cached_train_features_file) train_features = convert_examples_to_features( train_examples, label_list, args.max_seq_length, tokenizer, output_mode, cls_token_at_end=True, cls_token=tokenizer.CLS_TOKEN, - sep_token=tokenizer.SEP_TOKEN, cls_token_segment_id=2) + sep_token=tokenizer.SEP_TOKEN, cls_token_segment_id=2, + pad_on_left=True, pad_token_segment_id=4) if args.local_rank == -1 or torch.distributed.get_rank() == 0: logger.info(" Saving train features into cached file %s", cached_train_features_file) with open(cached_train_features_file, "wb") as writer: @@ -344,14 +347,17 @@ def main(): list(filter(None, args.xlnet_model.split('/'))).pop(), str(args.max_seq_length), str(task_name))) - try: + if os.path.exists(cached_eval_features_file): + logger.info("Loading eval features for cache file %s", cached_eval_features_file) with open(cached_eval_features_file, "rb") as reader: eval_features = pickle.load(reader) - except: + else: + logger.info("No cache file at %s, preparing eval features", cached_eval_features_file) eval_features = convert_examples_to_features( eval_examples, label_list, args.max_seq_length, tokenizer, output_mode, cls_token_at_end=True, cls_token=tokenizer.CLS_TOKEN, - sep_token=tokenizer.SEP_TOKEN, cls_token_segment_id=2) + sep_token=tokenizer.SEP_TOKEN, cls_token_segment_id=2, + pad_on_left=True, pad_token_segment_id=4) if args.local_rank == -1 or torch.distributed.get_rank() == 0: logger.info(" Saving eval features into cached file %s", cached_eval_features_file) with open(cached_eval_features_file, "wb") as writer: diff --git a/examples/utils_glue.py b/examples/utils_glue.py index ed3cde5a93..5d3454f439 100644 --- a/examples/utils_glue.py +++ b/examples/utils_glue.py @@ -389,8 +389,11 @@ class WnliProcessor(DataProcessor): def convert_examples_to_features(examples, label_list, max_seq_length, tokenizer, output_mode, - cls_token_at_end=False, cls_token='[CLS]', - sep_token='[SEP]', cls_token_segment_id=0): + cls_token_at_end=False, pad_on_left=False, + cls_token='[CLS]', sep_token='[SEP]', pad_token=0, + sequence_a_segment_id=0, sequence_b_segment_id=1, + cls_token_segment_id=1, pad_token_segment_id=0, + mask_padding_with_zero=True): """ Loads a data file into a list of `InputBatch`s `cls_token_at_end` define the location of the CLS token: - False (BERT pattern): [CLS] + A + [SEP] + B + [SEP] @@ -438,11 +441,11 @@ def convert_examples_to_features(examples, label_list, max_seq_length, # used as as the "sentence vector". Note that this only makes sense because # the entire model is fine-tuned. tokens = tokens_a + [sep_token] - segment_ids = [0] * len(tokens) + segment_ids = [sequence_a_segment_id] * len(tokens) if tokens_b: tokens += tokens_b + [sep_token] - segment_ids += [1] * (len(tokens_b) + 1) + segment_ids += [sequence_b_segment_id] * (len(tokens_b) + 1) if cls_token_at_end: tokens = tokens + [cls_token] @@ -455,13 +458,18 @@ def convert_examples_to_features(examples, label_list, max_seq_length, # The mask has 1 for real tokens and 0 for padding tokens. Only real # tokens are attended to. - input_mask = [1] * len(input_ids) + input_mask = [1 if mask_padding_with_zero else 0] * len(input_ids) # Zero-pad up to the sequence length. - padding = [0] * (max_seq_length - len(input_ids)) - input_ids += padding - input_mask += padding - segment_ids += padding + padding_length = max_seq_length - len(input_ids) + if pad_on_left: + input_ids = ([pad_token] * padding_length) + input_ids + input_mask = ([0 if mask_padding_with_zero else 1] * padding_length) + input_mask + segment_ids = ([pad_token_segment_id] * padding_length) + segment_ids + else: + input_ids = input_ids + ([pad_token] * padding_length) + input_mask = input_mask + ([0 if mask_padding_with_zero else 1] * padding_length) + segment_ids = segment_ids + ([pad_token_segment_id] * padding_length) assert len(input_ids) == max_seq_length assert len(input_mask) == max_seq_length