T5Tokenizer adds EOS token if not already added (#5866)
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
@@ -18,6 +18,7 @@
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import warnings
|
||||
from shutil import copyfile
|
||||
from typing import List, Optional
|
||||
|
||||
@@ -148,6 +149,74 @@ class T5Tokenizer(PreTrainedTokenizer):
|
||||
vocab.update(self.added_tokens_encoder)
|
||||
return vocab
|
||||
|
||||
def get_special_tokens_mask(
    self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
) -> List[int]:
    """
    Build a mask that flags which positions of the model input hold special tokens.

    Args:
        token_ids_0 (:obj:`List[int]`):
            List of ids.
        token_ids_1 (:obj:`List[int]`, `optional`):
            Optional second list of IDs for sequence pairs.
        already_has_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Set to True if the token list is already formatted with special tokens for the model

    Returns:
        :obj:`List[int]`: A list of integers in the range [0, 1], 1 for a special token, 0 for a sequence token.

    Raises:
        ValueError: If ``token_ids_1`` is given together with ``already_has_special_tokens=True``.
    """
    if already_has_special_tokens:
        if token_ids_1 is not None:
            raise ValueError(
                "You should not supply a second sequence if the provided sequence of "
                "ids is already formatted with special tokens for the model."
            )
        # This tokenizer only ever appends EOS; comparing against sep/cls ids
        # (which this class never sets) would mark nothing as special. Use the
        # tokenizer's full set of special ids instead.
        special_ids = set(self.all_special_ids)
        return [1 if token_id in special_ids else 0 for token_id in token_ids_0]

    # normal case: one EOS is accounted for after each sequence
    # NOTE(review): this always counts an appended EOS, whereas
    # _add_eos_if_not_present skips it when the ids already end with EOS.
    first_mask = [0] * len(token_ids_0) + [1]
    if token_ids_1 is None:
        return first_mask
    return first_mask + [0] * len(token_ids_1) + [1]
|
||||
|
||||
def _add_eos_if_not_present(self, token_ids: List[int]) -> List[int]:
|
||||
"""Do not add eos again if user already added it."""
|
||||
if len(token_ids) > 0 and token_ids[-1] == self.eos_token_id:
|
||||
warnings.warn(
|
||||
f"This sequence already has {self.eos_token}. In future versions this behavior may lead to duplicated eos tokens being added."
|
||||
)
|
||||
return token_ids
|
||||
else:
|
||||
return token_ids + [self.eos_token_id]
|
||||
|
||||
def build_inputs_with_special_tokens(
    self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
) -> List[int]:
    """
    Assemble model inputs from one sequence or a pair of sequences by prepending
    ``self.prefix_tokens`` and terminating each sequence with EOS (added only when it is
    not already the last id).

    Resulting layout:

    - single sequence: ``X </s>``
    - pair of sequences: ``A </s> B </s>``

    Args:
        token_ids_0 (:obj:`List[int]`):
            List of IDs to which the special tokens will be added
        token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`):
            Optional second list of IDs for sequence pairs.

    Returns:
        :obj:`List[int]`: list of `input IDs <../glossary.html#input-ids>`__ with the appropriate special tokens.
    """
    first = self._add_eos_if_not_present(token_ids_0)
    if token_ids_1 is not None:
        second = self._add_eos_if_not_present(token_ids_1)
        return self.prefix_tokens + first + second
    return self.prefix_tokens + first
|
||||
|
||||
def __getstate__(self):
|
||||
state = self.__dict__.copy()
|
||||
state["sp_model"] = None
|
||||
@@ -210,31 +279,6 @@ class T5Tokenizer(PreTrainedTokenizer):
|
||||
|
||||
return (out_vocab_file,)
|
||||
|
||||
def build_inputs_with_special_tokens(
    self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
) -> List[int]:
    """
    Assemble model inputs by prepending ``self.prefix_tokens`` to the given ids.

    This method itself appends no end-of-sequence token. Pairs of sequences are not the
    expected use case; when a second sequence is given it is concatenated directly after
    the first, without a separator.

    Args:
        token_ids_0 (:obj:`List[int]`):
            List of IDs to which the special tokens will be added
        token_ids_1 (:obj:`List[int]`, `optional`):
            Optional second list of IDs for sequence pairs.

    Returns:
        :obj:`List[int]`: List of `input IDs <../glossary.html#input-ids>`__ with the appropriate special tokens.
    """
    sequence = self.prefix_tokens + token_ids_0
    if token_ids_1 is not None:
        # Pair handling is kept only for API consistency.
        sequence = sequence + token_ids_1
    return sequence
|
||||
|
||||
def prepare_seq2seq_batch(
|
||||
self,
|
||||
src_texts: List[str],
|
||||
|
||||
File diff suppressed because one or more lines are too long
@@ -18,6 +18,7 @@ import os
|
||||
import unittest
|
||||
|
||||
from transformers import BatchEncoding
|
||||
from transformers.file_utils import cached_property
|
||||
from transformers.testing_utils import _torch_available
|
||||
from transformers.tokenization_t5 import T5Tokenizer
|
||||
from transformers.tokenization_xlnet import SPIECE_UNDERLINE
|
||||
@@ -107,28 +108,37 @@ class T5TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
],
|
||||
)
|
||||
|
||||
@cached_property
def t5_base_tokenizer(self):
    """Return the pretrained ``t5-base`` tokenizer (wrapped in ``cached_property``)."""
    checkpoint_name = "t5-base"
    return T5Tokenizer.from_pretrained(checkpoint_name)
|
||||
|
||||
def test_eos_treatment(self):
    """Texts that already end in ``</s>`` must encode to the same ids as the same texts without it."""
    tokenizer = self.t5_base_tokenizer
    plain_texts = ["hi", "I went to the gym", ""]
    texts_with_eos = [text + "</s>" for text in plain_texts]

    batch_with_eos_added = tokenizer(texts_with_eos)
    batch_without_eos_added = tokenizer(plain_texts)

    self.assertListEqual(batch_with_eos_added["input_ids"], batch_without_eos_added["input_ids"])
|
||||
|
||||
def test_prepare_seq2seq_batch(self):
    """Check that ``prepare_seq2seq_batch`` appends EOS to the source ids and resets ``prefix_tokens``."""
    tokenizer = self.t5_base_tokenizer
    src_text = ["A long paragraph for summrization.", "Another paragraph for summrization."]
    tgt_text = [
        "Summary of the text.",
        "Another summary.",
    ]
    # Nine ids for the first source text, followed by the EOS the tokenizer appends.
    expected_src_tokens = [71, 307, 8986, 21, 4505, 51, 52, 1707, 5, tokenizer.eos_token_id]
    batch = tokenizer.prepare_seq2seq_batch(src_text, tgt_texts=tgt_text, return_tensors=FRAMEWORK)
    self.assertIsInstance(batch, BatchEncoding)
    result = list(batch.input_ids.numpy()[0])
    self.assertListEqual(expected_src_tokens, result)

    self.assertEqual((2, 10), batch.input_ids.shape)
    self.assertEqual((2, 10), batch.attention_mask.shape)

    # Test that special tokens are reset
    self.assertEqual(tokenizer.prefix_tokens, [])
|
||||
|
||||
def test_empty_target_text(self):
|
||||
tokenizer = T5Tokenizer.from_pretrained("t5-small")
|
||||
tokenizer = self.t5_base_tokenizer
|
||||
src_text = ["A long paragraph for summrization.", "Another paragraph for summrization."]
|
||||
batch = tokenizer.prepare_seq2seq_batch(src_text, return_tensors=FRAMEWORK)
|
||||
# check if input_ids are returned and no decoder_input_ids
|
||||
@@ -138,7 +148,7 @@ class T5TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
self.assertNotIn("decoder_attention_mask", batch)
|
||||
|
||||
def test_max_target_length(self):
|
||||
tokenizer = T5Tokenizer.from_pretrained("t5-small")
|
||||
tokenizer = self.t5_base_tokenizer
|
||||
src_text = ["A long paragraph for summrization.", "Another paragraph for summrization."]
|
||||
tgt_text = [
|
||||
"Summary of the text.",
|
||||
@@ -158,7 +168,7 @@ class T5TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
self.assertEqual(32, batch["decoder_attention_mask"].shape[1])
|
||||
|
||||
def test_outputs_not_longer_than_maxlen(self):
|
||||
tokenizer = T5Tokenizer.from_pretrained("t5-small")
|
||||
tokenizer = self.t5_base_tokenizer
|
||||
|
||||
batch = tokenizer.prepare_seq2seq_batch(
|
||||
["I am a small frog" * 1000, "I am a small frog"], return_tensors=FRAMEWORK
|
||||
@@ -167,7 +177,7 @@ class T5TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
self.assertEqual(batch.input_ids.shape, (2, 512))
|
||||
|
||||
def test_eos_in_input(self):
|
||||
tokenizer = T5Tokenizer.from_pretrained("t5-small")
|
||||
tokenizer = self.t5_base_tokenizer
|
||||
src_text = ["A long paragraph for summrization. </s>"]
|
||||
tgt_text = ["Summary of the text. </s>"]
|
||||
expected_src_tokens = [71, 307, 8986, 21, 4505, 51, 52, 1707, 5, 1]
|
||||
|
||||
Reference in New Issue
Block a user