diff --git a/examples/run_swag.py b/examples/run_swag.py index 37212b4e86..9fa5bad050 100644 --- a/examples/run_swag.py +++ b/examples/run_swag.py @@ -14,6 +14,9 @@ # limitations under the License. """BERT finetuning runner.""" +import pandas as pd + + class SwagExample(object): """A single training/test example for the SWAG dataset.""" def __init__(self, @@ -53,26 +56,32 @@ class SwagExample(object): return ', '.join(l) -if __name__ == "__main__": - e = SwagExample( - 3416, - 'Members of the procession walk down the street holding small horn brass instruments.', - 'A drum line', - 'passes by walking down the street playing their instruments.', - 'has heard approaching them.', - "arrives and they're outside dancing and asleep.", - 'turns the lead singer watches the performance.', - ) - print(e) +def read_swag_examples(input_file, is_training): + input_df = pd.read_csv(input_file) - e = SwagExample( - 3416, - 'Members of the procession walk down the street holding small horn brass instruments.', - 'A drum line', - 'passes by walking down the street playing their instruments.', - 'has heard approaching them.', - "arrives and they're outside dancing and asleep.", - 'turns the lead singer watches the performance.', - 0 - ) - print(e) + if is_training and 'label' not in input_df.columns: + raise ValueError( + "For training, the input file must contain a label column.") + + examples = [ + SwagExample( + swag_id = row['fold-ind'], + context_sentence = row['sent1'], + start_ending = row['sent2'], + ending_0 = row['ending0'], + ending_1 = row['ending1'], + ending_2 = row['ending2'], + ending_3 = row['ending3'], + label = row['label'] if is_training else None + ) for _, row in input_df.iterrows() + ] + + return examples + + +if __name__ == "__main__": + examples = read_swag_examples('data/train.csv', True) + print(len(examples)) + for example in examples[:5]: + print('###########################') + print(example)