From c45d8ac55439decd059d697e21daf27e85ac3412 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gr=C3=A9gory=20Ch=C3=A2tel?= Date: Thu, 6 Dec 2018 16:01:28 +0100 Subject: [PATCH] Storing the feature of each choice as a dict for readability. --- examples/run_swag.py | 20 ++++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/examples/run_swag.py b/examples/run_swag.py index 06169a3e9b..f8494f3a1f 100644 --- a/examples/run_swag.py +++ b/examples/run_swag.py @@ -73,9 +73,17 @@ class InputFeatures(object): example_id, choices_features, label + ): self.example_id = example_id - self.choices_features = choices_features + self.choices_features = [ + { + 'input_ids': input_ids, + 'input_mask': input_mask, + 'segment_ids': segment_ids + } + for _, input_ids, input_mask, segment_ids in choices_features + ] self.label = label def read_swag_examples(input_file, is_training): @@ -181,6 +189,7 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length, return features + def _truncate_seq_pair(tokens_a, tokens_b, max_length): """Truncates a sequence pair in place to the maximum length.""" @@ -207,4 +216,11 @@ if __name__ == "__main__": print("###########################") print(example) tokenizer = BertTokenizer.from_pretrained("bert-base-uncased") - features = convert_examples_to_features(examples, tokenizer, max_seq_length, is_training) + features = convert_examples_to_features(examples[:500], tokenizer, max_seq_length, is_training) + for i in range(10): + choice_feature_list = features[i].choices_features + for choice_idx, choice_feature in enumerate(choice_feature_list): + print(f'choice_idx: {choice_idx}') + print(f'input_ids: {" ".join(map(str, choice_feature["input_ids"]))}') + print(f'input_mask: {" ".join(map(str, choice_feature["input_mask"]))}') + print(f'segment_ids: {" ".join(map(str, choice_feature["segment_ids"]))}')