Replace as_target context managers by direct calls (#18325)
* Preliminary work on tokenizers * Quality + fix tests * Treat processors * Fix pad * Remove all uses of in tests, docs and examples * Replace all as_target_tokenizer * Fix tests * Fix quality * Update examples/flax/image-captioning/run_image_captioning_flax.py Co-authored-by: amyeroberts <amy@huggingface.co> * Style Co-authored-by: amyeroberts <amy@huggingface.co>
This commit is contained in:
@@ -112,14 +112,13 @@ class TestTokenizationBart(TokenizerTesterMixin, unittest.TestCase):
|
||||
self.assertNotIn("decoder_attention_mask", batch)
|
||||
|
||||
@require_torch
|
||||
def test_as_target_tokenizer_target_length(self):
|
||||
def test_tokenizer_as_target_length(self):
|
||||
tgt_text = [
|
||||
"Summary of the text.",
|
||||
"Another summary.",
|
||||
]
|
||||
for tokenizer in [self.default_tokenizer, self.default_tokenizer_fast]:
|
||||
with tokenizer.as_target_tokenizer():
|
||||
targets = tokenizer(tgt_text, max_length=32, padding="max_length", return_tensors="pt")
|
||||
targets = tokenizer(text_target=tgt_text, max_length=32, padding="max_length", return_tensors="pt")
|
||||
self.assertEqual(32, targets["input_ids"].shape[1])
|
||||
|
||||
@require_torch
|
||||
@@ -140,8 +139,7 @@ class TestTokenizationBart(TokenizerTesterMixin, unittest.TestCase):
|
||||
]
|
||||
for tokenizer in [self.default_tokenizer, self.default_tokenizer_fast]:
|
||||
inputs = tokenizer(src_text, return_tensors="pt")
|
||||
with tokenizer.as_target_tokenizer():
|
||||
targets = tokenizer(tgt_text, return_tensors="pt")
|
||||
targets = tokenizer(text_target=tgt_text, return_tensors="pt")
|
||||
input_ids = inputs["input_ids"]
|
||||
labels = targets["input_ids"]
|
||||
self.assertTrue((input_ids[:, 0] == tokenizer.bos_token_id).all().item())
|
||||
|
||||
@@ -152,10 +152,9 @@ class ByT5TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
"Summary of the text.",
|
||||
"Another summary.",
|
||||
]
|
||||
with tokenizer.as_target_tokenizer():
|
||||
targets = tokenizer(
|
||||
tgt_text, max_length=32, padding="max_length", truncation=True, return_tensors=FRAMEWORK
|
||||
)
|
||||
targets = tokenizer(
|
||||
text_target=tgt_text, max_length=32, padding="max_length", truncation=True, return_tensors=FRAMEWORK
|
||||
)
|
||||
self.assertEqual(32, targets["input_ids"].shape[1])
|
||||
|
||||
def test_eos_in_input(self):
|
||||
@@ -167,12 +166,10 @@ class ByT5TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
expected_tgt_tokens = [86, 120, 112, 112, 100, 117, 124, 35, 114, 105, 35, 119, 107, 104, 35, 119, 104, 123, 119, 49, 35, 1]
|
||||
# fmt: on
|
||||
|
||||
batch = tokenizer(src_text)
|
||||
with tokenizer.as_target_tokenizer():
|
||||
targets = tokenizer(tgt_text)
|
||||
batch = tokenizer(src_text, text_target=tgt_text)
|
||||
|
||||
self.assertEqual(expected_src_tokens, batch["input_ids"][0])
|
||||
self.assertEqual(expected_tgt_tokens, targets["input_ids"][0])
|
||||
self.assertEqual(expected_tgt_tokens, batch["labels"][0])
|
||||
|
||||
# cannot use default save_and_load_tokenzier test method because tokenzier has no vocab
|
||||
def test_save_and_load_tokenizer(self):
|
||||
|
||||
@@ -80,8 +80,9 @@ class CanineTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
"What's the weater?",
|
||||
"It's about 25 degrees.",
|
||||
]
|
||||
with tokenizer.as_target_tokenizer():
|
||||
targets = tokenizer(tgt_text, max_length=32, padding="max_length", truncation=True, return_tensors="pt")
|
||||
targets = tokenizer(
|
||||
text_target=tgt_text, max_length=32, padding="max_length", truncation=True, return_tensors="pt"
|
||||
)
|
||||
self.assertEqual(32, targets["input_ids"].shape[1])
|
||||
|
||||
# cannot use default save_and_load_tokenzier test method because tokenzier has no vocab
|
||||
|
||||
@@ -13,7 +13,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from transformers import (
|
||||
DPRContextEncoderTokenizer,
|
||||
DPRContextEncoderTokenizerFast,
|
||||
|
||||
@@ -187,9 +187,7 @@ class M2M100TokenizerIntegrationTest(unittest.TestCase):
|
||||
self.tokenizer.src_lang = "en"
|
||||
self.tokenizer.tgt_lang = "fr"
|
||||
|
||||
batch = self.tokenizer(self.src_text, padding=True, return_tensors="pt")
|
||||
with self.tokenizer.as_target_tokenizer():
|
||||
batch["labels"] = self.tokenizer(self.tgt_text, padding=True, return_tensors="pt").input_ids
|
||||
batch = self.tokenizer(self.src_text, text_target=self.tgt_text, padding=True, return_tensors="pt")
|
||||
|
||||
batch["decoder_input_ids"] = shift_tokens_right(
|
||||
batch["labels"], self.tokenizer.pad_token_id, self.tokenizer.eos_token_id
|
||||
@@ -217,17 +215,19 @@ class M2M100TokenizerIntegrationTest(unittest.TestCase):
|
||||
self.assertListEqual(self.tokenizer.suffix_tokens, [self.tokenizer.eos_token_id])
|
||||
|
||||
@require_torch
|
||||
def test_as_target_tokenizer(self):
|
||||
def test_tokenizer_target_mode(self):
|
||||
self.tokenizer.tgt_lang = "mr"
|
||||
with self.tokenizer.as_target_tokenizer():
|
||||
self.assertListEqual(self.tokenizer.prefix_tokens, [self.tokenizer.get_lang_id("mr")])
|
||||
self.assertListEqual(self.tokenizer.suffix_tokens, [self.tokenizer.eos_token_id])
|
||||
self.tokenizer._switch_to_target_mode()
|
||||
self.assertListEqual(self.tokenizer.prefix_tokens, [self.tokenizer.get_lang_id("mr")])
|
||||
self.assertListEqual(self.tokenizer.suffix_tokens, [self.tokenizer.eos_token_id])
|
||||
self.tokenizer._switch_to_input_mode()
|
||||
self.assertListEqual(self.tokenizer.prefix_tokens, [self.tokenizer.get_lang_id(self.tokenizer.src_lang)])
|
||||
|
||||
self.tokenizer.tgt_lang = "zh"
|
||||
with self.tokenizer.as_target_tokenizer():
|
||||
self.assertListEqual(self.tokenizer.prefix_tokens, [self.tokenizer.get_lang_id("zh")])
|
||||
self.assertListEqual(self.tokenizer.suffix_tokens, [self.tokenizer.eos_token_id])
|
||||
self.tokenizer._switch_to_target_mode()
|
||||
self.assertListEqual(self.tokenizer.prefix_tokens, [self.tokenizer.get_lang_id("zh")])
|
||||
self.assertListEqual(self.tokenizer.suffix_tokens, [self.tokenizer.eos_token_id])
|
||||
self.tokenizer._switch_to_input_mode()
|
||||
self.assertListEqual(self.tokenizer.prefix_tokens, [self.tokenizer.get_lang_id(self.tokenizer.src_lang)])
|
||||
|
||||
@require_torch
|
||||
|
||||
@@ -438,10 +438,7 @@ class TestMarian_EN_DE_More(MarianIntegrationTest):
|
||||
src, tgt = ["I am a small frog"], ["Ich bin ein kleiner Frosch."]
|
||||
expected_ids = [38, 121, 14, 697, 38848, 0]
|
||||
|
||||
model_inputs = self.tokenizer(src, return_tensors="pt").to(torch_device)
|
||||
with self.tokenizer.as_target_tokenizer():
|
||||
targets = self.tokenizer(tgt, return_tensors="pt")
|
||||
model_inputs["labels"] = targets["input_ids"].to(torch_device)
|
||||
model_inputs = self.tokenizer(src, text_target=tgt, return_tensors="pt").to(torch_device)
|
||||
|
||||
self.assertListEqual(expected_ids, model_inputs.input_ids[0].tolist())
|
||||
|
||||
|
||||
@@ -145,9 +145,8 @@ class MarianTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
src_ids = tokenizer(source_text).input_ids
|
||||
self.assertListEqual(src_ids, expected_src_ids)
|
||||
|
||||
with tokenizer.as_target_tokenizer():
|
||||
target_ids = tokenizer(target_text).input_ids
|
||||
self.assertListEqual(target_ids, expected_target_ids)
|
||||
target_ids = tokenizer(text_target=target_text).input_ids
|
||||
self.assertListEqual(target_ids, expected_target_ids)
|
||||
|
||||
decoded = tokenizer.decode(target_ids, skip_special_tokens=True)
|
||||
self.assertEqual(decoded, target_text)
|
||||
|
||||
@@ -265,33 +265,27 @@ class MBartEnroIntegrationTest(unittest.TestCase):
|
||||
|
||||
@require_torch
|
||||
def test_batch_fairseq_parity(self):
|
||||
batch = self.tokenizer(self.src_text, padding=True)
|
||||
with self.tokenizer.as_target_tokenizer():
|
||||
targets = self.tokenizer(self.tgt_text, padding=True, return_tensors="pt")
|
||||
labels = targets["input_ids"]
|
||||
batch["decoder_input_ids"] = shift_tokens_right(labels, self.tokenizer.pad_token_id).tolist()
|
||||
batch = self.tokenizer(self.src_text, text_target=self.tgt_text, padding=True, return_tensors="pt")
|
||||
batch["decoder_input_ids"] = shift_tokens_right(batch["labels"], self.tokenizer.pad_token_id)
|
||||
|
||||
# fairseq batch: https://gist.github.com/sshleifer/cba08bc2109361a74ac3760a7e30e4f4
|
||||
assert batch.input_ids[1][-2:] == [2, EN_CODE]
|
||||
assert batch.decoder_input_ids[1][0] == RO_CODE
|
||||
assert batch.input_ids[1][-2:].tolist() == [2, EN_CODE]
|
||||
assert batch.decoder_input_ids[1][0].tolist() == RO_CODE
|
||||
assert batch.decoder_input_ids[1][-1] == 2
|
||||
assert labels[1][-2:].tolist() == [2, RO_CODE]
|
||||
assert batch.labels[1][-2:].tolist() == [2, RO_CODE]
|
||||
|
||||
@require_torch
|
||||
def test_enro_tokenizer_prepare_batch(self):
|
||||
batch = self.tokenizer(
|
||||
self.src_text, padding=True, truncation=True, max_length=len(self.expected_src_tokens), return_tensors="pt"
|
||||
self.src_text,
|
||||
text_target=self.tgt_text,
|
||||
padding=True,
|
||||
truncation=True,
|
||||
max_length=len(self.expected_src_tokens),
|
||||
return_tensors="pt",
|
||||
)
|
||||
with self.tokenizer.as_target_tokenizer():
|
||||
targets = self.tokenizer(
|
||||
self.tgt_text,
|
||||
padding=True,
|
||||
truncation=True,
|
||||
max_length=len(self.expected_src_tokens),
|
||||
return_tensors="pt",
|
||||
)
|
||||
labels = targets["input_ids"]
|
||||
batch["decoder_input_ids"] = shift_tokens_right(labels, self.tokenizer.pad_token_id)
|
||||
|
||||
batch["decoder_input_ids"] = shift_tokens_right(batch["labels"], self.tokenizer.pad_token_id)
|
||||
|
||||
self.assertIsInstance(batch, BatchEncoding)
|
||||
|
||||
@@ -306,8 +300,9 @@ class MBartEnroIntegrationTest(unittest.TestCase):
|
||||
|
||||
def test_seq2seq_max_length(self):
|
||||
batch = self.tokenizer(self.src_text, padding=True, truncation=True, max_length=3, return_tensors="pt")
|
||||
with self.tokenizer.as_target_tokenizer():
|
||||
targets = self.tokenizer(self.tgt_text, padding=True, truncation=True, max_length=10, return_tensors="pt")
|
||||
targets = self.tokenizer(
|
||||
text_target=self.tgt_text, padding=True, truncation=True, max_length=10, return_tensors="pt"
|
||||
)
|
||||
labels = targets["input_ids"]
|
||||
batch["decoder_input_ids"] = shift_tokens_right(labels, self.tokenizer.pad_token_id)
|
||||
|
||||
|
||||
@@ -256,35 +256,27 @@ class MBart50OneToManyIntegrationTest(unittest.TestCase):
|
||||
|
||||
@require_torch
|
||||
def test_batch_fairseq_parity(self):
|
||||
batch = self.tokenizer(self.src_text, padding=True)
|
||||
with self.tokenizer.as_target_tokenizer():
|
||||
targets = self.tokenizer(self.tgt_text, padding=True, return_tensors="pt")
|
||||
labels = targets["input_ids"]
|
||||
batch["decoder_input_ids"] = shift_tokens_right(labels, self.tokenizer.pad_token_id).tolist()
|
||||
labels = labels.tolist()
|
||||
batch = self.tokenizer(self.src_text, text_target=self.tgt_text, padding=True, return_tensors="pt")
|
||||
batch["decoder_input_ids"] = shift_tokens_right(batch["labels"], self.tokenizer.pad_token_id)
|
||||
|
||||
# fairseq batch: https://gist.github.com/sshleifer/cba08bc2109361a74ac3760a7e30e4f4
|
||||
assert batch.input_ids[1][0] == EN_CODE
|
||||
assert batch.input_ids[1][-1] == 2
|
||||
assert labels[1][0] == RO_CODE
|
||||
assert labels[1][-1] == 2
|
||||
assert batch.decoder_input_ids[1][:2] == [2, RO_CODE]
|
||||
assert batch.labels[1][0] == RO_CODE
|
||||
assert batch.labels[1][-1] == 2
|
||||
assert batch.decoder_input_ids[1][:2].tolist() == [2, RO_CODE]
|
||||
|
||||
@require_torch
|
||||
def test_tokenizer_prepare_batch(self):
|
||||
batch = self.tokenizer(
|
||||
self.src_text, padding=True, truncation=True, max_length=len(self.expected_src_tokens), return_tensors="pt"
|
||||
self.src_text,
|
||||
text_target=self.tgt_text,
|
||||
padding=True,
|
||||
truncation=True,
|
||||
max_length=len(self.expected_src_tokens),
|
||||
return_tensors="pt",
|
||||
)
|
||||
with self.tokenizer.as_target_tokenizer():
|
||||
targets = self.tokenizer(
|
||||
self.tgt_text,
|
||||
padding=True,
|
||||
truncation=True,
|
||||
max_length=len(self.expected_src_tokens),
|
||||
return_tensors="pt",
|
||||
)
|
||||
labels = targets["input_ids"]
|
||||
batch["decoder_input_ids"] = shift_tokens_right(labels, self.tokenizer.pad_token_id)
|
||||
batch["decoder_input_ids"] = shift_tokens_right(batch["labels"], self.tokenizer.pad_token_id)
|
||||
|
||||
self.assertIsInstance(batch, BatchEncoding)
|
||||
|
||||
@@ -299,8 +291,9 @@ class MBart50OneToManyIntegrationTest(unittest.TestCase):
|
||||
|
||||
def test_seq2seq_max_target_length(self):
|
||||
batch = self.tokenizer(self.src_text, padding=True, truncation=True, max_length=3, return_tensors="pt")
|
||||
with self.tokenizer.as_target_tokenizer():
|
||||
targets = self.tokenizer(self.tgt_text, padding=True, truncation=True, max_length=10, return_tensors="pt")
|
||||
targets = self.tokenizer(
|
||||
text_target=self.tgt_text, padding=True, truncation=True, max_length=10, return_tensors="pt"
|
||||
)
|
||||
labels = targets["input_ids"]
|
||||
batch["decoder_input_ids"] = shift_tokens_right(labels, self.tokenizer.pad_token_id)
|
||||
|
||||
|
||||
@@ -125,8 +125,7 @@ class MCTCTProcessorTest(unittest.TestCase):
|
||||
|
||||
input_str = "This is a test string"
|
||||
|
||||
with processor.as_target_processor():
|
||||
encoded_processor = processor(input_str)
|
||||
encoded_processor = processor(text=input_str)
|
||||
|
||||
encoded_tok = tokenizer(input_str)
|
||||
|
||||
|
||||
@@ -112,14 +112,13 @@ class TestTokenizationMvp(TokenizerTesterMixin, unittest.TestCase):
|
||||
self.assertNotIn("decoder_attention_mask", batch)
|
||||
|
||||
@require_torch
|
||||
def test_as_target_tokenizer_target_length(self):
|
||||
def test_tokenizer_as_target_length(self):
|
||||
tgt_text = [
|
||||
"Summary of the text.",
|
||||
"Another summary.",
|
||||
]
|
||||
for tokenizer in [self.default_tokenizer, self.default_tokenizer_fast]:
|
||||
with tokenizer.as_target_tokenizer():
|
||||
targets = tokenizer(tgt_text, max_length=32, padding="max_length", return_tensors="pt")
|
||||
targets = tokenizer(text_target=tgt_text, max_length=32, padding="max_length", return_tensors="pt")
|
||||
self.assertEqual(32, targets["input_ids"].shape[1])
|
||||
|
||||
@require_torch
|
||||
@@ -139,11 +138,9 @@ class TestTokenizationMvp(TokenizerTesterMixin, unittest.TestCase):
|
||||
"Summary of the text.",
|
||||
]
|
||||
for tokenizer in [self.default_tokenizer, self.default_tokenizer_fast]:
|
||||
inputs = tokenizer(src_text, return_tensors="pt")
|
||||
with tokenizer.as_target_tokenizer():
|
||||
targets = tokenizer(tgt_text, return_tensors="pt")
|
||||
inputs = tokenizer(src_text, text_target=tgt_text, return_tensors="pt")
|
||||
input_ids = inputs["input_ids"]
|
||||
labels = targets["input_ids"]
|
||||
labels = inputs["labels"]
|
||||
self.assertTrue((input_ids[:, 0] == tokenizer.bos_token_id).all().item())
|
||||
self.assertTrue((labels[:, 0] == tokenizer.bos_token_id).all().item())
|
||||
self.assertTrue((input_ids[:, -1] == tokenizer.eos_token_id).all().item())
|
||||
|
||||
@@ -373,19 +373,15 @@ class NllbDistilledIntegrationTest(unittest.TestCase):
|
||||
@require_torch
|
||||
def test_enro_tokenizer_prepare_batch(self):
|
||||
batch = self.tokenizer(
|
||||
self.src_text, padding=True, truncation=True, max_length=len(self.expected_src_tokens), return_tensors="pt"
|
||||
self.src_text,
|
||||
text_target=self.tgt_text,
|
||||
padding=True,
|
||||
truncation=True,
|
||||
max_length=len(self.expected_src_tokens),
|
||||
return_tensors="pt",
|
||||
)
|
||||
with self.tokenizer.as_target_tokenizer():
|
||||
targets = self.tokenizer(
|
||||
self.tgt_text,
|
||||
padding=True,
|
||||
truncation=True,
|
||||
max_length=len(self.expected_src_tokens),
|
||||
return_tensors="pt",
|
||||
)
|
||||
labels = targets["input_ids"]
|
||||
batch["decoder_input_ids"] = shift_tokens_right(
|
||||
labels, self.tokenizer.pad_token_id, self.tokenizer.lang_code_to_id["ron_Latn"]
|
||||
batch["labels"], self.tokenizer.pad_token_id, self.tokenizer.lang_code_to_id["ron_Latn"]
|
||||
)
|
||||
|
||||
self.assertIsInstance(batch, BatchEncoding)
|
||||
@@ -401,8 +397,9 @@ class NllbDistilledIntegrationTest(unittest.TestCase):
|
||||
|
||||
def test_seq2seq_max_length(self):
|
||||
batch = self.tokenizer(self.src_text, padding=True, truncation=True, max_length=3, return_tensors="pt")
|
||||
with self.tokenizer.as_target_tokenizer():
|
||||
targets = self.tokenizer(self.tgt_text, padding=True, truncation=True, max_length=10, return_tensors="pt")
|
||||
targets = self.tokenizer(
|
||||
text_target=self.tgt_text, padding=True, truncation=True, max_length=10, return_tensors="pt"
|
||||
)
|
||||
labels = targets["input_ids"]
|
||||
batch["decoder_input_ids"] = shift_tokens_right(
|
||||
labels,
|
||||
|
||||
@@ -109,10 +109,9 @@ class PegasusTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
src_texts = ["This is going to be way too long." * 150, "short example"]
|
||||
tgt_texts = ["not super long but more than 5 tokens", "tiny"]
|
||||
batch = self._large_tokenizer(src_texts, padding=True, truncation=True, return_tensors="pt")
|
||||
with self._large_tokenizer.as_target_tokenizer():
|
||||
targets = self._large_tokenizer(
|
||||
tgt_texts, max_length=5, padding=True, truncation=True, return_tensors="pt"
|
||||
)
|
||||
targets = self._large_tokenizer(
|
||||
text_target=tgt_texts, max_length=5, padding=True, truncation=True, return_tensors="pt"
|
||||
)
|
||||
|
||||
assert batch.input_ids.shape == (2, 1024)
|
||||
assert batch.attention_mask.shape == (2, 1024)
|
||||
@@ -174,10 +173,9 @@ class BigBirdPegasusTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
src_texts = ["This is going to be way too long." * 1000, "short example"]
|
||||
tgt_texts = ["not super long but more than 5 tokens", "tiny"]
|
||||
batch = self._large_tokenizer(src_texts, padding=True, truncation=True, return_tensors="pt")
|
||||
with self._large_tokenizer.as_target_tokenizer():
|
||||
targets = self._large_tokenizer(
|
||||
tgt_texts, max_length=5, padding=True, truncation=True, return_tensors="pt"
|
||||
)
|
||||
targets = self._large_tokenizer(
|
||||
text_target=tgt_texts, max_length=5, padding=True, truncation=True, return_tensors="pt"
|
||||
)
|
||||
|
||||
assert batch.input_ids.shape == (2, 4096)
|
||||
assert batch.attention_mask.shape == (2, 4096)
|
||||
|
||||
@@ -146,10 +146,9 @@ class PerceiverTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
"Summary of the text.",
|
||||
"Another summary.",
|
||||
]
|
||||
with tokenizer.as_target_tokenizer():
|
||||
targets = tokenizer(
|
||||
tgt_text, max_length=32, padding="max_length", truncation=True, return_tensors=FRAMEWORK
|
||||
)
|
||||
targets = tokenizer(
|
||||
text_target=tgt_text, max_length=32, padding="max_length", truncation=True, return_tensors=FRAMEWORK
|
||||
)
|
||||
self.assertEqual(32, targets["input_ids"].shape[1])
|
||||
|
||||
# cannot use default save_and_load_tokenzier test method because tokenzier has no vocab
|
||||
|
||||
@@ -299,33 +299,26 @@ class PLBartPythonEnIntegrationTest(unittest.TestCase):
|
||||
|
||||
@require_torch
|
||||
def test_batch_fairseq_parity(self):
|
||||
batch = self.tokenizer(self.src_text, padding=True)
|
||||
with self.tokenizer.as_target_tokenizer():
|
||||
targets = self.tokenizer(self.tgt_text, padding=True, return_tensors="pt")
|
||||
labels = targets["input_ids"]
|
||||
batch["decoder_input_ids"] = shift_tokens_right(labels, self.tokenizer.pad_token_id).tolist()
|
||||
batch = self.tokenizer(self.src_text, text_target=self.tgt_text, padding=True, return_tensors="pt")
|
||||
batch["decoder_input_ids"] = shift_tokens_right(batch["labels"], self.tokenizer.pad_token_id)
|
||||
|
||||
# fairseq batch: https://gist.github.com/sshleifer/cba08bc2109361a74ac3760a7e30e4f4
|
||||
self.assertEqual(batch.input_ids[1][-2:], [2, PYTHON_CODE])
|
||||
self.assertEqual(batch.input_ids[1][-2:].tolist(), [2, PYTHON_CODE])
|
||||
self.assertEqual(batch.decoder_input_ids[1][0], EN_CODE)
|
||||
self.assertEqual(batch.decoder_input_ids[1][-1], 2)
|
||||
self.assertEqual(labels[1][-2:].tolist(), [2, EN_CODE])
|
||||
self.assertEqual(batch.labels[1][-2:].tolist(), [2, EN_CODE])
|
||||
|
||||
@require_torch
|
||||
def test_python_en_tokenizer_prepare_batch(self):
|
||||
batch = self.tokenizer(
|
||||
self.src_text, padding=True, truncation=True, max_length=len(self.expected_src_tokens), return_tensors="pt"
|
||||
self.src_text,
|
||||
text_target=self.tgt_text,
|
||||
padding=True,
|
||||
truncation=True,
|
||||
max_length=len(self.expected_src_tokens),
|
||||
return_tensors="pt",
|
||||
)
|
||||
with self.tokenizer.as_target_tokenizer():
|
||||
targets = self.tokenizer(
|
||||
self.tgt_text,
|
||||
padding=True,
|
||||
truncation=True,
|
||||
max_length=len(self.expected_src_tokens),
|
||||
return_tensors="pt",
|
||||
)
|
||||
labels = targets["input_ids"]
|
||||
batch["decoder_input_ids"] = shift_tokens_right(labels, self.tokenizer.pad_token_id)
|
||||
batch["decoder_input_ids"] = shift_tokens_right(batch["labels"], self.tokenizer.pad_token_id)
|
||||
|
||||
self.assertIsInstance(batch, BatchEncoding)
|
||||
|
||||
@@ -340,8 +333,9 @@ class PLBartPythonEnIntegrationTest(unittest.TestCase):
|
||||
|
||||
def test_seq2seq_max_length(self):
|
||||
batch = self.tokenizer(self.src_text, padding=True, truncation=True, max_length=3, return_tensors="pt")
|
||||
with self.tokenizer.as_target_tokenizer():
|
||||
targets = self.tokenizer(self.tgt_text, padding=True, truncation=True, max_length=10, return_tensors="pt")
|
||||
targets = self.tokenizer(
|
||||
text_target=self.tgt_text, padding=True, truncation=True, max_length=10, return_tensors="pt"
|
||||
)
|
||||
labels = targets["input_ids"]
|
||||
batch["decoder_input_ids"] = shift_tokens_right(labels, self.tokenizer.pad_token_id)
|
||||
|
||||
|
||||
@@ -125,8 +125,7 @@ class Speech2TextProcessorTest(unittest.TestCase):
|
||||
|
||||
input_str = "This is a test string"
|
||||
|
||||
with processor.as_target_processor():
|
||||
encoded_processor = processor(input_str)
|
||||
encoded_processor = processor(text=input_str)
|
||||
|
||||
encoded_tok = tokenizer(input_str)
|
||||
|
||||
|
||||
@@ -210,10 +210,9 @@ class T5TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
"Summary of the text.",
|
||||
"Another summary.",
|
||||
]
|
||||
with tokenizer.as_target_tokenizer():
|
||||
targets = tokenizer(
|
||||
tgt_text, max_length=32, padding="max_length", truncation=True, return_tensors=FRAMEWORK
|
||||
)
|
||||
targets = tokenizer(
|
||||
text_target=tgt_text, max_length=32, padding="max_length", truncation=True, return_tensors=FRAMEWORK
|
||||
)
|
||||
self.assertEqual(32, targets["input_ids"].shape[1])
|
||||
|
||||
def test_outputs_not_longer_than_maxlen(self):
|
||||
@@ -235,12 +234,10 @@ class T5TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
expected_src_tokens = [71, 307, 8986, 21, 4505, 1635, 1707, 5, 1]
|
||||
expected_tgt_tokens = [20698, 13, 8, 1499, 5, 1]
|
||||
|
||||
batch = tokenizer(src_text)
|
||||
with tokenizer.as_target_tokenizer():
|
||||
targets = tokenizer(tgt_text)
|
||||
batch = tokenizer(src_text, text_target=tgt_text)
|
||||
|
||||
self.assertEqual(expected_src_tokens, batch["input_ids"][0])
|
||||
self.assertEqual(expected_tgt_tokens, targets["input_ids"][0])
|
||||
self.assertEqual(expected_tgt_tokens, batch["labels"][0])
|
||||
|
||||
def test_token_type_ids(self):
|
||||
src_text_1 = ["A first paragraph for summarization."]
|
||||
|
||||
@@ -859,9 +859,8 @@ class TapexTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
tokenizer = TapexTokenizer.from_pretrained("microsoft/tapex-base")
|
||||
answer_text = "tapex is a good model!"
|
||||
expected_src_tokens = [0, 90, 5776, 1178, 16, 10, 205, 1421, 328, 2]
|
||||
with tokenizer.as_target_tokenizer():
|
||||
answer_encoding = tokenizer(answer=answer_text)
|
||||
self.assertListEqual(answer_encoding.input_ids, expected_src_tokens)
|
||||
answer_encoding = tokenizer(answer=answer_text)
|
||||
self.assertListEqual(answer_encoding.input_ids, expected_src_tokens)
|
||||
|
||||
@slow
|
||||
def test_tokenizer_lower_case(self):
|
||||
@@ -870,23 +869,21 @@ class TapexTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
answer_text = "Beijing, London, Paris"
|
||||
answer_text_lower = "beijing, london, paris"
|
||||
|
||||
with cased_tokenizer.as_target_tokenizer():
|
||||
with uncased_tokenizer.as_target_tokenizer():
|
||||
self.assertNotEqual(
|
||||
cased_tokenizer(answer=answer_text).input_ids, uncased_tokenizer(answer=answer_text).input_ids
|
||||
)
|
||||
self.assertEqual(
|
||||
cased_tokenizer(answer=answer_text_lower).input_ids,
|
||||
uncased_tokenizer(answer=answer_text).input_ids,
|
||||
)
|
||||
# batched encoding assert
|
||||
self.assertNotEqual(
|
||||
cased_tokenizer(answer=[answer_text]).input_ids, uncased_tokenizer(answer=[answer_text]).input_ids
|
||||
)
|
||||
self.assertEqual(
|
||||
cased_tokenizer(answer=[answer_text_lower]).input_ids,
|
||||
uncased_tokenizer(answer=[answer_text]).input_ids,
|
||||
)
|
||||
self.assertNotEqual(
|
||||
cased_tokenizer(answer=answer_text).input_ids, uncased_tokenizer(answer=answer_text).input_ids
|
||||
)
|
||||
self.assertEqual(
|
||||
cased_tokenizer(answer=answer_text_lower).input_ids,
|
||||
uncased_tokenizer(answer=answer_text).input_ids,
|
||||
)
|
||||
# batched encoding assert
|
||||
self.assertNotEqual(
|
||||
cased_tokenizer(answer=[answer_text]).input_ids, uncased_tokenizer(answer=[answer_text]).input_ids
|
||||
)
|
||||
self.assertEqual(
|
||||
cased_tokenizer(answer=[answer_text_lower]).input_ids,
|
||||
uncased_tokenizer(answer=[answer_text]).input_ids,
|
||||
)
|
||||
# test input encoding lowercase
|
||||
question = "Greece held its last Summer Olympics in 2004"
|
||||
table_dict = {
|
||||
|
||||
@@ -118,8 +118,7 @@ class Wav2Vec2ProcessorTest(unittest.TestCase):
|
||||
|
||||
input_str = "This is a test string"
|
||||
|
||||
with processor.as_target_processor():
|
||||
encoded_processor = processor(input_str)
|
||||
encoded_processor = processor(text=input_str)
|
||||
|
||||
encoded_tok = tokenizer(input_str)
|
||||
|
||||
|
||||
@@ -164,8 +164,7 @@ class Wav2Vec2ProcessorWithLMTest(unittest.TestCase):
|
||||
|
||||
input_str = "This is a test string"
|
||||
|
||||
with processor.as_target_processor():
|
||||
encoded_processor = processor(input_str)
|
||||
encoded_processor = processor(text=input_str)
|
||||
|
||||
encoded_tok = tokenizer(input_str)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user