python 2 compatibility

This commit is contained in:
thomwolf
2019-02-06 00:07:46 +01:00
parent ba37ddc5ce
commit 448937c00d
17 changed files with 246 additions and 184 deletions

View File

@@ -15,29 +15,36 @@
# limitations under the License.
"""Run BERT on SQuAD."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import absolute_import, division, print_function
import argparse
import collections
import logging
import json
import logging
import math
import os
import random
import pickle
from tqdm import tqdm, trange
import sys
from io import open
import numpy as np
import torch
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler,
TensorDataset)
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm, trange
from pytorch_pretrained_bert.tokenization import whitespace_tokenize, BasicTokenizer, BertTokenizer
from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE
from pytorch_pretrained_bert.modeling import BertForQuestionAnswering
from pytorch_pretrained_bert.optimization import BertAdam
from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE
from pytorch_pretrained_bert.tokenization import (BasicTokenizer,
BertTokenizer,
whitespace_tokenize)
if sys.version_info[0] == 2:
import cPickle as pickle
else:
import pickle
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
datefmt = '%m/%d/%Y %H:%M:%S',
@@ -784,7 +791,8 @@ def main():
if os.path.exists(args.output_dir) and os.listdir(args.output_dir):
raise ValueError("Output directory () already exists and is not empty.")
os.makedirs(args.output_dir, exist_ok=True)
if not os.path.exists(args.output_dir):
os.makedirs(args.output_dir)
tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)
@@ -798,7 +806,7 @@ def main():
# Prepare model
model = BertForQuestionAnswering.from_pretrained(args.bert_model,
cache_dir=PYTORCH_PRETRAINED_BERT_CACHE / 'distributed_{}'.format(args.local_rank))
cache_dir=os.path.join(PYTORCH_PRETRAINED_BERT_CACHE, 'distributed_{}'.format(args.local_rank)))
if args.fp16:
model.half()