Disambiguate test for required_input in tokenization base file. (#20731)

* Disambiguate the emptiness test for required_input in the tokenization base file (explicit `None` / zero-length check instead of relying on truthiness).

* Add test for size: only call `len()` on `required_input` when it is an instance of `Sized`.
This commit is contained in:
Sylvain Gugger
2022-12-12 13:13:09 -05:00
committed by GitHub
parent 29ff8716a2
commit a450789d9a

View File

@@ -24,7 +24,7 @@ import os
import re import re
import warnings import warnings
from collections import OrderedDict, UserDict from collections import OrderedDict, UserDict
from collections.abc import Mapping from collections.abc import Mapping, Sized
from contextlib import contextmanager from contextlib import contextmanager
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional, Sequence, Tuple, Union from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional, Sequence, Tuple, Union
@@ -2940,7 +2940,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
required_input = encoded_inputs[self.model_input_names[0]] required_input = encoded_inputs[self.model_input_names[0]]
if not required_input: if required_input is None or (isinstance(required_input, Sized) and len(required_input) == 0):
if return_attention_mask: if return_attention_mask:
encoded_inputs["attention_mask"] = [] encoded_inputs["attention_mask"] = []
return encoded_inputs return encoded_inputs