From df34f22854a5174f5ad941c72255098e1b47e1bd Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Gr=C3=A9gory=20Ch=C3=A2tel?=
Date: Mon, 10 Dec 2018 17:45:23 +0100
Subject: [PATCH] Removing the dependency to pandas and using the csv module to load data.

---
Reviewer notes (free text placed after the three-dash separator, ahead of the
diffstat; nothing below changes the diff itself):

  * NOTE(review): this copy of the patch had lost its line breaks. They were
    restored so that the hunk headers (-14,13 +14,12 and -100,25 +99,28) and
    the diffstat (16 insertions, 14 deletions) add up again; the leading
    whitespace of each line is a reconstruction -- confirm it against the
    upstream commit before applying.
  * read_swag_examples now loads every row with csv.reader into a list and
    addresses columns by fixed position (2, 4, 5, 7, 8, 9, 10, 11), where the
    removed pandas code addressed them by header name ('fold-ind', 'sent1',
    'sent2', 'ending0'..'ending3', 'label').
  * The training-mode check now compares only the last cell of the header row
    with 'label'; the removed code accepted a 'label' column at any position.
  * lines[0][-1] is evaluated without an emptiness check, so with
    is_training set an input file that yields no rows raises IndexError
    instead of the ValueError below it.
  * label is now converted with int(); the removed code passed row['label']
    through unchanged.
  * The file is opened as open(input_file, 'r') with no encoding or newline
    argument. NOTE(review): the csv module documentation recommends
    newline='' for file objects handed to csv.reader -- confirm whether
    quoted fields in the data can contain line breaks.

 examples/run_swag.py | 30 ++++++++++++++++--------------
 1 file changed, 16 insertions(+), 14 deletions(-)

diff --git a/examples/run_swag.py b/examples/run_swag.py
index 201317766f..88297bf801 100644
--- a/examples/run_swag.py
+++ b/examples/run_swag.py
@@ -14,13 +14,12 @@
 # limitations under the License.
 """BERT finetuning runner."""
 
-import pandas as pd
-
 import logging
 import os
 import argparse
 import random
 from tqdm import tqdm, trange
+import csv
 
 import numpy as np
 import torch
@@ -100,25 +99,28 @@ class InputFeatures(object):
 
 
 def read_swag_examples(input_file, is_training):
-    input_df = pd.read_csv(input_file)
+    with open(input_file, 'r') as f:
+        reader = csv.reader(f)
+        lines = list(reader)
 
-    if is_training and 'label' not in input_df.columns:
+    if is_training and lines[0][-1] != 'label':
         raise ValueError(
-            "For training, the input file must contain a label column.")
+            "For training, the input file must contain a label column."
+        )
 
     examples = [
         SwagExample(
-            swag_id = row['fold-ind'],
-            context_sentence = row['sent1'],
-            start_ending = row['sent2'], # in the swag dataset, the
+            swag_id = line[2],
+            context_sentence = line[4],
+            start_ending = line[5], # in the swag dataset, the
                                          # common beginning of each
                                          # choice is stored in "sent2".
-            ending_0 = row['ending0'],
-            ending_1 = row['ending1'],
-            ending_2 = row['ending2'],
-            ending_3 = row['ending3'],
-            label = row['label'] if is_training else None
-        ) for _, row in input_df.iterrows()
+            ending_0 = line[7],
+            ending_1 = line[8],
+            ending_2 = line[9],
+            ending_3 = line[10],
+            label = int(line[11]) if is_training else None
+        ) for line in lines[1:] # we skip the line with the column names
     ]
 
     return examples