python 2 compatibility
This commit is contained in:
@@ -17,26 +17,26 @@
|
||||
Adapted from https://github.com/kimiyoung/transformer-xl.
|
||||
In particular https://github.com/kimiyoung/transformer-xl/blob/master/pytorch/eval.py
|
||||
"""
|
||||
from __future__ import absolute_import, division, print_function, unicode_literals
|
||||
|
||||
import os
|
||||
import sys
|
||||
import functools
|
||||
import argparse
|
||||
import logging
|
||||
import time
|
||||
import math
|
||||
import sys
|
||||
from io import open
|
||||
|
||||
import torch
|
||||
|
||||
from pytorch_pretrained_bert import TransfoXLModel, TransfoXLCorpus
|
||||
|
||||
def logging(s, log_path, print_=True, log_=True):
|
||||
if print_:
|
||||
print(s)
|
||||
if log_:
|
||||
with open(log_path, 'a+') as f_log:
|
||||
f_log.write(s + '\n')
|
||||
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
|
||||
datefmt = '%m/%d/%Y %H:%M:%S',
|
||||
level = logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def get_logger(log_path, **kwargs):
|
||||
return functools.partial(logging, log_path=log_path, **kwargs)
|
||||
|
||||
parser = argparse.ArgumentParser(description='PyTorch Transformer Language Model')
|
||||
# parser.add_argument('--data', type=str, default='../data/wikitext-103',
|
||||
@@ -71,8 +71,8 @@ assert args.ext_len >= 0, 'extended context length must be non-negative'
|
||||
device = torch.device("cuda" if args.cuda else "cpu")
|
||||
|
||||
# Get logger
|
||||
logging = get_logger(os.path.join(args.work_dir, 'log.txt'),
|
||||
log_=not args.no_log)
|
||||
# logging = get_logger(os.path.join(args.work_dir, 'log.txt'),
|
||||
# log_=not args.no_log)
|
||||
|
||||
# Load dataset
|
||||
corpus = TransfoXLCorpus.from_pretrained(args.model_name)
|
||||
@@ -90,7 +90,7 @@ te_iter = corpus.get_iterator('test', args.batch_size, args.tgt_len,
|
||||
model = TransfoXLModel.from_pretrained(args.model_name)
|
||||
model = model.to(device)
|
||||
|
||||
logging('Evaluating with bsz {} tgt_len {} ext_len {} mem_len {} clamp_len {}'.format(
|
||||
logger.info('Evaluating with bsz {} tgt_len {} ext_len {} mem_len {} clamp_len {}'.format(
|
||||
args.batch_size, args.tgt_len, args.ext_len, args.mem_len, args.clamp_len))
|
||||
|
||||
model.reset_length(args.tgt_len, args.ext_len, args.mem_len)
|
||||
@@ -116,7 +116,7 @@ def evaluate(eval_iter):
|
||||
total_loss += seq_len * loss.item()
|
||||
total_len += seq_len
|
||||
total_time = time.time() - start_time
|
||||
logging('Time : {:.2f}s, {:.2f}ms/segment'.format(
|
||||
logger.info('Time : {:.2f}s, {:.2f}ms/segment'.format(
|
||||
total_time, 1000 * total_time / (idx+1)))
|
||||
return total_loss / total_len
|
||||
|
||||
@@ -146,6 +146,6 @@ if valid_loss is not None:
|
||||
if test_loss is not None:
|
||||
log_str += format_log(test_loss, 'test')
|
||||
|
||||
logging('=' * 100)
|
||||
logging(log_str)
|
||||
logging('=' * 100)
|
||||
logger.info('=' * 100)
|
||||
logger.info(log_str)
|
||||
logger.info('=' * 100)
|
||||
|
||||
@@ -15,26 +15,27 @@
|
||||
# limitations under the License.
|
||||
"""BERT finetuning runner."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
import csv
|
||||
import os
|
||||
import logging
|
||||
import argparse
|
||||
import csv
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
from tqdm import tqdm, trange
|
||||
import sys
|
||||
from io import open
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
|
||||
from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler,
|
||||
TensorDataset)
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
from tqdm import tqdm, trange
|
||||
|
||||
from pytorch_pretrained_bert.tokenization import BertTokenizer
|
||||
from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE
|
||||
from pytorch_pretrained_bert.modeling import BertForSequenceClassification
|
||||
from pytorch_pretrained_bert.optimization import BertAdam
|
||||
from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE
|
||||
from pytorch_pretrained_bert.tokenization import BertTokenizer
|
||||
|
||||
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
|
||||
datefmt = '%m/%d/%Y %H:%M:%S',
|
||||
@@ -91,10 +92,12 @@ class DataProcessor(object):
|
||||
@classmethod
def _read_tsv(cls, input_file, quotechar=None):
    """Reads a tab separated value file.

    Args:
        input_file: path of the TSV file to read (UTF-8 encoded).
        quotechar: forwarded to ``csv.reader``; ``None`` is passed through
            unchanged.

    Returns:
        A list of rows, each row being a list of text (unicode) cells.
    """
    if sys.version_info[0] == 2:
        # Python 2's csv module only works on byte strings, so read the raw
        # bytes and decode every cell afterwards.
        with open(input_file, "rb") as f:
            reader = csv.reader(f, delimiter="\t", quotechar=quotechar)
            return [[unicode(cell, 'utf-8') for cell in line]  # noqa: F821
                    for line in reader]
    # Python 3's csv module requires text rows; opening in binary mode makes
    # csv.reader raise. newline='' lets the csv module handle line endings.
    with open(input_file, "r", encoding='utf-8', newline='') as f:
        reader = csv.reader(f, delimiter="\t", quotechar=quotechar)
        return list(reader)
|
||||
|
||||
@@ -429,7 +432,8 @@ def main():
|
||||
|
||||
if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train:
|
||||
raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir))
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
if not os.path.exists(args.output_dir):
|
||||
os.makedirs(args.output_dir)
|
||||
|
||||
task_name = args.task_name.lower()
|
||||
|
||||
@@ -451,7 +455,7 @@ def main():
|
||||
|
||||
# Prepare model
|
||||
model = BertForSequenceClassification.from_pretrained(args.bert_model,
|
||||
cache_dir=PYTORCH_PRETRAINED_BERT_CACHE / 'distributed_{}'.format(args.local_rank),
|
||||
cache_dir=os.path.join(PYTORCH_PRETRAINED_BERT_CACHE, 'distributed_{}'.format(args.local_rank)),
|
||||
num_labels = num_labels)
|
||||
if args.fp16:
|
||||
model.half()
|
||||
|
||||
@@ -15,26 +15,23 @@
|
||||
# limitations under the License.
|
||||
"""BERT finetuning runner."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
from __future__ import absolute_import, division, print_function, unicode_literals
|
||||
|
||||
import os
|
||||
import logging
|
||||
import argparse
|
||||
from tqdm import tqdm, trange
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
from io import open
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.utils.data import DataLoader, RandomSampler
|
||||
from torch.utils.data import DataLoader, Dataset, RandomSampler
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
from tqdm import tqdm, trange
|
||||
|
||||
from pytorch_pretrained_bert.tokenization import BertTokenizer
|
||||
from pytorch_pretrained_bert.modeling import BertForPreTraining
|
||||
from pytorch_pretrained_bert.optimization import BertAdam
|
||||
|
||||
from torch.utils.data import Dataset
|
||||
import random
|
||||
from pytorch_pretrained_bert.tokenization import BertTokenizer
|
||||
|
||||
logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
|
||||
datefmt='%m/%d/%Y %H:%M:%S',
|
||||
@@ -185,16 +182,16 @@ class BERTDataset(Dataset):
|
||||
if self.line_buffer is None:
|
||||
# read first non-empty line of file
|
||||
while t1 == "" :
|
||||
t1 = self.file.__next__().strip()
|
||||
t2 = self.file.__next__().strip()
|
||||
t1 = next(self.file).strip()
|
||||
t2 = next(self.file).strip()
|
||||
else:
|
||||
# use t2 from previous iteration as new t1
|
||||
t1 = self.line_buffer
|
||||
t2 = self.file.__next__().strip()
|
||||
t2 = next(self.file).strip()
|
||||
# skip empty rows that are used for separating documents and keep track of current doc id
|
||||
while t2 == "" or t1 == "":
|
||||
t1 = self.file.__next__().strip()
|
||||
t2 = self.file.__next__().strip()
|
||||
t1 = next(self.file).strip()
|
||||
t2 = next(self.file).strip()
|
||||
self.current_doc = self.current_doc+1
|
||||
self.line_buffer = t2
|
||||
|
||||
@@ -228,15 +225,15 @@ class BERTDataset(Dataset):
|
||||
def get_next_line(self):
    """ Gets next line of random_file and starts over when reaching end of file

    Returns:
        The next stripped line read from ``self.random_file``. When an empty
        line is read, ``self.current_random_doc`` is incremented and the line
        after it is returned instead.
    """
    try:
        # Use the next() builtin (works on Python 2 and 3) and read exactly
        # once per step so no line is silently skipped.
        line = next(self.random_file).strip()
        #keep track of which document we are currently looking at to later avoid having the same doc as t1
        if line == "":
            self.current_random_doc = self.current_random_doc + 1
            line = next(self.random_file).strip()
    except StopIteration:
        # End of file reached: reopen the corpus and restart from its first line.
        self.random_file.close()
        self.random_file = open(self.corpus_path, "r", encoding=self.encoding)
        line = next(self.random_file).strip()
    return line
|
||||
|
||||
|
||||
@@ -425,6 +422,7 @@ def main():
|
||||
help="The output directory where the model checkpoints will be written.")
|
||||
|
||||
## Other parameters
|
||||
parser.add_argument("--do_lower_case", action='store_true', help="Set this flag if you are using an uncased model.")
|
||||
parser.add_argument("--max_seq_length",
|
||||
default=128,
|
||||
type=int,
|
||||
@@ -513,7 +511,8 @@ def main():
|
||||
|
||||
if os.path.exists(args.output_dir) and os.listdir(args.output_dir):
|
||||
raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir))
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
if not os.path.exists(args.output_dir):
|
||||
os.makedirs(args.output_dir)
|
||||
|
||||
tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)
|
||||
|
||||
@@ -579,7 +578,7 @@ def main():
|
||||
if args.local_rank == -1:
|
||||
train_sampler = RandomSampler(train_dataset)
|
||||
else:
|
||||
#TODO: check if this works with current data generator from disk that relies on file.__next__
|
||||
#TODO: check if this works with current data generator from disk that relies on next(file)
|
||||
# (it doesn't return item back by index)
|
||||
train_sampler = DistributedSampler(train_dataset)
|
||||
train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size)
|
||||
|
||||
@@ -15,29 +15,36 @@
|
||||
# limitations under the License.
|
||||
"""Run BERT on SQuAD."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
import argparse
|
||||
import collections
|
||||
import logging
|
||||
import json
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
import random
|
||||
import pickle
|
||||
from tqdm import tqdm, trange
|
||||
import sys
|
||||
from io import open
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
|
||||
from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler,
|
||||
TensorDataset)
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
from tqdm import tqdm, trange
|
||||
|
||||
from pytorch_pretrained_bert.tokenization import whitespace_tokenize, BasicTokenizer, BertTokenizer
|
||||
from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE
|
||||
from pytorch_pretrained_bert.modeling import BertForQuestionAnswering
|
||||
from pytorch_pretrained_bert.optimization import BertAdam
|
||||
from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE
|
||||
from pytorch_pretrained_bert.tokenization import (BasicTokenizer,
|
||||
BertTokenizer,
|
||||
whitespace_tokenize)
|
||||
|
||||
if sys.version_info[0] == 2:
|
||||
import cPickle as pickle
|
||||
else:
|
||||
import pickle
|
||||
|
||||
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
|
||||
datefmt = '%m/%d/%Y %H:%M:%S',
|
||||
@@ -784,7 +791,8 @@ def main():
|
||||
|
||||
if os.path.exists(args.output_dir) and os.listdir(args.output_dir):
|
||||
raise ValueError("Output directory () already exists and is not empty.")
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
if not os.path.exists(args.output_dir):
|
||||
os.makedirs(args.output_dir)
|
||||
|
||||
tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)
|
||||
|
||||
@@ -798,7 +806,7 @@ def main():
|
||||
|
||||
# Prepare model
|
||||
model = BertForQuestionAnswering.from_pretrained(args.bert_model,
|
||||
cache_dir=PYTORCH_PRETRAINED_BERT_CACHE / 'distributed_{}'.format(args.local_rank))
|
||||
cache_dir=os.path.join(PYTORCH_PRETRAINED_BERT_CACHE, 'distributed_{}'.format(args.local_rank)))
|
||||
|
||||
if args.fp16:
|
||||
model.half()
|
||||
|
||||
@@ -15,22 +15,25 @@
|
||||
# limitations under the License.
|
||||
"""BERT finetuning runner."""
|
||||
|
||||
import argparse
|
||||
import csv
|
||||
import logging
|
||||
import os
|
||||
import argparse
|
||||
import random
|
||||
from tqdm import tqdm, trange
|
||||
import csv
|
||||
import sys
|
||||
from io import open
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
|
||||
from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler,
|
||||
TensorDataset)
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
from tqdm import tqdm, trange
|
||||
|
||||
from pytorch_pretrained_bert.tokenization import BertTokenizer
|
||||
from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE
|
||||
from pytorch_pretrained_bert.modeling import BertForMultipleChoice
|
||||
from pytorch_pretrained_bert.optimization import BertAdam
|
||||
from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE
|
||||
from pytorch_pretrained_bert.tokenization import BertTokenizer
|
||||
|
||||
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
|
||||
datefmt = '%m/%d/%Y %H:%M:%S',
|
||||
@@ -65,17 +68,17 @@ class SwagExample(object):
|
||||
|
||||
def __repr__(self):
    """Return a comma-separated, human-readable summary of the example.

    Lists the swag id, context sentence, start of the ending and the four
    candidate endings; the label is appended only when it is not None.
    """
    # str.format keeps this valid on Python 2 (no f-strings); each field
    # appears exactly once.
    parts = [
        "swag_id: {}".format(self.swag_id),
        "context_sentence: {}".format(self.context_sentence),
        "start_ending: {}".format(self.start_ending),
        "ending_0: {}".format(self.endings[0]),
        "ending_1: {}".format(self.endings[1]),
        "ending_2: {}".format(self.endings[2]),
        "ending_3: {}".format(self.endings[3]),
    ]

    if self.label is not None:
        parts.append("label: {}".format(self.label))

    return ", ".join(parts)
|
||||
|
||||
@@ -102,7 +105,11 @@ class InputFeatures(object):
|
||||
def read_swag_examples(input_file, is_training):
|
||||
with open(input_file, 'r', encoding='utf-8') as f:
|
||||
reader = csv.reader(f)
|
||||
lines = list(reader)
|
||||
lines = []
|
||||
for line in reader:
|
||||
if sys.version_info[0] == 2:
|
||||
line = list(unicode(cell, 'utf-8') for cell in line)
|
||||
lines.append(line)
|
||||
|
||||
if is_training and lines[0][-1] != 'label':
|
||||
raise ValueError(
|
||||
@@ -184,15 +191,15 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length,
|
||||
label = example.label
|
||||
if example_index < 5:
|
||||
logger.info("*** Example ***")
|
||||
logger.info(f"swag_id: {example.swag_id}")
|
||||
logger.info("swag_id: {}".format(example.swag_id))
|
||||
for choice_idx, (tokens, input_ids, input_mask, segment_ids) in enumerate(choices_features):
|
||||
logger.info(f"choice: {choice_idx}")
|
||||
logger.info(f"tokens: {' '.join(tokens)}")
|
||||
logger.info(f"input_ids: {' '.join(map(str, input_ids))}")
|
||||
logger.info(f"input_mask: {' '.join(map(str, input_mask))}")
|
||||
logger.info(f"segment_ids: {' '.join(map(str, segment_ids))}")
|
||||
logger.info("choice: {}".format(choice_idx))
|
||||
logger.info("tokens: {}".format(' '.join(tokens)))
|
||||
logger.info("input_ids: {}".format(' '.join(map(str, input_ids))))
|
||||
logger.info("input_mask: {}".format(' '.join(map(str, input_mask))))
|
||||
logger.info("segment_ids: {}".format(' '.join(map(str, segment_ids))))
|
||||
if is_training:
|
||||
logger.info(f"label: {label}")
|
||||
logger.info("label: {}".format(label))
|
||||
|
||||
features.append(
|
||||
InputFeatures(
|
||||
@@ -349,7 +356,8 @@ def main():
|
||||
|
||||
if os.path.exists(args.output_dir) and os.listdir(args.output_dir):
|
||||
raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir))
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
if not os.path.exists(args.output_dir):
|
||||
os.makedirs(args.output_dir)
|
||||
|
||||
tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)
|
||||
|
||||
@@ -362,7 +370,7 @@ def main():
|
||||
|
||||
# Prepare model
|
||||
model = BertForMultipleChoice.from_pretrained(args.bert_model,
|
||||
cache_dir=PYTORCH_PRETRAINED_BERT_CACHE / 'distributed_{}'.format(args.local_rank),
|
||||
cache_dir=os.path.join(PYTORCH_PRETRAINED_BERT_CACHE, 'distributed_{}'.format(args.local_rank)),
|
||||
num_choices=4)
|
||||
if args.fp16:
|
||||
model.half()
|
||||
|
||||
@@ -15,7 +15,7 @@ def main():
|
||||
if sys.argv[1] == "convert_tf_checkpoint_to_pytorch":
|
||||
try:
|
||||
import tensorflow as tf
|
||||
except ModuleNotFoundError:
|
||||
except ImportError:
|
||||
print("pytorch_pretrained_bert can only be used from the commandline to convert TensorFlow models in PyTorch, "
|
||||
"In that case, it requires TensorFlow to be installed. Please see "
|
||||
"https://www.tensorflow.org/install/ for installation instructions.")
|
||||
@@ -43,7 +43,7 @@ def main():
|
||||
else:
|
||||
try:
|
||||
import tensorflow as tf
|
||||
except ModuleNotFoundError:
|
||||
except ImportError:
|
||||
print("pytorch_pretrained_bert can only be used from the commandline to convert TensorFlow models in PyTorch, "
|
||||
"In that case, it requires TensorFlow to be installed. Please see "
|
||||
"https://www.tensorflow.org/install/ for installation instructions.")
|
||||
|
||||
@@ -14,14 +14,18 @@
|
||||
# limitations under the License.
|
||||
"""Convert OpenAI GPT checkpoint."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
import argparse
|
||||
from io import open
|
||||
|
||||
import torch
|
||||
|
||||
from pytorch_pretrained_bert.modeling_openai import load_tf_weights_in_openai_gpt, OpenAIGPTConfig, OpenAIGPTModel, CONFIG_NAME, WEIGHTS_NAME
|
||||
from pytorch_pretrained_bert.modeling_openai import (CONFIG_NAME, WEIGHTS_NAME,
|
||||
OpenAIGPTConfig,
|
||||
OpenAIGPTModel,
|
||||
load_tf_weights_in_openai_gpt)
|
||||
|
||||
|
||||
def convert_openai_checkpoint_to_pytorch(openai_checkpoint_folder_path, openai_config_file, pytorch_dump_folder_path):
|
||||
# Construct model
|
||||
|
||||
@@ -14,25 +14,31 @@
|
||||
# limitations under the License.
|
||||
"""Convert Transformer XL checkpoint and datasets."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import sys
|
||||
import argparse
|
||||
import pickle
|
||||
from io import open
|
||||
|
||||
import tensorflow as tf
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
from pytorch_pretrained_bert.modeling_transfo_xl import TransfoXLConfig, TransfoXLModel, CONFIG_NAME, WEIGHTS_NAME, load_tf_weights_in_transfo_xl
|
||||
from pytorch_pretrained_bert.tokenization_transfo_xl import VOCAB_NAME, CORPUS_NAME
|
||||
import pytorch_pretrained_bert.tokenization_transfo_xl as data_utils
|
||||
from pytorch_pretrained_bert.modeling_transfo_xl import (CONFIG_NAME,
|
||||
WEIGHTS_NAME,
|
||||
TransfoXLConfig,
|
||||
TransfoXLModel,
|
||||
load_tf_weights_in_transfo_xl)
|
||||
from pytorch_pretrained_bert.tokenization_transfo_xl import (CORPUS_NAME,
|
||||
VOCAB_NAME)
|
||||
|
||||
if sys.version_info[0] == 2:
|
||||
import cPickle as pickle
|
||||
else:
|
||||
import pickle
|
||||
|
||||
# We do this to be able to load the python 2 datasets pickles
|
||||
# See e.g. https://stackoverflow.com/questions/2121874/python-pickling-after-changing-a-modules-directory/2121918#2121918
|
||||
import pytorch_pretrained_bert.tokenization_transfo_xl as data_utils
|
||||
data_utils.Vocab = data_utils.TransfoXLTokenizer
|
||||
data_utils.Corpus = data_utils.TransfoXLCorpus
|
||||
sys.modules['data_utils'] = data_utils
|
||||
|
||||
@@ -3,31 +3,39 @@ Utilities for working with the local dataset cache.
|
||||
This file is adapted from the AllenNLP library at https://github.com/allenai/allennlp
|
||||
Copyright by the AllenNLP authors.
|
||||
"""
|
||||
from __future__ import (absolute_import, division, print_function, unicode_literals)
|
||||
|
||||
import os
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
import tempfile
|
||||
import json
|
||||
from urllib.parse import urlparse
|
||||
from pathlib import Path
|
||||
from typing import Optional, Tuple, Union, IO, Callable, Set
|
||||
from hashlib import sha256
|
||||
from functools import wraps
|
||||
|
||||
from tqdm import tqdm
|
||||
from hashlib import sha256
|
||||
from io import open
|
||||
|
||||
import boto3
|
||||
from botocore.exceptions import ClientError
|
||||
import requests
|
||||
from botocore.exceptions import ClientError
|
||||
from tqdm import tqdm
|
||||
|
||||
try:
|
||||
from urllib.parse import urlparse
|
||||
except ImportError:
|
||||
from urlparse import urlparse
|
||||
|
||||
try:
|
||||
from pathlib import Path
|
||||
PYTORCH_PRETRAINED_BERT_CACHE = Path(os.getenv('PYTORCH_PRETRAINED_BERT_CACHE',
|
||||
Path.home() / '.pytorch_pretrained_bert'))
|
||||
except ImportError:
|
||||
PYTORCH_PRETRAINED_BERT_CACHE = os.getenv('PYTORCH_PRETRAINED_BERT_CACHE',
|
||||
os.path.join(os.path.expanduser("~"), '.pytorch_pretrained_bert'))
|
||||
|
||||
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
PYTORCH_PRETRAINED_BERT_CACHE = Path(os.getenv('PYTORCH_PRETRAINED_BERT_CACHE',
|
||||
Path.home() / '.pytorch_pretrained_bert'))
|
||||
|
||||
|
||||
def url_to_filename(url: str, etag: str = None) -> str:
|
||||
def url_to_filename(url, etag=None):
|
||||
"""
|
||||
Convert `url` into a hashed filename in a repeatable way.
|
||||
If `etag` is specified, append its hash to the url's, delimited
|
||||
@@ -45,25 +53,23 @@ def url_to_filename(url: str, etag: str = None) -> str:
|
||||
return filename
|
||||
|
||||
|
||||
def filename_to_url(filename: str, cache_dir: Union[str, Path] = None) -> Tuple[str, str]:
|
||||
def filename_to_url(filename, cache_dir=None):
|
||||
"""
|
||||
Return the url and etag (which may be ``None``) stored for `filename`.
|
||||
Raise ``FileNotFoundError`` if `filename` or its stored metadata do not exist.
|
||||
Raise ``EnvironmentError`` if `filename` or its stored metadata do not exist.
|
||||
"""
|
||||
if cache_dir is None:
|
||||
cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
|
||||
if isinstance(cache_dir, Path):
|
||||
cache_dir = str(cache_dir)
|
||||
|
||||
cache_path = os.path.join(cache_dir, filename)
|
||||
if not os.path.exists(cache_path):
|
||||
raise FileNotFoundError("file {} not found".format(cache_path))
|
||||
raise EnvironmentError("file {} not found".format(cache_path))
|
||||
|
||||
meta_path = cache_path + '.json'
|
||||
if not os.path.exists(meta_path):
|
||||
raise FileNotFoundError("file {} not found".format(meta_path))
|
||||
raise EnvironmentError("file {} not found".format(meta_path))
|
||||
|
||||
with open(meta_path) as meta_file:
|
||||
with open(meta_path, encoding="utf-8") as meta_file:
|
||||
metadata = json.load(meta_file)
|
||||
url = metadata['url']
|
||||
etag = metadata['etag']
|
||||
@@ -71,7 +77,7 @@ def filename_to_url(filename: str, cache_dir: Union[str, Path] = None) -> Tuple[
|
||||
return url, etag
|
||||
|
||||
|
||||
def cached_path(url_or_filename: Union[str, Path], cache_dir: Union[str, Path] = None) -> str:
|
||||
def cached_path(url_or_filename, cache_dir=None):
|
||||
"""
|
||||
Given something that might be a URL (or might be a local path),
|
||||
determine which. If it's a URL, download the file and cache it, and
|
||||
@@ -80,10 +86,6 @@ def cached_path(url_or_filename: Union[str, Path], cache_dir: Union[str, Path] =
|
||||
"""
|
||||
if cache_dir is None:
|
||||
cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
|
||||
if isinstance(url_or_filename, Path):
|
||||
url_or_filename = str(url_or_filename)
|
||||
if isinstance(cache_dir, Path):
|
||||
cache_dir = str(cache_dir)
|
||||
|
||||
parsed = urlparse(url_or_filename)
|
||||
|
||||
@@ -95,13 +97,13 @@ def cached_path(url_or_filename: Union[str, Path], cache_dir: Union[str, Path] =
|
||||
return url_or_filename
|
||||
elif parsed.scheme == '':
|
||||
# File, but it doesn't exist.
|
||||
raise FileNotFoundError("file {} not found".format(url_or_filename))
|
||||
raise EnvironmentError("file {} not found".format(url_or_filename))
|
||||
else:
|
||||
# Something unknown
|
||||
raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename))
|
||||
|
||||
|
||||
def split_s3_path(url: str) -> Tuple[str, str]:
|
||||
def split_s3_path(url):
|
||||
"""Split a full s3 path into the bucket name and path."""
|
||||
parsed = urlparse(url)
|
||||
if not parsed.netloc or not parsed.path:
|
||||
@@ -114,19 +116,19 @@ def split_s3_path(url: str) -> Tuple[str, str]:
|
||||
return bucket_name, s3_path
|
||||
|
||||
|
||||
def s3_request(func: Callable):
|
||||
def s3_request(func):
|
||||
"""
|
||||
Wrapper function for s3 requests in order to create more helpful error
|
||||
messages.
|
||||
"""
|
||||
|
||||
@wraps(func)
|
||||
def wrapper(url: str, *args, **kwargs):
|
||||
def wrapper(url, *args, **kwargs):
|
||||
try:
|
||||
return func(url, *args, **kwargs)
|
||||
except ClientError as exc:
|
||||
if int(exc.response["Error"]["Code"]) == 404:
|
||||
raise FileNotFoundError("file {} not found".format(url))
|
||||
raise EnvironmentError("file {} not found".format(url))
|
||||
else:
|
||||
raise
|
||||
|
||||
@@ -134,7 +136,7 @@ def s3_request(func: Callable):
|
||||
|
||||
|
||||
@s3_request
|
||||
def s3_etag(url: str) -> Optional[str]:
|
||||
def s3_etag(url):
|
||||
"""Check ETag on S3 object."""
|
||||
s3_resource = boto3.resource("s3")
|
||||
bucket_name, s3_path = split_s3_path(url)
|
||||
@@ -143,14 +145,14 @@ def s3_etag(url: str) -> Optional[str]:
|
||||
|
||||
|
||||
@s3_request
def s3_get(url, temp_file):
    """Pull a file directly from S3.

    Args:
        url: full S3 url; it is split into bucket name and object path with
            ``split_s3_path``.
        temp_file: file object handed to ``download_fileobj`` as the
            download target.

    The ``s3_request`` decorator re-raises a 404 ``ClientError`` as
    ``EnvironmentError`` and lets any other ``ClientError`` propagate.
    """
    # Plain signature (no annotations) keeps the module importable on Python 2.
    s3_resource = boto3.resource("s3")
    bucket_name, s3_path = split_s3_path(url)
    s3_resource.Bucket(bucket_name).download_fileobj(s3_path, temp_file)
|
||||
|
||||
|
||||
def http_get(url: str, temp_file: IO) -> None:
|
||||
def http_get(url, temp_file):
|
||||
req = requests.get(url, stream=True)
|
||||
content_length = req.headers.get('Content-Length')
|
||||
total = int(content_length) if content_length is not None else None
|
||||
@@ -162,17 +164,16 @@ def http_get(url: str, temp_file: IO) -> None:
|
||||
progress.close()
|
||||
|
||||
|
||||
def get_from_cache(url: str, cache_dir: Union[str, Path] = None) -> str:
|
||||
def get_from_cache(url, cache_dir=None):
|
||||
"""
|
||||
Given a URL, look for the corresponding dataset in the local cache.
|
||||
If it's not there, download it. Then return the path to the cached file.
|
||||
"""
|
||||
if cache_dir is None:
|
||||
cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
|
||||
if isinstance(cache_dir, Path):
|
||||
cache_dir = str(cache_dir)
|
||||
|
||||
os.makedirs(cache_dir, exist_ok=True)
|
||||
if not os.path.exists(cache_dir):
|
||||
os.makedirs(cache_dir)
|
||||
|
||||
# Get eTag to add to filename, if it exists.
|
||||
if url.startswith("s3://"):
|
||||
@@ -213,7 +214,7 @@ def get_from_cache(url: str, cache_dir: Union[str, Path] = None) -> str:
|
||||
logger.info("creating metadata file for %s", cache_path)
|
||||
meta = {'url': url, 'etag': etag}
|
||||
meta_path = cache_path + '.json'
|
||||
with open(meta_path, 'w') as meta_file:
|
||||
with open(meta_path, 'w', encoding="utf-8") as meta_file:
|
||||
json.dump(meta, meta_file)
|
||||
|
||||
logger.info("removing temp file %s", temp_file.name)
|
||||
@@ -221,7 +222,7 @@ def get_from_cache(url: str, cache_dir: Union[str, Path] = None) -> str:
|
||||
return cache_path
|
||||
|
||||
|
||||
def read_set_from_file(filename: str) -> Set[str]:
|
||||
def read_set_from_file(filename):
|
||||
'''
|
||||
Extract a de-duped collection (set) of text from a file.
|
||||
Expected file format is one item per line.
|
||||
@@ -233,7 +234,7 @@ def read_set_from_file(filename: str) -> Set[str]:
|
||||
return collection
|
||||
|
||||
|
||||
def get_file_extension(path, dot=True, lower=True):
    """Return the extension of `path` as given by ``os.path.splitext``.

    Args:
        path: file path or file name.
        dot: when True keep the leading dot (".txt"); otherwise drop it ("txt").
        lower: when True return the extension lower-cased.

    Returns:
        The extension string, or an empty string when `path` has none.
    """
    ext = os.path.splitext(path)[1]
    # Slicing an empty extension is safe: ""[1:] is still "".
    ext = ext if dot else ext[1:]
    return ext.lower() if lower else ext
|
||||
|
||||
@@ -15,18 +15,18 @@
|
||||
# limitations under the License.
|
||||
"""PyTorch BERT model."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
from __future__ import absolute_import, division, print_function, unicode_literals
|
||||
|
||||
import os
|
||||
import copy
|
||||
import json
|
||||
import math
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
import shutil
|
||||
import tarfile
|
||||
import tempfile
|
||||
import shutil
|
||||
import sys
|
||||
from io import open
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@@ -56,7 +56,7 @@ def load_tf_weights_in_bert(model, tf_checkpoint_path):
|
||||
import re
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
except ModuleNotFoundError:
|
||||
except ImportError:
|
||||
print("Loading a TensorFlow models in PyTorch, requires TensorFlow to be installed. Please see "
|
||||
"https://www.tensorflow.org/install/ for installation instructions.")
|
||||
raise
|
||||
@@ -164,7 +164,8 @@ class BertConfig(object):
|
||||
initializer_range: The stddev of the truncated_normal_initializer for
|
||||
initializing all weight matrices.
|
||||
"""
|
||||
if isinstance(vocab_size_or_config_json_file, str):
|
||||
if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2
|
||||
and isinstance(vocab_size_or_config_json_file, unicode)):
|
||||
with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader:
|
||||
json_config = json.loads(reader.read())
|
||||
for key, value in json_config.items():
|
||||
@@ -343,8 +344,10 @@ class BertIntermediate(nn.Module):
|
||||
def __init__(self, config):
|
||||
super(BertIntermediate, self).__init__()
|
||||
self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
|
||||
self.intermediate_act_fn = ACT2FN[config.hidden_act] \
|
||||
if isinstance(config.hidden_act, str) else config.hidden_act
|
||||
if isinstance(config.hidden_act, str) or (sys.version_info[0] == 2 and isinstance(config.hidden_act, unicode)):
|
||||
self.intermediate_act_fn = ACT2FN[config.hidden_act]
|
||||
else:
|
||||
self.intermediate_act_fn = config.hidden_act
|
||||
|
||||
def forward(self, hidden_states):
|
||||
hidden_states = self.dense(hidden_states)
|
||||
@@ -416,8 +419,10 @@ class BertPredictionHeadTransform(nn.Module):
|
||||
def __init__(self, config):
|
||||
super(BertPredictionHeadTransform, self).__init__()
|
||||
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
||||
self.transform_act_fn = ACT2FN[config.hidden_act] \
|
||||
if isinstance(config.hidden_act, str) else config.hidden_act
|
||||
if isinstance(config.hidden_act, str) or (sys.version_info[0] == 2 and isinstance(config.hidden_act, unicode)):
|
||||
self.transform_act_fn = ACT2FN[config.hidden_act]
|
||||
else:
|
||||
self.transform_act_fn = config.hidden_act
|
||||
self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)
|
||||
|
||||
def forward(self, hidden_states):
|
||||
@@ -542,7 +547,7 @@ class BertPreTrainedModel(nn.Module):
|
||||
# redirect to the cache, if necessary
|
||||
try:
|
||||
resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir)
|
||||
except FileNotFoundError:
|
||||
except EnvironmentError:
|
||||
logger.error(
|
||||
"Model name '{}' was not found in model name list ({}). "
|
||||
"We assumed '{}' was a path or url but couldn't find any file "
|
||||
|
||||
@@ -24,6 +24,8 @@ import os
|
||||
import shutil
|
||||
import tarfile
|
||||
import tempfile
|
||||
import sys
|
||||
from io import open
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@@ -160,7 +162,8 @@ class OpenAIGPTConfig(object):
|
||||
initializer_range: The stddev of the truncated_normal_initializer for
|
||||
initializing all weight matrices.
|
||||
"""
|
||||
if isinstance(vocab_size_or_config_json_file, str):
|
||||
if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2
|
||||
and isinstance(vocab_size_or_config_json_file, unicode)):
|
||||
with open(vocab_size_or_config_json_file, "r", encoding="utf-8") as reader:
|
||||
json_config = json.loads(reader.read())
|
||||
for key, value in json_config.items():
|
||||
@@ -442,7 +445,7 @@ class OpenAIGPTPreTrainedModel(nn.Module):
|
||||
# redirect to the cache, if necessary
|
||||
try:
|
||||
resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir)
|
||||
except FileNotFoundError:
|
||||
except EnvironmentError:
|
||||
logger.error(
|
||||
"Model name '{}' was not found in model name list ({}). "
|
||||
"We assumed '{}' was a path or url but couldn't find any file "
|
||||
@@ -641,7 +644,8 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
|
||||
hidden_states = inputs_embeds + position_embeds + token_type_embeds
|
||||
for block in self.h:
|
||||
hidden_states = block(hidden_states)
|
||||
return hidden_states.view(*input_shape, hidden_states.size(-1))
|
||||
output_shape = input_shape + (hidden_states.size(-1),)
|
||||
return hidden_states.view(*output_shape)
|
||||
|
||||
|
||||
class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
|
||||
|
||||
@@ -27,6 +27,8 @@ import tarfile
|
||||
import tempfile
|
||||
import shutil
|
||||
import collections
|
||||
import sys
|
||||
from io import open
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@@ -124,7 +126,7 @@ def load_tf_weights_in_transfo_xl(model, config, tf_path):
|
||||
try:
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
except ModuleNotFoundError:
|
||||
except ImportError:
|
||||
print("Loading a TensorFlow models in PyTorch, requires TensorFlow to be installed. Please see "
|
||||
"https://www.tensorflow.org/install/ for installation instructions.")
|
||||
raise
|
||||
@@ -239,7 +241,8 @@ class TransfoXLConfig(object):
|
||||
proj_init_std: parameters initialized by N(0, init_std)
|
||||
init_std: parameters initialized by N(0, init_std)
|
||||
"""
|
||||
if isinstance(vocab_size_or_config_json_file, str):
|
||||
if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2
|
||||
and isinstance(vocab_size_or_config_json_file, unicode)):
|
||||
with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader:
|
||||
json_config = json.loads(reader.read())
|
||||
for key, value in json_config.items():
|
||||
@@ -503,11 +506,12 @@ class RelMultiHeadAttn(nn.Module):
|
||||
return x
|
||||
|
||||
def _rel_shift(self, x, zero_triu=False):
|
||||
zero_pad = torch.zeros((x.size(0), 1, *x.size()[2:]),
|
||||
device=x.device, dtype=x.dtype)
|
||||
zero_pad_shape = (x.size(0), 1) + x.size()[2:]
|
||||
zero_pad = torch.zeros(zero_pad_shape, device=x.device, dtype=x.dtype)
|
||||
x_padded = torch.cat([zero_pad, x], dim=1)
|
||||
|
||||
x_padded = x_padded.view(x.size(1) + 1, x.size(0), *x.size()[2:])
|
||||
x_padded_shape = (x.size(1) + 1, x.size(0)) + x.size()[2:]
|
||||
x_padded = x_padded.view(*x_padded_shape)
|
||||
|
||||
x = x_padded[1:].view_as(x)
|
||||
|
||||
@@ -797,7 +801,8 @@ class AdaptiveEmbedding(nn.Module):
|
||||
|
||||
emb_flat.index_copy_(0, indices_i, emb_i)
|
||||
|
||||
embed = emb_flat.view(*inp.size(), self.d_proj)
|
||||
embed_shape = inp.size() + (self.d_proj,)
|
||||
embed = emb_flat.view(embed_shape)
|
||||
|
||||
embed.mul_(self.emb_scale)
|
||||
|
||||
@@ -905,7 +910,7 @@ class TransfoXLPreTrainedModel(nn.Module):
|
||||
try:
|
||||
resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir)
|
||||
resolved_config_file = cached_path(config_file, cache_dir=cache_dir)
|
||||
except FileNotFoundError:
|
||||
except EnvironmentError:
|
||||
logger.error(
|
||||
"Model name '{}' was not found in model name list ({}). "
|
||||
"We assumed '{}' was a path or url but couldn't find files {} and {} "
|
||||
|
||||
@@ -14,14 +14,13 @@
|
||||
# limitations under the License.
|
||||
"""Tokenization classes."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
from __future__ import absolute_import, division, print_function, unicode_literals
|
||||
|
||||
import collections
|
||||
import unicodedata
|
||||
import os
|
||||
import logging
|
||||
import os
|
||||
import unicodedata
|
||||
from io import open
|
||||
|
||||
from .file_utils import cached_path
|
||||
|
||||
@@ -129,7 +128,7 @@ class BertTokenizer(object):
|
||||
# redirect to the cache, if necessary
|
||||
try:
|
||||
resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir)
|
||||
except FileNotFoundError:
|
||||
except EnvironmentError:
|
||||
logger.error(
|
||||
"Model name '{}' was not found in model name list ({}). "
|
||||
"We assumed '{}' was a path or url but couldn't find any file "
|
||||
|
||||
@@ -13,11 +13,17 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Tokenization classes for OpenAI GPT."""
|
||||
from __future__ import (absolute_import, division, print_function,
|
||||
unicode_literals)
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import json
|
||||
import sys
|
||||
from io import open
|
||||
|
||||
from tqdm import tqdm
|
||||
import logging
|
||||
|
||||
from .file_utils import cached_path
|
||||
|
||||
@@ -82,7 +88,7 @@ class OpenAIGPTTokenizer(object):
|
||||
try:
|
||||
resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir)
|
||||
resolved_merges_file = cached_path(merges_file, cache_dir=cache_dir)
|
||||
except FileNotFoundError:
|
||||
except EnvironmentError:
|
||||
logger.error(
|
||||
"Model name '{}' was not found in model name list ({}). "
|
||||
"We assumed '{}' was a path or url but couldn't find files {} and {} "
|
||||
@@ -119,7 +125,7 @@ class OpenAIGPTTokenizer(object):
|
||||
self.max_len = max_len if max_len is not None else int(1e12)
|
||||
self.nlp = spacy.load('en', disable=['parser', 'tagger', 'ner', 'textcat'])
|
||||
self.fix_text = ftfy.fix_text
|
||||
self.encoder = json.load(open(vocab_file))
|
||||
self.encoder = json.load(open(vocab_file, encoding="utf-8"))
|
||||
self.decoder = {v:k for k,v in self.encoder.items()}
|
||||
merges = open(merges_file, encoding='utf-8').read().split('\n')[1:-1]
|
||||
merges = [tuple(merge.split()) for merge in merges]
|
||||
@@ -196,7 +202,7 @@ class OpenAIGPTTokenizer(object):
|
||||
def convert_tokens_to_ids(self, tokens):
|
||||
"""Converts a sequence of tokens into ids using the vocab."""
|
||||
ids = []
|
||||
if isinstance(tokens, str):
|
||||
if isinstance(tokens, str) or (sys.version_info[0] == 2 and isinstance(tokens, unicode)):
|
||||
if tokens in self.special_tokens:
|
||||
return self.special_tokens[tokens]
|
||||
else:
|
||||
|
||||
@@ -16,16 +16,27 @@
|
||||
""" Tokenization classes for Transformer XL model.
|
||||
Adapted from https://github.com/kimiyoung/transformer-xl.
|
||||
"""
|
||||
from __future__ import (absolute_import, division, print_function,
|
||||
unicode_literals)
|
||||
|
||||
import os
|
||||
import glob
|
||||
import logging
|
||||
import pickle
|
||||
import torch
|
||||
import os
|
||||
import sys
|
||||
from collections import Counter, OrderedDict
|
||||
from io import open
|
||||
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
from .file_utils import cached_path
|
||||
|
||||
if sys.version_info[0] == 2:
|
||||
import cPickle as pickle
|
||||
else:
|
||||
import pickle
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
PRETRAINED_VOCAB_ARCHIVE_MAP = {
|
||||
@@ -55,7 +66,7 @@ class TransfoXLTokenizer(object):
|
||||
# redirect to the cache, if necessary
|
||||
try:
|
||||
resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir)
|
||||
except FileNotFoundError:
|
||||
except EnvironmentError:
|
||||
logger.error(
|
||||
"Model name '{}' was not found in model name list ({}). "
|
||||
"We assumed '{}' was a path or url but couldn't find files {} "
|
||||
@@ -422,7 +433,7 @@ class TransfoXLCorpus(object):
|
||||
# redirect to the cache, if necessary
|
||||
try:
|
||||
resolved_corpus_file = cached_path(corpus_file, cache_dir=cache_dir)
|
||||
except FileNotFoundError:
|
||||
except EnvironmentError:
|
||||
logger.error(
|
||||
"Model name '{}' was not found in model name list ({}). "
|
||||
"We assumed '{}' was a path or url but couldn't find files {} "
|
||||
|
||||
3
setup.py
3
setup.py
@@ -33,6 +33,7 @@ To create the package for pypi.
|
||||
7. Copy the release notes from RELEASE.md to the tag in github once everything is looking hunky-dory.
|
||||
|
||||
"""
|
||||
from io import open
|
||||
from setuptools import find_packages, setup
|
||||
|
||||
setup(
|
||||
@@ -58,7 +59,7 @@ setup(
|
||||
"pytorch_pretrained_bert=pytorch_pretrained_bert.__main__:main",
|
||||
]
|
||||
},
|
||||
python_requires='>=3.5.0',
|
||||
# python_requires='>=3.5.0',
|
||||
tests_require=['pytest'],
|
||||
classifiers=[
|
||||
'Intended Audience :: Science/Research',
|
||||
|
||||
@@ -18,6 +18,7 @@ from __future__ import print_function
|
||||
|
||||
import os
|
||||
import unittest
|
||||
from io import open
|
||||
|
||||
from pytorch_pretrained_bert.tokenization import (BertTokenizer, BasicTokenizer, WordpieceTokenizer,
|
||||
_is_whitespace, _is_control, _is_punctuation)
|
||||
@@ -30,7 +31,7 @@ class TokenizationTest(unittest.TestCase):
|
||||
"[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn",
|
||||
"##ing", ","
|
||||
]
|
||||
with open("/tmp/bert_tokenizer_test.txt", "w") as vocab_writer:
|
||||
with open("/tmp/bert_tokenizer_test.txt", "w", encoding='utf-8') as vocab_writer:
|
||||
vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
|
||||
|
||||
vocab_file = vocab_writer.name
|
||||
@@ -49,7 +50,7 @@ class TokenizationTest(unittest.TestCase):
|
||||
"[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn",
|
||||
"##ing", ","
|
||||
]
|
||||
with open("/tmp/bert_tokenizer_test.txt", "w") as vocab_writer:
|
||||
with open("/tmp/bert_tokenizer_test.txt", "w", encoding='utf-8') as vocab_writer:
|
||||
vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
|
||||
vocab_file = vocab_writer.name
|
||||
|
||||
|
||||
Reference in New Issue
Block a user