From f2b873e995e36732e41f2484b990b8109c239cd6 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Gr=C3=A9gory=20Ch=C3=A2tel?=
Date: Thu, 6 Dec 2018 15:40:47 +0100
Subject: [PATCH] convert_examples_to_features code and small improvements.

---
Review notes (text between the "---" line and the diffstat is not part of
the commit):
* Line breaks and indentation of this patch were restored from a
  whitespace-collapsed copy; blank context lines were placed to match the
  hunk line counts and should be checked against the base file.
* convert_examples_to_features now passes `context_tokens_choice` (the
  per-choice copy) to _truncate_seq_pair. The call used `context_tokens`,
  so the copy that goes into `tokens` was never shortened and the shared
  list was shrunk in place for every later choice.
* The docstring of convert_examples_to_features was reworded; the function
  still appends nothing to `features` and has no return statement.
* The blob id on the "index" line was left as it was.

 examples/run_swag.py | 154 ++++++++++++++++++++++++++++++++++++++-----
 1 file changed, 138 insertions(+), 16 deletions(-)

diff --git a/examples/run_swag.py b/examples/run_swag.py
index 9fa5bad050..5a92f811b4 100644
--- a/examples/run_swag.py
+++ b/examples/run_swag.py
@@ -16,6 +16,15 @@
 
 import pandas as pd
 
+import logging
+
+from pytorch_pretrained_bert.tokenization import BertTokenizer
+
+logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
+                    datefmt = '%m/%d/%Y %H:%M:%S',
+                    level = logging.INFO)
+logger = logging.getLogger(__name__)
+
 
 class SwagExample(object):
     """A single training/test example for the SWAG dataset."""
@@ -31,10 +40,12 @@ class SwagExample(object):
         self.swag_id = swag_id
         self.context_sentence = context_sentence
         self.start_ending = start_ending
-        self.ending_0 = ending_0
-        self.ending_1 = ending_1
-        self.ending_2 = ending_2
-        self.ending_3 = ending_3
+        self.endings = [
+            ending_0,
+            ending_1,
+            ending_2,
+            ending_3,
+        ]
         self.label = label
 
     def __str__(self):
@@ -42,19 +53,37 @@ class SwagExample(object):
 
     def __repr__(self):
         l = [
-            f'swag_id: {self.swag_id}',
-            f'context_sentence: {self.context_sentence}',
-            f'start_ending: {self.start_ending}',
-            f'ending_0: {self.ending_0}',
-            f'ending_1: {self.ending_1}',
-            f'ending_2: {self.ending_2}',
-            f'ending_3: {self.ending_3}',
+            f"swag_id: {self.swag_id}",
+            f"context_sentence: {self.context_sentence}",
+            f"start_ending: {self.start_ending}",
+            f"ending_0: {self.endings[0]}",
+            f"ending_1: {self.endings[1]}",
+            f"ending_2: {self.endings[2]}",
+            f"ending_3: {self.endings[3]}",
         ]
 
         if self.label is not None:
-            l.append(f'label: {self.label}')
+            l.append(f"label: {self.label}")
+
+        return ", ".join(l)
+
+
+class InputFeatures(object):
+    def __init__(self,
+                 unique_id,
+                 example_id,
+                 input_ids,
+                 input_mask,
+                 segment_ids,
+                 label_id
+    ):
+        self.unique_id = unique_id
+        self.example_id = example_id
+        self.input_ids = input_ids
+        self.input_mask = input_mask
+        self.segment_ids = segment_ids
+        self.label_id = label_id
 
-        return ', '.join(l)
 
 def read_swag_examples(input_file, is_training):
     input_df = pd.read_csv(input_file)
@@ -67,7 +96,9 @@ def read_swag_examples(input_file, is_training):
         SwagExample(
             swag_id = row['fold-ind'],
             context_sentence = row['sent1'],
-            start_ending = row['sent2'],
+            start_ending = row['sent2'], # in the swag dataset, the
+                                         # common beginning of each
+                                         # choice is stored in "sent2".
             ending_0 = row['ending0'],
             ending_1 = row['ending1'],
             ending_2 = row['ending2'],
@@ -79,9 +110,100 @@ def read_swag_examples(input_file, is_training):
     return examples
 
 
+def convert_examples_to_features(examples, tokenizer, max_seq_length,
+                                 is_training):
+    """Tokenizes SWAG examples into per-choice BERT inputs and logs the first five."""
+
+    # Swag is a multiple choice task. To perform this task using Bert,
+    # we will use the formatting proposed in "Improving Language
+    # Understanding by Generative Pre-Training" and suggested by
+    # @jacobdevlin-google in this issue
+    # https://github.com/google-research/bert/issues/38.
+    #
+    # Each choice will correspond to a sample on which we run the
+    # inference. For a given Swag example, we will create the 4
+    # following inputs:
+    # - [CLS] context [SEP] choice_1 [SEP]
+    # - [CLS] context [SEP] choice_2 [SEP]
+    # - [CLS] context [SEP] choice_3 [SEP]
+    # - [CLS] context [SEP] choice_4 [SEP]
+    # The model will output a single value for each input. To get the
+    # final decision of the model, we will run a softmax over these 4
+    # outputs.
+    features = []
+    for example_index, example in enumerate(examples):
+        context_tokens = tokenizer.tokenize(example.context_sentence)
+        start_ending_tokens = tokenizer.tokenize(example.start_ending)
+
+        choices_features = []
+        for ending_index, ending in enumerate(example.endings):
+            # We create a copy of the context tokens in order to be
+            # able to shrink it according to ending_tokens
+            context_tokens_choice = context_tokens[:]
+            ending_tokens = start_ending_tokens + tokenizer.tokenize(ending)
+            # Modifies `context_tokens_choice` and `ending_tokens` in
+            # place so that the total length is less than the
+            # specified length. Account for [CLS], [SEP], [SEP] with
+            # "- 3"
+            _truncate_seq_pair(context_tokens_choice, ending_tokens, max_seq_length - 3)
+
+            tokens = ["[CLS]"] + context_tokens_choice + ["[SEP]"] + ending_tokens + ["[SEP]"]
+            segment_ids = [0] * (len(context_tokens_choice) + 2) + [1] * (len(ending_tokens) + 1)
+
+            input_ids = tokenizer.convert_tokens_to_ids(tokens)
+            input_mask = [1] * len(input_ids)
+
+            # Zero-pad up to the sequence length.
+            padding = [0] * (max_seq_length - len(input_ids))
+            input_ids += padding
+            input_mask += padding
+            segment_ids += padding
+
+            assert len(input_ids) == max_seq_length
+            assert len(input_mask) == max_seq_length
+            assert len(segment_ids) == max_seq_length
+
+            choices_features.append((tokens, input_ids, input_mask, segment_ids))
+
+        label = example.label
+        if example_index < 5:
+            logger.info("*** Example ***")
+            logger.info(f"swag_id: {example.swag_id}")
+            for choice_idx, (tokens, input_ids, input_mask, segment_ids) in enumerate(choices_features):
+                logger.info(f"choice: {choice_idx}")
+                logger.info(f"tokens: {' '.join(tokens)}")
+                logger.info(f"input_ids: {' '.join(map(str, input_ids))}")
+                logger.info(f"input_mask: {' '.join(map(str, input_mask))}")
+                logger.info(f"segment_ids: {' '.join(map(str, segment_ids))}")
+            if is_training:
+                logger.info(f"label: {label}")
+
+
+
+def _truncate_seq_pair(tokens_a, tokens_b, max_length):
+    """Truncates a sequence pair in place to the maximum length."""
+
+    # This is a simple heuristic which will always truncate the longer sequence
+    # one token at a time. This makes more sense than truncating an equal percent
+    # of tokens from each, since if one sequence is very short then each token
+    # that's truncated likely contains more information than a longer sequence.
+    while True:
+        total_length = len(tokens_a) + len(tokens_b)
+        if total_length <= max_length:
+            break
+        if len(tokens_a) > len(tokens_b):
+            tokens_a.pop()
+        else:
+            tokens_b.pop()
+
+
 if __name__ == "__main__":
-    examples = read_swag_examples('data/train.csv', True)
+    is_training = True
+    max_seq_length = 80
+    examples = read_swag_examples('data/train.csv', is_training)
     print(len(examples))
     for example in examples[:5]:
-        print('###########################')
+        print("###########################")
         print(example)
+    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
+    convert_examples_to_features(examples, tokenizer, max_seq_length, is_training)