Adding read_swag_examples to load the dataset.
This commit is contained in:
@@ -14,6 +14,9 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
"""BERT finetuning runner."""
|
"""BERT finetuning runner."""
|
||||||
|
|
||||||
|
import pandas as pd
|
||||||
|
|
||||||
|
|
||||||
class SwagExample(object):
|
class SwagExample(object):
|
||||||
"""A single training/test example for the SWAG dataset."""
|
"""A single training/test example for the SWAG dataset."""
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
@@ -53,26 +56,32 @@ class SwagExample(object):
|
|||||||
|
|
||||||
return ', '.join(l)
|
return ', '.join(l)
|
||||||
|
|
||||||
if __name__ == "__main__":
|
def read_swag_examples(input_file, is_training):
|
||||||
e = SwagExample(
|
input_df = pd.read_csv(input_file)
|
||||||
3416,
|
|
||||||
'Members of the procession walk down the street holding small horn brass instruments.',
|
|
||||||
'A drum line',
|
|
||||||
'passes by walking down the street playing their instruments.',
|
|
||||||
'has heard approaching them.',
|
|
||||||
"arrives and they're outside dancing and asleep.",
|
|
||||||
'turns the lead singer watches the performance.',
|
|
||||||
)
|
|
||||||
print(e)
|
|
||||||
|
|
||||||
e = SwagExample(
|
if is_training and 'label' not in input_df.columns:
|
||||||
3416,
|
raise ValueError(
|
||||||
'Members of the procession walk down the street holding small horn brass instruments.',
|
"For training, the input file must contain a label column.")
|
||||||
'A drum line',
|
|
||||||
'passes by walking down the street playing their instruments.',
|
examples = [
|
||||||
'has heard approaching them.',
|
SwagExample(
|
||||||
"arrives and they're outside dancing and asleep.",
|
swag_id = row['fold-ind'],
|
||||||
'turns the lead singer watches the performance.',
|
context_sentence = row['sent1'],
|
||||||
0
|
start_ending = row['sent2'],
|
||||||
)
|
ending_0 = row['ending0'],
|
||||||
print(e)
|
ending_1 = row['ending1'],
|
||||||
|
ending_2 = row['ending2'],
|
||||||
|
ending_3 = row['ending3'],
|
||||||
|
label = row['label'] if is_training else None
|
||||||
|
) for _, row in input_df.iterrows()
|
||||||
|
]
|
||||||
|
|
||||||
|
return examples
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
examples = read_swag_examples('data/train.csv', True)
|
||||||
|
print(len(examples))
|
||||||
|
for example in examples[:5]:
|
||||||
|
print('###########################')
|
||||||
|
print(example)
|
||||||
|
|||||||
Reference in New Issue
Block a user