Add FastTokenizer to REALM (#15211)
* Remove BertTokenizer abstraction
* Add FastTokenizer to REALM
* Fix config archive map
* Fix copies
* Update realm.mdx
* Apply suggestions from code review
This commit is contained in:
committed by
GitHub
parent
021b52e7a8
commit
841d979190
@@ -16,6 +16,7 @@
|
||||
import os
|
||||
import unittest
|
||||
|
||||
from transformers import RealmTokenizerFast
|
||||
from transformers.models.bert.tokenization_bert import (
|
||||
VOCAB_FILES_NAMES,
|
||||
BasicTokenizer,
|
||||
@@ -34,8 +35,8 @@ from .test_tokenization_common import TokenizerTesterMixin, filter_non_english
|
||||
class RealmTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
|
||||
tokenizer_class = RealmTokenizer
|
||||
rust_tokenizer_class = None
|
||||
test_rust_tokenizer = False
|
||||
rust_tokenizer_class = RealmTokenizerFast
|
||||
test_rust_tokenizer = True
|
||||
space_between_special_tokens = True
|
||||
from_pretrained_filter = filter_non_english
|
||||
|
||||
@@ -301,14 +302,21 @@ class RealmTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
|
||||
@slow
|
||||
def test_batch_encode_candidates(self):
|
||||
tokenizer = self.tokenizer_class.from_pretrained("bert-base-uncased")
|
||||
for tokenizer, pretrained_name, kwargs in self.tokenizers_list:
|
||||
with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name})"):
|
||||
tokenizer_r = self.rust_tokenizer_class.from_pretrained(pretrained_name, **kwargs)
|
||||
tokenizer_p = self.tokenizer_class.from_pretrained(pretrained_name, **kwargs)
|
||||
text = [["Hello world!", "Nice to meet you!"], ["The cute cat.", "The adorable dog."]]
|
||||
|
||||
text = [["Hello world!", "Nice to meet you!"], ["The cute cat.", "The adorable dog."]]
|
||||
encoded_sentence_r = tokenizer_r.batch_encode_candidates(text, max_length=10, return_tensors="np")
|
||||
encoded_sentence_p = tokenizer_p.batch_encode_candidates(text, max_length=10, return_tensors="np")
|
||||
|
||||
encoded_sentence = tokenizer.batch_encode_candidates(text, max_length=10, return_tensors="pt")
|
||||
expected_shape = (2, 2, 10)
|
||||
|
||||
expected_shape = (2, 2, 10)
|
||||
self.assertEqual(encoded_sentence_r["input_ids"].shape, expected_shape)
|
||||
self.assertEqual(encoded_sentence_r["attention_mask"].shape, expected_shape)
|
||||
self.assertEqual(encoded_sentence_r["token_type_ids"].shape, expected_shape)
|
||||
|
||||
assert encoded_sentence["input_ids"].shape == expected_shape
|
||||
assert encoded_sentence["attention_mask"].shape == expected_shape
|
||||
assert encoded_sentence["token_type_ids"].shape == expected_shape
|
||||
self.assertEqual(encoded_sentence_p["input_ids"].shape, expected_shape)
|
||||
self.assertEqual(encoded_sentence_p["attention_mask"].shape, expected_shape)
|
||||
self.assertEqual(encoded_sentence_p["token_type_ids"].shape, expected_shape)
|
||||
|
||||
Reference in New Issue
Block a user