[PreTrainedTokenizerFast] Keep properties from fast tokenizer (#25053)
* draft solution * use `setdefault` * nits * add tests and fix truncation issue * fix test * test passes locally * quality * updates * update tests
This commit is contained in:
@@ -132,6 +132,26 @@ class PreTrainedTokenizerFast(PreTrainedTokenizerBase):
|
|||||||
|
|
||||||
self._decode_use_source_tokenizer = False
|
self._decode_use_source_tokenizer = False
|
||||||
|
|
||||||
|
_truncation = self._tokenizer.truncation
|
||||||
|
|
||||||
|
if _truncation is not None:
|
||||||
|
self._tokenizer.enable_truncation(**_truncation)
|
||||||
|
kwargs.setdefault("max_length", _truncation["max_length"])
|
||||||
|
kwargs.setdefault("truncation_side", _truncation["direction"])
|
||||||
|
kwargs.setdefault("stride", _truncation["stride"])
|
||||||
|
kwargs.setdefault("truncation_strategy", _truncation["strategy"])
|
||||||
|
else:
|
||||||
|
self._tokenizer.no_truncation()
|
||||||
|
|
||||||
|
_padding = self._tokenizer.padding
|
||||||
|
if _padding is not None:
|
||||||
|
self._tokenizer.enable_padding(**_padding)
|
||||||
|
kwargs.setdefault("pad_token", _padding["pad_token"])
|
||||||
|
kwargs.setdefault("pad_token_type_id", _padding["pad_type_id"])
|
||||||
|
kwargs.setdefault("padding_side", _padding["direction"])
|
||||||
|
kwargs.setdefault("max_length", _padding["length"])
|
||||||
|
kwargs.setdefault("pad_to_multiple_of", _padding["pad_to_multiple_of"])
|
||||||
|
|
||||||
# We call this after having initialized the backend tokenizer because we update it.
|
# We call this after having initialized the backend tokenizer because we update it.
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
|
|||||||
@@ -109,6 +109,58 @@ class PreTrainedTokenizationFastTest(TokenizerTesterMixin, unittest.TestCase):
|
|||||||
encoding_ids = new_tokenizer.encode("a🤗")
|
encoding_ids = new_tokenizer.encode("a🤗")
|
||||||
self.assertEqual(encoding_ids, [64, 172, 253, 97, 245])
|
self.assertEqual(encoding_ids, [64, 172, 253, 97, 245])
|
||||||
|
|
||||||
|
def test_init_from_tokenizers_model(self):
    """Check that backend padding/truncation settings survive wrapping and reloading.

    A `tokenizers.Tokenizer` is configured with padding, then with truncation.
    After each step it is wrapped in `PreTrainedTokenizerFast` and round-tripped
    through `save_pretrained` / `from_pretrained`; both the wrapped and the
    reloaded tokenizer must expose the backend settings and produce the expected
    encodings.
    """
    from tokenizers import Tokenizer

    sentences = ["Hello, y'all!", "How are you 😁 ? There should not be any issue right?"]

    tokenizer = Tokenizer.from_pretrained("t5-base")
    # Enable padding
    tokenizer.enable_padding(pad_id=0, pad_token="<pad>", length=512, pad_to_multiple_of=8)
    self.assertEqual(
        tokenizer.padding,
        {
            "length": 512,
            "pad_to_multiple_of": 8,
            "pad_id": 0,
            "pad_token": "<pad>",
            "pad_type_id": 0,
            "direction": "right",
        },
    )
    fast_tokenizer = PreTrainedTokenizerFast(tokenizer_object=tokenizer)
    # Use a self-cleaning temporary directory so the test does not leave files
    # behind; the reloaded tokenizer is only used for encoding afterwards.
    with tempfile.TemporaryDirectory() as tmpdirname:
        fast_tokenizer.save_pretrained(tmpdirname)
        fast_from_saved = PreTrainedTokenizerFast.from_pretrained(tmpdirname)
    for tok in [fast_tokenizer, fast_from_saved]:
        self.assertEqual(tok.pad_token_id, 0)
        self.assertEqual(tok.padding_side, "right")
        self.assertEqual(tok.pad_token, "<pad>")
        self.assertEqual(tok.init_kwargs["max_length"], 512)
        self.assertEqual(tok.init_kwargs["pad_to_multiple_of"], 8)
        # fmt: off
        self.assertEqual(tok(sentences, padding = True), {'input_ids': [[8774, 6, 3, 63, 31, 1748, 55, 1, 0, 0, 0, 0,0, 0, 0, 0],[ 571, 33, 25, 3, 2, 3, 58, 290, 225, 59, 36, 136, 962, 269, 58, 1]], 'token_type_ids': [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0],[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]})
        # fmt: on

    tokenizer.enable_truncation(8, stride=0, strategy="longest_first", direction="right")
    self.assertEqual(
        tokenizer.truncation, {"max_length": 8, "stride": 0, "strategy": "longest_first", "direction": "right"}
    )
    fast_tokenizer = PreTrainedTokenizerFast(tokenizer_object=tokenizer)
    with tempfile.TemporaryDirectory() as tmpdirname:
        fast_tokenizer.save_pretrained(tmpdirname)
        fast_from_saved = PreTrainedTokenizerFast.from_pretrained(tmpdirname)
    for tok in [fast_tokenizer, fast_from_saved]:
        self.assertEqual(tok.truncation_side, "right")
        self.assertEqual(tok.init_kwargs["truncation_strategy"], "longest_first")
        self.assertEqual(tok.init_kwargs["max_length"], 8)
        self.assertEqual(tok.init_kwargs["stride"], 0)
        # NOTE even if the model has a default max_length, it is not used...
        # thus tok(sentences, truncation = True) does nothing and does not warn either
        # fmt: off
        self.assertEqual(tok(sentences, truncation = True, max_length = 8), {'input_ids': [[8774, 6, 3, 63, 31, 1748, 55, 1],[ 571, 33, 25, 3, 2, 3, 58, 1]], 'token_type_ids': [[0, 0, 0, 0, 0, 0, 0, 0],[0, 0, 0, 0, 0, 0, 0, 0]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1],[1, 1, 1, 1, 1, 1, 1, 1]]})
        # fmt: on
|
||||||
|
|
||||||
|
|
||||||
@require_tokenizers
|
@require_tokenizers
|
||||||
class TokenizerVersioningTest(unittest.TestCase):
|
class TokenizerVersioningTest(unittest.TestCase):
|
||||||
|
|||||||
Reference in New Issue
Block a user