added tests + fixed losses

This commit is contained in:
thomwolf
2019-01-08 16:24:23 +01:00
parent eed51c5bdf
commit 3cf12b235a
4 changed files with 484 additions and 225 deletions

View File

@@ -67,19 +67,17 @@ class OpenAIGPTTokenizer(object):
mostly a wrapper for a public python bpe tokenizer
"""
@classmethod
def from_pretrained(cls, pretrained_model_name, cache_dir=None, *inputs, **kwargs):
def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs):
"""
Instantiate an OpenAIGPTTokenizer from pre-trained vocabulary and merges files.
Download and cache the vocabulary and merges files if needed.
"""
if pretrained_model_name in PRETRAINED_VOCAB_ARCHIVE_MAP:
vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name]
merges_file = PRETRAINED_MERGES_ARCHIVE_MAP[pretrained_model_name]
if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP:
vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path]
merges_file = PRETRAINED_MERGES_ARCHIVE_MAP[pretrained_model_name_or_path]
else:
vocab_file = pretrained_model_name
if os.path.isdir(vocab_file):
vocab_file = os.path.join(vocab_file, VOCAB_NAME)
merges_file = os.path.join(vocab_file, MERGES_NAME)
vocab_file = os.path.join(pretrained_model_name_or_path, VOCAB_NAME)
merges_file = os.path.join(pretrained_model_name_or_path, MERGES_NAME)
# redirect to the cache, if necessary
try:
resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir)
@@ -87,11 +85,12 @@ class OpenAIGPTTokenizer(object):
except FileNotFoundError:
logger.error(
"Model name '{}' was not found in model name list ({}). "
"We assumed '{}' was a path or url but couldn't find any file "
"associated to this path or url.".format(
pretrained_model_name,
"We assumed '{}' was a path or url but couldn't find files {} and {} "
"at this path or url.".format(
pretrained_model_name_or_path,
', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()),
vocab_file))
pretrained_model_name_or_path,
vocab_file, merges_file))
return None
if resolved_vocab_file == vocab_file and resolved_merges_file == merges_file:
logger.info("loading vocabulary file {}".format(vocab_file))
@@ -101,29 +100,38 @@ class OpenAIGPTTokenizer(object):
vocab_file, resolved_vocab_file))
logger.info("loading merges file {} from cache at {}".format(
merges_file, resolved_merges_file))
if pretrained_model_name in PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP:
if pretrained_model_name_or_path in PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP:
# if we're using a pretrained model, ensure the tokenizer wont index sequences longer
# than the number of positional embeddings
max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[pretrained_model_name]
max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[pretrained_model_name_or_path]
kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len)
# Instantiate tokenizer.
tokenizer = cls(resolved_vocab_file, resolved_merges_file, *inputs, **kwargs)
return tokenizer
def __init__(self, vocab_file, merges_file):
def __init__(self, vocab_file, merges_file, special_tokens=None, max_len=None):
    """Build the tokenizer from a JSON vocabulary file and a BPE merges file.

    Args:
        vocab_file: path to a JSON file mapping token strings to ids.
        merges_file: path to a text file of BPE merges, one whitespace-separated
            merge per line; the first and the last line of the file are ignored.
        special_tokens: optional iterable of extra tokens; they receive ids
            starting at ``len(vocabulary)``, in iteration order.
        max_len: optional maximum sequence length; ``None`` means
            effectively unbounded (``int(1e12)``).

    Raises:
        ImportError: if ``ftfy`` or ``spacy`` is not installed.
    """
    try:
        import ftfy
        import spacy
    except ImportError:
        raise ImportError("Please install ftfy and spacy to use OpenAI GPT tokenizer.")
    self.max_len = max_len if max_len is not None else int(1e12)
    self.nlp = spacy.load('en', disable=['parser', 'tagger', 'ner', 'textcat'])
    self.fix_text = ftfy.fix_text
    # Context managers close both files deterministically; the vocabulary is
    # read with the same explicit encoding as the merges file.
    with open(vocab_file, encoding='utf-8') as vocab_handle:
        self.encoder = json.load(vocab_handle)
    self.decoder = {v: k for k, v in self.encoder.items()}
    with open(merges_file, encoding='utf-8') as merges_handle:
        # Drop the first line and the (empty) chunk after the final newline.
        merges = merges_handle.read().split('\n')[1:-1]
    merges = [tuple(merge.split()) for merge in merges]
    # Rank of a merge == its position in the merges file.
    self.bpe_ranks = dict(zip(merges, range(len(merges))))
    self.cache = {}
    if not special_tokens:
        self.special_tokens = {}
    else:
        self.special_tokens = dict((tok, len(self.encoder) + i) for i, tok in enumerate(special_tokens))
def set_special_tokens(self, special_tokens):
    """Replace the special-token table.

    Each token in ``special_tokens`` is assigned the id
    ``len(self.encoder) + position``. ``None`` or an empty collection clears
    the table, matching how the constructor treats a missing
    ``special_tokens`` argument.
    """
    first_id = len(self.encoder)
    self.special_tokens = {
        tok: first_id + offset
        for offset, tok in enumerate(special_tokens or ())
    }
def bpe(self, token):
word = tuple(token[:-1]) + ( token[-1] + '</w>',)
@@ -168,20 +176,38 @@ class OpenAIGPTTokenizer(object):
self.cache[token] = word
return word
def tokenize(self, texts, verbose=True):
    """Encode every text in ``texts`` into a list of vocabulary ids.

    Each text is passed through ``ftfy.fix_text``, ``text_standardize`` and
    ``self.nlp``; every resulting token is lower-cased, BPE-split, and each
    BPE piece is looked up in ``self.encoder`` (pieces not found map to 0).
    With ``verbose`` set, iteration is wrapped in a ``tqdm`` progress bar.

    Returns:
        A list with one list of ids per input text.
    """
    # The verbose and quiet paths differ only in what is iterated over.
    source = tqdm(texts, ncols=80, leave=False) if verbose else texts
    encoded = []
    for raw in source:
        doc = self.nlp(text_standardize(ftfy.fix_text(raw)))
        ids = []
        for tok in doc:
            pieces = self.bpe(tok.text.lower()).split(' ')
            ids.extend(self.encoder.get(piece, 0) for piece in pieces)
        encoded.append(ids)
    return encoded
def tokenize(self, text):
    """Split ``text`` into BPE token strings.

    The text is passed through ``self.fix_text``, ``text_standardize`` and
    ``self.nlp``; each resulting token is lower-cased, run through
    ``self.bpe`` and split on spaces into its BPE pieces.

    Returns:
        A flat list of BPE piece strings, in order.
    """
    doc = self.nlp(text_standardize(self.fix_text(text)))
    pieces = []
    for word in doc:
        pieces += self.bpe(word.text.lower()).split(' ')
    return pieces
def convert_tokens_to_ids(self, tokens):
    """Converts a sequence of tokens into ids using the vocab.

    Special tokens are looked up in ``self.special_tokens`` first; all other
    tokens are looked up in ``self.encoder``, with tokens not found there
    mapped to id 0.

    Raises:
        ValueError: if the resulting sequence is longer than ``self.max_len``.
    """
    ids = [
        self.special_tokens[token] if token in self.special_tokens
        else self.encoder.get(token, 0)
        for token in tokens
    ]
    if len(ids) > self.max_len:
        # The message names the OpenAI GPT model this tokenizer belongs to
        # (it previously referred to BERT).
        raise ValueError(
            "Token indices sequence length is longer than the specified maximum"
            " sequence length for this OpenAI GPT model ({} > {}). Running this"
            " sequence through the model will result in indexing errors".format(len(ids), self.max_len)
        )
    return ids
def convert_ids_to_tokens(self, ids):
    """Converts a sequence of ids in BPE tokens using the vocab.

    Ids are resolved through ``self.decoder`` first; ids not found there are
    resolved through the reverse of ``self.special_tokens``, so ids produced
    for special tokens can be mapped back to their token strings.

    Raises:
        KeyError: if an id is neither a vocabulary id nor a special-token id.
    """
    # Reverse the token -> id table once so special ids can be decoded too.
    special_decoder = {idx: tok for tok, idx in self.special_tokens.items()}
    tokens = []
    for i in ids:
        if i in self.decoder:
            tokens.append(self.decoder[i])
        else:
            tokens.append(special_decoder[i])
    return tokens
def decode(self, ids):
    """Turn a sequence of ids back into a single string.

    The ids are mapped to tokens with ``convert_ids_to_tokens``, the tokens
    are concatenated, and every ``</w>`` marker is replaced by a space.
    """
    joined = ''.join(self.convert_ids_to_tokens(ids))
    return joined.replace('</w>', ' ')