Exposing prepare_for_model for both slow & fast tokenizers (#5479)

* Exposing prepare_for_model for both slow & fast tokenizers

* Update method signature

* The traditional style commit

* Hide the warnings behind the verbose flag

* update default truncation strategy and prepare_for_model

* fix tests and prepare_for_models methods

Co-authored-by: Thomas Wolf <thomwolf@users.noreply.github.com>
This commit is contained in:
Lysandre Debut
2020-07-03 10:51:21 -04:00
committed by GitHub
parent 814ed7ee76
commit 17ade127b9
4 changed files with 285 additions and 205 deletions

View File

@@ -508,9 +508,7 @@ class TokenizerTesterMixin:
self.assertEqual(len(truncated_sequence), total_length - 2)
self.assertEqual(truncated_sequence, sequence[:-2])
self.assertEqual(
len(overflowing_tokens), 0
) # No overflowing tokens when using 'longest' in python tokenizers
self.assertEqual(len(overflowing_tokens), 2 + stride)
def test_maximum_encoding_length_pair_input(self):
tokenizers = self.get_tokenizers(do_lower_case=False, model_max_length=100)
@@ -634,7 +632,39 @@ class TokenizerTesterMixin:
self.assertEqual(truncated_sequence, truncated_longest_sequence)
self.assertEqual(
len(overflowing_tokens), 0
len(overflowing_tokens), 2 + stride
) # No overflowing tokens when using 'longest' in python tokenizers
information = tokenizer.encode_plus(
seq_0,
seq_1,
max_length=len(sequence) - 2,
add_special_tokens=False,
stride=stride,
truncation=True,
return_overflowing_tokens=True,
# add_prefix_space=False,
)
# Overflowing tokens are handled quite differently in slow and fast tokenizers
if isinstance(tokenizer, PreTrainedTokenizerFast):
truncated_sequence = information["input_ids"][0]
overflowing_tokens = information["input_ids"][1]
self.assertEqual(len(information["input_ids"]), 2)
self.assertEqual(len(truncated_sequence), len(sequence) - 2)
self.assertEqual(truncated_sequence, truncated_longest_sequence)
self.assertEqual(len(overflowing_tokens), 2 + stride + len(smallest))
self.assertEqual(overflowing_tokens, overflow_longest_sequence)
else:
truncated_sequence = information["input_ids"]
overflowing_tokens = information["overflowing_tokens"]
self.assertEqual(len(truncated_sequence), len(sequence) - 2)
self.assertEqual(truncated_sequence, truncated_longest_sequence)
self.assertEqual(
len(overflowing_tokens), 2 + stride
) # No overflowing tokens when using 'longest' in python tokenizers
information_first_truncated = tokenizer.encode_plus(
@@ -643,7 +673,7 @@ class TokenizerTesterMixin:
max_length=len(sequence) - 2,
add_special_tokens=False,
stride=stride,
truncation=True,
truncation="only_first",
return_overflowing_tokens=True,
# add_prefix_space=False,
)
@@ -1293,6 +1323,16 @@ class TokenizerTesterMixin:
for key in output.keys():
self.assertEqual(output[key], output_sequence[key])
def test_prepare_for_model(self):
tokenizers = self.get_tokenizers(do_lower_case=False)
for tokenizer in tokenizers:
string_sequence = "Testing the prepare_for_model method."
ids = tokenizer.encode(string_sequence, add_special_tokens=False)
input_dict = tokenizer.encode_plus(string_sequence)
prepared_input_dict = tokenizer.prepare_for_model(ids)
self.assertEqual(input_dict, prepared_input_dict)
@require_torch
@require_tf
def test_batch_encode_plus_tensors(self):

View File

@@ -90,6 +90,7 @@ class CommonFastTokenizerTest(unittest.TestCase):
self.assert_embeded_special_tokens(tokenizer_r, tokenizer_p)
self.assert_padding(tokenizer_r, tokenizer_p)
self.assert_create_token_type_ids(tokenizer_r, tokenizer_p)
self.assert_prepare_for_model(tokenizer_r, tokenizer_p)
# TODO: enable for v3.0.0
# self.assert_empty_output_no_special_tokens(tokenizer_r, tokenizer_p)
@@ -709,6 +710,12 @@ class CommonFastTokenizerTest(unittest.TestCase):
for i_no, i_with in zip(no_special_tokens[key], with_special_tokens[key]):
self.assertEqual(len(i_no), len(i_with) - simple_num_special_tokens_to_add)
def assert_prepare_for_model(self, tokenizer_r, tokenizer_p):
string_sequence = "Asserting that both tokenizers are equal"
python_output = tokenizer_p.prepare_for_model(tokenizer_p.encode(string_sequence))
rust_output = tokenizer_r.prepare_for_model(tokenizer_r.encode(string_sequence))
self.assertEqual(python_output, rust_output)
class WordPieceFastTokenizerTest(CommonFastTokenizerTest):
"""