From 05f961840b0901b3689de20d7b18ed07b24be5e1 Mon Sep 17 00:00:00 2001 From: thomwolf Date: Mon, 4 Feb 2019 13:06:19 +0100 Subject: [PATCH] logging --- .../tokenization_openai.py | 21 ++++++++++++++----- 1 file changed, 16 insertions(+), 5 deletions(-) diff --git a/pytorch_pretrained_bert/tokenization_openai.py b/pytorch_pretrained_bert/tokenization_openai.py index e5e4dbda39..a12e58721b 100644 --- a/pytorch_pretrained_bert/tokenization_openai.py +++ b/pytorch_pretrained_bert/tokenization_openai.py @@ -125,16 +125,19 @@ class OpenAIGPTTokenizer(object): merges = [tuple(merge.split()) for merge in merges] self.bpe_ranks = dict(zip(merges, range(len(merges)))) self.cache = {} - if not special_tokens: - self.special_tokens = {} - else: - self.special_tokens = dict((tok, len(self.encoder) + i) for i, tok in enumerate(special_tokens)) + self.set_special_tokens(special_tokens) def __len__(self): return len(self.encoder) + len(self.special_tokens) def set_special_tokens(self, special_tokens): + if not special_tokens: + self.special_tokens = {} + self.special_tokens_decoder = {} + return self.special_tokens = dict((tok, len(self.encoder) + i) for i, tok in enumerate(special_tokens)) + self.special_tokens_decoder = {v:k for k, v in self.special_tokens.items()} + logger.info("Special tokens {}".format(self.special_tokens)) def bpe(self, token): word = tuple(token[:-1]) + ( token[-1] + '',) @@ -189,6 +192,11 @@ class OpenAIGPTTokenizer(object): def convert_tokens_to_ids(self, tokens): """Converts a sequence of tokens into ids using the vocab.""" ids = [] + if isinstance(tokens, str): + if tokens in self.special_tokens: + return self.special_tokens[tokens] + else: + return self.encoder.get(tokens, 0) for token in tokens: if token in self.special_tokens: ids.append(self.special_tokens[token]) @@ -206,7 +214,10 @@ class OpenAIGPTTokenizer(object): """Converts a sequence of ids in BPE tokens using the vocab.""" tokens = [] for i in ids: - tokens.append(self.decoder[i]) + if i in self.special_tokens_decoder: + tokens.append(self.special_tokens_decoder[i]) + else: + tokens.append(self.decoder[i]) return tokens def decode(self, ids):