logging
This commit is contained in:
@@ -125,16 +125,19 @@ class OpenAIGPTTokenizer(object):
|
|||||||
merges = [tuple(merge.split()) for merge in merges]
|
merges = [tuple(merge.split()) for merge in merges]
|
||||||
self.bpe_ranks = dict(zip(merges, range(len(merges))))
|
self.bpe_ranks = dict(zip(merges, range(len(merges))))
|
||||||
self.cache = {}
|
self.cache = {}
|
||||||
if not special_tokens:
|
self.set_special_tokens(special_tokens)
|
||||||
self.special_tokens = {}
|
|
||||||
else:
|
|
||||||
self.special_tokens = dict((tok, len(self.encoder) + i) for i, tok in enumerate(special_tokens))
|
|
||||||
|
|
||||||
def __len__(self):
    """Return the combined size of the base vocabulary and the special tokens."""
    base_size = len(self.encoder)
    extra_size = len(self.special_tokens)
    return base_size + extra_size
|
||||||
|
|
||||||
def set_special_tokens(self, special_tokens):
    """Install the special-token maps on this tokenizer.

    When ``special_tokens`` is truthy, each token is given the id
    ``len(self.encoder) + position`` (position = its index in the iterable),
    the reverse id -> token map is rebuilt from that mapping, and the
    mapping is logged at INFO level. When it is falsy, both maps are reset
    to empty dicts and nothing is logged.
    """
    if special_tokens:
        self.special_tokens = {
            name: len(self.encoder) + position
            for position, name in enumerate(special_tokens)
        }
        # Reverse lookup used to turn special-token ids back into tokens.
        reverse_map = {}
        for name, token_id in self.special_tokens.items():
            reverse_map[token_id] = name
        self.special_tokens_decoder = reverse_map
        logger.info("Special tokens {}".format(self.special_tokens))
    else:
        self.special_tokens = {}
        self.special_tokens_decoder = {}
|
||||||
|
|
||||||
def bpe(self, token):
|
def bpe(self, token):
|
||||||
word = tuple(token[:-1]) + ( token[-1] + '</w>',)
|
word = tuple(token[:-1]) + ( token[-1] + '</w>',)
|
||||||
@@ -189,6 +192,11 @@ class OpenAIGPTTokenizer(object):
|
|||||||
def convert_tokens_to_ids(self, tokens):
|
def convert_tokens_to_ids(self, tokens):
|
||||||
"""Converts a sequence of tokens into ids using the vocab."""
|
"""Converts a sequence of tokens into ids using the vocab."""
|
||||||
ids = []
|
ids = []
|
||||||
|
if isinstance(tokens, str):
|
||||||
|
if tokens in self.special_tokens:
|
||||||
|
return self.special_tokens[tokens]
|
||||||
|
else:
|
||||||
|
return self.encoder.get(tokens, 0)
|
||||||
for token in tokens:
|
for token in tokens:
|
||||||
if token in self.special_tokens:
|
if token in self.special_tokens:
|
||||||
ids.append(self.special_tokens[token])
|
ids.append(self.special_tokens[token])
|
||||||
@@ -206,6 +214,9 @@ class OpenAIGPTTokenizer(object):
|
|||||||
"""Converts a sequence of ids in BPE tokens using the vocab."""
|
"""Converts a sequence of ids in BPE tokens using the vocab."""
|
||||||
tokens = []
|
tokens = []
|
||||||
for i in ids:
|
for i in ids:
|
||||||
|
if i in self.special_tokens_decoder:
|
||||||
|
tokens.append(self.special_tokens_decoder[i])
|
||||||
|
else:
|
||||||
tokens.append(self.decoder[i])
|
tokens.append(self.decoder[i])
|
||||||
return tokens
|
return tokens
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user