Exposing prepare_for_model for both slow & fast tokenizers (#5479)
* Exposing prepare_for_model for both slow & fast tokenizers * Update method signature * The traditional style commit * Hide the warnings behind the verbose flag * update default truncation strategy and prepare_for_model * fix tests and prepare_for_models methods Co-authored-by: Thomas Wolf <thomwolf@users.noreply.github.com>
This commit is contained in:
@@ -945,9 +945,9 @@ ENCODE_KWARGS_DOCSTRING = r"""
|
||||
`truncation` (:obj:`Union[bool, str]`, `optional`, defaults to :obj:`False`):
|
||||
Activate and control truncation. Accepts the following values:
|
||||
|
||||
* `True` or `'only_first'`: truncate to a max length specified in `max_length` or to the max acceptable input length for the model if no length is provided (`max_length=None`). This will only truncate the first sequence of a pair if a pair of sequences (or a batch of pairs) is provided,
|
||||
* `True` or `'longest_first'`: truncate to a max length specified in `max_length` or to the max acceptable input length for the model if no length is provided (`max_length=None`). This will truncate token by token, removing a token from the longest sequence in the pair if a pair of sequences (or a batch of pairs) is provided,
|
||||
* `'only_first'`: truncate to a max length specified in `max_length` or to the max acceptable input length for the model if no length is provided (`max_length=None`). This will only truncate the first sequence of a pair if a pair of sequences (or a batch of pairs) is provided,
|
||||
* `'only_second'`: truncate to a max length specified in `max_length` or to the max acceptable input length for the model if no length is provided (`max_length=None`). This will only truncate the second sequence of a pair if a pair of sequences (or a batch of pairs) is provided,
|
||||
* `'longest_first'`: truncate to a max length specified in `max_length` or to the max acceptable input length for the model if no length is provided (`max_length=None`). This will truncate token by token, removing a token from the longest sequence in the pair if a pair of sequences (or a batch of pairs) is provided,
|
||||
* `False` or `'do_not_truncate'` (default): No truncation (i.e. can output batch with sequences length greater than the model max admissible input size)
|
||||
`max_length` (:obj:`Union[int, None]`, `optional`, defaults to :obj:`None`):
|
||||
Control the length for padding/truncation. Accepts the following values
|
||||
@@ -1446,10 +1446,11 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
|
||||
logger.warning(
|
||||
"Truncation was not explicitely activated but `max_length` is provided a specific value, "
|
||||
"please use `truncation=True` to explicitely truncate examples to max length. "
|
||||
"Defaulting to 'only_first' truncation strategy. "
|
||||
"If you encode pairs of sequences (GLUE-style) with the tokenizer you may want to check this is the right behavior."
|
||||
"Defaulting to 'longest_first' truncation strategy. "
|
||||
"If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy "
|
||||
"more precisely by providing a specific strategy to `truncation`."
|
||||
)
|
||||
truncation = "only_first"
|
||||
truncation = "longest_first"
|
||||
|
||||
# Get padding strategy
|
||||
if padding is False and old_pad_to_max_length:
|
||||
@@ -1469,7 +1470,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
|
||||
elif padding is not False:
|
||||
if padding is True:
|
||||
padding_strategy = PaddingStrategy.LONGEST # Default to pad to the longest sequence in the batch
|
||||
else:
|
||||
elif not isinstance(padding, PaddingStrategy):
|
||||
padding_strategy = PaddingStrategy(padding)
|
||||
else:
|
||||
padding_strategy = PaddingStrategy.DO_NOT_PAD
|
||||
@@ -1492,9 +1493,9 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
|
||||
elif truncation is not False:
|
||||
if truncation is True:
|
||||
truncation_strategy = (
|
||||
TruncationStrategy.ONLY_FIRST
|
||||
) # Default to truncate the first sequences in pairs of inputs
|
||||
else:
|
||||
TruncationStrategy.LONGEST_FIRST
|
||||
) # Default to truncate the longest sequences in pairs of inputs
|
||||
elif not isinstance(truncation, TruncationStrategy):
|
||||
truncation_strategy = TruncationStrategy(truncation)
|
||||
else:
|
||||
truncation_strategy = TruncationStrategy.DO_NOT_TRUNCATE
|
||||
@@ -1960,6 +1961,225 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
|
||||
|
||||
return BatchEncoding(batch_outputs, tensor_type=return_tensors)
|
||||
|
||||
def create_token_type_ids_from_sequences(self, token_ids_0: List, token_ids_1: Optional[List] = None) -> List[int]:
|
||||
if token_ids_1 is None:
|
||||
return len(token_ids_0) * [0]
|
||||
return [0] * len(token_ids_0) + [1] * len(token_ids_1)
|
||||
|
||||
def build_inputs_with_special_tokens(self, token_ids_0: List, token_ids_1: Optional[List] = None) -> List:
|
||||
"""
|
||||
Build model inputs from a sequence or a pair of sequence for sequence classification tasks
|
||||
by concatenating and adding special tokens. This implementation does not add special tokens.
|
||||
"""
|
||||
if token_ids_1 is None:
|
||||
return token_ids_0
|
||||
return token_ids_0 + token_ids_1
|
||||
|
||||
@add_end_docstrings(ENCODE_KWARGS_DOCSTRING, ENCODE_PLUS_ADDITIONAL_KWARGS_DOCSTRING)
|
||||
def prepare_for_model(
|
||||
self,
|
||||
ids: List[int],
|
||||
pair_ids: Optional[List[int]] = None,
|
||||
add_special_tokens: bool = True,
|
||||
padding: Union[bool, str] = False,
|
||||
truncation: Union[bool, str] = False,
|
||||
max_length: Optional[int] = None,
|
||||
stride: int = 0,
|
||||
pad_to_multiple_of: Optional[int] = None,
|
||||
return_tensors: Optional[Union[str, TensorType]] = None,
|
||||
return_token_type_ids: Optional[bool] = None,
|
||||
return_attention_mask: Optional[bool] = None,
|
||||
return_overflowing_tokens: bool = False,
|
||||
return_special_tokens_mask: bool = False,
|
||||
return_offsets_mapping: bool = False,
|
||||
return_length: bool = False,
|
||||
verbose: bool = True,
|
||||
prepend_batch_axis: bool = False,
|
||||
**kwargs
|
||||
) -> BatchEncoding:
|
||||
""" Prepares a sequence of input id, or a pair of sequences of inputs ids so that it can be used by the model.
|
||||
It adds special tokens, truncates sequences if overflowing while taking into account the special tokens and
|
||||
manages a moving window (with user defined stride) for overflowing tokens
|
||||
|
||||
Args:
|
||||
ids: list of tokenized input ids. Can be obtained from a string by chaining the
|
||||
`tokenize` and `convert_tokens_to_ids` methods.
|
||||
pair_ids: Optional second list of input ids. Can be obtained from a string by chaining the
|
||||
`tokenize` and `convert_tokens_to_ids` methods.
|
||||
"""
|
||||
|
||||
if "return_lengths" in kwargs:
|
||||
if verbose:
|
||||
warnings.warn(
|
||||
"The PreTrainedTokenizerBase.prepare_for_model `return_lengths` parameter is deprecated. "
|
||||
"Please use `return_length` instead.",
|
||||
FutureWarning,
|
||||
)
|
||||
return_length = kwargs["return_lengths"]
|
||||
|
||||
# Backward compatibility for 'truncation_strategy', 'pad_to_max_length'
|
||||
padding_strategy, truncation_strategy, max_length, kwargs = self._get_padding_truncation_strategies(
|
||||
padding=padding,
|
||||
truncation=truncation,
|
||||
max_length=max_length,
|
||||
pad_to_multiple_of=pad_to_multiple_of,
|
||||
verbose=verbose,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
pair = bool(pair_ids is not None)
|
||||
len_ids = len(ids)
|
||||
len_pair_ids = len(pair_ids) if pair else 0
|
||||
|
||||
# Load from model defaults
|
||||
if return_token_type_ids is None:
|
||||
return_token_type_ids = "token_type_ids" in self.model_input_names
|
||||
if return_attention_mask is None:
|
||||
return_attention_mask = "attention_mask" in self.model_input_names
|
||||
|
||||
encoded_inputs = {}
|
||||
|
||||
# Compute the total size of the returned encodings
|
||||
total_len = len_ids + len_pair_ids + (self.num_special_tokens_to_add(pair=pair) if add_special_tokens else 0)
|
||||
|
||||
# Truncation: Handle max sequence length
|
||||
if truncation_strategy != TruncationStrategy.DO_NOT_TRUNCATE and max_length and total_len > max_length:
|
||||
ids, pair_ids, overflowing_tokens = self.truncate_sequences(
|
||||
ids,
|
||||
pair_ids=pair_ids,
|
||||
num_tokens_to_remove=total_len - max_length,
|
||||
truncation_strategy=truncation_strategy,
|
||||
stride=stride,
|
||||
)
|
||||
if return_overflowing_tokens:
|
||||
encoded_inputs["overflowing_tokens"] = overflowing_tokens
|
||||
encoded_inputs["num_truncated_tokens"] = total_len - max_length
|
||||
|
||||
# Add special tokens
|
||||
if add_special_tokens:
|
||||
sequence = self.build_inputs_with_special_tokens(ids, pair_ids)
|
||||
token_type_ids = self.create_token_type_ids_from_sequences(ids, pair_ids)
|
||||
else:
|
||||
sequence = ids + pair_ids if pair else ids
|
||||
token_type_ids = [0] * len(ids) + ([1] * len(pair_ids) if pair else [])
|
||||
|
||||
# Build output dictionnary
|
||||
encoded_inputs["input_ids"] = sequence
|
||||
if return_token_type_ids:
|
||||
encoded_inputs["token_type_ids"] = token_type_ids
|
||||
if return_special_tokens_mask:
|
||||
if add_special_tokens:
|
||||
encoded_inputs["special_tokens_mask"] = self.get_special_tokens_mask(ids, pair_ids)
|
||||
else:
|
||||
encoded_inputs["special_tokens_mask"] = [0] * len(sequence)
|
||||
|
||||
# Check lengths
|
||||
if max_length is None and len(encoded_inputs["input_ids"]) > self.model_max_length and verbose:
|
||||
logger.warning(
|
||||
"Token indices sequence length is longer than the specified maximum sequence length "
|
||||
"for this model ({} > {}). Running this sequence through the model will result in "
|
||||
"indexing errors".format(len(ids), self.model_max_length)
|
||||
)
|
||||
|
||||
# Padding
|
||||
if padding_strategy != PaddingStrategy.DO_NOT_PAD or return_attention_mask:
|
||||
encoded_inputs = self.pad(
|
||||
encoded_inputs,
|
||||
max_length=max_length,
|
||||
padding=padding_strategy.value,
|
||||
pad_to_multiple_of=pad_to_multiple_of,
|
||||
return_attention_mask=return_attention_mask,
|
||||
)
|
||||
|
||||
if return_length:
|
||||
encoded_inputs["length"] = len(encoded_inputs["input_ids"])
|
||||
|
||||
batch_outputs = BatchEncoding(
|
||||
encoded_inputs, tensor_type=return_tensors, prepend_batch_axis=prepend_batch_axis
|
||||
)
|
||||
|
||||
return batch_outputs
|
||||
|
||||
def truncate_sequences(
|
||||
self,
|
||||
ids: List[int],
|
||||
pair_ids: Optional[List[int]] = None,
|
||||
num_tokens_to_remove: int = 0,
|
||||
truncation_strategy: Union[str, TruncationStrategy] = "longest_first",
|
||||
stride: int = 0,
|
||||
) -> Tuple[List[int], List[int], List[int]]:
|
||||
""" Truncates a sequence pair in place to the maximum length.
|
||||
|
||||
Args:
|
||||
ids: list of tokenized input ids. Can be obtained from a string by chaining the
|
||||
`tokenize` and `convert_tokens_to_ids` methods.
|
||||
pair_ids: Optional second list of input ids. Can be obtained from a string by chaining the
|
||||
`tokenize` and `convert_tokens_to_ids` methods.
|
||||
num_tokens_to_remove (:obj:`int`, `optional`, defaults to ``0``):
|
||||
number of tokens to remove using the truncation strategy
|
||||
truncation_strategy (:obj:`string`, `optional`, defaults to "longest_first"):
|
||||
String selected in the following options:
|
||||
|
||||
- 'longest_first' (default): Iteratively reduce the inputs sequence until the input is under max_length
|
||||
starting from the longest one at each token (when there is a pair of input sequences).
|
||||
Overflowing tokens only contains overflow from the first sequence.
|
||||
- 'only_first': Only truncate the first sequence. raise an error if the first sequence is shorter or equal to than num_tokens_to_remove.
|
||||
- 'only_second': Only truncate the second sequence
|
||||
- 'do_not_truncate'
|
||||
stride (:obj:`int`, `optional`, defaults to ``0``):
|
||||
If set to a number along with max_length, the overflowing tokens returned will contain some tokens
|
||||
from the main sequence returned. The value of this argument defines the number of additional tokens.
|
||||
"""
|
||||
if num_tokens_to_remove <= 0:
|
||||
return ids, pair_ids, []
|
||||
|
||||
if not isinstance(truncation_strategy, TruncationStrategy):
|
||||
truncation_strategy = TruncationStrategy(truncation_strategy)
|
||||
|
||||
overflowing_tokens = []
|
||||
if truncation_strategy == TruncationStrategy.LONGEST_FIRST:
|
||||
for _ in range(num_tokens_to_remove):
|
||||
if pair_ids is None or len(ids) > len(pair_ids):
|
||||
if not overflowing_tokens:
|
||||
window_len = min(len(ids), stride + 1)
|
||||
else:
|
||||
window_len = 1
|
||||
overflowing_tokens.extend(ids[-window_len:])
|
||||
ids = ids[:-1]
|
||||
else:
|
||||
if not overflowing_tokens:
|
||||
window_len = min(len(pair_ids), stride + 1)
|
||||
else:
|
||||
window_len = 1
|
||||
overflowing_tokens.extend(pair_ids[-window_len:])
|
||||
pair_ids = pair_ids[:-1]
|
||||
elif truncation_strategy == TruncationStrategy.ONLY_FIRST:
|
||||
if len(ids) > num_tokens_to_remove:
|
||||
window_len = min(len(ids), stride + num_tokens_to_remove)
|
||||
overflowing_tokens = ids[-window_len:]
|
||||
ids = ids[:-num_tokens_to_remove]
|
||||
else:
|
||||
logger.error(
|
||||
f"We need to remove {num_tokens_to_remove} to truncate the input"
|
||||
f"but the first sequence has a length {len(ids)}. "
|
||||
f"Please select another truncation strategy than {truncation_strategy}, "
|
||||
f"for instance 'longest_first' or 'only_second'."
|
||||
)
|
||||
elif truncation_strategy == TruncationStrategy.ONLY_SECOND and pair_ids is not None:
|
||||
if len(pair_ids) > num_tokens_to_remove:
|
||||
window_len = min(len(pair_ids), stride + num_tokens_to_remove)
|
||||
overflowing_tokens = pair_ids[-window_len:]
|
||||
pair_ids = pair_ids[:-num_tokens_to_remove]
|
||||
else:
|
||||
logger.error(
|
||||
f"We need to remove {num_tokens_to_remove} to truncate the input"
|
||||
f"but the second sequence has a length {len(pair_ids)}. "
|
||||
f"Please select another truncation strategy than {truncation_strategy}, "
|
||||
f"for instance 'longest_first' or 'only_first'."
|
||||
)
|
||||
|
||||
return (ids, pair_ids, overflowing_tokens)
|
||||
|
||||
def _pad(
|
||||
self,
|
||||
encoded_inputs: Union[Dict[str, EncodedInput], BatchEncoding],
|
||||
|
||||
Reference in New Issue
Block a user