Encode and Decode are back in the superclass. They now handle sentence pairs special tokens.

This commit is contained in:
LysandreJik
2019-08-08 18:20:32 -04:00
parent e367ac469c
commit 6c41a8f5dc
4 changed files with 81 additions and 79 deletions

View File

@@ -21,18 +21,19 @@ import logging
import re
from io import open
import six
import os
from .tokenization_utils import PreTrainedTokenizer, clean_up_tokenization
from .tokenization_utils import PreTrainedTokenizer
from .tokenization_gpt2 import GPT2Tokenizer
logger = logging.getLogger(__name__)
VOCAB_FILES_NAMES = {
'vocab_file': 'dict.txt',
DICT_FILES_NAMES = {
'dict_file': 'dict.txt',
}
PRETRAINED_VOCAB_FILES_MAP = {
'vocab_file':
PRETRAINED_DICT_FILES_MAP = {
'dict_file':
{
'roberta-base': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-base-dict.txt",
'roberta-large': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-base-dict.txt",
@@ -178,89 +179,62 @@ class RobertaTokenizer(PreTrainedTokenizer):
RoBERTa tokenizer. Peculiarities:
- GPT-2 tokenizer with a different integer mapping on top.
"""
vocab_files_names = VOCAB_FILES_NAMES
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
vocab_files_names = DICT_FILES_NAMES
pretrained_vocab_files_map = PRETRAINED_DICT_FILES_MAP
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
def __init__(self, vocab_file,
bos_token="<s>", eos_token="</s>", **kwargs):
super(RobertaTokenizer, self).__init__(cls_token=bos_token, sep_token=eos_token, eos_token=eos_token, **kwargs)
def __init__(self, dict_file, bpe_tokenizer=None, bos_token="<s>", eos_token="</s>", sep_token="</s>", cls_token="<s>",
unk_token="<unk>", **kwargs):
super(RobertaTokenizer, self).__init__(cls_token=bos_token, sep_token=eos_token, eos_token=eos_token,
unk_token=unk_token, **kwargs)
self.gpt2_tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
self.dictionary = Dictionary.load(vocab_file)
self.gpt2_tokenizer = GPT2Tokenizer.from_pretrained("gpt2") if bpe_tokenizer is None else bpe_tokenizer
self.dictionary = Dictionary.load(dict_file)
@property
def vocab_size(self):
return len(self.dictionary.indices)
def _tokenize(self, text):
""" Use GPT-2 Tokenizer """
return self.gpt2_tokenizer._tokenize(text)
def encode(self, text, *args):
""" Converts a string in a sequence of ids (integer), using the tokenizer and vocabulary.
"""
bpe_sentence = [self.cls_token] + \
self.gpt2_tokenizer.convert_tokens_to_ids(self.tokenize(text)) + \
[self.sep_token]
if len(args):
for additional_sentence in args:
bpe_sentence += [self.sep_token
] + \
self.gpt2_tokenizer.convert_tokens_to_ids(self.tokenize(additional_sentence)) + \
[self.sep_token]
return self.dictionary.encode_line(' '.join([str(token) for token in bpe_sentence]), append_eos=False)
def decode(self, token_ids, skip_special_tokens=False, clean_up_tokenization_spaces=True):
""" Converts a sequence of ids (integer) in a string, using the tokenizer and vocabulary
with options to remove special tokens and clean up tokenization spaces.
Handles sentence pairs.
"""
filtered_tokens = self.convert_ids_to_tokens(token_ids, skip_special_tokens=skip_special_tokens)
if any(isinstance(element, list) for element in filtered_tokens):
texts = []
for element in filtered_tokens:
text = self.convert_tokens_to_string(element)
if clean_up_tokenization_spaces:
text = clean_up_tokenization(text)
texts.append(text)
return texts
else:
text = self.convert_tokens_to_string(filtered_tokens)
if clean_up_tokenization_spaces:
text = clean_up_tokenization(text)
return text
def _convert_token_to_id(self, token):
return self.dictionary.index(token)
if self.dictionary.index(token) != 3:
return self.dictionary.index(token)
return self.dictionary.index(str(self.gpt2_tokenizer.convert_tokens_to_ids(token)))
def _convert_id_to_token(self, index):
symbol = self.dictionary[index]
try:
idx = int(symbol)
return self.gpt2_tokenizer._convert_id_to_token(idx)
except:
except ValueError:
return symbol
def convert_tokens_to_string(self, tokens):
return self.gpt2_tokenizer.convert_tokens_to_string(tokens)
def convert_tokens_to_ids(self, tokens, no_sep_cls_tokens=False):
cls = [self._convert_token_to_id(self.cls_token)]
tokens = super().convert_tokens_to_ids(tokens)
sep = [self._convert_token_to_id(self.sep_token)]
return (cls + tokens + sep) if (isinstance(tokens, list) and not no_sep_cls_tokens) else tokens
def convert_ids_to_tokens(self, ids, skip_special_tokens=False):
# Remove the first and last tokens which are cls and sep tokens
ids = ids[1:-1]
# If multi sentence, then split (multi sentence found by looking for two sequential sep tokens)
ids = [list(map(int, example.split(' '))) for example in ' '.join([str(id) for id in ids]).split(' 2 2 ')]
return super().convert_ids_to_tokens(ids, skip_special_tokens=skip_special_tokens)[1:-1]
if len(ids) == 1:
tokens = self.gpt2_tokenizer.convert_ids_to_tokens(list(map(lambda id: int(self.dictionary[id]), ids[0])))
else:
tokens = []
for example in ids:
tokens += [
self.gpt2_tokenizer.convert_ids_to_tokens(list(map(lambda id: int(self.dictionary[id]), example)))]
return tokens
def save_vocabulary(self, save_directory):
"""Save the tokenizer vocabulary and merge files to a directory."""
if not os.path.isdir(save_directory):
logger.error("Vocabulary path ({}) should be a directory".format(save_directory))
return
dict_file = os.path.join(save_directory, DICT_FILES_NAMES['dict_file'])
def convert_tokens_to_ids(self, tokens):
tokens = " ".join(str(x) for x in self.gpt2_tokenizer.convert_tokens_to_ids(tokens))
bpe_sentence = '<s> ' + tokens + ' </s>'
return self.dictionary.encode_line(bpe_sentence, append_eos=False)
with open(dict_file, 'w', encoding='utf-8') as f:
for i in range(self.dictionary.nspecial, len(self.dictionary.count)):
f.write(f"{list(self.dictionary.indices.keys())[i]} {self.dictionary.count[i]}\n")
vocab_files = self.gpt2_tokenizer.save_pretrained(save_directory)
return vocab_files + (dict_file,)