From a450789d9adfeb1a323c5a7b5cf2193214dc1d7d Mon Sep 17 00:00:00 2001 From: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Date: Mon, 12 Dec 2022 13:13:09 -0500 Subject: [PATCH] Disambiguate test for required_input in tokenization base file. (#20731) * Disambiguate test for required_input in tokenization base file. * Add test for size --- src/transformers/tokenization_utils_base.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/tokenization_utils_base.py b/src/transformers/tokenization_utils_base.py index 011edfa1e7..7999ed5c91 100644 --- a/src/transformers/tokenization_utils_base.py +++ b/src/transformers/tokenization_utils_base.py @@ -24,7 +24,7 @@ import os import re import warnings from collections import OrderedDict, UserDict -from collections.abc import Mapping +from collections.abc import Mapping, Sized from contextlib import contextmanager from dataclasses import dataclass, field from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional, Sequence, Tuple, Union @@ -2940,7 +2940,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin): required_input = encoded_inputs[self.model_input_names[0]] - if not required_input: + if required_input is None or (isinstance(required_input, Sized) and len(required_input) == 0): if return_attention_mask: encoded_inputs["attention_mask"] = [] return encoded_inputs