python 2 compatibility
This commit is contained in:
@@ -17,26 +17,26 @@
|
|||||||
Adapted from https://github.com/kimiyoung/transformer-xl.
|
Adapted from https://github.com/kimiyoung/transformer-xl.
|
||||||
In particular https://github.com/kimiyoung/transformer-xl/blob/master/pytorch/eval.py
|
In particular https://github.com/kimiyoung/transformer-xl/blob/master/pytorch/eval.py
|
||||||
"""
|
"""
|
||||||
|
from __future__ import absolute_import, division, print_function, unicode_literals
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import sys
|
|
||||||
import functools
|
import functools
|
||||||
import argparse
|
import argparse
|
||||||
|
import logging
|
||||||
import time
|
import time
|
||||||
import math
|
import math
|
||||||
|
import sys
|
||||||
|
from io import open
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from pytorch_pretrained_bert import TransfoXLModel, TransfoXLCorpus
|
from pytorch_pretrained_bert import TransfoXLModel, TransfoXLCorpus
|
||||||
|
|
||||||
def logging(s, log_path, print_=True, log_=True):
|
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
|
||||||
if print_:
|
datefmt = '%m/%d/%Y %H:%M:%S',
|
||||||
print(s)
|
level = logging.INFO)
|
||||||
if log_:
|
logger = logging.getLogger(__name__)
|
||||||
with open(log_path, 'a+') as f_log:
|
|
||||||
f_log.write(s + '\n')
|
|
||||||
|
|
||||||
def get_logger(log_path, **kwargs):
|
|
||||||
return functools.partial(logging, log_path=log_path, **kwargs)
|
|
||||||
|
|
||||||
parser = argparse.ArgumentParser(description='PyTorch Transformer Language Model')
|
parser = argparse.ArgumentParser(description='PyTorch Transformer Language Model')
|
||||||
# parser.add_argument('--data', type=str, default='../data/wikitext-103',
|
# parser.add_argument('--data', type=str, default='../data/wikitext-103',
|
||||||
@@ -71,8 +71,8 @@ assert args.ext_len >= 0, 'extended context length must be non-negative'
|
|||||||
device = torch.device("cuda" if args.cuda else "cpu")
|
device = torch.device("cuda" if args.cuda else "cpu")
|
||||||
|
|
||||||
# Get logger
|
# Get logger
|
||||||
logging = get_logger(os.path.join(args.work_dir, 'log.txt'),
|
# logging = get_logger(os.path.join(args.work_dir, 'log.txt'),
|
||||||
log_=not args.no_log)
|
# log_=not args.no_log)
|
||||||
|
|
||||||
# Load dataset
|
# Load dataset
|
||||||
corpus = TransfoXLCorpus.from_pretrained(args.model_name)
|
corpus = TransfoXLCorpus.from_pretrained(args.model_name)
|
||||||
@@ -90,7 +90,7 @@ te_iter = corpus.get_iterator('test', args.batch_size, args.tgt_len,
|
|||||||
model = TransfoXLModel.from_pretrained(args.model_name)
|
model = TransfoXLModel.from_pretrained(args.model_name)
|
||||||
model = model.to(device)
|
model = model.to(device)
|
||||||
|
|
||||||
logging('Evaluating with bsz {} tgt_len {} ext_len {} mem_len {} clamp_len {}'.format(
|
logger.info('Evaluating with bsz {} tgt_len {} ext_len {} mem_len {} clamp_len {}'.format(
|
||||||
args.batch_size, args.tgt_len, args.ext_len, args.mem_len, args.clamp_len))
|
args.batch_size, args.tgt_len, args.ext_len, args.mem_len, args.clamp_len))
|
||||||
|
|
||||||
model.reset_length(args.tgt_len, args.ext_len, args.mem_len)
|
model.reset_length(args.tgt_len, args.ext_len, args.mem_len)
|
||||||
@@ -116,7 +116,7 @@ def evaluate(eval_iter):
|
|||||||
total_loss += seq_len * loss.item()
|
total_loss += seq_len * loss.item()
|
||||||
total_len += seq_len
|
total_len += seq_len
|
||||||
total_time = time.time() - start_time
|
total_time = time.time() - start_time
|
||||||
logging('Time : {:.2f}s, {:.2f}ms/segment'.format(
|
logger.info('Time : {:.2f}s, {:.2f}ms/segment'.format(
|
||||||
total_time, 1000 * total_time / (idx+1)))
|
total_time, 1000 * total_time / (idx+1)))
|
||||||
return total_loss / total_len
|
return total_loss / total_len
|
||||||
|
|
||||||
@@ -146,6 +146,6 @@ if valid_loss is not None:
|
|||||||
if test_loss is not None:
|
if test_loss is not None:
|
||||||
log_str += format_log(test_loss, 'test')
|
log_str += format_log(test_loss, 'test')
|
||||||
|
|
||||||
logging('=' * 100)
|
logger.info('=' * 100)
|
||||||
logging(log_str)
|
logger.info(log_str)
|
||||||
logging('=' * 100)
|
logger.info('=' * 100)
|
||||||
|
|||||||
@@ -15,26 +15,27 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
"""BERT finetuning runner."""
|
"""BERT finetuning runner."""
|
||||||
|
|
||||||
from __future__ import absolute_import
|
from __future__ import absolute_import, division, print_function
|
||||||
from __future__ import division
|
|
||||||
from __future__ import print_function
|
|
||||||
|
|
||||||
import csv
|
|
||||||
import os
|
|
||||||
import logging
|
|
||||||
import argparse
|
import argparse
|
||||||
|
import csv
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
import random
|
import random
|
||||||
from tqdm import tqdm, trange
|
import sys
|
||||||
|
from io import open
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
|
from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler,
|
||||||
|
TensorDataset)
|
||||||
from torch.utils.data.distributed import DistributedSampler
|
from torch.utils.data.distributed import DistributedSampler
|
||||||
|
from tqdm import tqdm, trange
|
||||||
|
|
||||||
from pytorch_pretrained_bert.tokenization import BertTokenizer
|
from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE
|
||||||
from pytorch_pretrained_bert.modeling import BertForSequenceClassification
|
from pytorch_pretrained_bert.modeling import BertForSequenceClassification
|
||||||
from pytorch_pretrained_bert.optimization import BertAdam
|
from pytorch_pretrained_bert.optimization import BertAdam
|
||||||
from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE
|
from pytorch_pretrained_bert.tokenization import BertTokenizer
|
||||||
|
|
||||||
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
|
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
|
||||||
datefmt = '%m/%d/%Y %H:%M:%S',
|
datefmt = '%m/%d/%Y %H:%M:%S',
|
||||||
@@ -91,10 +92,12 @@ class DataProcessor(object):
|
|||||||
@classmethod
|
@classmethod
|
||||||
def _read_tsv(cls, input_file, quotechar=None):
|
def _read_tsv(cls, input_file, quotechar=None):
|
||||||
"""Reads a tab separated value file."""
|
"""Reads a tab separated value file."""
|
||||||
with open(input_file, "r", encoding='utf-8') as f:
|
with open(input_file, "rb") as f:
|
||||||
reader = csv.reader(f, delimiter="\t", quotechar=quotechar)
|
reader = csv.reader(f, delimiter="\t", quotechar=quotechar)
|
||||||
lines = []
|
lines = []
|
||||||
for line in reader:
|
for line in reader:
|
||||||
|
if sys.version_info[0] == 2:
|
||||||
|
line = list(unicode(cell, 'utf-8') for cell in line)
|
||||||
lines.append(line)
|
lines.append(line)
|
||||||
return lines
|
return lines
|
||||||
|
|
||||||
@@ -429,7 +432,8 @@ def main():
|
|||||||
|
|
||||||
if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train:
|
if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train:
|
||||||
raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir))
|
raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir))
|
||||||
os.makedirs(args.output_dir, exist_ok=True)
|
if not os.path.exists(args.output_dir):
|
||||||
|
os.makedirs(args.output_dir)
|
||||||
|
|
||||||
task_name = args.task_name.lower()
|
task_name = args.task_name.lower()
|
||||||
|
|
||||||
@@ -451,7 +455,7 @@ def main():
|
|||||||
|
|
||||||
# Prepare model
|
# Prepare model
|
||||||
model = BertForSequenceClassification.from_pretrained(args.bert_model,
|
model = BertForSequenceClassification.from_pretrained(args.bert_model,
|
||||||
cache_dir=PYTORCH_PRETRAINED_BERT_CACHE / 'distributed_{}'.format(args.local_rank),
|
cache_dir=os.path.join(PYTORCH_PRETRAINED_BERT_CACHE, 'distributed_{}'.format(args.local_rank)),
|
||||||
num_labels = num_labels)
|
num_labels = num_labels)
|
||||||
if args.fp16:
|
if args.fp16:
|
||||||
model.half()
|
model.half()
|
||||||
|
|||||||
@@ -15,26 +15,23 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
"""BERT finetuning runner."""
|
"""BERT finetuning runner."""
|
||||||
|
|
||||||
from __future__ import absolute_import
|
from __future__ import absolute_import, division, print_function, unicode_literals
|
||||||
from __future__ import division
|
|
||||||
from __future__ import print_function
|
|
||||||
|
|
||||||
import os
|
|
||||||
import logging
|
|
||||||
import argparse
|
import argparse
|
||||||
from tqdm import tqdm, trange
|
import logging
|
||||||
|
import os
|
||||||
|
import random
|
||||||
|
from io import open
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from torch.utils.data import DataLoader, RandomSampler
|
from torch.utils.data import DataLoader, Dataset, RandomSampler
|
||||||
from torch.utils.data.distributed import DistributedSampler
|
from torch.utils.data.distributed import DistributedSampler
|
||||||
|
from tqdm import tqdm, trange
|
||||||
|
|
||||||
from pytorch_pretrained_bert.tokenization import BertTokenizer
|
|
||||||
from pytorch_pretrained_bert.modeling import BertForPreTraining
|
from pytorch_pretrained_bert.modeling import BertForPreTraining
|
||||||
from pytorch_pretrained_bert.optimization import BertAdam
|
from pytorch_pretrained_bert.optimization import BertAdam
|
||||||
|
from pytorch_pretrained_bert.tokenization import BertTokenizer
|
||||||
from torch.utils.data import Dataset
|
|
||||||
import random
|
|
||||||
|
|
||||||
logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
|
logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
|
||||||
datefmt='%m/%d/%Y %H:%M:%S',
|
datefmt='%m/%d/%Y %H:%M:%S',
|
||||||
@@ -185,16 +182,16 @@ class BERTDataset(Dataset):
|
|||||||
if self.line_buffer is None:
|
if self.line_buffer is None:
|
||||||
# read first non-empty line of file
|
# read first non-empty line of file
|
||||||
while t1 == "" :
|
while t1 == "" :
|
||||||
t1 = self.file.__next__().strip()
|
t1 = next(self.file).strip()
|
||||||
t2 = self.file.__next__().strip()
|
t2 = next(self.file).strip()
|
||||||
else:
|
else:
|
||||||
# use t2 from previous iteration as new t1
|
# use t2 from previous iteration as new t1
|
||||||
t1 = self.line_buffer
|
t1 = self.line_buffer
|
||||||
t2 = self.file.__next__().strip()
|
t2 = next(self.file).strip()
|
||||||
# skip empty rows that are used for separating documents and keep track of current doc id
|
# skip empty rows that are used for separating documents and keep track of current doc id
|
||||||
while t2 == "" or t1 == "":
|
while t2 == "" or t1 == "":
|
||||||
t1 = self.file.__next__().strip()
|
t1 = next(self.file).strip()
|
||||||
t2 = self.file.__next__().strip()
|
t2 = next(self.file).strip()
|
||||||
self.current_doc = self.current_doc+1
|
self.current_doc = self.current_doc+1
|
||||||
self.line_buffer = t2
|
self.line_buffer = t2
|
||||||
|
|
||||||
@@ -228,15 +225,15 @@ class BERTDataset(Dataset):
|
|||||||
def get_next_line(self):
|
def get_next_line(self):
|
||||||
""" Gets next line of random_file and starts over when reaching end of file"""
|
""" Gets next line of random_file and starts over when reaching end of file"""
|
||||||
try:
|
try:
|
||||||
line = self.random_file.__next__().strip()
|
line = next(self.random_file).strip()
|
||||||
#keep track of which document we are currently looking at to later avoid having the same doc as t1
|
#keep track of which document we are currently looking at to later avoid having the same doc as t1
|
||||||
if line == "":
|
if line == "":
|
||||||
self.current_random_doc = self.current_random_doc + 1
|
self.current_random_doc = self.current_random_doc + 1
|
||||||
line = self.random_file.__next__().strip()
|
line = next(self.random_file).strip()
|
||||||
except StopIteration:
|
except StopIteration:
|
||||||
self.random_file.close()
|
self.random_file.close()
|
||||||
self.random_file = open(self.corpus_path, "r", encoding=self.encoding)
|
self.random_file = open(self.corpus_path, "r", encoding=self.encoding)
|
||||||
line = self.random_file.__next__().strip()
|
line = next(self.random_file).strip()
|
||||||
return line
|
return line
|
||||||
|
|
||||||
|
|
||||||
@@ -425,6 +422,7 @@ def main():
|
|||||||
help="The output directory where the model checkpoints will be written.")
|
help="The output directory where the model checkpoints will be written.")
|
||||||
|
|
||||||
## Other parameters
|
## Other parameters
|
||||||
|
parser.add_argument("--do_lower_case", action='store_true', help="Set this flag if you are using an uncased model.")
|
||||||
parser.add_argument("--max_seq_length",
|
parser.add_argument("--max_seq_length",
|
||||||
default=128,
|
default=128,
|
||||||
type=int,
|
type=int,
|
||||||
@@ -513,7 +511,8 @@ def main():
|
|||||||
|
|
||||||
if os.path.exists(args.output_dir) and os.listdir(args.output_dir):
|
if os.path.exists(args.output_dir) and os.listdir(args.output_dir):
|
||||||
raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir))
|
raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir))
|
||||||
os.makedirs(args.output_dir, exist_ok=True)
|
if not os.path.exists(args.output_dir):
|
||||||
|
os.makedirs(args.output_dir)
|
||||||
|
|
||||||
tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)
|
tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)
|
||||||
|
|
||||||
@@ -579,7 +578,7 @@ def main():
|
|||||||
if args.local_rank == -1:
|
if args.local_rank == -1:
|
||||||
train_sampler = RandomSampler(train_dataset)
|
train_sampler = RandomSampler(train_dataset)
|
||||||
else:
|
else:
|
||||||
#TODO: check if this works with current data generator from disk that relies on file.__next__
|
#TODO: check if this works with current data generator from disk that relies on next(file)
|
||||||
# (it doesn't return item back by index)
|
# (it doesn't return item back by index)
|
||||||
train_sampler = DistributedSampler(train_dataset)
|
train_sampler = DistributedSampler(train_dataset)
|
||||||
train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size)
|
train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size)
|
||||||
|
|||||||
@@ -15,29 +15,36 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
"""Run BERT on SQuAD."""
|
"""Run BERT on SQuAD."""
|
||||||
|
|
||||||
from __future__ import absolute_import
|
from __future__ import absolute_import, division, print_function
|
||||||
from __future__ import division
|
|
||||||
from __future__ import print_function
|
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import collections
|
import collections
|
||||||
import logging
|
|
||||||
import json
|
import json
|
||||||
|
import logging
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
import random
|
import random
|
||||||
import pickle
|
import sys
|
||||||
from tqdm import tqdm, trange
|
from io import open
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
|
from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler,
|
||||||
|
TensorDataset)
|
||||||
from torch.utils.data.distributed import DistributedSampler
|
from torch.utils.data.distributed import DistributedSampler
|
||||||
|
from tqdm import tqdm, trange
|
||||||
|
|
||||||
from pytorch_pretrained_bert.tokenization import whitespace_tokenize, BasicTokenizer, BertTokenizer
|
from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE
|
||||||
from pytorch_pretrained_bert.modeling import BertForQuestionAnswering
|
from pytorch_pretrained_bert.modeling import BertForQuestionAnswering
|
||||||
from pytorch_pretrained_bert.optimization import BertAdam
|
from pytorch_pretrained_bert.optimization import BertAdam
|
||||||
from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE
|
from pytorch_pretrained_bert.tokenization import (BasicTokenizer,
|
||||||
|
BertTokenizer,
|
||||||
|
whitespace_tokenize)
|
||||||
|
|
||||||
|
if sys.version_info[0] == 2:
|
||||||
|
import cPickle as pickle
|
||||||
|
else:
|
||||||
|
import pickle
|
||||||
|
|
||||||
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
|
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
|
||||||
datefmt = '%m/%d/%Y %H:%M:%S',
|
datefmt = '%m/%d/%Y %H:%M:%S',
|
||||||
@@ -784,7 +791,8 @@ def main():
|
|||||||
|
|
||||||
if os.path.exists(args.output_dir) and os.listdir(args.output_dir):
|
if os.path.exists(args.output_dir) and os.listdir(args.output_dir):
|
||||||
raise ValueError("Output directory () already exists and is not empty.")
|
raise ValueError("Output directory () already exists and is not empty.")
|
||||||
os.makedirs(args.output_dir, exist_ok=True)
|
if not os.path.exists(args.output_dir):
|
||||||
|
os.makedirs(args.output_dir)
|
||||||
|
|
||||||
tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)
|
tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)
|
||||||
|
|
||||||
@@ -798,7 +806,7 @@ def main():
|
|||||||
|
|
||||||
# Prepare model
|
# Prepare model
|
||||||
model = BertForQuestionAnswering.from_pretrained(args.bert_model,
|
model = BertForQuestionAnswering.from_pretrained(args.bert_model,
|
||||||
cache_dir=PYTORCH_PRETRAINED_BERT_CACHE / 'distributed_{}'.format(args.local_rank))
|
cache_dir=os.path.join(PYTORCH_PRETRAINED_BERT_CACHE, 'distributed_{}'.format(args.local_rank)))
|
||||||
|
|
||||||
if args.fp16:
|
if args.fp16:
|
||||||
model.half()
|
model.half()
|
||||||
|
|||||||
@@ -15,22 +15,25 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
"""BERT finetuning runner."""
|
"""BERT finetuning runner."""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import csv
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import argparse
|
|
||||||
import random
|
import random
|
||||||
from tqdm import tqdm, trange
|
import sys
|
||||||
import csv
|
from io import open
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
|
from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler,
|
||||||
|
TensorDataset)
|
||||||
from torch.utils.data.distributed import DistributedSampler
|
from torch.utils.data.distributed import DistributedSampler
|
||||||
|
from tqdm import tqdm, trange
|
||||||
|
|
||||||
from pytorch_pretrained_bert.tokenization import BertTokenizer
|
from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE
|
||||||
from pytorch_pretrained_bert.modeling import BertForMultipleChoice
|
from pytorch_pretrained_bert.modeling import BertForMultipleChoice
|
||||||
from pytorch_pretrained_bert.optimization import BertAdam
|
from pytorch_pretrained_bert.optimization import BertAdam
|
||||||
from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE
|
from pytorch_pretrained_bert.tokenization import BertTokenizer
|
||||||
|
|
||||||
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
|
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
|
||||||
datefmt = '%m/%d/%Y %H:%M:%S',
|
datefmt = '%m/%d/%Y %H:%M:%S',
|
||||||
@@ -65,17 +68,17 @@ class SwagExample(object):
|
|||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
l = [
|
l = [
|
||||||
f"swag_id: {self.swag_id}",
|
"swag_id: {}".format(self.swag_id),
|
||||||
f"context_sentence: {self.context_sentence}",
|
"context_sentence: {}".format(self.context_sentence),
|
||||||
f"start_ending: {self.start_ending}",
|
"start_ending: {}".format(self.start_ending),
|
||||||
f"ending_0: {self.endings[0]}",
|
"ending_0: {}".format(self.endings[0]),
|
||||||
f"ending_1: {self.endings[1]}",
|
"ending_1: {}".format(self.endings[1]),
|
||||||
f"ending_2: {self.endings[2]}",
|
"ending_2: {}".format(self.endings[2]),
|
||||||
f"ending_3: {self.endings[3]}",
|
"ending_3: {}".format(self.endings[3]),
|
||||||
]
|
]
|
||||||
|
|
||||||
if self.label is not None:
|
if self.label is not None:
|
||||||
l.append(f"label: {self.label}")
|
l.append("label: {}".format(self.label))
|
||||||
|
|
||||||
return ", ".join(l)
|
return ", ".join(l)
|
||||||
|
|
||||||
@@ -102,7 +105,11 @@ class InputFeatures(object):
|
|||||||
def read_swag_examples(input_file, is_training):
|
def read_swag_examples(input_file, is_training):
|
||||||
with open(input_file, 'r', encoding='utf-8') as f:
|
with open(input_file, 'r', encoding='utf-8') as f:
|
||||||
reader = csv.reader(f)
|
reader = csv.reader(f)
|
||||||
lines = list(reader)
|
lines = []
|
||||||
|
for line in reader:
|
||||||
|
if sys.version_info[0] == 2:
|
||||||
|
line = list(unicode(cell, 'utf-8') for cell in line)
|
||||||
|
lines.append(line)
|
||||||
|
|
||||||
if is_training and lines[0][-1] != 'label':
|
if is_training and lines[0][-1] != 'label':
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@@ -184,15 +191,15 @@ def convert_examples_to_features(examples, tokenizer, max_seq_length,
|
|||||||
label = example.label
|
label = example.label
|
||||||
if example_index < 5:
|
if example_index < 5:
|
||||||
logger.info("*** Example ***")
|
logger.info("*** Example ***")
|
||||||
logger.info(f"swag_id: {example.swag_id}")
|
logger.info("swag_id: {}".format(example.swag_id))
|
||||||
for choice_idx, (tokens, input_ids, input_mask, segment_ids) in enumerate(choices_features):
|
for choice_idx, (tokens, input_ids, input_mask, segment_ids) in enumerate(choices_features):
|
||||||
logger.info(f"choice: {choice_idx}")
|
logger.info("choice: {}".format(choice_idx))
|
||||||
logger.info(f"tokens: {' '.join(tokens)}")
|
logger.info("tokens: {}".format(' '.join(tokens)))
|
||||||
logger.info(f"input_ids: {' '.join(map(str, input_ids))}")
|
logger.info("input_ids: {}".format(' '.join(map(str, input_ids))))
|
||||||
logger.info(f"input_mask: {' '.join(map(str, input_mask))}")
|
logger.info("input_mask: {}".format(' '.join(map(str, input_mask))))
|
||||||
logger.info(f"segment_ids: {' '.join(map(str, segment_ids))}")
|
logger.info("segment_ids: {}".format(' '.join(map(str, segment_ids))))
|
||||||
if is_training:
|
if is_training:
|
||||||
logger.info(f"label: {label}")
|
logger.info("label: {}".format(label))
|
||||||
|
|
||||||
features.append(
|
features.append(
|
||||||
InputFeatures(
|
InputFeatures(
|
||||||
@@ -349,7 +356,8 @@ def main():
|
|||||||
|
|
||||||
if os.path.exists(args.output_dir) and os.listdir(args.output_dir):
|
if os.path.exists(args.output_dir) and os.listdir(args.output_dir):
|
||||||
raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir))
|
raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir))
|
||||||
os.makedirs(args.output_dir, exist_ok=True)
|
if not os.path.exists(args.output_dir):
|
||||||
|
os.makedirs(args.output_dir)
|
||||||
|
|
||||||
tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)
|
tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)
|
||||||
|
|
||||||
@@ -362,7 +370,7 @@ def main():
|
|||||||
|
|
||||||
# Prepare model
|
# Prepare model
|
||||||
model = BertForMultipleChoice.from_pretrained(args.bert_model,
|
model = BertForMultipleChoice.from_pretrained(args.bert_model,
|
||||||
cache_dir=PYTORCH_PRETRAINED_BERT_CACHE / 'distributed_{}'.format(args.local_rank),
|
cache_dir=os.path.join(PYTORCH_PRETRAINED_BERT_CACHE, 'distributed_{}'.format(args.local_rank)),
|
||||||
num_choices=4)
|
num_choices=4)
|
||||||
if args.fp16:
|
if args.fp16:
|
||||||
model.half()
|
model.half()
|
||||||
|
|||||||
@@ -15,7 +15,7 @@ def main():
|
|||||||
if sys.argv[1] == "convert_tf_checkpoint_to_pytorch":
|
if sys.argv[1] == "convert_tf_checkpoint_to_pytorch":
|
||||||
try:
|
try:
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
except ModuleNotFoundError:
|
except ImportError:
|
||||||
print("pytorch_pretrained_bert can only be used from the commandline to convert TensorFlow models in PyTorch, "
|
print("pytorch_pretrained_bert can only be used from the commandline to convert TensorFlow models in PyTorch, "
|
||||||
"In that case, it requires TensorFlow to be installed. Please see "
|
"In that case, it requires TensorFlow to be installed. Please see "
|
||||||
"https://www.tensorflow.org/install/ for installation instructions.")
|
"https://www.tensorflow.org/install/ for installation instructions.")
|
||||||
@@ -43,7 +43,7 @@ def main():
|
|||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
except ModuleNotFoundError:
|
except ImportError:
|
||||||
print("pytorch_pretrained_bert can only be used from the commandline to convert TensorFlow models in PyTorch, "
|
print("pytorch_pretrained_bert can only be used from the commandline to convert TensorFlow models in PyTorch, "
|
||||||
"In that case, it requires TensorFlow to be installed. Please see "
|
"In that case, it requires TensorFlow to be installed. Please see "
|
||||||
"https://www.tensorflow.org/install/ for installation instructions.")
|
"https://www.tensorflow.org/install/ for installation instructions.")
|
||||||
|
|||||||
@@ -14,14 +14,18 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
"""Convert OpenAI GPT checkpoint."""
|
"""Convert OpenAI GPT checkpoint."""
|
||||||
|
|
||||||
from __future__ import absolute_import
|
from __future__ import absolute_import, division, print_function
|
||||||
from __future__ import division
|
|
||||||
from __future__ import print_function
|
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
|
from io import open
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from pytorch_pretrained_bert.modeling_openai import load_tf_weights_in_openai_gpt, OpenAIGPTConfig, OpenAIGPTModel, CONFIG_NAME, WEIGHTS_NAME
|
from pytorch_pretrained_bert.modeling_openai import (CONFIG_NAME, WEIGHTS_NAME,
|
||||||
|
OpenAIGPTConfig,
|
||||||
|
OpenAIGPTModel,
|
||||||
|
load_tf_weights_in_openai_gpt)
|
||||||
|
|
||||||
|
|
||||||
def convert_openai_checkpoint_to_pytorch(openai_checkpoint_folder_path, openai_config_file, pytorch_dump_folder_path):
|
def convert_openai_checkpoint_to_pytorch(openai_checkpoint_folder_path, openai_config_file, pytorch_dump_folder_path):
|
||||||
# Construct model
|
# Construct model
|
||||||
|
|||||||
@@ -14,25 +14,31 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
"""Convert Transformer XL checkpoint and datasets."""
|
"""Convert Transformer XL checkpoint and datasets."""
|
||||||
|
|
||||||
from __future__ import absolute_import
|
from __future__ import absolute_import, division, print_function
|
||||||
from __future__ import division
|
|
||||||
from __future__ import print_function
|
|
||||||
|
|
||||||
|
import argparse
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import argparse
|
from io import open
|
||||||
import pickle
|
|
||||||
|
|
||||||
import tensorflow as tf
|
|
||||||
import torch
|
import torch
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
from pytorch_pretrained_bert.modeling_transfo_xl import TransfoXLConfig, TransfoXLModel, CONFIG_NAME, WEIGHTS_NAME, load_tf_weights_in_transfo_xl
|
import pytorch_pretrained_bert.tokenization_transfo_xl as data_utils
|
||||||
from pytorch_pretrained_bert.tokenization_transfo_xl import VOCAB_NAME, CORPUS_NAME
|
from pytorch_pretrained_bert.modeling_transfo_xl import (CONFIG_NAME,
|
||||||
|
WEIGHTS_NAME,
|
||||||
|
TransfoXLConfig,
|
||||||
|
TransfoXLModel,
|
||||||
|
load_tf_weights_in_transfo_xl)
|
||||||
|
from pytorch_pretrained_bert.tokenization_transfo_xl import (CORPUS_NAME,
|
||||||
|
VOCAB_NAME)
|
||||||
|
|
||||||
|
if sys.version_info[0] == 2:
|
||||||
|
import cPickle as pickle
|
||||||
|
else:
|
||||||
|
import pickle
|
||||||
|
|
||||||
# We do this to be able to load the python 2 datasets pickles
|
# We do this to be able to load the python 2 datasets pickles
|
||||||
# See e.g. https://stackoverflow.com/questions/2121874/python-pickling-after-changing-a-modules-directory/2121918#2121918
|
# See e.g. https://stackoverflow.com/questions/2121874/python-pickling-after-changing-a-modules-directory/2121918#2121918
|
||||||
import pytorch_pretrained_bert.tokenization_transfo_xl as data_utils
|
|
||||||
data_utils.Vocab = data_utils.TransfoXLTokenizer
|
data_utils.Vocab = data_utils.TransfoXLTokenizer
|
||||||
data_utils.Corpus = data_utils.TransfoXLCorpus
|
data_utils.Corpus = data_utils.TransfoXLCorpus
|
||||||
sys.modules['data_utils'] = data_utils
|
sys.modules['data_utils'] = data_utils
|
||||||
|
|||||||
@@ -3,31 +3,39 @@ Utilities for working with the local dataset cache.
|
|||||||
This file is adapted from the AllenNLP library at https://github.com/allenai/allennlp
|
This file is adapted from the AllenNLP library at https://github.com/allenai/allennlp
|
||||||
Copyright by the AllenNLP authors.
|
Copyright by the AllenNLP authors.
|
||||||
"""
|
"""
|
||||||
|
from __future__ import (absolute_import, division, print_function, unicode_literals)
|
||||||
|
|
||||||
import os
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
import os
|
||||||
import shutil
|
import shutil
|
||||||
import tempfile
|
import tempfile
|
||||||
import json
|
|
||||||
from urllib.parse import urlparse
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Optional, Tuple, Union, IO, Callable, Set
|
|
||||||
from hashlib import sha256
|
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
|
from hashlib import sha256
|
||||||
from tqdm import tqdm
|
from io import open
|
||||||
|
|
||||||
import boto3
|
import boto3
|
||||||
from botocore.exceptions import ClientError
|
|
||||||
import requests
|
import requests
|
||||||
|
from botocore.exceptions import ClientError
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
try:
|
||||||
|
from urllib.parse import urlparse
|
||||||
|
except ImportError:
|
||||||
|
from urlparse import urlparse
|
||||||
|
|
||||||
|
try:
|
||||||
|
from pathlib import Path
|
||||||
|
PYTORCH_PRETRAINED_BERT_CACHE = Path(os.getenv('PYTORCH_PRETRAINED_BERT_CACHE',
|
||||||
|
Path.home() / '.pytorch_pretrained_bert'))
|
||||||
|
except ImportError:
|
||||||
|
PYTORCH_PRETRAINED_BERT_CACHE = os.getenv('PYTORCH_PRETRAINED_BERT_CACHE',
|
||||||
|
os.path.join(os.path.expanduser("~"), '.pytorch_pretrained_bert'))
|
||||||
|
|
||||||
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
|
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
|
||||||
|
|
||||||
PYTORCH_PRETRAINED_BERT_CACHE = Path(os.getenv('PYTORCH_PRETRAINED_BERT_CACHE',
|
|
||||||
Path.home() / '.pytorch_pretrained_bert'))
|
|
||||||
|
|
||||||
|
def url_to_filename(url, etag=None):
|
||||||
def url_to_filename(url: str, etag: str = None) -> str:
|
|
||||||
"""
|
"""
|
||||||
Convert `url` into a hashed filename in a repeatable way.
|
Convert `url` into a hashed filename in a repeatable way.
|
||||||
If `etag` is specified, append its hash to the url's, delimited
|
If `etag` is specified, append its hash to the url's, delimited
|
||||||
@@ -45,25 +53,23 @@ def url_to_filename(url: str, etag: str = None) -> str:
|
|||||||
return filename
|
return filename
|
||||||
|
|
||||||
|
|
||||||
def filename_to_url(filename: str, cache_dir: Union[str, Path] = None) -> Tuple[str, str]:
|
def filename_to_url(filename, cache_dir=None):
|
||||||
"""
|
"""
|
||||||
Return the url and etag (which may be ``None``) stored for `filename`.
|
Return the url and etag (which may be ``None``) stored for `filename`.
|
||||||
Raise ``FileNotFoundError`` if `filename` or its stored metadata do not exist.
|
Raise ``EnvironmentError`` if `filename` or its stored metadata do not exist.
|
||||||
"""
|
"""
|
||||||
if cache_dir is None:
|
if cache_dir is None:
|
||||||
cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
|
cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
|
||||||
if isinstance(cache_dir, Path):
|
|
||||||
cache_dir = str(cache_dir)
|
|
||||||
|
|
||||||
cache_path = os.path.join(cache_dir, filename)
|
cache_path = os.path.join(cache_dir, filename)
|
||||||
if not os.path.exists(cache_path):
|
if not os.path.exists(cache_path):
|
||||||
raise FileNotFoundError("file {} not found".format(cache_path))
|
raise EnvironmentError("file {} not found".format(cache_path))
|
||||||
|
|
||||||
meta_path = cache_path + '.json'
|
meta_path = cache_path + '.json'
|
||||||
if not os.path.exists(meta_path):
|
if not os.path.exists(meta_path):
|
||||||
raise FileNotFoundError("file {} not found".format(meta_path))
|
raise EnvironmentError("file {} not found".format(meta_path))
|
||||||
|
|
||||||
with open(meta_path) as meta_file:
|
with open(meta_path, encoding="utf-8") as meta_file:
|
||||||
metadata = json.load(meta_file)
|
metadata = json.load(meta_file)
|
||||||
url = metadata['url']
|
url = metadata['url']
|
||||||
etag = metadata['etag']
|
etag = metadata['etag']
|
||||||
@@ -71,7 +77,7 @@ def filename_to_url(filename: str, cache_dir: Union[str, Path] = None) -> Tuple[
|
|||||||
return url, etag
|
return url, etag
|
||||||
|
|
||||||
|
|
||||||
def cached_path(url_or_filename: Union[str, Path], cache_dir: Union[str, Path] = None) -> str:
|
def cached_path(url_or_filename, cache_dir=None):
|
||||||
"""
|
"""
|
||||||
Given something that might be a URL (or might be a local path),
|
Given something that might be a URL (or might be a local path),
|
||||||
determine which. If it's a URL, download the file and cache it, and
|
determine which. If it's a URL, download the file and cache it, and
|
||||||
@@ -80,10 +86,6 @@ def cached_path(url_or_filename: Union[str, Path], cache_dir: Union[str, Path] =
|
|||||||
"""
|
"""
|
||||||
if cache_dir is None:
|
if cache_dir is None:
|
||||||
cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
|
cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
|
||||||
if isinstance(url_or_filename, Path):
|
|
||||||
url_or_filename = str(url_or_filename)
|
|
||||||
if isinstance(cache_dir, Path):
|
|
||||||
cache_dir = str(cache_dir)
|
|
||||||
|
|
||||||
parsed = urlparse(url_or_filename)
|
parsed = urlparse(url_or_filename)
|
||||||
|
|
||||||
@@ -95,13 +97,13 @@ def cached_path(url_or_filename: Union[str, Path], cache_dir: Union[str, Path] =
|
|||||||
return url_or_filename
|
return url_or_filename
|
||||||
elif parsed.scheme == '':
|
elif parsed.scheme == '':
|
||||||
# File, but it doesn't exist.
|
# File, but it doesn't exist.
|
||||||
raise FileNotFoundError("file {} not found".format(url_or_filename))
|
raise EnvironmentError("file {} not found".format(url_or_filename))
|
||||||
else:
|
else:
|
||||||
# Something unknown
|
# Something unknown
|
||||||
raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename))
|
raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename))
|
||||||
|
|
||||||
|
|
||||||
def split_s3_path(url: str) -> Tuple[str, str]:
|
def split_s3_path(url):
|
||||||
"""Split a full s3 path into the bucket name and path."""
|
"""Split a full s3 path into the bucket name and path."""
|
||||||
parsed = urlparse(url)
|
parsed = urlparse(url)
|
||||||
if not parsed.netloc or not parsed.path:
|
if not parsed.netloc or not parsed.path:
|
||||||
@@ -114,19 +116,19 @@ def split_s3_path(url: str) -> Tuple[str, str]:
|
|||||||
return bucket_name, s3_path
|
return bucket_name, s3_path
|
||||||
|
|
||||||
|
|
||||||
def s3_request(func: Callable):
|
def s3_request(func):
|
||||||
"""
|
"""
|
||||||
Wrapper function for s3 requests in order to create more helpful error
|
Wrapper function for s3 requests in order to create more helpful error
|
||||||
messages.
|
messages.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@wraps(func)
|
@wraps(func)
|
||||||
def wrapper(url: str, *args, **kwargs):
|
def wrapper(url, *args, **kwargs):
|
||||||
try:
|
try:
|
||||||
return func(url, *args, **kwargs)
|
return func(url, *args, **kwargs)
|
||||||
except ClientError as exc:
|
except ClientError as exc:
|
||||||
if int(exc.response["Error"]["Code"]) == 404:
|
if int(exc.response["Error"]["Code"]) == 404:
|
||||||
raise FileNotFoundError("file {} not found".format(url))
|
raise EnvironmentError("file {} not found".format(url))
|
||||||
else:
|
else:
|
||||||
raise
|
raise
|
||||||
|
|
||||||
@@ -134,7 +136,7 @@ def s3_request(func: Callable):
|
|||||||
|
|
||||||
|
|
||||||
@s3_request
|
@s3_request
|
||||||
def s3_etag(url: str) -> Optional[str]:
|
def s3_etag(url):
|
||||||
"""Check ETag on S3 object."""
|
"""Check ETag on S3 object."""
|
||||||
s3_resource = boto3.resource("s3")
|
s3_resource = boto3.resource("s3")
|
||||||
bucket_name, s3_path = split_s3_path(url)
|
bucket_name, s3_path = split_s3_path(url)
|
||||||
@@ -143,14 +145,14 @@ def s3_etag(url: str) -> Optional[str]:
|
|||||||
|
|
||||||
|
|
||||||
@s3_request
|
@s3_request
|
||||||
def s3_get(url: str, temp_file: IO) -> None:
|
def s3_get(url, temp_file):
|
||||||
"""Pull a file directly from S3."""
|
"""Pull a file directly from S3."""
|
||||||
s3_resource = boto3.resource("s3")
|
s3_resource = boto3.resource("s3")
|
||||||
bucket_name, s3_path = split_s3_path(url)
|
bucket_name, s3_path = split_s3_path(url)
|
||||||
s3_resource.Bucket(bucket_name).download_fileobj(s3_path, temp_file)
|
s3_resource.Bucket(bucket_name).download_fileobj(s3_path, temp_file)
|
||||||
|
|
||||||
|
|
||||||
def http_get(url: str, temp_file: IO) -> None:
|
def http_get(url, temp_file):
|
||||||
req = requests.get(url, stream=True)
|
req = requests.get(url, stream=True)
|
||||||
content_length = req.headers.get('Content-Length')
|
content_length = req.headers.get('Content-Length')
|
||||||
total = int(content_length) if content_length is not None else None
|
total = int(content_length) if content_length is not None else None
|
||||||
@@ -162,17 +164,16 @@ def http_get(url: str, temp_file: IO) -> None:
|
|||||||
progress.close()
|
progress.close()
|
||||||
|
|
||||||
|
|
||||||
def get_from_cache(url: str, cache_dir: Union[str, Path] = None) -> str:
|
def get_from_cache(url, cache_dir=None):
|
||||||
"""
|
"""
|
||||||
Given a URL, look for the corresponding dataset in the local cache.
|
Given a URL, look for the corresponding dataset in the local cache.
|
||||||
If it's not there, download it. Then return the path to the cached file.
|
If it's not there, download it. Then return the path to the cached file.
|
||||||
"""
|
"""
|
||||||
if cache_dir is None:
|
if cache_dir is None:
|
||||||
cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
|
cache_dir = PYTORCH_PRETRAINED_BERT_CACHE
|
||||||
if isinstance(cache_dir, Path):
|
|
||||||
cache_dir = str(cache_dir)
|
|
||||||
|
|
||||||
os.makedirs(cache_dir, exist_ok=True)
|
if not os.path.exists(cache_dir):
|
||||||
|
os.makedirs(cache_dir)
|
||||||
|
|
||||||
# Get eTag to add to filename, if it exists.
|
# Get eTag to add to filename, if it exists.
|
||||||
if url.startswith("s3://"):
|
if url.startswith("s3://"):
|
||||||
@@ -213,7 +214,7 @@ def get_from_cache(url: str, cache_dir: Union[str, Path] = None) -> str:
|
|||||||
logger.info("creating metadata file for %s", cache_path)
|
logger.info("creating metadata file for %s", cache_path)
|
||||||
meta = {'url': url, 'etag': etag}
|
meta = {'url': url, 'etag': etag}
|
||||||
meta_path = cache_path + '.json'
|
meta_path = cache_path + '.json'
|
||||||
with open(meta_path, 'w') as meta_file:
|
with open(meta_path, 'w', encoding="utf-8") as meta_file:
|
||||||
json.dump(meta, meta_file)
|
json.dump(meta, meta_file)
|
||||||
|
|
||||||
logger.info("removing temp file %s", temp_file.name)
|
logger.info("removing temp file %s", temp_file.name)
|
||||||
@@ -221,7 +222,7 @@ def get_from_cache(url: str, cache_dir: Union[str, Path] = None) -> str:
|
|||||||
return cache_path
|
return cache_path
|
||||||
|
|
||||||
|
|
||||||
def read_set_from_file(filename: str) -> Set[str]:
|
def read_set_from_file(filename):
|
||||||
'''
|
'''
|
||||||
Extract a de-duped collection (set) of text from a file.
|
Extract a de-duped collection (set) of text from a file.
|
||||||
Expected file format is one item per line.
|
Expected file format is one item per line.
|
||||||
@@ -233,7 +234,7 @@ def read_set_from_file(filename: str) -> Set[str]:
|
|||||||
return collection
|
return collection
|
||||||
|
|
||||||
|
|
||||||
def get_file_extension(path: str, dot=True, lower: bool = True):
|
def get_file_extension(path, dot=True, lower=True):
|
||||||
ext = os.path.splitext(path)[1]
|
ext = os.path.splitext(path)[1]
|
||||||
ext = ext if dot else ext[1:]
|
ext = ext if dot else ext[1:]
|
||||||
return ext.lower() if lower else ext
|
return ext.lower() if lower else ext
|
||||||
|
|||||||
@@ -15,18 +15,18 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
"""PyTorch BERT model."""
|
"""PyTorch BERT model."""
|
||||||
|
|
||||||
from __future__ import absolute_import
|
from __future__ import absolute_import, division, print_function, unicode_literals
|
||||||
from __future__ import division
|
|
||||||
from __future__ import print_function
|
|
||||||
|
|
||||||
import os
|
|
||||||
import copy
|
import copy
|
||||||
import json
|
import json
|
||||||
import math
|
|
||||||
import logging
|
import logging
|
||||||
|
import math
|
||||||
|
import os
|
||||||
|
import shutil
|
||||||
import tarfile
|
import tarfile
|
||||||
import tempfile
|
import tempfile
|
||||||
import shutil
|
import sys
|
||||||
|
from io import open
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
@@ -56,7 +56,7 @@ def load_tf_weights_in_bert(model, tf_checkpoint_path):
|
|||||||
import re
|
import re
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
except ModuleNotFoundError:
|
except ImportError:
|
||||||
print("Loading a TensorFlow models in PyTorch, requires TensorFlow to be installed. Please see "
|
print("Loading a TensorFlow models in PyTorch, requires TensorFlow to be installed. Please see "
|
||||||
"https://www.tensorflow.org/install/ for installation instructions.")
|
"https://www.tensorflow.org/install/ for installation instructions.")
|
||||||
raise
|
raise
|
||||||
@@ -164,7 +164,8 @@ class BertConfig(object):
|
|||||||
initializer_range: The sttdev of the truncated_normal_initializer for
|
initializer_range: The sttdev of the truncated_normal_initializer for
|
||||||
initializing all weight matrices.
|
initializing all weight matrices.
|
||||||
"""
|
"""
|
||||||
if isinstance(vocab_size_or_config_json_file, str):
|
if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2
|
||||||
|
and isinstance(vocab_size_or_config_json_file, unicode)):
|
||||||
with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader:
|
with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader:
|
||||||
json_config = json.loads(reader.read())
|
json_config = json.loads(reader.read())
|
||||||
for key, value in json_config.items():
|
for key, value in json_config.items():
|
||||||
@@ -343,8 +344,10 @@ class BertIntermediate(nn.Module):
|
|||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super(BertIntermediate, self).__init__()
|
super(BertIntermediate, self).__init__()
|
||||||
self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
|
self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
|
||||||
self.intermediate_act_fn = ACT2FN[config.hidden_act] \
|
if isinstance(config.hidden_act, str) or (sys.version_info[0] == 2 and isinstance(config.hidden_act, unicode)):
|
||||||
if isinstance(config.hidden_act, str) else config.hidden_act
|
self.intermediate_act_fn = ACT2FN[config.hidden_act]
|
||||||
|
else:
|
||||||
|
self.intermediate_act_fn = config.hidden_act
|
||||||
|
|
||||||
def forward(self, hidden_states):
|
def forward(self, hidden_states):
|
||||||
hidden_states = self.dense(hidden_states)
|
hidden_states = self.dense(hidden_states)
|
||||||
@@ -416,8 +419,10 @@ class BertPredictionHeadTransform(nn.Module):
|
|||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
super(BertPredictionHeadTransform, self).__init__()
|
super(BertPredictionHeadTransform, self).__init__()
|
||||||
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
||||||
self.transform_act_fn = ACT2FN[config.hidden_act] \
|
if isinstance(config.hidden_act, str) or (sys.version_info[0] == 2 and isinstance(config.hidden_act, unicode)):
|
||||||
if isinstance(config.hidden_act, str) else config.hidden_act
|
self.transform_act_fn = ACT2FN[config.hidden_act]
|
||||||
|
else:
|
||||||
|
self.transform_act_fn = config.hidden_act
|
||||||
self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)
|
self.LayerNorm = BertLayerNorm(config.hidden_size, eps=1e-12)
|
||||||
|
|
||||||
def forward(self, hidden_states):
|
def forward(self, hidden_states):
|
||||||
@@ -542,7 +547,7 @@ class BertPreTrainedModel(nn.Module):
|
|||||||
# redirect to the cache, if necessary
|
# redirect to the cache, if necessary
|
||||||
try:
|
try:
|
||||||
resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir)
|
resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir)
|
||||||
except FileNotFoundError:
|
except EnvironmentError:
|
||||||
logger.error(
|
logger.error(
|
||||||
"Model name '{}' was not found in model name list ({}). "
|
"Model name '{}' was not found in model name list ({}). "
|
||||||
"We assumed '{}' was a path or url but couldn't find any file "
|
"We assumed '{}' was a path or url but couldn't find any file "
|
||||||
|
|||||||
@@ -24,6 +24,8 @@ import os
|
|||||||
import shutil
|
import shutil
|
||||||
import tarfile
|
import tarfile
|
||||||
import tempfile
|
import tempfile
|
||||||
|
import sys
|
||||||
|
from io import open
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
@@ -160,7 +162,8 @@ class OpenAIGPTConfig(object):
|
|||||||
initializer_range: The sttdev of the truncated_normal_initializer for
|
initializer_range: The sttdev of the truncated_normal_initializer for
|
||||||
initializing all weight matrices.
|
initializing all weight matrices.
|
||||||
"""
|
"""
|
||||||
if isinstance(vocab_size_or_config_json_file, str):
|
if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2
|
||||||
|
and isinstance(vocab_size_or_config_json_file, unicode)):
|
||||||
with open(vocab_size_or_config_json_file, "r", encoding="utf-8") as reader:
|
with open(vocab_size_or_config_json_file, "r", encoding="utf-8") as reader:
|
||||||
json_config = json.loads(reader.read())
|
json_config = json.loads(reader.read())
|
||||||
for key, value in json_config.items():
|
for key, value in json_config.items():
|
||||||
@@ -442,7 +445,7 @@ class OpenAIGPTPreTrainedModel(nn.Module):
|
|||||||
# redirect to the cache, if necessary
|
# redirect to the cache, if necessary
|
||||||
try:
|
try:
|
||||||
resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir)
|
resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir)
|
||||||
except FileNotFoundError:
|
except EnvironmentError:
|
||||||
logger.error(
|
logger.error(
|
||||||
"Model name '{}' was not found in model name list ({}). "
|
"Model name '{}' was not found in model name list ({}). "
|
||||||
"We assumed '{}' was a path or url but couldn't find any file "
|
"We assumed '{}' was a path or url but couldn't find any file "
|
||||||
@@ -641,7 +644,8 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
|
|||||||
hidden_states = inputs_embeds + position_embeds + token_type_embeds
|
hidden_states = inputs_embeds + position_embeds + token_type_embeds
|
||||||
for block in self.h:
|
for block in self.h:
|
||||||
hidden_states = block(hidden_states)
|
hidden_states = block(hidden_states)
|
||||||
return hidden_states.view(*input_shape, hidden_states.size(-1))
|
output_shape = input_shape + (hidden_states.size(-1),)
|
||||||
|
return hidden_states.view(*output_shape)
|
||||||
|
|
||||||
|
|
||||||
class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
|
class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
|
||||||
|
|||||||
@@ -27,6 +27,8 @@ import tarfile
|
|||||||
import tempfile
|
import tempfile
|
||||||
import shutil
|
import shutil
|
||||||
import collections
|
import collections
|
||||||
|
import sys
|
||||||
|
from io import open
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
@@ -124,7 +126,7 @@ def load_tf_weights_in_transfo_xl(model, config, tf_path):
|
|||||||
try:
|
try:
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
except ModuleNotFoundError:
|
except ImportError:
|
||||||
print("Loading a TensorFlow models in PyTorch, requires TensorFlow to be installed. Please see "
|
print("Loading a TensorFlow models in PyTorch, requires TensorFlow to be installed. Please see "
|
||||||
"https://www.tensorflow.org/install/ for installation instructions.")
|
"https://www.tensorflow.org/install/ for installation instructions.")
|
||||||
raise
|
raise
|
||||||
@@ -239,7 +241,8 @@ class TransfoXLConfig(object):
|
|||||||
proj_init_std: parameters initialized by N(0, init_std)
|
proj_init_std: parameters initialized by N(0, init_std)
|
||||||
init_std: parameters initialized by N(0, init_std)
|
init_std: parameters initialized by N(0, init_std)
|
||||||
"""
|
"""
|
||||||
if isinstance(vocab_size_or_config_json_file, str):
|
if isinstance(vocab_size_or_config_json_file, str) or (sys.version_info[0] == 2
|
||||||
|
and isinstance(vocab_size_or_config_json_file, unicode)):
|
||||||
with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader:
|
with open(vocab_size_or_config_json_file, "r", encoding='utf-8') as reader:
|
||||||
json_config = json.loads(reader.read())
|
json_config = json.loads(reader.read())
|
||||||
for key, value in json_config.items():
|
for key, value in json_config.items():
|
||||||
@@ -503,11 +506,12 @@ class RelMultiHeadAttn(nn.Module):
|
|||||||
return x
|
return x
|
||||||
|
|
||||||
def _rel_shift(self, x, zero_triu=False):
|
def _rel_shift(self, x, zero_triu=False):
|
||||||
zero_pad = torch.zeros((x.size(0), 1, *x.size()[2:]),
|
zero_pad_shape = (x.size(0), 1) + x.size()[2:]
|
||||||
device=x.device, dtype=x.dtype)
|
zero_pad = torch.zeros(zero_pad_shape, device=x.device, dtype=x.dtype)
|
||||||
x_padded = torch.cat([zero_pad, x], dim=1)
|
x_padded = torch.cat([zero_pad, x], dim=1)
|
||||||
|
|
||||||
x_padded = x_padded.view(x.size(1) + 1, x.size(0), *x.size()[2:])
|
x_padded_shape = (x.size(1) + 1, x.size(0)) + x.size()[2:]
|
||||||
|
x_padded = x_padded.view(*x_padded_shape)
|
||||||
|
|
||||||
x = x_padded[1:].view_as(x)
|
x = x_padded[1:].view_as(x)
|
||||||
|
|
||||||
@@ -797,7 +801,8 @@ class AdaptiveEmbedding(nn.Module):
|
|||||||
|
|
||||||
emb_flat.index_copy_(0, indices_i, emb_i)
|
emb_flat.index_copy_(0, indices_i, emb_i)
|
||||||
|
|
||||||
embed = emb_flat.view(*inp.size(), self.d_proj)
|
embed_shape = inp.size() + (self.d_proj,)
|
||||||
|
embed = emb_flat.view(embed_shape)
|
||||||
|
|
||||||
embed.mul_(self.emb_scale)
|
embed.mul_(self.emb_scale)
|
||||||
|
|
||||||
@@ -905,7 +910,7 @@ class TransfoXLPreTrainedModel(nn.Module):
|
|||||||
try:
|
try:
|
||||||
resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir)
|
resolved_archive_file = cached_path(archive_file, cache_dir=cache_dir)
|
||||||
resolved_config_file = cached_path(config_file, cache_dir=cache_dir)
|
resolved_config_file = cached_path(config_file, cache_dir=cache_dir)
|
||||||
except FileNotFoundError:
|
except EnvironmentError:
|
||||||
logger.error(
|
logger.error(
|
||||||
"Model name '{}' was not found in model name list ({}). "
|
"Model name '{}' was not found in model name list ({}). "
|
||||||
"We assumed '{}' was a path or url but couldn't find files {} and {} "
|
"We assumed '{}' was a path or url but couldn't find files {} and {} "
|
||||||
|
|||||||
@@ -14,14 +14,13 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
"""Tokenization classes."""
|
"""Tokenization classes."""
|
||||||
|
|
||||||
from __future__ import absolute_import
|
from __future__ import absolute_import, division, print_function, unicode_literals
|
||||||
from __future__ import division
|
|
||||||
from __future__ import print_function
|
|
||||||
|
|
||||||
import collections
|
import collections
|
||||||
import unicodedata
|
|
||||||
import os
|
|
||||||
import logging
|
import logging
|
||||||
|
import os
|
||||||
|
import unicodedata
|
||||||
|
from io import open
|
||||||
|
|
||||||
from .file_utils import cached_path
|
from .file_utils import cached_path
|
||||||
|
|
||||||
@@ -129,7 +128,7 @@ class BertTokenizer(object):
|
|||||||
# redirect to the cache, if necessary
|
# redirect to the cache, if necessary
|
||||||
try:
|
try:
|
||||||
resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir)
|
resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir)
|
||||||
except FileNotFoundError:
|
except EnvironmentError:
|
||||||
logger.error(
|
logger.error(
|
||||||
"Model name '{}' was not found in model name list ({}). "
|
"Model name '{}' was not found in model name list ({}). "
|
||||||
"We assumed '{}' was a path or url but couldn't find any file "
|
"We assumed '{}' was a path or url but couldn't find any file "
|
||||||
|
|||||||
@@ -13,11 +13,17 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
"""Tokenization classes for OpenAI GPT."""
|
"""Tokenization classes for OpenAI GPT."""
|
||||||
|
from __future__ import (absolute_import, division, print_function,
|
||||||
|
unicode_literals)
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
import json
|
import sys
|
||||||
|
from io import open
|
||||||
|
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
import logging
|
|
||||||
|
|
||||||
from .file_utils import cached_path
|
from .file_utils import cached_path
|
||||||
|
|
||||||
@@ -82,7 +88,7 @@ class OpenAIGPTTokenizer(object):
|
|||||||
try:
|
try:
|
||||||
resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir)
|
resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir)
|
||||||
resolved_merges_file = cached_path(merges_file, cache_dir=cache_dir)
|
resolved_merges_file = cached_path(merges_file, cache_dir=cache_dir)
|
||||||
except FileNotFoundError:
|
except EnvironmentError:
|
||||||
logger.error(
|
logger.error(
|
||||||
"Model name '{}' was not found in model name list ({}). "
|
"Model name '{}' was not found in model name list ({}). "
|
||||||
"We assumed '{}' was a path or url but couldn't find files {} and {} "
|
"We assumed '{}' was a path or url but couldn't find files {} and {} "
|
||||||
@@ -119,7 +125,7 @@ class OpenAIGPTTokenizer(object):
|
|||||||
self.max_len = max_len if max_len is not None else int(1e12)
|
self.max_len = max_len if max_len is not None else int(1e12)
|
||||||
self.nlp = spacy.load('en', disable=['parser', 'tagger', 'ner', 'textcat'])
|
self.nlp = spacy.load('en', disable=['parser', 'tagger', 'ner', 'textcat'])
|
||||||
self.fix_text = ftfy.fix_text
|
self.fix_text = ftfy.fix_text
|
||||||
self.encoder = json.load(open(vocab_file))
|
self.encoder = json.load(open(vocab_file, encoding="utf-8"))
|
||||||
self.decoder = {v:k for k,v in self.encoder.items()}
|
self.decoder = {v:k for k,v in self.encoder.items()}
|
||||||
merges = open(merges_file, encoding='utf-8').read().split('\n')[1:-1]
|
merges = open(merges_file, encoding='utf-8').read().split('\n')[1:-1]
|
||||||
merges = [tuple(merge.split()) for merge in merges]
|
merges = [tuple(merge.split()) for merge in merges]
|
||||||
@@ -196,7 +202,7 @@ class OpenAIGPTTokenizer(object):
|
|||||||
def convert_tokens_to_ids(self, tokens):
|
def convert_tokens_to_ids(self, tokens):
|
||||||
"""Converts a sequence of tokens into ids using the vocab."""
|
"""Converts a sequence of tokens into ids using the vocab."""
|
||||||
ids = []
|
ids = []
|
||||||
if isinstance(tokens, str):
|
if isinstance(tokens, str) or (sys.version_info[0] == 2 and isinstance(tokens, unicode)):
|
||||||
if tokens in self.special_tokens:
|
if tokens in self.special_tokens:
|
||||||
return self.special_tokens[tokens]
|
return self.special_tokens[tokens]
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -16,16 +16,27 @@
|
|||||||
""" Tokenization classes for Transformer XL model.
|
""" Tokenization classes for Transformer XL model.
|
||||||
Adapted from https://github.com/kimiyoung/transformer-xl.
|
Adapted from https://github.com/kimiyoung/transformer-xl.
|
||||||
"""
|
"""
|
||||||
|
from __future__ import (absolute_import, division, print_function,
|
||||||
|
unicode_literals)
|
||||||
|
|
||||||
import os
|
|
||||||
import glob
|
import glob
|
||||||
import logging
|
import logging
|
||||||
import pickle
|
import os
|
||||||
import torch
|
import sys
|
||||||
from collections import Counter, OrderedDict
|
from collections import Counter, OrderedDict
|
||||||
|
from io import open
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
from .file_utils import cached_path
|
from .file_utils import cached_path
|
||||||
|
|
||||||
|
if sys.version_info[0] == 2:
|
||||||
|
import cPickle as pickle
|
||||||
|
else:
|
||||||
|
import pickle
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
PRETRAINED_VOCAB_ARCHIVE_MAP = {
|
PRETRAINED_VOCAB_ARCHIVE_MAP = {
|
||||||
@@ -55,7 +66,7 @@ class TransfoXLTokenizer(object):
|
|||||||
# redirect to the cache, if necessary
|
# redirect to the cache, if necessary
|
||||||
try:
|
try:
|
||||||
resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir)
|
resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir)
|
||||||
except FileNotFoundError:
|
except EnvironmentError:
|
||||||
logger.error(
|
logger.error(
|
||||||
"Model name '{}' was not found in model name list ({}). "
|
"Model name '{}' was not found in model name list ({}). "
|
||||||
"We assumed '{}' was a path or url but couldn't find files {} "
|
"We assumed '{}' was a path or url but couldn't find files {} "
|
||||||
@@ -422,7 +433,7 @@ class TransfoXLCorpus(object):
|
|||||||
# redirect to the cache, if necessary
|
# redirect to the cache, if necessary
|
||||||
try:
|
try:
|
||||||
resolved_corpus_file = cached_path(corpus_file, cache_dir=cache_dir)
|
resolved_corpus_file = cached_path(corpus_file, cache_dir=cache_dir)
|
||||||
except FileNotFoundError:
|
except EnvironmentError:
|
||||||
logger.error(
|
logger.error(
|
||||||
"Model name '{}' was not found in model name list ({}). "
|
"Model name '{}' was not found in model name list ({}). "
|
||||||
"We assumed '{}' was a path or url but couldn't find files {} "
|
"We assumed '{}' was a path or url but couldn't find files {} "
|
||||||
|
|||||||
3
setup.py
3
setup.py
@@ -33,6 +33,7 @@ To create the package for pypi.
|
|||||||
7. Copy the release notes from RELEASE.md to the tag in github once everything is looking hunky-dory.
|
7. Copy the release notes from RELEASE.md to the tag in github once everything is looking hunky-dory.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
from io import open
|
||||||
from setuptools import find_packages, setup
|
from setuptools import find_packages, setup
|
||||||
|
|
||||||
setup(
|
setup(
|
||||||
@@ -58,7 +59,7 @@ setup(
|
|||||||
"pytorch_pretrained_bert=pytorch_pretrained_bert.__main__:main",
|
"pytorch_pretrained_bert=pytorch_pretrained_bert.__main__:main",
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
python_requires='>=3.5.0',
|
# python_requires='>=3.5.0',
|
||||||
tests_require=['pytest'],
|
tests_require=['pytest'],
|
||||||
classifiers=[
|
classifiers=[
|
||||||
'Intended Audience :: Science/Research',
|
'Intended Audience :: Science/Research',
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ from __future__ import print_function
|
|||||||
|
|
||||||
import os
|
import os
|
||||||
import unittest
|
import unittest
|
||||||
|
from io import open
|
||||||
|
|
||||||
from pytorch_pretrained_bert.tokenization import (BertTokenizer, BasicTokenizer, WordpieceTokenizer,
|
from pytorch_pretrained_bert.tokenization import (BertTokenizer, BasicTokenizer, WordpieceTokenizer,
|
||||||
_is_whitespace, _is_control, _is_punctuation)
|
_is_whitespace, _is_control, _is_punctuation)
|
||||||
@@ -30,7 +31,7 @@ class TokenizationTest(unittest.TestCase):
|
|||||||
"[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn",
|
"[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn",
|
||||||
"##ing", ","
|
"##ing", ","
|
||||||
]
|
]
|
||||||
with open("/tmp/bert_tokenizer_test.txt", "w") as vocab_writer:
|
with open("/tmp/bert_tokenizer_test.txt", "w", encoding='utf-8') as vocab_writer:
|
||||||
vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
|
vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
|
||||||
|
|
||||||
vocab_file = vocab_writer.name
|
vocab_file = vocab_writer.name
|
||||||
@@ -49,7 +50,7 @@ class TokenizationTest(unittest.TestCase):
|
|||||||
"[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn",
|
"[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn",
|
||||||
"##ing", ","
|
"##ing", ","
|
||||||
]
|
]
|
||||||
with open("/tmp/bert_tokenizer_test.txt", "w") as vocab_writer:
|
with open("/tmp/bert_tokenizer_test.txt", "w", encoding='utf-8') as vocab_writer:
|
||||||
vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
|
vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
|
||||||
vocab_file = vocab_writer.name
|
vocab_file = vocab_writer.name
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user