From 75d5f98fd2a154bb5bfc0879c4a6e389c6789be5 Mon Sep 17 00:00:00 2001 From: LysandreJik Date: Fri, 9 Aug 2019 15:02:13 -0400 Subject: [PATCH] Roberta tokenization + fixed tests (py3 + py2). --- .../tests/modeling_roberta_test.py | 40 +-- .../tests/tokenization_roberta_test.py | 11 +- pytorch_transformers/tokenization_roberta.py | 311 ++++++++---------- 3 files changed, 138 insertions(+), 224 deletions(-) diff --git a/pytorch_transformers/tests/modeling_roberta_test.py b/pytorch_transformers/tests/modeling_roberta_test.py index e0455d8508..94035e9667 100644 --- a/pytorch_transformers/tests/modeling_roberta_test.py +++ b/pytorch_transformers/tests/modeling_roberta_test.py @@ -157,42 +157,6 @@ class RobertaModelTest(CommonTestCases.CommonModelTester): inputs_dict = {'input_ids': input_ids, 'token_type_ids': token_type_ids, 'attention_mask': input_mask} return config, inputs_dict - def test_inference_masked_lm(self): - model = RobertaForMaskedLM.from_pretrained('roberta-base') - - input_ids = torch.tensor([[0, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588, 2]]) - output = model(input_ids)[0] - expected_shape = torch.Size((1, 11, 50265)) - self.assertEqual( - output.shape, - expected_shape - ) - # compare the actual values for a slice. - expected_slice = torch.Tensor( - [[[33.8843, -4.3107, 22.7779], - [4.6533, -2.8099, 13.6252], - [1.8222, -3.6898, 8.8600]]] - ) - self.assertTrue( - torch.allclose(output[:, :3, :3], expected_slice, atol=1e-3) - ) - - # @pytest.mark.slow - def test_inference_no_head(self): - model = RobertaModel.from_pretrained('roberta-base') - - input_ids = torch.tensor([[0, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588, 2]]) - output = model(input_ids)[0] - # compare the actual values for a slice. - expected_slice = torch.Tensor( - [[[-0.0231, 0.0782, 0.0074], - [-0.1854, 0.0539, -0.0174], - [0.0548, 0.0799, 0.1687]]] - ) - self.assertTrue( - torch.allclose(output[:, :3, :3], expected_slice, atol=1e-3) - ) - def setUp(self): self.model_tester = RobertaModelTest.RobertaModelTester(self) self.config_tester = ConfigTester(self, config_class=RobertaConfig, hidden_size=37) @@ -220,7 +184,7 @@ class RobertaModelTest(CommonTestCases.CommonModelTester): class RobertaModelIntegrationTest(unittest.TestCase): - # @pytest.mark.slow + @pytest.mark.slow def test_inference_masked_lm(self): model = RobertaForMaskedLM.from_pretrained('roberta-base') @@ -241,7 +205,7 @@ class RobertaModelIntegrationTest(unittest.TestCase): torch.allclose(output[:, :3, :3], expected_slice, atol=1e-3) ) - # @pytest.mark.slow + @pytest.mark.slow def test_inference_no_head(self): model = RobertaModel.from_pretrained('roberta-base') diff --git a/pytorch_transformers/tests/tokenization_roberta_test.py b/pytorch_transformers/tests/tokenization_roberta_test.py index fbb3f8381d..daefea0fa7 100644 --- a/pytorch_transformers/tests/tokenization_roberta_test.py +++ b/pytorch_transformers/tests/tokenization_roberta_test.py @@ -18,8 +18,7 @@ import os import json import unittest -from pytorch_transformers.tokenization_roberta import RobertaTokenizer, DICT_FILES_NAMES -from pytorch_transformers.tokenization_gpt2 import GPT2Tokenizer, VOCAB_FILES_NAMES +from pytorch_transformers.tokenization_roberta import RobertaTokenizer, VOCAB_FILES_NAMES from .tokenization_tests_commons import CommonTestCases @@ -45,8 +44,7 @@ class RobertaTokenizationTest(CommonTestCases.CommonTokenizerTester): fp.write("\n".join(merges)) def get_tokenizer(self): - bpe_tokenizer = GPT2Tokenizer.from_pretrained(self.tmpdirname, **self.special_tokens_map) - return RobertaTokenizer.from_pretrained("roberta-base", bpe_tokenizer=bpe_tokenizer) + return RobertaTokenizer.from_pretrained(self.tmpdirname, **self.special_tokens_map) def get_input_output_texts(self): input_text = u"lower newer" @@ -54,15 +52,14 @@ class RobertaTokenizationTest(CommonTestCases.CommonTokenizerTester): return input_text, output_text def test_full_tokenizer(self): - tokenizer = self.get_tokenizer() + tokenizer = RobertaTokenizer(self.vocab_file, self.merges_file, **self.special_tokens_map) text = "lower" bpe_tokens = ["low", "er"] tokens = tokenizer.tokenize(text) self.assertListEqual(tokens, bpe_tokens) input_tokens = tokens + [tokenizer.unk_token] - input_bpe_tokens = [0, 4, 12, 176, 2] - tokenizer.convert_tokens_to_ids(input_tokens) + input_bpe_tokens = [13, 12, 17] self.assertListEqual( tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens) diff --git a/pytorch_transformers/tokenization_roberta.py b/pytorch_transformers/tokenization_roberta.py index 4ec53a65b0..b01b92653d 100644 --- a/pytorch_transformers/tokenization_roberta.py +++ b/pytorch_transformers/tokenization_roberta.py @@ -12,229 +12,182 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Tokenization classes for RoBERTa.""" +"""Tokenization classes for OpenAI GPT.""" from __future__ import (absolute_import, division, print_function, unicode_literals) +import sys import json import logging -import re -from io import open -import six import os +import regex as re +from io import open +from .tokenization_gpt2 import bytes_to_unicode, get_pairs from .tokenization_utils import PreTrainedTokenizer -from .tokenization_gpt2 import GPT2Tokenizer + +try: + from functools import lru_cache +except ImportError: + # Just a dummy decorator to get the checks to run on python2 + # because honestly I don't want to support a byte-level unicode BPE tokenizer on python 2 right now. + def lru_cache(): + return lambda func: func logger = logging.getLogger(__name__) -DICT_FILES_NAMES = { - 'dict_file': 'dict.txt', +VOCAB_FILES_NAMES = { + 'vocab_file': 'vocab.json', + 'merges_file': 'merges.txt', } -PRETRAINED_DICT_FILES_MAP = { - 'dict_file': - { - 'roberta-base': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-base-dict.txt", - 'roberta-large': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-base-dict.txt", - 'roberta-large-mnli': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-base-dict.txt", - }, +PRETRAINED_VOCAB_FILES_MAP = { + 'vocab_file': + { + 'roberta-base': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-base-vocab.json", + 'roberta-large': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-vocab.json", + 'roberta-large-mnli': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-mnli-vocab.json", + }, + 'merges_file': + { + 'roberta-base': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-base-merges.txt", + 'roberta-large': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-merges.txt", + 'roberta-large-mnli': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-mnli-merges.txt", + }, } PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { - 'roberta-base': 512, - 'roberta-large': 512, - 'roberta-large-mnli': 512, + 'roberta-base': 1024, + 'roberta-large': 1024, + 'roberta-large-mnli': 1024, } -SPACE_NORMALIZER = re.compile(r"\s+") - -def tokenize_line(line): - line = SPACE_NORMALIZER.sub(" ", line) - line = line.strip() - return line.split() - - -class Dictionary(object): - """ - A mapping from symbols to consecutive integers - - From Facebook's fairseq. - """ - - def __init__( - self, - pad='', - eos='', - unk='', - bos='', - extra_special_symbols=None, - ): - self.unk_word, self.pad_word, self.eos_word = unk, pad, eos - self.symbols = [] - self.count = [] - self.indices = {} - self.bos_index = self.add_symbol(bos) - self.pad_index = self.add_symbol(pad) - self.eos_index = self.add_symbol(eos) - self.unk_index = self.add_symbol(unk) - if extra_special_symbols: - for s in extra_special_symbols: - self.add_symbol(s) - self.nspecial = len(self.symbols) - - def __getitem__(self, idx): - if idx < len(self.symbols): - return self.symbols[idx] - return self.unk_word - - def index(self, sym): - """Returns the index of the specified symbol""" - assert isinstance(sym, str) - if sym in self.indices: - return self.indices[sym] - return self.unk_index - - def add_symbol(self, word, n=1): - """Adds a word to the dictionary""" - if word in self.indices: - idx = self.indices[word] - self.count[idx] = self.count[idx] + n - return idx - else: - idx = len(self.symbols) - self.indices[word] = idx - self.symbols.append(word) - self.count.append(n) - return idx - - @classmethod - def load(cls, f, ignore_utf_errors=False): - """Loads the dictionary from a text file with the format: - - ``` - - - ... - ``` - """ - d = cls() - d.add_from_file(f, ignore_utf_errors) - return d - - def add_from_file(self, f, ignore_utf_errors=False): - """ - Loads a pre-existing dictionary from a text file and adds its symbols - to this instance. - """ - if isinstance(f, six.string_types): - try: - if not ignore_utf_errors: - with open(f, 'r', encoding='utf-8') as fd: - self.add_from_file(fd) - else: - with open(f, 'r', encoding='utf-8', errors='ignore') as fd: - self.add_from_file(fd) - except FileNotFoundError as fnfe: - raise fnfe - except UnicodeError: - raise Exception("Incorrect encoding detected in {}, please " - "rebuild the dataset".format(f)) - return - - lines = f.read().splitlines() - for line in lines: - idx = line.rfind(' ') - if idx == -1: - raise ValueError("Incorrect dictionary format, expected ' '") - word = line[:idx] - count = int(line[idx + 1:]) - self.indices[word] = len(self.symbols) - self.symbols.append(word) - self.count.append(count) - - def encode_line(self, line, line_tokenizer=tokenize_line, add_if_not_exist=True, - consumer=None, append_eos=True, reverse_order=False): - words = line_tokenizer(line) - if reverse_order: - words = list(reversed(words)) - nwords = len(words) - ids = [0] * (nwords + 1 if append_eos else nwords) - - for i, word in enumerate(words): - if add_if_not_exist: - idx = self.add_symbol(word) - else: - idx = self.index(word) - if consumer is not None: - consumer(word, idx) - ids[i] = idx - if append_eos: - ids[nwords] = self.eos_index - return ids - class RobertaTokenizer(PreTrainedTokenizer): """ - RoBERTa tokenizer. Peculiarities: - - GPT-2 tokenizer with a different integer mapping on top. + GPT-2 BPE tokenizer. Peculiarities: + - Byte-level BPE """ - vocab_files_names = DICT_FILES_NAMES - pretrained_vocab_files_map = PRETRAINED_DICT_FILES_MAP + vocab_files_names = VOCAB_FILES_NAMES + pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES - def __init__(self, dict_file, bpe_tokenizer=None, bos_token="", eos_token="", sep_token="", cls_token="", - unk_token="", **kwargs): - super(RobertaTokenizer, self).__init__(cls_token=bos_token, sep_token=eos_token, eos_token=eos_token, - unk_token=unk_token, **kwargs) + def __init__(self, vocab_file, merges_file, errors='replace', bos_token="", eos_token="", sep_token="", + cls_token="", unk_token="", **kwargs): + super(RobertaTokenizer, self).__init__(bos_token=bos_token, eos_token=eos_token, unk_token=unk_token, + sep_token=sep_token, cls_token=cls_token, **kwargs) - self.gpt2_tokenizer = GPT2Tokenizer.from_pretrained("gpt2") if bpe_tokenizer is None else bpe_tokenizer - self.dictionary = Dictionary.load(dict_file) + self.encoder = json.load(open(vocab_file, encoding="utf-8")) + self.decoder = {v: k for k, v in self.encoder.items()} + self.errors = errors # how to handle errors in decoding + self.byte_encoder = bytes_to_unicode() + self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} + bpe_data = open(merges_file, encoding='utf-8').read().split('\n')[1:-1] + bpe_merges = [tuple(merge.split()) for merge in bpe_data] + self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges)))) + self.cache = {} + + # Should haved added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions + self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""") @property def vocab_size(self): - return len(self.dictionary.indices) + return len(self.encoder) + + def bpe(self, token): + if token in self.cache: + return self.cache[token] + word = tuple(token) + pairs = get_pairs(word) + + if not pairs: + return token + + while True: + bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) + if bigram not in self.bpe_ranks: + break + first, second = bigram + new_word = [] + i = 0 + while i < len(word): + try: + j = word.index(first, i) + new_word.extend(word[i:j]) + i = j + except: + new_word.extend(word[i:]) + break + + if word[i] == first and i < len(word)-1 and word[i+1] == second: + new_word.append(first+second) + i += 2 + else: + new_word.append(word[i]) + i += 1 + new_word = tuple(new_word) + word = new_word + if len(word) == 1: + break + else: + pairs = get_pairs(word) + word = ' '.join(word) + self.cache[token] = word + return word def _tokenize(self, text): - """ Use GPT-2 Tokenizer """ - return self.gpt2_tokenizer._tokenize(text) + """ Tokenize a string. """ + bpe_tokens = [] + for token in re.findall(self.pat, text): + if sys.version_info[0] == 2: + token = ''.join(self.byte_encoder[ord(b)] for b in token) + else: + token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) + bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(' ')) + return bpe_tokens def _convert_token_to_id(self, token): - if self.dictionary.index(token) != 3: - return self.dictionary.index(token) - return self.dictionary.index(str(self.gpt2_tokenizer.convert_tokens_to_ids(token))) + """ Converts a token (str/unicode) in an id using the vocab. """ + return self.encoder.get(token, self.encoder.get(self.unk_token)) def _convert_id_to_token(self, index): - symbol = self.dictionary[index] - try: - idx = int(symbol) - return self.gpt2_tokenizer._convert_id_to_token(idx) - except ValueError: - return symbol + """Converts an index (integer) in a token (string/unicode) using the vocab.""" + return self.decoder.get(index) def convert_tokens_to_string(self, tokens): - return self.gpt2_tokenizer.convert_tokens_to_string(tokens) + """ Converts a sequence of tokens (string) in a single string. """ + text = ''.join(tokens) + text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=self.errors) + return text - def convert_tokens_to_ids(self, tokens, no_sep_cls_tokens=False): - cls = [self._convert_token_to_id(self.cls_token)] - tokens = super().convert_tokens_to_ids(tokens) + def add_special_tokens_single_sentence(self, token_ids): + return [self._convert_token_to_id(self.cls_token)] + token_ids + [self._convert_token_to_id(self.sep_token)] + + def add_special_tokens_sentences_pair(self, *token_ids): sep = [self._convert_token_to_id(self.sep_token)] - return (cls + tokens + sep) if (isinstance(tokens, list) and not no_sep_cls_tokens) else tokens - - def convert_ids_to_tokens(self, ids, skip_special_tokens=False): - return super().convert_ids_to_tokens(ids, skip_special_tokens=skip_special_tokens)[1:-1] + cls = [self._convert_token_to_id(self.cls_token)] + return cls + token_ids[0] + sep + sep + token_ids[1] + sep def save_vocabulary(self, save_directory): """Save the tokenizer vocabulary and merge files to a directory.""" if not os.path.isdir(save_directory): logger.error("Vocabulary path ({}) should be a directory".format(save_directory)) return - dict_file = os.path.join(save_directory, DICT_FILES_NAMES['dict_file']) + vocab_file = os.path.join(save_directory, VOCAB_FILES_NAMES['vocab_file']) + merge_file = os.path.join(save_directory, VOCAB_FILES_NAMES['merges_file']) - with open(dict_file, 'w', encoding='utf-8') as f: - for i in range(self.dictionary.nspecial, len(self.dictionary.count)): - f.write(f"{list(self.dictionary.indices.keys())[i]} {self.dictionary.count[i]}\n") + with open(vocab_file, 'w', encoding='utf-8') as f: + f.write(json.dumps(self.encoder, ensure_ascii=False)) - vocab_files = self.gpt2_tokenizer.save_pretrained(save_directory) + index = 0 + with open(merge_file, "w", encoding="utf-8") as writer: + writer.write(u'#version: 0.2\n') + for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]): + if index != token_index: + logger.warning("Saving vocabulary to {}: BPE merge indices are not consecutive." + " Please check that the tokenizer is not corrupted!".format(merge_file)) + index = token_index + writer.write(' '.join(bpe_tokens) + u'\n') + index += 1 - return vocab_files + (dict_file,) + return vocab_file, merge_file