Updated DistilBERT
This commit is contained in:
@@ -536,13 +536,7 @@ class PreTrainedTokenizer(object):
|
||||
|
||||
if pair:
|
||||
initial_tokens_len = len(self.encode("This is a sequence") + self.encode("This is another"))
|
||||
final_tokens = self.encode("This is a sequence", "This is another", add_special_tokens=True)
|
||||
|
||||
# In some models (e.g. GPT-2), there is no sequence pair encoding.
|
||||
if len(final_tokens) == 2:
|
||||
return 0
|
||||
else:
|
||||
final_tokens_len = len(final_tokens)
|
||||
final_tokens_len = len(self.encode("This is a sequence", "This is another", add_special_tokens=True))
|
||||
else:
|
||||
initial_tokens_len = len(self.encode("This is a sequence"))
|
||||
final_tokens_len = len(self.encode("This is a sequence", add_special_tokens=True))
|
||||
@@ -700,86 +694,93 @@ class PreTrainedTokenizer(object):
|
||||
Same as doing ``self.convert_tokens_to_ids(self.tokenize(text))``.
|
||||
|
||||
Args:
|
||||
text: The first sequence to be encoded.
|
||||
text_pair: Optional second sequence to be encoded.
|
||||
text: The first sequence to be encoded. This can be a string, a list of strings (tokenized string using
|
||||
the `tokenize` method) or a list of integers (tokenized string ids using the `convert_tokens_to_ids`
|
||||
method)
|
||||
text_pair: Optional second sequence to be encoded. This can be a string, a list of strings (tokenized
|
||||
string using the `tokenize` method) or a list of integers (tokenized string ids using the
|
||||
`convert_tokens_to_ids` method)
|
||||
add_special_tokens: if set to ``True``, the sequences will be encoded with the special tokens relative
|
||||
to their model.
|
||||
**kwargs: passed to the `self.tokenize()` method
|
||||
"""
|
||||
if text_pair is None:
|
||||
if add_special_tokens:
|
||||
sequence_tokens = self.convert_tokens_to_ids(self.tokenize(text, **kwargs)) if isinstance(text, six.string_types) else text
|
||||
return self.add_special_tokens_single_sequence(sequence_tokens)
|
||||
else:
|
||||
ids = self.convert_tokens_to_ids(self.tokenize(text, **kwargs)) if isinstance(text, six.string_types) else text
|
||||
return ids
|
||||
|
||||
first_sentence_tokens = [self._convert_token_to_id(token) for token in self.tokenize(text, **kwargs)] if isinstance(text, six.string_types) else text
|
||||
second_sentence_tokens = [self._convert_token_to_id(token) for token in self.tokenize(text_pair, **kwargs)] if isinstance(text_pair, six.string_types) else text_pair
|
||||
|
||||
if add_special_tokens:
|
||||
return self.add_special_tokens_sequence_pair(first_sentence_tokens, second_sentence_tokens)
|
||||
else:
|
||||
logger.warning("No special tokens were added. The two sequences have been concatenated.")
|
||||
return first_sentence_tokens + second_sentence_tokens
|
||||
return self.encode_plus(text, text_pair, add_special_tokens, **kwargs)["input_ids"]
|
||||
|
||||
def encode_plus(self,
|
||||
text,
|
||||
text_pair=None,
|
||||
add_special_tokens=False,
|
||||
output_mask=False,
|
||||
output_token_type=False,
|
||||
max_length=None,
|
||||
stride=0,
|
||||
truncate_second_sequence_first=True,
|
||||
truncate_first_sequence=True,
|
||||
**kwargs):
|
||||
"""
|
||||
Returns a dictionary containing the encoded sequence or sequence pair. Other values can be returned by this
|
||||
method: the mask for sequence classification and the overflowing elements if a ``max_length`` is specified.
|
||||
|
||||
Args:
|
||||
text: The first sequence to be encoded.
|
||||
text_pair: Optional second sequence to be encoded.
|
||||
text: The first sequence to be encoded. This can be a string, a list of strings (tokenized string using
|
||||
the `tokenize` method) or a list of integers (tokenized string ids using the `convert_tokens_to_ids`
|
||||
method)
|
||||
text_pair: Optional second sequence to be encoded. This can be a string, a list of strings (tokenized
|
||||
string using the `tokenize` method) or a list of integers (tokenized string ids using the
|
||||
`convert_tokens_to_ids` method)
|
||||
add_special_tokens: if set to ``True``, the sequences will be encoded with the special tokens relative
|
||||
to their model.
|
||||
output_mask: if set to ``True``, returns the text pair corresponding mask with 0 for the first sequence,
|
||||
output_token_type: if set to ``True``, returns the mask corresponding to the text pair, with 0 for the first sequence,
|
||||
and 1 for the second.
|
||||
max_length: if set to a number, will limit the total sequence returned so that it has a maximum length.
|
||||
If there are overflowing tokens, those will be added to the returned dictionary
|
||||
stride: if set to a number along with max_length, the overflowing tokens returned will contain some tokens
|
||||
from the main sequence returned. The value of this argument defines the number of additional tokens.
|
||||
truncate_second_sequence_first: if there is a specified max_length, this flag will choose which sequence
|
||||
truncate_first_sequence: if there is a specified max_length, this flag will choose which sequence
|
||||
will be truncated.
|
||||
**kwargs: passed to the `self.tokenize()` method
|
||||
"""
|
||||
|
||||
information = {}
|
||||
|
||||
def get_input_ids(text):
|
||||
if isinstance(text, six.string_types):
|
||||
input_ids = self.convert_tokens_to_ids(self.tokenize(text, **kwargs))
|
||||
elif isinstance(text, (list, tuple)) and len(text) > 0 and isinstance(text[0], str):
|
||||
input_ids = self.convert_tokens_to_ids(text)
|
||||
elif isinstance(text, (list, tuple)) and len(text) > 0 and isinstance(text[0], int):
|
||||
input_ids = text
|
||||
else:
|
||||
raise ValueError("Input is not valid. Should be a string, a list/tuple of strings or a list/tuple of integers.")
|
||||
|
||||
return input_ids
|
||||
|
||||
if text_pair is None:
|
||||
sequence_tokens = self.convert_tokens_to_ids(self.tokenize(text, **kwargs)) if isinstance(text, six.string_types) else text
|
||||
sequence_tokens = get_input_ids(text)
|
||||
|
||||
if add_special_tokens:
|
||||
information = self.prepare_for_model(sequence_tokens, max_length, stride)
|
||||
information = self.prepare_for_model(sequence_tokens, max_length=max_length, stride=stride)
|
||||
else:
|
||||
if max_length:
|
||||
information["overflowing_tokens"] = sequence_tokens[max_length - stride:]
|
||||
sequence_tokens = sequence_tokens[:max_length]
|
||||
information["sequence"] = sequence_tokens
|
||||
information["input_ids"] = sequence_tokens
|
||||
|
||||
if output_mask:
|
||||
information["mask"] = [0] * len(information["sequence"])
|
||||
if output_token_type:
|
||||
information["output_token_type"] = [0] * len(information["input_ids"])
|
||||
else:
|
||||
first_sentence_tokens = [self._convert_token_to_id(token) for token in self.tokenize(text, **kwargs)] if isinstance(text, six.string_types) else text
|
||||
second_sentence_tokens = [self._convert_token_to_id(token) for token in self.tokenize(text_pair, **kwargs)] if isinstance(text_pair, six.string_types) else text_pair
|
||||
first_sentence_tokens = get_input_ids(text)
|
||||
second_sentence_tokens = get_input_ids(text_pair)
|
||||
|
||||
if add_special_tokens:
|
||||
information = self.prepare_pair_for_model(
|
||||
first_sentence_tokens,
|
||||
second_sentence_tokens,
|
||||
max_length,
|
||||
truncate_second_sequence_first,
|
||||
stride
|
||||
max_length=max_length,
|
||||
truncate_first_sequence=truncate_first_sequence,
|
||||
stride=stride
|
||||
)
|
||||
|
||||
if output_mask:
|
||||
information["mask"] = self.create_mask_from_sequences(text, text_pair)
|
||||
if output_token_type:
|
||||
information["output_token_type"] = self.create_mask_from_sequences(text, text_pair)
|
||||
else:
|
||||
logger.warning("No special tokens were added. The two sequences have been concatenated.")
|
||||
sequence = first_sentence_tokens + second_sentence_tokens
|
||||
@@ -787,43 +788,78 @@ class PreTrainedTokenizer(object):
|
||||
if max_length:
|
||||
information["overflowing_tokens"] = sequence[max_length - stride:]
|
||||
sequence = sequence[:max_length]
|
||||
if output_mask:
|
||||
information["mask"] = [0] * len(sequence)
|
||||
if output_token_type:
|
||||
information["output_token_type"] = [0] * len(sequence)
|
||||
|
||||
information["sequence"] = sequence
|
||||
information["input_ids"] = sequence
|
||||
|
||||
return information
|
||||
|
||||
def prepare_for_model(self, ids, max_length=None, stride=0):
|
||||
"""
|
||||
Prepares a list of tokenized input ids so that it can be used by the model. It adds special tokens, truncates
|
||||
sequences if overflowing while taking into account the special tokens and manages a window stride for
|
||||
overflowing tokens
|
||||
|
||||
Args:
|
||||
ids: list of tokenized input ids. Can be obtained from a string by chaining the
|
||||
`tokenize` and `convert_tokens_to_ids` methods.
|
||||
max_length: maximum length of the returned list. Will truncate by taking into account the special tokens.
|
||||
stride: window stride for overflowing tokens. Can be useful for edge effect removal when using sequential
|
||||
list of inputs.
|
||||
|
||||
Return:
|
||||
a dictionary containing the `input_ids` as well as the `overflowing_tokens` if a `max_length` was given.
|
||||
"""
|
||||
information = {}
|
||||
n_added_tokens = self.num_added_tokens()
|
||||
if max_length:
|
||||
n_added_tokens = self.num_added_tokens()
|
||||
information["overflowing_tokens"] = ids[max_length - n_added_tokens - stride:]
|
||||
ids = ids[:max_length - n_added_tokens]
|
||||
information["sequence"] = self.add_special_tokens_single_sequence(ids)
|
||||
information["input_ids"] = self.add_special_tokens_single_sequence(ids)
|
||||
|
||||
return information
|
||||
|
||||
def prepare_pair_for_model(self, ids_0, ids_1, max_length=None, truncate_second_sequence_first=True, stride=0):
|
||||
def prepare_pair_for_model(self, ids_0, ids_1, max_length=None, truncate_first_sequence=True, stride=0):
|
||||
"""
|
||||
Prepares a list of tokenized input ids pair so that it can be used by the model. It adds special tokens,
|
||||
truncates sequences if overflowing while taking into account the special tokens and manages a window stride for
|
||||
overflowing tokens
|
||||
|
||||
Args:
|
||||
ids_0: list of tokenized input ids. Can be obtained from a string by chaining the
|
||||
`tokenize` and `convert_tokens_to_ids` methods.
|
||||
ids_1: second list of tokenized input ids. Can be obtained from a string by chaining the
|
||||
`tokenize` and `convert_tokens_to_ids` methods.
|
||||
max_length: maximum length of the returned list. Will truncate by taking into account the special tokens.
|
||||
truncate_first_sequence: if set to `True`, alongside a specified `max_length`, will truncate the first
|
||||
sequence if the total size exceeds the specified `max_length`. If set to `False`, will
|
||||
truncate the second sequence instead.
|
||||
stride: window stride for overflowing tokens. Can be useful for edge effect removal when using sequential
|
||||
list of inputs.
|
||||
|
||||
Return:
|
||||
a dictionary containing the `input_ids` as well as the `overflowing_tokens` if a `max_length` was given.
|
||||
"""
|
||||
f_len, s_len = len(ids_0), len(ids_1)
|
||||
n_added_tokens = self.num_added_tokens(pair=True)
|
||||
information = {}
|
||||
|
||||
if max_length:
|
||||
n_added_tokens = self.num_added_tokens(pair=True)
|
||||
if len(ids_0) + n_added_tokens >= max_length:
|
||||
logger.warning(
|
||||
"The first sequence is longer than the maximum specified length. This sequence will not be truncated.")
|
||||
else:
|
||||
if f_len + s_len + self.num_added_tokens(pair=True) > max_length:
|
||||
if truncate_second_sequence_first:
|
||||
information["overflowing_tokens"] = ids_1[max_length - f_len - n_added_tokens - stride:]
|
||||
ids_1 = ids_1[:max_length - f_len - n_added_tokens]
|
||||
else:
|
||||
if truncate_first_sequence:
|
||||
information["overflowing_tokens"] = ids_0[max_length - s_len - n_added_tokens - stride:]
|
||||
ids_0 = ids_0[:max_length - s_len - n_added_tokens]
|
||||
else:
|
||||
information["overflowing_tokens"] = ids_1[max_length - f_len - n_added_tokens - stride:]
|
||||
ids_1 = ids_1[:max_length - f_len - n_added_tokens]
|
||||
|
||||
sequence = self.add_special_tokens_sequence_pair(ids_0, ids_1)
|
||||
information["sequence"] = sequence
|
||||
information["input_ids"] = sequence
|
||||
|
||||
return information
|
||||
|
||||
|
||||
Reference in New Issue
Block a user