From c0eb218a55bc9a2b6f5c98ddfb8d3dd47c5ad7cb Mon Sep 17 00:00:00 2001 From: Hamel Husain Date: Wed, 28 Apr 2021 07:11:17 -0700 Subject: [PATCH] Update `PreTrainedTokenizerBase` to check/handle batch length for `text_pair` parameter (#11486) * Update tokenization_utils_base.py * add assertion * check batch len * Update src/transformers/tokenization_utils_base.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * add error message Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> --- src/transformers/tokenization_utils_base.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/transformers/tokenization_utils_base.py b/src/transformers/tokenization_utils_base.py index 2f160a881b..2d7f7d8518 100644 --- a/src/transformers/tokenization_utils_base.py +++ b/src/transformers/tokenization_utils_base.py @@ -2279,6 +2279,14 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin): ) if is_batched: + if isinstance(text_pair, str): + raise TypeError( + "when tokenizing batches of text, `text_pair` must be a list or tuple with the same length as `text`." + ) + if text_pair is not None and len(text) != len(text_pair): + raise ValueError( + f"batch length of `text`: {len(text)} does not match batch length of `text_pair`: {len(text_pair)}." + ) batch_text_or_text_pairs = list(zip(text, text_pair)) if text_pair is not None else text return self.batch_encode_plus( batch_text_or_text_pairs=batch_text_or_text_pairs,