Merge pull request #1724 from huggingface/fix_encode_plus
Fix encode_plus
This commit is contained in:
@@ -273,7 +273,11 @@ class CommonTestCases:
|
|||||||
sequence = tokenizer.encode(seq_0, add_special_tokens=False)
|
sequence = tokenizer.encode(seq_0, add_special_tokens=False)
|
||||||
num_added_tokens = tokenizer.num_added_tokens()
|
num_added_tokens = tokenizer.num_added_tokens()
|
||||||
total_length = len(sequence) + num_added_tokens
|
total_length = len(sequence) + num_added_tokens
|
||||||
information = tokenizer.encode_plus(seq_0, max_length=total_length - 2, add_special_tokens=True, stride=stride)
|
information = tokenizer.encode_plus(seq_0,
|
||||||
|
max_length=total_length - 2,
|
||||||
|
add_special_tokens=True,
|
||||||
|
stride=stride,
|
||||||
|
return_overflowing_tokens=True)
|
||||||
|
|
||||||
truncated_sequence = information["input_ids"]
|
truncated_sequence = information["input_ids"]
|
||||||
overflowing_tokens = information["overflowing_tokens"]
|
overflowing_tokens = information["overflowing_tokens"]
|
||||||
@@ -300,10 +304,12 @@ class CommonTestCases:
|
|||||||
)
|
)
|
||||||
|
|
||||||
information = tokenizer.encode_plus(seq_0, seq_1, max_length=len(sequence) - 2, add_special_tokens=True,
|
information = tokenizer.encode_plus(seq_0, seq_1, max_length=len(sequence) - 2, add_special_tokens=True,
|
||||||
stride=stride, truncation_strategy='only_second')
|
stride=stride, truncation_strategy='only_second',
|
||||||
|
return_overflowing_tokens=True)
|
||||||
information_first_truncated = tokenizer.encode_plus(seq_0, seq_1, max_length=len(sequence) - 2,
|
information_first_truncated = tokenizer.encode_plus(seq_0, seq_1, max_length=len(sequence) - 2,
|
||||||
add_special_tokens=True, stride=stride,
|
add_special_tokens=True, stride=stride,
|
||||||
truncation_strategy='only_first')
|
truncation_strategy='only_first',
|
||||||
|
return_overflowing_tokens=True)
|
||||||
|
|
||||||
truncated_sequence = information["input_ids"]
|
truncated_sequence = information["input_ids"]
|
||||||
overflowing_tokens = information["overflowing_tokens"]
|
overflowing_tokens = information["overflowing_tokens"]
|
||||||
@@ -335,7 +341,7 @@ class CommonTestCases:
|
|||||||
|
|
||||||
# Testing single inputs
|
# Testing single inputs
|
||||||
encoded_sequence = tokenizer.encode(sequence_0, add_special_tokens=False)
|
encoded_sequence = tokenizer.encode(sequence_0, add_special_tokens=False)
|
||||||
encoded_sequence_dict = tokenizer.encode_plus(sequence_0, add_special_tokens=True)
|
encoded_sequence_dict = tokenizer.encode_plus(sequence_0, add_special_tokens=True, return_special_tokens_mask=True)
|
||||||
encoded_sequence_w_special = encoded_sequence_dict["input_ids"]
|
encoded_sequence_w_special = encoded_sequence_dict["input_ids"]
|
||||||
special_tokens_mask = encoded_sequence_dict["special_tokens_mask"]
|
special_tokens_mask = encoded_sequence_dict["special_tokens_mask"]
|
||||||
self.assertEqual(len(special_tokens_mask), len(encoded_sequence_w_special))
|
self.assertEqual(len(special_tokens_mask), len(encoded_sequence_w_special))
|
||||||
@@ -347,7 +353,8 @@ class CommonTestCases:
|
|||||||
# Testing inputs pairs
|
# Testing inputs pairs
|
||||||
encoded_sequence = tokenizer.encode(sequence_0, add_special_tokens=False) + tokenizer.encode(sequence_1,
|
encoded_sequence = tokenizer.encode(sequence_0, add_special_tokens=False) + tokenizer.encode(sequence_1,
|
||||||
add_special_tokens=False)
|
add_special_tokens=False)
|
||||||
encoded_sequence_dict = tokenizer.encode_plus(sequence_0, sequence_1, add_special_tokens=True)
|
encoded_sequence_dict = tokenizer.encode_plus(sequence_0, sequence_1, add_special_tokens=True,
|
||||||
|
return_special_tokens_mask=True)
|
||||||
encoded_sequence_w_special = encoded_sequence_dict["input_ids"]
|
encoded_sequence_w_special = encoded_sequence_dict["input_ids"]
|
||||||
special_tokens_mask = encoded_sequence_dict["special_tokens_mask"]
|
special_tokens_mask = encoded_sequence_dict["special_tokens_mask"]
|
||||||
self.assertEqual(len(special_tokens_mask), len(encoded_sequence_w_special))
|
self.assertEqual(len(special_tokens_mask), len(encoded_sequence_w_special))
|
||||||
@@ -359,7 +366,9 @@ class CommonTestCases:
|
|||||||
# Testing with already existing special tokens
|
# Testing with already existing special tokens
|
||||||
if tokenizer.cls_token_id == tokenizer.unk_token_id and tokenizer.cls_token_id == tokenizer.unk_token_id:
|
if tokenizer.cls_token_id == tokenizer.unk_token_id and tokenizer.cls_token_id == tokenizer.unk_token_id:
|
||||||
tokenizer.add_special_tokens({'cls_token': '</s>', 'sep_token': '<s>'})
|
tokenizer.add_special_tokens({'cls_token': '</s>', 'sep_token': '<s>'})
|
||||||
encoded_sequence_dict = tokenizer.encode_plus(sequence_0, add_special_tokens=True)
|
encoded_sequence_dict = tokenizer.encode_plus(sequence_0,
|
||||||
|
add_special_tokens=True,
|
||||||
|
return_special_tokens_mask=True)
|
||||||
encoded_sequence_w_special = encoded_sequence_dict["input_ids"]
|
encoded_sequence_w_special = encoded_sequence_dict["input_ids"]
|
||||||
special_tokens_mask_orig = encoded_sequence_dict["special_tokens_mask"]
|
special_tokens_mask_orig = encoded_sequence_dict["special_tokens_mask"]
|
||||||
special_tokens_mask = tokenizer.get_special_tokens_mask(encoded_sequence_w_special, already_has_special_tokens=True)
|
special_tokens_mask = tokenizer.get_special_tokens_mask(encoded_sequence_w_special, already_has_special_tokens=True)
|
||||||
|
|||||||
@@ -750,6 +750,9 @@ class PreTrainedTokenizer(object):
|
|||||||
stride=0,
|
stride=0,
|
||||||
truncation_strategy='longest_first',
|
truncation_strategy='longest_first',
|
||||||
return_tensors=None,
|
return_tensors=None,
|
||||||
|
return_token_type_ids=True,
|
||||||
|
return_overflowing_tokens=False,
|
||||||
|
return_special_tokens_mask=False,
|
||||||
**kwargs):
|
**kwargs):
|
||||||
"""
|
"""
|
||||||
Returns a dictionary containing the encoded sequence or sequence pair and additional informations:
|
Returns a dictionary containing the encoded sequence or sequence pair and additional informations:
|
||||||
@@ -776,7 +779,30 @@ class PreTrainedTokenizer(object):
|
|||||||
- 'do_not_truncate': Does not truncate (raise an error if the input sequence is longer than max_length)
|
- 'do_not_truncate': Does not truncate (raise an error if the input sequence is longer than max_length)
|
||||||
return_tensors: (optional) can be set to 'tf' or 'pt' to return respectively TensorFlow tf.constant
|
return_tensors: (optional) can be set to 'tf' or 'pt' to return respectively TensorFlow tf.constant
|
||||||
or PyTorch torch.Tensor instead of a list of python integers.
|
or PyTorch torch.Tensor instead of a list of python integers.
|
||||||
|
return_token_type_ids: (optional) Set to False to avoid returning token_type_ids (default True).
|
||||||
|
return_overflowing_tokens: (optional) Set to True to return overflowing token information (default False).
|
||||||
|
return_special_tokens_mask: (optional) Set to True to return special tokens mask information (default False).
|
||||||
**kwargs: passed to the `self.tokenize()` method
|
**kwargs: passed to the `self.tokenize()` method
|
||||||
|
|
||||||
|
Return:
|
||||||
|
A Dictionary of shape::
|
||||||
|
|
||||||
|
{
|
||||||
|
input_ids: list[int],
|
||||||
|
token_type_ids: list[int] if return_token_type_ids is True (default)
|
||||||
|
overflowing_tokens: list[int] if a ``max_length`` is specified and return_overflowing_tokens is True
|
||||||
|
num_truncated_tokens: int if a ``max_length`` is specified and return_overflowing_tokens is True
|
||||||
|
special_tokens_mask: list[int] if ``add_special_tokens`` if set to ``True`` and return_special_tokens_mask is True
|
||||||
|
}
|
||||||
|
|
||||||
|
With the fields:
|
||||||
|
``input_ids``: list of token ids to be fed to a model
|
||||||
|
``token_type_ids``: list of token type ids to be fed to a model
|
||||||
|
|
||||||
|
``overflowing_tokens``: list of overflowing tokens if a max length is specified.
|
||||||
|
``num_truncated_tokens``: number of overflowing tokens a ``max_length`` is specified
|
||||||
|
``special_tokens_mask``: if adding special tokens, this is a list of [0, 1], with 0 specifying special added
|
||||||
|
tokens and 1 specifying sequence tokens.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def get_input_ids(text):
|
def get_input_ids(text):
|
||||||
@@ -798,10 +824,17 @@ class PreTrainedTokenizer(object):
|
|||||||
add_special_tokens=add_special_tokens,
|
add_special_tokens=add_special_tokens,
|
||||||
stride=stride,
|
stride=stride,
|
||||||
truncation_strategy=truncation_strategy,
|
truncation_strategy=truncation_strategy,
|
||||||
return_tensors=return_tensors)
|
return_tensors=return_tensors,
|
||||||
|
return_token_type_ids=return_token_type_ids,
|
||||||
|
return_overflowing_tokens=return_overflowing_tokens,
|
||||||
|
return_special_tokens_mask=return_special_tokens_mask)
|
||||||
|
|
||||||
def prepare_for_model(self, ids, pair_ids=None, max_length=None, add_special_tokens=True, stride=0,
|
def prepare_for_model(self, ids, pair_ids=None, max_length=None, add_special_tokens=True, stride=0,
|
||||||
truncation_strategy='longest_first', return_tensors=None):
|
truncation_strategy='longest_first',
|
||||||
|
return_tensors=None,
|
||||||
|
return_token_type_ids=True,
|
||||||
|
return_overflowing_tokens=False,
|
||||||
|
return_special_tokens_mask=False):
|
||||||
"""
|
"""
|
||||||
Prepares a sequence of input id, or a pair of sequences of inputs ids so that it can be used by the model.
|
Prepares a sequence of input id, or a pair of sequences of inputs ids so that it can be used by the model.
|
||||||
It adds special tokens, truncates
|
It adds special tokens, truncates
|
||||||
@@ -826,21 +859,27 @@ class PreTrainedTokenizer(object):
|
|||||||
- 'do_not_truncate': Does not truncate (raise an error if the input sequence is longer than max_length)
|
- 'do_not_truncate': Does not truncate (raise an error if the input sequence is longer than max_length)
|
||||||
return_tensors: (optional) can be set to 'tf' or 'pt' to return respectively TensorFlow tf.constant
|
return_tensors: (optional) can be set to 'tf' or 'pt' to return respectively TensorFlow tf.constant
|
||||||
or PyTorch torch.Tensor instead of a list of python integers.
|
or PyTorch torch.Tensor instead of a list of python integers.
|
||||||
|
return_token_type_ids: (optional) Set to False to avoid returning token_type_ids (default True).
|
||||||
|
return_overflowing_tokens: (optional) Set to True to return overflowing token information (default False).
|
||||||
|
return_special_tokens_mask: (optional) Set to True to return special tokens mask information (default False).
|
||||||
|
|
||||||
Return:
|
Return:
|
||||||
A Dictionary of shape::
|
A Dictionary of shape::
|
||||||
|
|
||||||
{
|
{
|
||||||
input_ids: list[int],
|
input_ids: list[int],
|
||||||
overflowing_tokens: list[int] if a ``max_length`` is specified, else None
|
token_type_ids: list[int] if return_token_type_ids is True (default)
|
||||||
special_tokens_mask: list[int] if ``add_special_tokens`` if set to ``True``
|
overflowing_tokens: list[int] if a ``max_length`` is specified and return_overflowing_tokens is True
|
||||||
|
num_truncated_tokens: int if a ``max_length`` is specified and return_overflowing_tokens is True
|
||||||
|
special_tokens_mask: list[int] if ``add_special_tokens`` if set to ``True`` and return_special_tokens_mask is True
|
||||||
}
|
}
|
||||||
|
|
||||||
With the fields:
|
With the fields:
|
||||||
``input_ids``: list of tokens to be fed to a model
|
``input_ids``: list of token ids to be fed to a model
|
||||||
|
``token_type_ids``: list of token type ids to be fed to a model
|
||||||
|
|
||||||
``overflowing_tokens``: list of overflowing tokens if a max length is specified.
|
``overflowing_tokens``: list of overflowing tokens if a max length is specified.
|
||||||
|
``num_truncated_tokens``: number of overflowing tokens a ``max_length`` is specified
|
||||||
``special_tokens_mask``: if adding special tokens, this is a list of [0, 1], with 0 specifying special added
|
``special_tokens_mask``: if adding special tokens, this is a list of [0, 1], with 0 specifying special added
|
||||||
tokens and 1 specifying sequence tokens.
|
tokens and 1 specifying sequence tokens.
|
||||||
"""
|
"""
|
||||||
@@ -849,23 +888,31 @@ class PreTrainedTokenizer(object):
|
|||||||
len_pair_ids = len(pair_ids) if pair else 0
|
len_pair_ids = len(pair_ids) if pair else 0
|
||||||
|
|
||||||
encoded_inputs = {}
|
encoded_inputs = {}
|
||||||
|
|
||||||
|
# Handle max sequence length
|
||||||
total_len = len_ids + len_pair_ids + (self.num_added_tokens(pair=pair) if add_special_tokens else 0)
|
total_len = len_ids + len_pair_ids + (self.num_added_tokens(pair=pair) if add_special_tokens else 0)
|
||||||
if max_length and total_len > max_length:
|
if max_length and total_len > max_length:
|
||||||
ids, pair_ids, overflowing_tokens = self.truncate_sequences(ids, pair_ids=pair_ids,
|
ids, pair_ids, overflowing_tokens = self.truncate_sequences(ids, pair_ids=pair_ids,
|
||||||
num_tokens_to_remove=total_len-max_length,
|
num_tokens_to_remove=total_len-max_length,
|
||||||
truncation_strategy=truncation_strategy,
|
truncation_strategy=truncation_strategy,
|
||||||
stride=stride)
|
stride=stride)
|
||||||
|
if return_overflowing_tokens:
|
||||||
encoded_inputs["overflowing_tokens"] = overflowing_tokens
|
encoded_inputs["overflowing_tokens"] = overflowing_tokens
|
||||||
encoded_inputs["num_truncated_tokens"] = total_len - max_length
|
encoded_inputs["num_truncated_tokens"] = total_len - max_length
|
||||||
|
|
||||||
|
# Handle special_tokens
|
||||||
if add_special_tokens:
|
if add_special_tokens:
|
||||||
sequence = self.build_inputs_with_special_tokens(ids, pair_ids)
|
sequence = self.build_inputs_with_special_tokens(ids, pair_ids)
|
||||||
token_type_ids = self.create_token_type_ids_from_sequences(ids, pair_ids)
|
token_type_ids = self.create_token_type_ids_from_sequences(ids, pair_ids)
|
||||||
encoded_inputs["special_tokens_mask"] = self.get_special_tokens_mask(ids, pair_ids)
|
special_tokens_mask = self.get_special_tokens_mask(ids, pair_ids)
|
||||||
else:
|
else:
|
||||||
sequence = ids + pair_ids if pair else ids
|
sequence = ids + pair_ids if pair else ids
|
||||||
token_type_ids = [0] * len(ids) + ([1] * len(pair_ids) if pair else [])
|
token_type_ids = [0] * len(ids) + ([1] * len(pair_ids) if pair else [])
|
||||||
|
special_tokens_mask = [0] * (len(ids) + (len(pair_ids) if pair else 0))
|
||||||
|
if return_special_tokens_mask:
|
||||||
|
encoded_inputs["special_tokens_mask"] = self.get_special_tokens_mask(ids, pair_ids)
|
||||||
|
|
||||||
|
# Prepare inputs as tensors if asked
|
||||||
if return_tensors == 'tf' and is_tf_available():
|
if return_tensors == 'tf' and is_tf_available():
|
||||||
sequence = tf.constant([sequence])
|
sequence = tf.constant([sequence])
|
||||||
token_type_ids = tf.constant([token_type_ids])
|
token_type_ids = tf.constant([token_type_ids])
|
||||||
@@ -876,11 +923,14 @@ class PreTrainedTokenizer(object):
|
|||||||
logger.warning("Unable to convert output to tensors format {}, PyTorch or TensorFlow is not available.".format(return_tensors))
|
logger.warning("Unable to convert output to tensors format {}, PyTorch or TensorFlow is not available.".format(return_tensors))
|
||||||
|
|
||||||
encoded_inputs["input_ids"] = sequence
|
encoded_inputs["input_ids"] = sequence
|
||||||
|
if return_token_type_ids:
|
||||||
encoded_inputs["token_type_ids"] = token_type_ids
|
encoded_inputs["token_type_ids"] = token_type_ids
|
||||||
|
|
||||||
if max_length and len(encoded_inputs["input_ids"]) > max_length:
|
if max_length and len(encoded_inputs["input_ids"]) > max_length:
|
||||||
encoded_inputs["input_ids"] = encoded_inputs["input_ids"][:max_length]
|
encoded_inputs["input_ids"] = encoded_inputs["input_ids"][:max_length]
|
||||||
|
if return_token_type_ids:
|
||||||
encoded_inputs["token_type_ids"] = encoded_inputs["token_type_ids"][:max_length]
|
encoded_inputs["token_type_ids"] = encoded_inputs["token_type_ids"][:max_length]
|
||||||
|
if return_special_tokens_mask:
|
||||||
encoded_inputs["special_tokens_mask"] = encoded_inputs["special_tokens_mask"][:max_length]
|
encoded_inputs["special_tokens_mask"] = encoded_inputs["special_tokens_mask"][:max_length]
|
||||||
|
|
||||||
if max_length is None and len(encoded_inputs["input_ids"]) > self.max_len:
|
if max_length is None and len(encoded_inputs["input_ids"]) > self.max_len:
|
||||||
|
|||||||
Reference in New Issue
Block a user