Use None to detect if truncation was unset (#19794)

* Use None to detect if truncation was unset

* Fix repo consistency
This commit is contained in:
Sylvain Gugger
2022-10-21 12:53:37 -04:00
committed by GitHub
parent 2e5c6f5975
commit c4a997cd85
16 changed files with 57 additions and 55 deletions

View File

@@ -2235,7 +2235,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
text_pair: Optional[Union[TextInput, PreTokenizedInput, EncodedInput]] = None,
add_special_tokens: bool = True,
padding: Union[bool, str, PaddingStrategy] = False,
truncation: Union[bool, str, TruncationStrategy] = False,
truncation: Union[bool, str, TruncationStrategy] = None,
max_length: Optional[int] = None,
stride: int = 0,
return_tensors: Optional[Union[str, TensorType]] = None,
@@ -2274,7 +2274,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
raise NotImplementedError
def _get_padding_truncation_strategies(
self, padding=False, truncation=False, max_length=None, pad_to_multiple_of=None, verbose=True, **kwargs
self, padding=False, truncation=None, max_length=None, pad_to_multiple_of=None, verbose=True, **kwargs
):
"""
Find the correct padding/truncation strategy with backward compatibility for old arguments (truncation_strategy
@@ -2285,7 +2285,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
# Backward compatibility for previous behavior, maybe we should deprecate it:
# If you only set max_length, it activates truncation for max_length
if max_length is not None and padding is False and truncation is False:
if max_length is not None and padding is False and truncation is None:
if verbose:
if not self.deprecation_warnings.get("Truncation-not-explicitly-activated", False):
logger.warning(
@@ -2316,7 +2316,9 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
elif padding is not False:
if padding is True:
if verbose:
if max_length is not None and (truncation is False or truncation == "do_not_truncate"):
if max_length is not None and (
truncation is None or truncation is False or truncation == "do_not_truncate"
):
warnings.warn(
"`max_length` is ignored when `padding`=`True` and there is no truncation strategy. "
"To pad to max length, use `padding='max_length'`."
@@ -2332,7 +2334,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
padding_strategy = PaddingStrategy.DO_NOT_PAD
# Get truncation strategy
if truncation is False and old_truncation_strategy != "do_not_truncate":
if truncation is None and old_truncation_strategy != "do_not_truncate":
if verbose:
warnings.warn(
"The `truncation_strategy` argument is deprecated and will be removed in a future version, use"
@@ -2346,7 +2348,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
FutureWarning,
)
truncation_strategy = TruncationStrategy(old_truncation_strategy)
elif truncation is not False:
elif truncation is not False and truncation is not None:
if truncation is True:
truncation_strategy = (
TruncationStrategy.LONGEST_FIRST
@@ -2420,7 +2422,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
] = None,
add_special_tokens: bool = True,
padding: Union[bool, str, PaddingStrategy] = False,
truncation: Union[bool, str, TruncationStrategy] = False,
truncation: Union[bool, str, TruncationStrategy] = None,
max_length: Optional[int] = None,
stride: int = 0,
is_split_into_words: bool = False,
@@ -2504,7 +2506,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
text_pair: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] = None,
add_special_tokens: bool = True,
padding: Union[bool, str, PaddingStrategy] = False,
truncation: Union[bool, str, TruncationStrategy] = False,
truncation: Union[bool, str, TruncationStrategy] = None,
max_length: Optional[int] = None,
stride: int = 0,
is_split_into_words: bool = False,
@@ -2617,7 +2619,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
text_pair: Optional[Union[TextInput, PreTokenizedInput, EncodedInput]] = None,
add_special_tokens: bool = True,
padding: Union[bool, str, PaddingStrategy] = False,
truncation: Union[bool, str, TruncationStrategy] = False,
truncation: Union[bool, str, TruncationStrategy] = None,
max_length: Optional[int] = None,
stride: int = 0,
is_split_into_words: bool = False,
@@ -2719,7 +2721,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
],
add_special_tokens: bool = True,
padding: Union[bool, str, PaddingStrategy] = False,
truncation: Union[bool, str, TruncationStrategy] = False,
truncation: Union[bool, str, TruncationStrategy] = None,
max_length: Optional[int] = None,
stride: int = 0,
is_split_into_words: bool = False,
@@ -3029,7 +3031,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
pair_ids: Optional[List[int]] = None,
add_special_tokens: bool = True,
padding: Union[bool, str, PaddingStrategy] = False,
truncation: Union[bool, str, TruncationStrategy] = False,
truncation: Union[bool, str, TruncationStrategy] = None,
max_length: Optional[int] = None,
stride: int = 0,
pad_to_multiple_of: Optional[int] = None,