From 01a3966bc6d265aa8c7088b39bfdc20a905a2c74 Mon Sep 17 00:00:00 2001 From: thomwolf Date: Mon, 4 Feb 2019 17:26:25 +0100 Subject: [PATCH] more options on special tokens --- pytorch_pretrained_bert/tokenization_openai.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/pytorch_pretrained_bert/tokenization_openai.py b/pytorch_pretrained_bert/tokenization_openai.py index a12e58721b..e545e0d375 100644 --- a/pytorch_pretrained_bert/tokenization_openai.py +++ b/pytorch_pretrained_bert/tokenization_openai.py @@ -131,6 +131,10 @@ class OpenAIGPTTokenizer(object): return len(self.encoder) + len(self.special_tokens) def set_special_tokens(self, special_tokens): + """ Add a list of additional tokens to the encoder. + The additional tokens are indexed starting from the last index of the + current vocabulary in the order of the `special_tokens` list. + """ if not special_tokens: self.special_tokens = {} self.special_tokens_decoder = {} @@ -210,18 +214,19 @@ class OpenAIGPTTokenizer(object): ) return ids - def convert_ids_to_tokens(self, ids): + def convert_ids_to_tokens(self, ids, skip_special_tokens=False): """Converts a sequence of ids in BPE tokens using the vocab.""" tokens = [] for i in ids: if i in self.special_tokens_decoder: - tokens.append(self.special_tokens_decoder[i]) + if not skip_special_tokens: + tokens.append(self.special_tokens_decoder[i]) else: tokens.append(self.decoder[i]) return tokens - def decode(self, ids): + def decode(self, ids, skip_special_tokens=False): """Converts a sequence of ids in a string.""" - tokens = self.convert_ids_to_tokens(ids) + tokens = self.convert_ids_to_tokens(ids, skip_special_tokens=skip_special_tokens) out_string = ''.join(tokens).replace('', ' ') return out_string