Removing the dependency to pandas and using the csv module to load data.
This commit is contained in:
@@ -14,13 +14,12 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
"""BERT finetuning runner."""
|
"""BERT finetuning runner."""
|
||||||
|
|
||||||
import pandas as pd
|
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import argparse
|
import argparse
|
||||||
import random
|
import random
|
||||||
from tqdm import tqdm, trange
|
from tqdm import tqdm, trange
|
||||||
|
import csv
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
@@ -100,25 +99,28 @@ class InputFeatures(object):
|
|||||||
|
|
||||||
|
|
||||||
def read_swag_examples(input_file, is_training):
|
def read_swag_examples(input_file, is_training):
|
||||||
input_df = pd.read_csv(input_file)
|
with open(input_file, 'r') as f:
|
||||||
|
reader = csv.reader(f)
|
||||||
|
lines = list(reader)
|
||||||
|
|
||||||
if is_training and 'label' not in input_df.columns:
|
if is_training and lines[0][-1] != 'label':
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"For training, the input file must contain a label column.")
|
"For training, the input file must contain a label column."
|
||||||
|
)
|
||||||
|
|
||||||
examples = [
|
examples = [
|
||||||
SwagExample(
|
SwagExample(
|
||||||
swag_id = row['fold-ind'],
|
swag_id = line[2],
|
||||||
context_sentence = row['sent1'],
|
context_sentence = line[4],
|
||||||
start_ending = row['sent2'], # in the swag dataset, the
|
start_ending = line[5], # in the swag dataset, the
|
||||||
# common beginning of each
|
# common beginning of each
|
||||||
# choice is stored in "sent2".
|
# choice is stored in "sent2".
|
||||||
ending_0 = row['ending0'],
|
ending_0 = line[7],
|
||||||
ending_1 = row['ending1'],
|
ending_1 = line[8],
|
||||||
ending_2 = row['ending2'],
|
ending_2 = line[9],
|
||||||
ending_3 = row['ending3'],
|
ending_3 = line[10],
|
||||||
label = row['label'] if is_training else None
|
label = int(line[11]) if is_training else None
|
||||||
) for _, row in input_df.iterrows()
|
) for line in lines[1:] # we skip the line with the column names
|
||||||
]
|
]
|
||||||
|
|
||||||
return examples
|
return examples
|
||||||
|
|||||||
Reference in New Issue
Block a user