python 2 compatibility
This commit is contained in:
@@ -15,22 +15,25 @@
|
||||
# limitations under the License.
|
||||
"""BERT finetuning runner."""
|
||||
|
||||
import argparse
|
||||
import csv
|
||||
import logging
|
||||
import os
|
||||
import argparse
|
||||
import random
|
||||
from tqdm import tqdm, trange
|
||||
import csv
|
||||
import sys
|
||||
from io import open
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
|
||||
from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler,
|
||||
TensorDataset)
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
from tqdm import tqdm, trange
|
||||
|
||||
from pytorch_pretrained_bert.tokenization import BertTokenizer
|
||||
from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE
|
||||
from pytorch_pretrained_bert.modeling import BertForMultipleChoice
|
||||
from pytorch_pretrained_bert.optimization import BertAdam
|
||||
from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE
|
||||
from pytorch_pretrained_bert.tokenization import BertTokenizer
|
||||
|
||||
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
|
||||
datefmt = '%m/%d/%Y %H:%M:%S',
|
||||
@@ -65,17 +68,17 @@ class SwagExample(object):
|
||||
|
||||
def __repr__(self):
|
||||
l = [
|
||||
f"swag_id: {self.swag_id}",
|
||||
f"context_sentence: {self.context_sentence}",
|
||||
f"start_ending: {self.start_ending}",
|
||||
f"ending_0: {self.endings[0]}",
|
||||
f"ending_1: {self.endings[1]}",
|
||||
f"ending_2: {self.endings[2]}",
|
||||
f"ending_3: {self.endings[3]}",
|
||||
"swag_id: {}".format(self.swag_id),
|
||||
"context_sentence: {}".format(self.context_sentence),
|
||||
"start_ending: {}".format(self.start_ending),
|
||||
"ending_0: {}".format(self.endings[0]),
|
||||
"ending_1: {}".format(self.endings[1]),
|
||||
"ending_2: {}".format(self.endings[2]),
|
||||
"ending_3: {}".format(self.endings[3]),
|
||||
]
|
||||
|
||||
if self.label is not None:
|
||||
l.append(f"label: {self.label}")
|
||||
l.append("label: {}".format(self.label))
|
||||
|
||||
return ", ".join(l)
|
||||
|
||||
@@ -102,7 +105,11 @@ class InputFeatures(object):
|
||||
def read_swag_examples(input_file, is_training):
|
||||
with open(input_file, 'r', encoding='utf-8') as f:
|
||||
reader = csv.reader(f)
|
||||
lines = list(reader)
|
||||
lines = []
|
||||
for line in reader:
|
||||
if sys.version_info[0] == 2:
|
||||
line = list(unicode(cell, 'utf-8') for cell in line)
|
||||
lines.append(line)
|
||||
|
||||
if is_training and lines[0][-1] != 'label':
|
||||
raise ValueError(
|
||||
@@ -184,15 +191,15 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length,
|
||||
label = example.label
|
||||
if example_index < 5:
|
||||
logger.info("*** Example ***")
|
||||
logger.info(f"swag_id: {example.swag_id}")
|
||||
logger.info("swag_id: {}".format(example.swag_id))
|
||||
for choice_idx, (tokens, input_ids, input_mask, segment_ids) in enumerate(choices_features):
|
||||
logger.info(f"choice: {choice_idx}")
|
||||
logger.info(f"tokens: {' '.join(tokens)}")
|
||||
logger.info(f"input_ids: {' '.join(map(str, input_ids))}")
|
||||
logger.info(f"input_mask: {' '.join(map(str, input_mask))}")
|
||||
logger.info(f"segment_ids: {' '.join(map(str, segment_ids))}")
|
||||
logger.info("choice: {}".format(choice_idx))
|
||||
logger.info("tokens: {}".format(' '.join(tokens)))
|
||||
logger.info("input_ids: {}".format(' '.join(map(str, input_ids))))
|
||||
logger.info("input_mask: {}".format(' '.join(map(str, input_mask))))
|
||||
logger.info("segment_ids: {}".format(' '.join(map(str, segment_ids))))
|
||||
if is_training:
|
||||
logger.info(f"label: {label}")
|
||||
logger.info("label: {}".format(label))
|
||||
|
||||
features.append(
|
||||
InputFeatures(
|
||||
@@ -349,7 +356,8 @@ def main():
|
||||
|
||||
if os.path.exists(args.output_dir) and os.listdir(args.output_dir):
|
||||
raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir))
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
if not os.path.exists(args.output_dir):
|
||||
os.makedirs(args.output_dir)
|
||||
|
||||
tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)
|
||||
|
||||
@@ -362,7 +370,7 @@ def main():
|
||||
|
||||
# Prepare model
|
||||
model = BertForMultipleChoice.from_pretrained(args.bert_model,
|
||||
cache_dir=PYTORCH_PRETRAINED_BERT_CACHE / 'distributed_{}'.format(args.local_rank),
|
||||
cache_dir=os.path.join(PYTORCH_PRETRAINED_BERT_CACHE, 'distributed_{}'.format(args.local_rank)),
|
||||
num_choices=4)
|
||||
if args.fp16:
|
||||
model.half()
|
||||
|
||||
Reference in New Issue
Block a user