This commit is contained in:
thomwolf
2019-02-04 13:06:19 +01:00
parent 3a848111e6
commit 05f961840b

View File

@@ -125,16 +125,19 @@ class OpenAIGPTTokenizer(object):
merges = [tuple(merge.split()) for merge in merges]
self.bpe_ranks = dict(zip(merges, range(len(merges))))
self.cache = {}
if not special_tokens:
self.special_tokens = {}
else:
self.special_tokens = dict((tok, len(self.encoder) + i) for i, tok in enumerate(special_tokens))
self.set_special_tokens(special_tokens)
def __len__(self):
return len(self.encoder) + len(self.special_tokens)
def set_special_tokens(self, special_tokens):
if not special_tokens:
self.special_tokens = {}
self.special_tokens_decoder = {}
return
self.special_tokens = dict((tok, len(self.encoder) + i) for i, tok in enumerate(special_tokens))
self.special_tokens_decoder = {v:k for k, v in self.special_tokens.items()}
logger.info("Special tokens {}".format(self.special_tokens))
def bpe(self, token):
word = tuple(token[:-1]) + ( token[-1] + '</w>',)
@@ -189,6 +192,11 @@ class OpenAIGPTTokenizer(object):
def convert_tokens_to_ids(self, tokens):
"""Converts a sequence of tokens into ids using the vocab."""
ids = []
if isinstance(tokens, str):
if tokens in self.special_tokens:
return self.special_tokens[tokens]
else:
return self.encoder.get(tokens, 0)
for token in tokens:
if token in self.special_tokens:
ids.append(self.special_tokens[token])
@@ -206,7 +214,10 @@ class OpenAIGPTTokenizer(object):
"""Converts a sequence of ids in BPE tokens using the vocab."""
tokens = []
for i in ids:
tokens.append(self.decoder[i])
if i in self.special_tokens_decoder:
tokens.append(self.special_tokens_decoder[i])
else:
tokens.append(self.decoder[i])
return tokens
def decode(self, ids):