added tests + fixed losses
This commit is contained in:
@@ -67,19 +67,17 @@ class OpenAIGPTTokenizer(object):
|
||||
mostly a wrapper for a public python bpe tokenizer
|
||||
"""
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_name, cache_dir=None, *inputs, **kwargs):
|
||||
def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs):
|
||||
"""
|
||||
Instantiate a PreTrainedBertModel from a pre-trained model file.
|
||||
Download and cache the pre-trained model file if needed.
|
||||
"""
|
||||
if pretrained_model_name in PRETRAINED_VOCAB_ARCHIVE_MAP:
|
||||
vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name]
|
||||
merges_file = PRETRAINED_MERGES_ARCHIVE_MAP[pretrained_model_name]
|
||||
if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP:
|
||||
vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path]
|
||||
merges_file = PRETRAINED_MERGES_ARCHIVE_MAP[pretrained_model_name_or_path]
|
||||
else:
|
||||
vocab_file = pretrained_model_name
|
||||
if os.path.isdir(vocab_file):
|
||||
vocab_file = os.path.join(vocab_file, VOCAB_NAME)
|
||||
merges_file = os.path.join(vocab_file, MERGES_NAME)
|
||||
vocab_file = os.path.join(pretrained_model_name_or_path, VOCAB_NAME)
|
||||
merges_file = os.path.join(pretrained_model_name_or_path, MERGES_NAME)
|
||||
# redirect to the cache, if necessary
|
||||
try:
|
||||
resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir)
|
||||
@@ -87,11 +85,12 @@ class OpenAIGPTTokenizer(object):
|
||||
except FileNotFoundError:
|
||||
logger.error(
|
||||
"Model name '{}' was not found in model name list ({}). "
|
||||
"We assumed '{}' was a path or url but couldn't find any file "
|
||||
"associated to this path or url.".format(
|
||||
pretrained_model_name,
|
||||
"We assumed '{}' was a path or url but couldn't find files {} and {} "
|
||||
"at this path or url.".format(
|
||||
pretrained_model_name_or_path,
|
||||
', '.join(PRETRAINED_VOCAB_ARCHIVE_MAP.keys()),
|
||||
vocab_file))
|
||||
pretrained_model_name_or_path,
|
||||
vocab_file, merges_file))
|
||||
return None
|
||||
if resolved_vocab_file == vocab_file and resolved_merges_file == merges_file:
|
||||
logger.info("loading vocabulary file {}".format(vocab_file))
|
||||
@@ -101,29 +100,38 @@ class OpenAIGPTTokenizer(object):
|
||||
vocab_file, resolved_vocab_file))
|
||||
logger.info("loading merges file {} from cache at {}".format(
|
||||
merges_file, resolved_merges_file))
|
||||
if pretrained_model_name in PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP:
|
||||
if pretrained_model_name_or_path in PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP:
|
||||
# if we're using a pretrained model, ensure the tokenizer wont index sequences longer
|
||||
# than the number of positional embeddings
|
||||
max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[pretrained_model_name]
|
||||
max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[pretrained_model_name_or_path]
|
||||
kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len)
|
||||
# Instantiate tokenizer.
|
||||
tokenizer = cls(resolved_vocab_file, resolved_merges_file, *inputs, **kwargs)
|
||||
return tokenizer
|
||||
|
||||
def __init__(self, vocab_file, merges_file):
|
||||
def __init__(self, vocab_file, merges_file, special_tokens=None, max_len=None):
|
||||
try:
|
||||
import ftfy
|
||||
import spacy
|
||||
except ImportError:
|
||||
raise ImportError("Please install ftfy and spacy to use OpenAI GPT tokenizer.")
|
||||
|
||||
self.max_len = max_len if max_len is not None else int(1e12)
|
||||
self.nlp = spacy.load('en', disable=['parser', 'tagger', 'ner', 'textcat'])
|
||||
self.fix_text = ftfy.fix_text
|
||||
self.encoder = json.load(open(vocab_file))
|
||||
self.decoder = {v:k for k,v in self.encoder.items()}
|
||||
merges = open(merges_file, encoding='utf-8').read().split('\n')[1:-1]
|
||||
merges = [tuple(merge.split()) for merge in merges]
|
||||
self.bpe_ranks = dict(zip(merges, range(len(merges))))
|
||||
self.cache = {}
|
||||
if not special_tokens:
|
||||
self.special_tokens = {}
|
||||
else:
|
||||
self.special_tokens = dict((tok, len(self.encoder) + i) for i, tok in enumerate(special_tokens))
|
||||
|
||||
def set_special_tokens(self, special_tokens):
|
||||
self.special_tokens = dict((tok, len(self.encoder) + i) for i, tok in enumerate(special_tokens))
|
||||
|
||||
def bpe(self, token):
|
||||
word = tuple(token[:-1]) + ( token[-1] + '</w>',)
|
||||
@@ -168,20 +176,38 @@ class OpenAIGPTTokenizer(object):
|
||||
self.cache[token] = word
|
||||
return word
|
||||
|
||||
def tokenize(self, texts, verbose=True):
|
||||
texts_tokens = []
|
||||
if verbose:
|
||||
for text in tqdm(texts, ncols=80, leave=False):
|
||||
text = self.nlp(text_standardize(ftfy.fix_text(text)))
|
||||
text_tokens = []
|
||||
for token in text:
|
||||
text_tokens.extend([self.encoder.get(t, 0) for t in self.bpe(token.text.lower()).split(' ')])
|
||||
texts_tokens.append(text_tokens)
|
||||
else:
|
||||
for text in texts:
|
||||
text = self.nlp(text_standardize(ftfy.fix_text(text)))
|
||||
text_tokens = []
|
||||
for token in text:
|
||||
text_tokens.extend([self.encoder.get(t, 0) for t in self.bpe(token.text.lower()).split(' ')])
|
||||
texts_tokens.append(text_tokens)
|
||||
return texts_tokens
|
||||
def tokenize(self, text):
|
||||
split_tokens = []
|
||||
text = self.nlp(text_standardize(self.fix_text(text)))
|
||||
for token in text:
|
||||
split_tokens.extend([t for t in self.bpe(token.text.lower()).split(' ')])
|
||||
return split_tokens
|
||||
|
||||
def convert_tokens_to_ids(self, tokens):
|
||||
"""Converts a sequence of tokens into ids using the vocab."""
|
||||
ids = []
|
||||
for token in tokens:
|
||||
if token in self.special_tokens:
|
||||
ids.append(self.special_tokens[token])
|
||||
else:
|
||||
ids.append(self.encoder.get(token, 0))
|
||||
if len(ids) > self.max_len:
|
||||
raise ValueError(
|
||||
"Token indices sequence length is longer than the specified maximum "
|
||||
" sequence length for this BERT model ({} > {}). Running this"
|
||||
" sequence through BERT will result in indexing errors".format(len(ids), self.max_len)
|
||||
)
|
||||
return ids
|
||||
|
||||
def convert_ids_to_tokens(self, ids):
|
||||
"""Converts a sequence of ids in BPE tokens using the vocab."""
|
||||
tokens = []
|
||||
for i in ids:
|
||||
tokens.append(self.decoder[i])
|
||||
return tokens
|
||||
|
||||
def decode(self, ids):
|
||||
"""Converts a sequence of ids in a string."""
|
||||
tokens = self.convert_ids_to_tokens(ids)
|
||||
out_string = ''.join(tokens).replace('</w>', ' ')
|
||||
return out_string
|
||||
|
||||
Reference in New Issue
Block a user