python 2 compatibility

This commit is contained in:
thomwolf
2019-02-06 00:07:46 +01:00
parent ba37ddc5ce
commit 448937c00d
17 changed files with 246 additions and 184 deletions

View File

@@ -15,22 +15,25 @@
# limitations under the License.
"""BERT finetuning runner."""
import argparse
import csv
import logging
import os
import argparse
import random
from tqdm import tqdm, trange
import csv
import sys
from io import open
import numpy as np
import torch
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler,
TensorDataset)
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm, trange
from pytorch_pretrained_bert.tokenization import BertTokenizer
from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE
from pytorch_pretrained_bert.modeling import BertForMultipleChoice
from pytorch_pretrained_bert.optimization import BertAdam
from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE
from pytorch_pretrained_bert.tokenization import BertTokenizer
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
datefmt = '%m/%d/%Y %H:%M:%S',
@@ -65,17 +68,17 @@ class SwagExample(object):
def __repr__(self):
l = [
f"swag_id: {self.swag_id}",
f"context_sentence: {self.context_sentence}",
f"start_ending: {self.start_ending}",
f"ending_0: {self.endings[0]}",
f"ending_1: {self.endings[1]}",
f"ending_2: {self.endings[2]}",
f"ending_3: {self.endings[3]}",
"swag_id: {}".format(self.swag_id),
"context_sentence: {}".format(self.context_sentence),
"start_ending: {}".format(self.start_ending),
"ending_0: {}".format(self.endings[0]),
"ending_1: {}".format(self.endings[1]),
"ending_2: {}".format(self.endings[2]),
"ending_3: {}".format(self.endings[3]),
]
if self.label is not None:
l.append(f"label: {self.label}")
l.append("label: {}".format(self.label))
return ", ".join(l)
@@ -102,7 +105,11 @@ class InputFeatures(object):
def read_swag_examples(input_file, is_training):
with open(input_file, 'r', encoding='utf-8') as f:
reader = csv.reader(f)
lines = list(reader)
lines = []
for line in reader:
if sys.version_info[0] == 2:
line = list(unicode(cell, 'utf-8') for cell in line)
lines.append(line)
if is_training and lines[0][-1] != 'label':
raise ValueError(
@@ -184,15 +191,15 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length,
label = example.label
if example_index < 5:
logger.info("*** Example ***")
logger.info(f"swag_id: {example.swag_id}")
logger.info("swag_id: {}".format(example.swag_id))
for choice_idx, (tokens, input_ids, input_mask, segment_ids) in enumerate(choices_features):
logger.info(f"choice: {choice_idx}")
logger.info(f"tokens: {' '.join(tokens)}")
logger.info(f"input_ids: {' '.join(map(str, input_ids))}")
logger.info(f"input_mask: {' '.join(map(str, input_mask))}")
logger.info(f"segment_ids: {' '.join(map(str, segment_ids))}")
logger.info("choice: {}".format(choice_idx))
logger.info("tokens: {}".format(' '.join(tokens)))
logger.info("input_ids: {}".format(' '.join(map(str, input_ids))))
logger.info("input_mask: {}".format(' '.join(map(str, input_mask))))
logger.info("segment_ids: {}".format(' '.join(map(str, segment_ids))))
if is_training:
logger.info(f"label: {label}")
logger.info("label: {}".format(label))
features.append(
InputFeatures(
@@ -349,7 +356,8 @@ def main():
if os.path.exists(args.output_dir) and os.listdir(args.output_dir):
raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir))
os.makedirs(args.output_dir, exist_ok=True)
if not os.path.exists(args.output_dir):
os.makedirs(args.output_dir)
tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)
@@ -362,7 +370,7 @@ def main():
# Prepare model
model = BertForMultipleChoice.from_pretrained(args.bert_model,
cache_dir=PYTORCH_PRETRAINED_BERT_CACHE / 'distributed_{}'.format(args.local_rank),
cache_dir=os.path.join(PYTORCH_PRETRAINED_BERT_CACHE, 'distributed_{}'.format(args.local_rank)),
num_choices=4)
if args.fp16:
model.half()