Disambiguate test for required_input in tokenization base file. (#20731)
* Disambiguate test for required_input in tokenization base file. * Add test for size
This commit is contained in:
@@ -24,7 +24,7 @@ import os
|
|||||||
import re
|
import re
|
||||||
import warnings
|
import warnings
|
||||||
from collections import OrderedDict, UserDict
|
from collections import OrderedDict, UserDict
|
||||||
from collections.abc import Mapping
|
from collections.abc import Mapping, Sized
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional, Sequence, Tuple, Union
|
from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional, Sequence, Tuple, Union
|
||||||
@@ -2940,7 +2940,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
|
|||||||
|
|
||||||
required_input = encoded_inputs[self.model_input_names[0]]
|
required_input = encoded_inputs[self.model_input_names[0]]
|
||||||
|
|
||||||
if not required_input:
|
if required_input is None or (isinstance(required_input, Sized) and len(required_input) == 0):
|
||||||
if return_attention_mask:
|
if return_attention_mask:
|
||||||
encoded_inputs["attention_mask"] = []
|
encoded_inputs["attention_mask"] = []
|
||||||
return encoded_inputs
|
return encoded_inputs
|
||||||
|
|||||||
Reference in New Issue
Block a user