Exposing prepare_for_model for both slow & fast tokenizers (#5479)
* Exposing prepare_for_model for both slow & fast tokenizers * Update method signature * The traditional style commit * Hide the warnings behind the verbose flag * update default truncation strategy and prepare_for_model * fix tests and prepare_for_models methods Co-authored-by: Thomas Wolf <thomwolf@users.noreply.github.com>
This commit is contained in:
@@ -508,9 +508,7 @@ class TokenizerTesterMixin:
|
||||
self.assertEqual(len(truncated_sequence), total_length - 2)
|
||||
self.assertEqual(truncated_sequence, sequence[:-2])
|
||||
|
||||
self.assertEqual(
|
||||
len(overflowing_tokens), 0
|
||||
) # No overflowing tokens when using 'longest' in python tokenizers
|
||||
self.assertEqual(len(overflowing_tokens), 2 + stride)
|
||||
|
||||
def test_maximum_encoding_length_pair_input(self):
|
||||
tokenizers = self.get_tokenizers(do_lower_case=False, model_max_length=100)
|
||||
@@ -634,7 +632,39 @@ class TokenizerTesterMixin:
|
||||
self.assertEqual(truncated_sequence, truncated_longest_sequence)
|
||||
|
||||
self.assertEqual(
|
||||
len(overflowing_tokens), 0
|
||||
len(overflowing_tokens), 2 + stride
|
||||
) # No overflowing tokens when using 'longest' in python tokenizers
|
||||
|
||||
information = tokenizer.encode_plus(
|
||||
seq_0,
|
||||
seq_1,
|
||||
max_length=len(sequence) - 2,
|
||||
add_special_tokens=False,
|
||||
stride=stride,
|
||||
truncation=True,
|
||||
return_overflowing_tokens=True,
|
||||
# add_prefix_space=False,
|
||||
)
|
||||
# Overflowing tokens are handled quite differently in slow and fast tokenizers
|
||||
if isinstance(tokenizer, PreTrainedTokenizerFast):
|
||||
truncated_sequence = information["input_ids"][0]
|
||||
overflowing_tokens = information["input_ids"][1]
|
||||
self.assertEqual(len(information["input_ids"]), 2)
|
||||
|
||||
self.assertEqual(len(truncated_sequence), len(sequence) - 2)
|
||||
self.assertEqual(truncated_sequence, truncated_longest_sequence)
|
||||
|
||||
self.assertEqual(len(overflowing_tokens), 2 + stride + len(smallest))
|
||||
self.assertEqual(overflowing_tokens, overflow_longest_sequence)
|
||||
else:
|
||||
truncated_sequence = information["input_ids"]
|
||||
overflowing_tokens = information["overflowing_tokens"]
|
||||
|
||||
self.assertEqual(len(truncated_sequence), len(sequence) - 2)
|
||||
self.assertEqual(truncated_sequence, truncated_longest_sequence)
|
||||
|
||||
self.assertEqual(
|
||||
len(overflowing_tokens), 2 + stride
|
||||
) # No overflowing tokens when using 'longest' in python tokenizers
|
||||
|
||||
information_first_truncated = tokenizer.encode_plus(
|
||||
@@ -643,7 +673,7 @@ class TokenizerTesterMixin:
|
||||
max_length=len(sequence) - 2,
|
||||
add_special_tokens=False,
|
||||
stride=stride,
|
||||
truncation=True,
|
||||
truncation="only_first",
|
||||
return_overflowing_tokens=True,
|
||||
# add_prefix_space=False,
|
||||
)
|
||||
@@ -1293,6 +1323,16 @@ class TokenizerTesterMixin:
|
||||
for key in output.keys():
|
||||
self.assertEqual(output[key], output_sequence[key])
|
||||
|
||||
def test_prepare_for_model(self):
|
||||
tokenizers = self.get_tokenizers(do_lower_case=False)
|
||||
for tokenizer in tokenizers:
|
||||
string_sequence = "Testing the prepare_for_model method."
|
||||
ids = tokenizer.encode(string_sequence, add_special_tokens=False)
|
||||
input_dict = tokenizer.encode_plus(string_sequence)
|
||||
prepared_input_dict = tokenizer.prepare_for_model(ids)
|
||||
|
||||
self.assertEqual(input_dict, prepared_input_dict)
|
||||
|
||||
@require_torch
|
||||
@require_tf
|
||||
def test_batch_encode_plus_tensors(self):
|
||||
|
||||
@@ -90,6 +90,7 @@ class CommonFastTokenizerTest(unittest.TestCase):
|
||||
self.assert_embeded_special_tokens(tokenizer_r, tokenizer_p)
|
||||
self.assert_padding(tokenizer_r, tokenizer_p)
|
||||
self.assert_create_token_type_ids(tokenizer_r, tokenizer_p)
|
||||
self.assert_prepare_for_model(tokenizer_r, tokenizer_p)
|
||||
# TODO: enable for v3.0.0
|
||||
# self.assert_empty_output_no_special_tokens(tokenizer_r, tokenizer_p)
|
||||
|
||||
@@ -709,6 +710,12 @@ class CommonFastTokenizerTest(unittest.TestCase):
|
||||
for i_no, i_with in zip(no_special_tokens[key], with_special_tokens[key]):
|
||||
self.assertEqual(len(i_no), len(i_with) - simple_num_special_tokens_to_add)
|
||||
|
||||
def assert_prepare_for_model(self, tokenizer_r, tokenizer_p):
|
||||
string_sequence = "Asserting that both tokenizers are equal"
|
||||
python_output = tokenizer_p.prepare_for_model(tokenizer_p.encode(string_sequence))
|
||||
rust_output = tokenizer_r.prepare_for_model(tokenizer_r.encode(string_sequence))
|
||||
self.assertEqual(python_output, rust_output)
|
||||
|
||||
|
||||
class WordPieceFastTokenizerTest(CommonFastTokenizerTest):
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user