This commit is contained in:
thomwolf
2019-09-02 02:27:39 +02:00
parent b6cd856b08
commit fede4ef45d
2 changed files with 25 additions and 4 deletions

View File

@@ -722,7 +722,7 @@ class PreTrainedTokenizer(object):
return self._convert_id_to_token(ids)
tokens = []
for index in ids:
if index in self.all_special_ids and skip_special_tokens:
if skip_special_tokens and index in self.all_special_ids:
continue
if index in self.added_tokens_decoder:
tokens.append(self.added_tokens_decoder[index])
@@ -747,7 +747,25 @@ class PreTrainedTokenizer(object):
Similar to doing ``self.convert_tokens_to_string(self.convert_ids_to_tokens(token_ids))``.
"""
filtered_tokens = self.convert_ids_to_tokens(token_ids, skip_special_tokens=skip_special_tokens)
text = self.convert_tokens_to_string(filtered_tokens)
# To avoid mixing byte-level and unicode for byte-level BPT
# we need to build string separatly for added tokens and byte-level tokens
# cf. https://github.com/huggingface/pytorch-transformers/issues/1133
sub_texts = []
current_sub_text = []
for token in filtered_tokens:
if skip_special_tokens and token in self.all_special_ids:
continue
if token in self.added_tokens_encoder:
if current_sub_text:
sub_texts.append(self.convert_tokens_to_string(current_sub_text))
current_sub_text = []
sub_texts.append(token)
else:
current_sub_text.append(token)
if current_sub_text:
sub_texts.append(self.convert_tokens_to_string(current_sub_text))
text = ''.join(sub_texts)
if self._sep_token is not None and self._sep_token in text:
text = text.replace(self._cls_token, self._sep_token)