diff --git a/src/transformers/tokenization_utils_base.py b/src/transformers/tokenization_utils_base.py index 3648f5e99b..4330ae0a36 100644 --- a/src/transformers/tokenization_utils_base.py +++ b/src/transformers/tokenization_utils_base.py @@ -1781,6 +1781,8 @@ class PreTrainedTokenizerBase(SpecialTokensMixin): padding_strategy = PaddingStrategy.LONGEST # Default to pad to the longest sequence in the batch elif not isinstance(padding, PaddingStrategy): padding_strategy = PaddingStrategy(padding) + elif isinstance(padding, PaddingStrategy): + padding_strategy = padding else: padding_strategy = PaddingStrategy.DO_NOT_PAD @@ -1806,6 +1808,8 @@ class PreTrainedTokenizerBase(SpecialTokensMixin): ) # Default to truncate the longest sequences in pairs of inputs elif not isinstance(truncation, TruncationStrategy): truncation_strategy = TruncationStrategy(truncation) + elif isinstance(truncation, TruncationStrategy): + truncation_strategy = truncation else: truncation_strategy = TruncationStrategy.DO_NOT_TRUNCATE