Add FastTokenizer to REALM (#15211)

* Remove BertTokenizer abstraction

* Add FastTokenizer to REALM

* Fix config archive map

* Fix copies

* Update realm.mdx

* Apply suggestions from code review
This commit is contained in:
Li-Huai (Allan) Lin
2022-01-19 22:19:36 +08:00
committed by GitHub
parent 021b52e7a8
commit 841d979190
10 changed files with 824 additions and 46 deletions

View File

@@ -16,6 +16,7 @@
 import os
 import unittest
+from transformers import RealmTokenizerFast
 from transformers.models.bert.tokenization_bert import (
     VOCAB_FILES_NAMES,
     BasicTokenizer,
@@ -34,8 +35,8 @@ from .test_tokenization_common import TokenizerTesterMixin, filter_non_english
 class RealmTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
     tokenizer_class = RealmTokenizer
-    rust_tokenizer_class = None
-    test_rust_tokenizer = False
+    rust_tokenizer_class = RealmTokenizerFast
+    test_rust_tokenizer = True
     space_between_special_tokens = True
     from_pretrained_filter = filter_non_english
@@ -301,14 +302,21 @@ class RealmTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
     @slow
     def test_batch_encode_candidates(self):
-        tokenizer = self.tokenizer_class.from_pretrained("bert-base-uncased")
+        for tokenizer, pretrained_name, kwargs in self.tokenizers_list:
+            with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name})"):
+                tokenizer_r = self.rust_tokenizer_class.from_pretrained(pretrained_name, **kwargs)
+                tokenizer_p = self.tokenizer_class.from_pretrained(pretrained_name, **kwargs)
+                text = [["Hello world!", "Nice to meet you!"], ["The cute cat.", "The adorable dog."]]
-        text = [["Hello world!", "Nice to meet you!"], ["The cute cat.", "The adorable dog."]]
+                encoded_sentence_r = tokenizer_r.batch_encode_candidates(text, max_length=10, return_tensors="np")
+                encoded_sentence_p = tokenizer_p.batch_encode_candidates(text, max_length=10, return_tensors="np")
-        encoded_sentence = tokenizer.batch_encode_candidates(text, max_length=10, return_tensors="pt")
+                expected_shape = (2, 2, 10)
-        expected_shape = (2, 2, 10)
+                self.assertEqual(encoded_sentence_r["input_ids"].shape, expected_shape)
+                self.assertEqual(encoded_sentence_r["attention_mask"].shape, expected_shape)
+                self.assertEqual(encoded_sentence_r["token_type_ids"].shape, expected_shape)
-        assert encoded_sentence["input_ids"].shape == expected_shape
-        assert encoded_sentence["attention_mask"].shape == expected_shape
-        assert encoded_sentence["token_type_ids"].shape == expected_shape
+                self.assertEqual(encoded_sentence_p["input_ids"].shape, expected_shape)
+                self.assertEqual(encoded_sentence_p["attention_mask"].shape, expected_shape)
+                self.assertEqual(encoded_sentence_p["token_type_ids"].shape, expected_shape)