diff --git a/examples/run_swag.py b/examples/run_swag.py index 5a92f811b4..06169a3e9b 100644 --- a/examples/run_swag.py +++ b/examples/run_swag.py @@ -70,20 +70,13 @@ class SwagExample(object): class InputFeatures(object): def __init__(self, - unique_id, example_id, - input_ids, - input_mask, - segment_ids, - label_id + choices_features, + label ): - self.unique_id = unique_id self.example_id = example_id - self.input_ids = input_ids - self.input_mask = input_mask - self.segment_ids = segment_ids - self.label_id = label_id - + self.choices_features = choices_features + self.label = label def read_swag_examples(input_file, is_training): input_df = pd.read_csv(input_file) @@ -145,7 +138,7 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length, # place so that the total length is less than the # specified length. Account for [CLS], [SEP], [SEP] with # "- 3" - _truncate_seq_pair(context_tokens, ending_tokens, max_seq_length - 3) + _truncate_seq_pair(context_tokens_choice, ending_tokens, max_seq_length - 3) tokens = ["[CLS]"] + context_tokens_choice + ["[SEP]"] + ending_tokens + ["[SEP]"] segment_ids = [0] * (len(context_tokens_choice) + 2) + [1] * (len(ending_tokens) + 1) @@ -178,7 +171,15 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length, if is_training: logger.info(f"label: {label}") + features.append( + InputFeatures( + example_id = example.swag_id, + choices_features = choices_features, + label = label + ) + ) + return features def _truncate_seq_pair(tokens_a, tokens_b, max_length): """Truncates a sequence pair in place to the maximum length.""" @@ -206,4 +207,4 @@ if __name__ == "__main__": print("###########################") print(example) tokenizer = BertTokenizer.from_pretrained("bert-base-uncased") - convert_examples_to_features(examples, tokenizer, max_seq_length, is_training) + features = convert_examples_to_features(examples, tokenizer, max_seq_length, is_training)