update the arguments add_prefix_space and trim_offsets in backend_tokenizer.post_processor of RobertaTokenizerFast (#14752)
* add tests * change post-processor, pre-tokenizer and decoder (can't update decoder) * update test (remove decoder which doesn't depend on trim and add_prefix) * just update the post_processor * fix change * `trim_offsets` has no influence on `pre_tokenizer` * remove a test that need some input from the `tokenizers` lib maintainers * format * add new test offsets roberta * polish comments
This commit is contained in:
@@ -14,6 +14,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import itertools
|
||||
import json
|
||||
import os
|
||||
import unittest
|
||||
@@ -196,3 +197,107 @@ class RobertaTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
self.assertSequenceEqual(
|
||||
tokens_r_str, ["<s>", "A", ",", "<mask>", "ĠAllen", "N", "LP", "Ġsentence", ".", "</s>"]
|
||||
)
|
||||
|
||||
def test_change_add_prefix_space_and_trim_offsets_args(self):
|
||||
for trim_offsets, add_prefix_space in itertools.product([True, False], repeat=2):
|
||||
tokenizer_r = self.rust_tokenizer_class.from_pretrained(
|
||||
self.tmpdirname, use_fast=True, add_prefix_space=add_prefix_space, trim_offsets=trim_offsets
|
||||
)
|
||||
|
||||
pre_tokenizer_state = json.loads(tokenizer_r.backend_tokenizer.pre_tokenizer.__getstate__())
|
||||
post_processor_state = json.loads(tokenizer_r.backend_tokenizer.post_processor.__getstate__())
|
||||
|
||||
self.assertEqual(pre_tokenizer_state["add_prefix_space"], add_prefix_space)
|
||||
|
||||
self.assertEqual(post_processor_state["add_prefix_space"], add_prefix_space)
|
||||
self.assertEqual(post_processor_state["trim_offsets"], trim_offsets)
|
||||
|
||||
def test_offsets_mapping_with_different_add_prefix_space_and_trim_space_arguments(self):
|
||||
# Test which aims to verify that the offsets are well adapted to the argument `add_prefix_space` and
|
||||
# `trim_offsets`
|
||||
for tokenizer, pretrained_name, kwargs in self.tokenizers_list:
|
||||
with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name})"):
|
||||
text_of_1_token = "hello" # `hello` is a token in the vocabulary of `pretrained_name`
|
||||
text = f"{text_of_1_token} {text_of_1_token}"
|
||||
|
||||
tokenizer_r = self.rust_tokenizer_class.from_pretrained(
|
||||
pretrained_name, use_fast=True, add_prefix_space=True, trim_offsets=True
|
||||
)
|
||||
encoding = tokenizer_r(text, return_offsets_mapping=True, add_special_tokens=False)
|
||||
self.assertEqual(encoding.offset_mapping[0], (0, len(text_of_1_token)))
|
||||
self.assertEqual(
|
||||
encoding.offset_mapping[1],
|
||||
(len(text_of_1_token) + 1, len(text_of_1_token) + 1 + len(text_of_1_token)),
|
||||
)
|
||||
|
||||
tokenizer_r = self.rust_tokenizer_class.from_pretrained(
|
||||
pretrained_name, use_fast=True, add_prefix_space=False, trim_offsets=True
|
||||
)
|
||||
encoding = tokenizer_r(text, return_offsets_mapping=True, add_special_tokens=False)
|
||||
self.assertEqual(encoding.offset_mapping[0], (0, len(text_of_1_token)))
|
||||
self.assertEqual(
|
||||
encoding.offset_mapping[1],
|
||||
(len(text_of_1_token) + 1, len(text_of_1_token) + 1 + len(text_of_1_token)),
|
||||
)
|
||||
|
||||
tokenizer_r = self.rust_tokenizer_class.from_pretrained(
|
||||
pretrained_name, use_fast=True, add_prefix_space=True, trim_offsets=False
|
||||
)
|
||||
encoding = tokenizer_r(text, return_offsets_mapping=True, add_special_tokens=False)
|
||||
self.assertEqual(encoding.offset_mapping[0], (0, len(text_of_1_token)))
|
||||
self.assertEqual(
|
||||
encoding.offset_mapping[1],
|
||||
(len(text_of_1_token), len(text_of_1_token) + 1 + len(text_of_1_token)),
|
||||
)
|
||||
|
||||
tokenizer_r = self.rust_tokenizer_class.from_pretrained(
|
||||
pretrained_name, use_fast=True, add_prefix_space=False, trim_offsets=False
|
||||
)
|
||||
encoding = tokenizer_r(text, return_offsets_mapping=True, add_special_tokens=False)
|
||||
self.assertEqual(encoding.offset_mapping[0], (0, len(text_of_1_token)))
|
||||
self.assertEqual(
|
||||
encoding.offset_mapping[1],
|
||||
(len(text_of_1_token), len(text_of_1_token) + 1 + len(text_of_1_token)),
|
||||
)
|
||||
|
||||
text = f" {text}"
|
||||
|
||||
# tokenizer_r = self.rust_tokenizer_class.from_pretrained(
|
||||
# pretrained_name, use_fast=True, add_prefix_space=True, trim_offsets=True
|
||||
# )
|
||||
# encoding = tokenizer_r(text, return_offsets_mapping=True, add_special_tokens=False)
|
||||
# self.assertEqual(encoding.offset_mapping[0], (1, 1 + len(text_of_1_token)))
|
||||
# self.assertEqual(
|
||||
# encoding.offset_mapping[1],
|
||||
# (1 + len(text_of_1_token) + 1, 1 + len(text_of_1_token) + 1 + len(text_of_1_token)),
|
||||
# )
|
||||
|
||||
tokenizer_r = self.rust_tokenizer_class.from_pretrained(
|
||||
pretrained_name, use_fast=True, add_prefix_space=False, trim_offsets=True
|
||||
)
|
||||
encoding = tokenizer_r(text, return_offsets_mapping=True, add_special_tokens=False)
|
||||
self.assertEqual(encoding.offset_mapping[0], (1, 1 + len(text_of_1_token)))
|
||||
self.assertEqual(
|
||||
encoding.offset_mapping[1],
|
||||
(1 + len(text_of_1_token) + 1, 1 + len(text_of_1_token) + 1 + len(text_of_1_token)),
|
||||
)
|
||||
|
||||
tokenizer_r = self.rust_tokenizer_class.from_pretrained(
|
||||
pretrained_name, use_fast=True, add_prefix_space=True, trim_offsets=False
|
||||
)
|
||||
encoding = tokenizer_r(text, return_offsets_mapping=True, add_special_tokens=False)
|
||||
self.assertEqual(encoding.offset_mapping[0], (0, 1 + len(text_of_1_token)))
|
||||
self.assertEqual(
|
||||
encoding.offset_mapping[1],
|
||||
(1 + len(text_of_1_token), 1 + len(text_of_1_token) + 1 + len(text_of_1_token)),
|
||||
)
|
||||
|
||||
tokenizer_r = self.rust_tokenizer_class.from_pretrained(
|
||||
pretrained_name, use_fast=True, add_prefix_space=False, trim_offsets=False
|
||||
)
|
||||
encoding = tokenizer_r(text, return_offsets_mapping=True, add_special_tokens=False)
|
||||
self.assertEqual(encoding.offset_mapping[0], (0, 1 + len(text_of_1_token)))
|
||||
self.assertEqual(
|
||||
encoding.offset_mapping[1],
|
||||
(1 + len(text_of_1_token), 1 + len(text_of_1_token) + 1 + len(text_of_1_token)),
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user