improved corpus and tokenization conversion - added evaluation script
This commit is contained in:
@@ -14,15 +14,14 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" Tokenization classes for Transformer XL model.
|
||||
Directly adapted from https://github.com/kimiyoung/transformer-xl.
|
||||
Adapted from https://github.com/kimiyoung/transformer-xl.
|
||||
"""
|
||||
|
||||
import os
|
||||
import re
|
||||
import json
|
||||
from tqdm import tqdm
|
||||
import glob
|
||||
import logging
|
||||
import pickle
|
||||
import torch
|
||||
from collections import Counter, OrderedDict
|
||||
|
||||
from .file_utils import cached_path
|
||||
@@ -30,16 +29,14 @@ from .file_utils import cached_path
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
PRETRAINED_VOCAB_ARCHIVE_MAP = {
|
||||
'transfo-xl': "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-vocab.json",
|
||||
'transfo-xl-wt103': "https://s3.amazonaws.com/models.huggingface.co/bert/transfo-xl-wt103-vocab.bin",
|
||||
}
|
||||
PRETRAINED_MERGES_ARCHIVE_MAP = {
|
||||
'openai-gpt': "https://s3.amazonaws.com/models.huggingface.co/bert/openai-gpt-merges.txt",
|
||||
VOCAB_NAME = 'vocab.bin'
|
||||
|
||||
PRETRAINED_CORPUS_ARCHIVE_MAP = {
|
||||
'transfo-xl-wt103': "https://s3.amazonaws.com/models.huggingface.co/bert/transfo-xl-wt103-corpus.bin",
|
||||
}
|
||||
PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = {
|
||||
'openai-gpt': 512,
|
||||
}
|
||||
VOCAB_NAME = 'vocab.json'
|
||||
MERGES_NAME = 'merges.txt'
|
||||
CORPUS_NAME = 'corpus.bin'
|
||||
|
||||
class TransfoXLTokenizer(object):
|
||||
"""
|
||||
@@ -49,43 +46,36 @@ class TransfoXLTokenizer(object):
|
||||
def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs):
|
||||
"""
|
||||
Instantiate a TransfoXLTokenizer.
|
||||
Download and cache the vocabulary if needed.
|
||||
The TransfoXLTokenizer.
|
||||
"""
|
||||
if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP:
|
||||
vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path]
|
||||
merges_file = PRETRAINED_MERGES_ARCHIVE_MAP[pretrained_model_name_or_path]
|
||||
else:
|
||||
vocab_file = os.path.join(pretrained_model_name_or_path, VOCAB_NAME)
|
||||
merges_file = os.path.join(pretrained_model_name_or_path, MERGES_NAME)
|
||||
# redirect to the cache, if necessary
|
||||
try:
|
||||
resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir)
|
||||
resolved_merges_file = cached_path(merges_file, cache_dir=cache_dir)
|
||||
except FileNotFoundError:
|
||||
logger.error(
|
||||
"Model name '{}' was not found in model name list ({}). "
|
||||
"We assumed '{}' was a path or url but couldn't find files {} and {} "
|
||||
"We assumed '{}' was a path or url but couldn't find files {} "
|
||||
"at this path or url.".format(
|
||||
pretrained_model_name_or_path,
|
||||
', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()),
|
||||
pretrained_model_name_or_path,
|
||||
vocab_file, merges_file))
|
||||
vocab_file))
|
||||
return None
|
||||
if resolved_vocab_file == vocab_file and resolved_merges_file == merges_file:
|
||||
if resolved_vocab_file == vocab_file:
|
||||
logger.info("loading vocabulary file {}".format(vocab_file))
|
||||
logger.info("loading merges file {}".format(merges_file))
|
||||
else:
|
||||
logger.info("loading vocabulary file {} from cache at {}".format(
|
||||
vocab_file, resolved_vocab_file))
|
||||
logger.info("loading merges file {} from cache at {}".format(
|
||||
merges_file, resolved_merges_file))
|
||||
if pretrained_model_name_or_path in PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP:
|
||||
# if we're using a pretrained model, ensure the tokenizer wont index sequences longer
|
||||
# than the number of positional embeddings
|
||||
max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[pretrained_model_name_or_path]
|
||||
kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len)
|
||||
|
||||
# Instantiate tokenizer.
|
||||
tokenizer = cls(resolved_vocab_file, resolved_merges_file, *inputs, **kwargs)
|
||||
tokenizer = cls(*inputs, **kwargs)
|
||||
vocab_dict = torch.load(resolved_vocab_file)
|
||||
for key, value in vocab_dict.items():
|
||||
tokenizer.__dict__[key] = value
|
||||
return tokenizer
|
||||
|
||||
def __init__(self, special=[], min_freq=0, max_size=None, lower_case=True,
|
||||
@@ -418,10 +408,53 @@ class LMMultiFileIterator(LMShuffledIterator):
|
||||
yield batch
|
||||
|
||||
|
||||
class Corpus(object):
|
||||
def __init__(self, path, dataset, *args, **kwargs):
|
||||
class TransfoXLCorpus(object):
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs):
|
||||
"""
|
||||
Instantiate a pre-processed corpus.
|
||||
"""
|
||||
vocab = TransfoXLTokenizer.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
|
||||
if pretrained_model_name_or_path in PRETRAINED_CORPUS_ARCHIVE_MAP:
|
||||
corpus_file = PRETRAINED_CORPUS_ARCHIVE_MAP[pretrained_model_name_or_path]
|
||||
else:
|
||||
corpus_file = os.path.join(pretrained_model_name_or_path, CORPUS_NAME)
|
||||
# redirect to the cache, if necessary
|
||||
try:
|
||||
resolved_corpus_file = cached_path(corpus_file, cache_dir=cache_dir)
|
||||
except FileNotFoundError:
|
||||
logger.error(
|
||||
"Model name '{}' was not found in model name list ({}). "
|
||||
"We assumed '{}' was a path or url but couldn't find files {} "
|
||||
"at this path or url.".format(
|
||||
pretrained_model_name_or_path,
|
||||
', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()),
|
||||
pretrained_model_name_or_path,
|
||||
corpus_file))
|
||||
return None
|
||||
if resolved_corpus_file == corpus_file:
|
||||
logger.info("loading corpus file {}".format(corpus_file))
|
||||
else:
|
||||
logger.info("loading corpus file {} from cache at {}".format(
|
||||
corpus_file, resolved_corpus_file))
|
||||
|
||||
# Instantiate tokenizer.
|
||||
corpus = cls(*inputs, **kwargs)
|
||||
corpus_dict = torch.load(resolved_corpus_file)
|
||||
for key, value in corpus_dict.items():
|
||||
corpus.__dict__[key] = value
|
||||
corpus.vocab = vocab
|
||||
return corpus
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.vocab = TransfoXLTokenizer(*args, **kwargs)
|
||||
self.dataset = None
|
||||
self.train = None
|
||||
self.valid = None
|
||||
self.test = None
|
||||
|
||||
def build_corpus(self, path, dataset):
|
||||
self.dataset = dataset
|
||||
self.vocab = Vocab(*args, **kwargs)
|
||||
|
||||
if self.dataset in ['ptb', 'wt2', 'enwik8', 'text8']:
|
||||
self.vocab.count_file(os.path.join(path, 'train.txt'))
|
||||
@@ -443,20 +476,20 @@ class Corpus(object):
|
||||
os.path.join(path, 'train.txt'), ordered=True)
|
||||
self.valid = self.vocab.encode_file(
|
||||
os.path.join(path, 'valid.txt'), ordered=True)
|
||||
self.test = self.vocab.encode_file(
|
||||
self.test = self.vocab.encode_file(
|
||||
os.path.join(path, 'test.txt'), ordered=True)
|
||||
elif self.dataset in ['enwik8', 'text8']:
|
||||
self.train = self.vocab.encode_file(
|
||||
os.path.join(path, 'train.txt'), ordered=True, add_eos=False)
|
||||
self.valid = self.vocab.encode_file(
|
||||
os.path.join(path, 'valid.txt'), ordered=True, add_eos=False)
|
||||
self.test = self.vocab.encode_file(
|
||||
self.test = self.vocab.encode_file(
|
||||
os.path.join(path, 'test.txt'), ordered=True, add_eos=False)
|
||||
elif self.dataset == 'lm1b':
|
||||
self.train = train_paths
|
||||
self.valid = self.vocab.encode_file(
|
||||
os.path.join(path, 'valid.txt'), ordered=False, add_double_eos=True)
|
||||
self.test = self.vocab.encode_file(
|
||||
self.test = self.vocab.encode_file(
|
||||
os.path.join(path, 'test.txt'), ordered=False, add_double_eos=True)
|
||||
|
||||
def get_iterator(self, split, *args, **kwargs):
|
||||
@@ -502,7 +535,7 @@ def get_lm_corpus(datadir, dataset):
|
||||
elif dataset in ['enwik8', 'text8']:
|
||||
pass
|
||||
|
||||
corpus = Corpus(datadir, dataset, **kwargs)
|
||||
corpus = TransfoXLCorpus(datadir, dataset, **kwargs)
|
||||
torch.save(corpus, fn)
|
||||
|
||||
return corpus
|
||||
|
||||
Reference in New Issue
Block a user