Sentence-pair tasks handling. Using common tests on RoBERTa. Forced push to fix indentation.
This commit is contained in:
@@ -22,22 +22,22 @@ import re
|
||||
from io import open
|
||||
import six
|
||||
|
||||
from .tokenization_utils import PreTrainedTokenizer
|
||||
from .tokenization_utils import PreTrainedTokenizer, clean_up_tokenization
|
||||
from .tokenization_gpt2 import GPT2Tokenizer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
VOCAB_FILES_NAMES = {
|
||||
'dict_file': 'dict.txt',
|
||||
'vocab_file': 'dict.txt',
|
||||
}
|
||||
|
||||
PRETRAINED_VOCAB_FILES_MAP = {
|
||||
'dict_file':
|
||||
{
|
||||
'roberta-base': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-base-dict.txt",
|
||||
'roberta-large': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-base-dict.txt",
|
||||
'roberta-large-mnli': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-base-dict.txt",
|
||||
},
|
||||
'vocab_file':
|
||||
{
|
||||
'roberta-base': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-base-dict.txt",
|
||||
'roberta-large': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-base-dict.txt",
|
||||
'roberta-large-mnli': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-base-dict.txt",
|
||||
},
|
||||
}
|
||||
|
||||
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
|
||||
@@ -46,7 +46,6 @@ PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
|
||||
'roberta-large-mnli': 512,
|
||||
}
|
||||
|
||||
|
||||
SPACE_NORMALIZER = re.compile(r"\s+")
|
||||
|
||||
def tokenize_line(line):
|
||||
@@ -142,7 +141,7 @@ class Dictionary(object):
|
||||
"rebuild the dataset".format(f))
|
||||
return
|
||||
|
||||
lines = f.readlines()
|
||||
lines = f.read().splitlines()
|
||||
for line in lines:
|
||||
idx = line.rfind(' ')
|
||||
if idx == -1:
|
||||
@@ -152,7 +151,7 @@ class Dictionary(object):
|
||||
self.indices[word] = len(self.symbols)
|
||||
self.symbols.append(word)
|
||||
self.count.append(count)
|
||||
|
||||
|
||||
def encode_line(self, line, line_tokenizer=tokenize_line, add_if_not_exist=True,
|
||||
consumer=None, append_eos=True, reverse_order=False):
|
||||
words = line_tokenizer(line)
|
||||
@@ -174,8 +173,6 @@ class Dictionary(object):
|
||||
return ids
|
||||
|
||||
|
||||
|
||||
|
||||
class RobertaTokenizer(PreTrainedTokenizer):
|
||||
"""
|
||||
RoBERTa tokenizer. Peculiarities:
|
||||
@@ -185,25 +182,53 @@ class RobertaTokenizer(PreTrainedTokenizer):
|
||||
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
|
||||
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
|
||||
|
||||
def __init__(self, dict_file,
|
||||
def __init__(self, vocab_file,
|
||||
bos_token="<s>", eos_token="</s>", **kwargs):
|
||||
super(RobertaTokenizer, self).__init__(bos_token=bos_token, eos_token=eos_token, **kwargs)
|
||||
super(RobertaTokenizer, self).__init__(cls_token=bos_token, sep_token=eos_token, eos_token=eos_token, **kwargs)
|
||||
|
||||
self.gpt2_tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
|
||||
self.dictionary = Dictionary.load(dict_file)
|
||||
self.dictionary = Dictionary.load(vocab_file)
|
||||
|
||||
def _tokenize(self, text):
|
||||
""" Use GPT-2 Tokenizer """
|
||||
return self.gpt2_tokenizer._tokenize(text)
|
||||
|
||||
def encode(self, text):
|
||||
def encode(self, text, *args):
|
||||
""" Converts a string in a sequence of ids (integer), using the tokenizer and vocabulary.
|
||||
"""
|
||||
gpt2_tokens_joined = " ".join(
|
||||
str(x) for x in self.gpt2_tokenizer.convert_tokens_to_ids(self.tokenize(text))
|
||||
)
|
||||
bpe_sentence = '<s> ' + gpt2_tokens_joined + ' </s>'
|
||||
return self.dictionary.encode_line(bpe_sentence, append_eos=False)
|
||||
bpe_sentence = [self.cls_token] + \
|
||||
self.gpt2_tokenizer.convert_tokens_to_ids(self.tokenize(text)) + \
|
||||
[self.sep_token]
|
||||
|
||||
if len(args):
|
||||
for additional_sentence in args:
|
||||
bpe_sentence += [self.sep_token
|
||||
] + \
|
||||
self.gpt2_tokenizer.convert_tokens_to_ids(self.tokenize(additional_sentence)) + \
|
||||
[self.sep_token]
|
||||
|
||||
return self.dictionary.encode_line(' '.join([str(token) for token in bpe_sentence]), append_eos=False)
|
||||
|
||||
def decode(self, token_ids, skip_special_tokens=False, clean_up_tokenization_spaces=True):
|
||||
""" Converts a sequence of ids (integer) in a string, using the tokenizer and vocabulary
|
||||
with options to remove special tokens and clean up tokenization spaces.
|
||||
Handles sentence pairs.
|
||||
"""
|
||||
filtered_tokens = self.convert_ids_to_tokens(token_ids, skip_special_tokens=skip_special_tokens)
|
||||
|
||||
if any(isinstance(element, list) for element in filtered_tokens):
|
||||
texts = []
|
||||
for element in filtered_tokens:
|
||||
text = self.convert_tokens_to_string(element)
|
||||
if clean_up_tokenization_spaces:
|
||||
text = clean_up_tokenization(text)
|
||||
texts.append(text)
|
||||
return texts
|
||||
else:
|
||||
text = self.convert_tokens_to_string(filtered_tokens)
|
||||
if clean_up_tokenization_spaces:
|
||||
text = clean_up_tokenization(text)
|
||||
return text
|
||||
|
||||
def _convert_token_to_id(self, token):
|
||||
return self.dictionary.index(token)
|
||||
@@ -218,3 +243,24 @@ class RobertaTokenizer(PreTrainedTokenizer):
|
||||
|
||||
def convert_tokens_to_string(self, tokens):
|
||||
return self.gpt2_tokenizer.convert_tokens_to_string(tokens)
|
||||
|
||||
def convert_ids_to_tokens(self, ids, skip_special_tokens=False):
|
||||
# Remove the first and last tokens which are cls and sep tokens
|
||||
ids = ids[1:-1]
|
||||
# If multi sentence, then split (multi sentence found by looking for two sequential sep tokens)
|
||||
ids = [list(map(int, example.split(' '))) for example in ' '.join([str(id) for id in ids]).split(' 2 2 ')]
|
||||
|
||||
if len(ids) == 1:
|
||||
tokens = self.gpt2_tokenizer.convert_ids_to_tokens(list(map(lambda id: int(self.dictionary[id]), ids[0])))
|
||||
else:
|
||||
tokens = []
|
||||
for example in ids:
|
||||
tokens += [
|
||||
self.gpt2_tokenizer.convert_ids_to_tokens(list(map(lambda id: int(self.dictionary[id]), example)))]
|
||||
return tokens
|
||||
|
||||
def convert_tokens_to_ids(self, tokens):
|
||||
tokens = " ".join(str(x) for x in self.gpt2_tokenizer.convert_tokens_to_ids(tokens))
|
||||
bpe_sentence = '<s> ' + tokens + ' </s>'
|
||||
return self.dictionary.encode_line(bpe_sentence, append_eos=False)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user