python 2 compatibility

This commit is contained in:
thomwolf
2019-02-06 00:07:46 +01:00
parent ba37ddc5ce
commit 448937c00d
17 changed files with 246 additions and 184 deletions

View File

@@ -15,26 +15,23 @@
# limitations under the License.
"""BERT finetuning runner."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import absolute_import, division, print_function, unicode_literals
import os
import logging
import argparse
from tqdm import tqdm, trange
import logging
import os
import random
from io import open
import numpy as np
import torch
from torch.utils.data import DataLoader, RandomSampler
from torch.utils.data import DataLoader, Dataset, RandomSampler
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm, trange
from pytorch_pretrained_bert.tokenization import BertTokenizer
from pytorch_pretrained_bert.modeling import BertForPreTraining
from pytorch_pretrained_bert.optimization import BertAdam
from torch.utils.data import Dataset
import random
from pytorch_pretrained_bert.tokenization import BertTokenizer
logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
datefmt='%m/%d/%Y %H:%M:%S',
@@ -185,16 +182,16 @@ class BERTDataset(Dataset):
if self.line_buffer is None:
# read first non-empty line of file
while t1 == "" :
t1 = self.file.__next__().strip()
t2 = self.file.__next__().strip()
t1 = next(self.file).strip()
t2 = next(self.file).strip()
else:
# use t2 from previous iteration as new t1
t1 = self.line_buffer
t2 = self.file.__next__().strip()
t2 = next(self.file).strip()
# skip empty rows that are used for separating documents and keep track of current doc id
while t2 == "" or t1 == "":
t1 = self.file.__next__().strip()
t2 = self.file.__next__().strip()
t1 = next(self.file).strip()
t2 = next(self.file).strip()
self.current_doc = self.current_doc+1
self.line_buffer = t2
@@ -228,15 +225,15 @@ class BERTDataset(Dataset):
def get_next_line(self):
""" Gets next line of random_file and starts over when reaching end of file"""
try:
line = self.random_file.__next__().strip()
line = next(self.random_file).strip()
#keep track of which document we are currently looking at to later avoid having the same doc as t1
if line == "":
self.current_random_doc = self.current_random_doc + 1
line = self.random_file.__next__().strip()
line = next(self.random_file).strip()
except StopIteration:
self.random_file.close()
self.random_file = open(self.corpus_path, "r", encoding=self.encoding)
line = self.random_file.__next__().strip()
line = next(self.random_file).strip()
return line
@@ -425,6 +422,7 @@ def main():
help="The output directory where the model checkpoints will be written.")
## Other parameters
parser.add_argument("--do_lower_case", action='store_true', help="Set this flag if you are using an uncased model.")
parser.add_argument("--max_seq_length",
default=128,
type=int,
@@ -513,7 +511,8 @@ def main():
if os.path.exists(args.output_dir) and os.listdir(args.output_dir):
raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir))
os.makedirs(args.output_dir, exist_ok=True)
if not os.path.exists(args.output_dir):
os.makedirs(args.output_dir)
tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)
@@ -579,7 +578,7 @@ def main():
if args.local_rank == -1:
train_sampler = RandomSampler(train_dataset)
else:
#TODO: check if this works with current data generator from disk that relies on file.__next__
#TODO: check if this works with current data generator from disk that relies on next(file)
# (it doesn't return item back by index)
train_sampler = DistributedSampler(train_dataset)
train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size)
@@ -643,4 +642,4 @@ def accuracy(out, labels):
if __name__ == "__main__":
main()
main()