Add correct batched handling for apply_chat_template (#29222)

* Add correct batched handling for apply_chat_template

* Fix warning method

* Add error for incompatible options

* expand tests

* Add a skip for markuplm

* Add skips for other layout models

* Skip for LayoutLMv2

* Slightly update the warning message

* Update src/transformers/tokenization_utils_base.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/tokenization_utils_base.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/tokenization_utils_base.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/tokenization_utils_base.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/tokenization_utils_base.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/tokenization_utils_base.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* typo fix

* Update docstring for conversation kwarg

* Update return docstring

* Remove the warning, improve error message

* Update src/transformers/tokenization_utils_base.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/tokenization_utils_base.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update tests/test_tokenization_common.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update tests/test_tokenization_common.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Remove return_dict=None

* Fix up some merge cruft

* More merge cruft

* Add another skip

* Add another skip

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
Matt
2024-03-20 15:50:22 +00:00
committed by GitHub
parent 3c17c529cc
commit 9d999481b2
8 changed files with 127 additions and 40 deletions

View File

@@ -195,6 +195,10 @@ class LayoutLMv2TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
tokenizer.tokenize(" \tHeLLo!how \n Are yoU? [UNK]"), ["HeLLo", "!", "how", "Are", "yoU", "?", "[UNK]"]
)
# Opt-out stub: the batched chat-template test is skipped here with the reason given in the decorator,
# so the body is intentionally empty.
# NOTE(review): presumably shadows the same-named test inherited from TokenizerTesterMixin — confirm.
@unittest.skip("Chat template tests don't play well with table/layout models.")
def test_chat_template_batched(self):
pass
def test_wordpiece_tokenizer(self):
vocab_tokens = ["[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed", "wa", "un", "runn", "##ing"]

View File

@@ -140,6 +140,10 @@ class LayoutLMv3TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
output_text = "lower newer"
return input_text, output_text
# Skipped placeholder for the batched chat-template test; unittest reports the decorator's reason
# instead of running anything, hence the empty body.
# NOTE(review): presumably replaces the version provided by TokenizerTesterMixin — confirm.
@unittest.skip("Chat template tests don't play well with table/layout models.")
def test_chat_template_batched(self):
pass
def test_full_tokenizer(self):
tokenizer = self.tokenizer_class(self.vocab_file, self.merges_file, **self.special_tokens_map)
text = "lower newer"

View File

@@ -107,6 +107,10 @@ class LayoutXLMTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
output_text = "unwanted, running"
return input_text, output_text
# Disables the batched chat-template test for this test class: the method is marked skipped
# (reason in the decorator) and deliberately does nothing.
# NOTE(review): assumes the mixin defines a test of the same name that this definition hides — confirm.
@unittest.skip("Chat template tests don't play well with table/layout models.")
def test_chat_template_batched(self):
pass
# override test in `test_tokenization_common.py` because of the required input format of the `__call__` method of
# this tokenizer
def test_save_sentencepiece_tokenizer(self) -> None:

View File

@@ -101,6 +101,10 @@ class MarkupLMTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
return questions, nodes, xpaths
# No-op stub marked as skipped so the batched chat-template test does not run for this test class;
# the skip reason is the string passed to the decorator.
# NOTE(review): presumably overrides the common test from TokenizerTesterMixin — confirm.
@unittest.skip("Chat template tests don't play well with table/layout models.")
def test_chat_template_batched(self):
pass
def get_input_output_texts(self, tokenizer):
input_text = "UNwant\u00E9d,running"
output_text = "unwanted, running"

View File

@@ -223,6 +223,10 @@ class TapasTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
rust_ids = rust_tokenizer.encode(sequence)
self.assertListEqual(ids, rust_ids)
# Batched chat-template test is turned off here: the decorator skips it with the stated reason
# and the body is an intentional `pass`.
# NOTE(review): presumably takes the place of the same-named TokenizerTesterMixin test — confirm.
@unittest.skip("Chat template tests don't play well with table/layout models.")
def test_chat_template_batched(self):
pass
def test_chinese(self):
tokenizer = BasicTokenizer()

View File

@@ -1153,6 +1153,14 @@ class UdopTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
# Assert there is only added_tokens special_tokens
self.assertEqual(sum(tokens_with_offsets["special_tokens_mask"]), added_tokens)
# Skipped stub for the (non-batched) chat-template test; nothing is executed and unittest reports
# the reason given in the decorator.
# NOTE(review): presumably hides the same-named test inherited from TokenizerTesterMixin — confirm.
@unittest.skip("Chat template tests don't play well with table/layout models.")
def test_chat_template(self):
pass
# Skipped stub for the batched chat-template test; the empty body is intentional because the
# decorator prevents the test from running and supplies the reported reason.
# NOTE(review): presumably hides the same-named test inherited from TokenizerTesterMixin — confirm.
@unittest.skip("Chat template tests don't play well with table/layout models.")
def test_chat_template_batched(self):
pass
@require_torch
@slow
def test_torch_encode_plus_sent_to_model(self):