Merge pull request #1724 from huggingface/fix_encode_plus
Fix encode_plus
This commit is contained in:
@@ -750,6 +750,9 @@ class PreTrainedTokenizer(object):
|
||||
stride=0,
|
||||
truncation_strategy='longest_first',
|
||||
return_tensors=None,
|
||||
return_token_type_ids=True,
|
||||
return_overflowing_tokens=False,
|
||||
return_special_tokens_mask=False,
|
||||
**kwargs):
|
||||
"""
|
||||
Returns a dictionary containing the encoded sequence or sequence pair and additional informations:
|
||||
@@ -776,7 +779,30 @@ class PreTrainedTokenizer(object):
|
||||
- 'do_not_truncate': Does not truncate (raise an error if the input sequence is longer than max_length)
|
||||
return_tensors: (optional) can be set to 'tf' or 'pt' to return respectively TensorFlow tf.constant
|
||||
or PyTorch torch.Tensor instead of a list of python integers.
|
||||
return_token_type_ids: (optional) Set to False to avoid returning token_type_ids (default True).
|
||||
return_overflowing_tokens: (optional) Set to True to return overflowing token information (default False).
|
||||
return_special_tokens_mask: (optional) Set to True to return special tokens mask information (default False).
|
||||
**kwargs: passed to the `self.tokenize()` method
|
||||
|
||||
Return:
|
||||
A Dictionary of shape::
|
||||
|
||||
{
|
||||
input_ids: list[int],
|
||||
token_type_ids: list[int] if return_token_type_ids is True (default)
|
||||
overflowing_tokens: list[int] if a ``max_length`` is specified and return_overflowing_tokens is True
|
||||
num_truncated_tokens: int if a ``max_length`` is specified and return_overflowing_tokens is True
|
||||
special_tokens_mask: list[int] if ``add_special_tokens`` if set to ``True`` and return_special_tokens_mask is True
|
||||
}
|
||||
|
||||
With the fields:
|
||||
``input_ids``: list of token ids to be fed to a model
|
||||
``token_type_ids``: list of token type ids to be fed to a model
|
||||
|
||||
``overflowing_tokens``: list of overflowing tokens if a max length is specified.
|
||||
``num_truncated_tokens``: number of overflowing tokens a ``max_length`` is specified
|
||||
``special_tokens_mask``: if adding special tokens, this is a list of [0, 1], with 0 specifying special added
|
||||
tokens and 1 specifying sequence tokens.
|
||||
"""
|
||||
|
||||
def get_input_ids(text):
|
||||
@@ -798,10 +824,17 @@ class PreTrainedTokenizer(object):
|
||||
add_special_tokens=add_special_tokens,
|
||||
stride=stride,
|
||||
truncation_strategy=truncation_strategy,
|
||||
return_tensors=return_tensors)
|
||||
return_tensors=return_tensors,
|
||||
return_token_type_ids=return_token_type_ids,
|
||||
return_overflowing_tokens=return_overflowing_tokens,
|
||||
return_special_tokens_mask=return_special_tokens_mask)
|
||||
|
||||
def prepare_for_model(self, ids, pair_ids=None, max_length=None, add_special_tokens=True, stride=0,
|
||||
truncation_strategy='longest_first', return_tensors=None):
|
||||
truncation_strategy='longest_first',
|
||||
return_tensors=None,
|
||||
return_token_type_ids=True,
|
||||
return_overflowing_tokens=False,
|
||||
return_special_tokens_mask=False):
|
||||
"""
|
||||
Prepares a sequence of input id, or a pair of sequences of inputs ids so that it can be used by the model.
|
||||
It adds special tokens, truncates
|
||||
@@ -826,21 +859,27 @@ class PreTrainedTokenizer(object):
|
||||
- 'do_not_truncate': Does not truncate (raise an error if the input sequence is longer than max_length)
|
||||
return_tensors: (optional) can be set to 'tf' or 'pt' to return respectively TensorFlow tf.constant
|
||||
or PyTorch torch.Tensor instead of a list of python integers.
|
||||
return_token_type_ids: (optional) Set to False to avoid returning token_type_ids (default True).
|
||||
return_overflowing_tokens: (optional) Set to True to return overflowing token information (default False).
|
||||
return_special_tokens_mask: (optional) Set to True to return special tokens mask information (default False).
|
||||
|
||||
Return:
|
||||
A Dictionary of shape::
|
||||
|
||||
{
|
||||
input_ids: list[int],
|
||||
overflowing_tokens: list[int] if a ``max_length`` is specified, else None
|
||||
special_tokens_mask: list[int] if ``add_special_tokens`` if set to ``True``
|
||||
token_type_ids: list[int] if return_token_type_ids is True (default)
|
||||
overflowing_tokens: list[int] if a ``max_length`` is specified and return_overflowing_tokens is True
|
||||
num_truncated_tokens: int if a ``max_length`` is specified and return_overflowing_tokens is True
|
||||
special_tokens_mask: list[int] if ``add_special_tokens`` if set to ``True`` and return_special_tokens_mask is True
|
||||
}
|
||||
|
||||
With the fields:
|
||||
``input_ids``: list of tokens to be fed to a model
|
||||
``input_ids``: list of token ids to be fed to a model
|
||||
``token_type_ids``: list of token type ids to be fed to a model
|
||||
|
||||
``overflowing_tokens``: list of overflowing tokens if a max length is specified.
|
||||
|
||||
``num_truncated_tokens``: number of overflowing tokens a ``max_length`` is specified
|
||||
``special_tokens_mask``: if adding special tokens, this is a list of [0, 1], with 0 specifying special added
|
||||
tokens and 1 specifying sequence tokens.
|
||||
"""
|
||||
@@ -849,23 +888,31 @@ class PreTrainedTokenizer(object):
|
||||
len_pair_ids = len(pair_ids) if pair else 0
|
||||
|
||||
encoded_inputs = {}
|
||||
|
||||
# Handle max sequence length
|
||||
total_len = len_ids + len_pair_ids + (self.num_added_tokens(pair=pair) if add_special_tokens else 0)
|
||||
if max_length and total_len > max_length:
|
||||
ids, pair_ids, overflowing_tokens = self.truncate_sequences(ids, pair_ids=pair_ids,
|
||||
num_tokens_to_remove=total_len-max_length,
|
||||
truncation_strategy=truncation_strategy,
|
||||
stride=stride)
|
||||
encoded_inputs["overflowing_tokens"] = overflowing_tokens
|
||||
encoded_inputs["num_truncated_tokens"] = total_len - max_length
|
||||
if return_overflowing_tokens:
|
||||
encoded_inputs["overflowing_tokens"] = overflowing_tokens
|
||||
encoded_inputs["num_truncated_tokens"] = total_len - max_length
|
||||
|
||||
# Handle special_tokens
|
||||
if add_special_tokens:
|
||||
sequence = self.build_inputs_with_special_tokens(ids, pair_ids)
|
||||
token_type_ids = self.create_token_type_ids_from_sequences(ids, pair_ids)
|
||||
encoded_inputs["special_tokens_mask"] = self.get_special_tokens_mask(ids, pair_ids)
|
||||
special_tokens_mask = self.get_special_tokens_mask(ids, pair_ids)
|
||||
else:
|
||||
sequence = ids + pair_ids if pair else ids
|
||||
token_type_ids = [0] * len(ids) + ([1] * len(pair_ids) if pair else [])
|
||||
special_tokens_mask = [0] * (len(ids) + (len(pair_ids) if pair else 0))
|
||||
if return_special_tokens_mask:
|
||||
encoded_inputs["special_tokens_mask"] = self.get_special_tokens_mask(ids, pair_ids)
|
||||
|
||||
# Prepare inputs as tensors if asked
|
||||
if return_tensors == 'tf' and is_tf_available():
|
||||
sequence = tf.constant([sequence])
|
||||
token_type_ids = tf.constant([token_type_ids])
|
||||
@@ -876,12 +923,15 @@ class PreTrainedTokenizer(object):
|
||||
logger.warning("Unable to convert output to tensors format {}, PyTorch or TensorFlow is not available.".format(return_tensors))
|
||||
|
||||
encoded_inputs["input_ids"] = sequence
|
||||
encoded_inputs["token_type_ids"] = token_type_ids
|
||||
if return_token_type_ids:
|
||||
encoded_inputs["token_type_ids"] = token_type_ids
|
||||
|
||||
if max_length and len(encoded_inputs["input_ids"]) > max_length:
|
||||
encoded_inputs["input_ids"] = encoded_inputs["input_ids"][:max_length]
|
||||
encoded_inputs["token_type_ids"] = encoded_inputs["token_type_ids"][:max_length]
|
||||
encoded_inputs["special_tokens_mask"] = encoded_inputs["special_tokens_mask"][:max_length]
|
||||
if return_token_type_ids:
|
||||
encoded_inputs["token_type_ids"] = encoded_inputs["token_type_ids"][:max_length]
|
||||
if return_special_tokens_mask:
|
||||
encoded_inputs["special_tokens_mask"] = encoded_inputs["special_tokens_mask"][:max_length]
|
||||
|
||||
if max_length is None and len(encoded_inputs["input_ids"]) > self.max_len:
|
||||
logger.warning("Token indices sequence length is longer than the specified maximum sequence length "
|
||||
|
||||
Reference in New Issue
Block a user