Modified encode to return only lists. Added a more complete encode_plus method
This commit is contained in:
@@ -535,7 +535,7 @@ class PreTrainedTokenizer(object):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
if pair:
|
if pair:
|
||||||
initial_tokens_len = sum([len(encoded) for encoded in self.encode("This is a sequence", "This is another")])
|
initial_tokens_len = len(self.encode("This is a sequence") + self.encode("This is another"))
|
||||||
final_tokens = self.encode("This is a sequence", "This is another", add_special_tokens=True)
|
final_tokens = self.encode("This is a sequence", "This is another", add_special_tokens=True)
|
||||||
|
|
||||||
# In some models (e.g. GPT-2), there is no sequence pair encoding.
|
# In some models (e.g. GPT-2), there is no sequence pair encoding.
|
||||||
@@ -693,10 +693,39 @@ class PreTrainedTokenizer(object):
|
|||||||
def _convert_token_to_id(self, token):
|
def _convert_token_to_id(self, token):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def encode(self, text, text_pair=None, add_special_tokens=False, output_mask=False, max_length=None, **kwargs):
|
def encode(self, text, text_pair=None, add_special_tokens=False, **kwargs):
|
||||||
"""
|
"""
|
||||||
Converts a string in a sequence of ids (integer), using the tokenizer and vocabulary.
|
Converts a string in a sequence of ids (integer), using the tokenizer and vocabulary.
|
||||||
|
|
||||||
|
Same as doing ``self.convert_tokens_to_ids(self.tokenize(text))``.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: The first sequence to be encoded.
|
||||||
|
text_pair: Optional second sequence to be encoded.
|
||||||
|
add_special_tokens: if set to ``True``, the sequences will be encoded with the special tokens relative
|
||||||
|
to their model.
|
||||||
|
"""
|
||||||
|
if text_pair is None:
|
||||||
|
if add_special_tokens:
|
||||||
|
sequence_tokens = self.convert_tokens_to_ids(self.tokenize(text, **kwargs))
|
||||||
|
return self.add_special_tokens_single_sentence(sequence_tokens)
|
||||||
|
else:
|
||||||
|
ids = self.convert_tokens_to_ids(self.tokenize(text, **kwargs))
|
||||||
|
return ids
|
||||||
|
|
||||||
|
first_sentence_tokens = [self._convert_token_to_id(token) for token in self.tokenize(text, **kwargs)]
|
||||||
|
second_sentence_tokens = [self._convert_token_to_id(token) for token in self.tokenize(text_pair, **kwargs)]
|
||||||
|
|
||||||
|
if add_special_tokens:
|
||||||
|
return self.add_special_tokens_sentences_pair(first_sentence_tokens, second_sentence_tokens)
|
||||||
|
else:
|
||||||
|
logger.warning("No special tokens were added. The two sequences have been concatenated.")
|
||||||
|
return first_sentence_tokens + second_sentence_tokens
|
||||||
|
|
||||||
|
def encode_plus(self, text, text_pair=None, add_special_tokens=False, output_mask=False, max_length=None, **kwargs):
|
||||||
|
"""
|
||||||
|
Converts a string in a sequence of ids (integer), using the tokenizer and vocabulary.
|
||||||
|
|
||||||
Same as doing ``self.convert_tokens_to_ids(self.tokenize(text))``.
|
Same as doing ``self.convert_tokens_to_ids(self.tokenize(text))``.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -709,6 +738,69 @@ class PreTrainedTokenizer(object):
|
|||||||
max_length: if set to a number, will limit the total sequence returned so that it has a maximum length.
|
max_length: if set to a number, will limit the total sequence returned so that it has a maximum length.
|
||||||
**kwargs: passed to the `self.tokenize()` method
|
**kwargs: passed to the `self.tokenize()` method
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
information = {}
|
||||||
|
|
||||||
|
if text_pair is None:
|
||||||
|
n_added_tokens = self.num_added_tokens()
|
||||||
|
if add_special_tokens:
|
||||||
|
sequence_tokens = self.convert_tokens_to_ids(self.tokenize(text, **kwargs))
|
||||||
|
if max_length:
|
||||||
|
information["overflowing_tokens"] = sequence_tokens[max_length - n_added_tokens:]
|
||||||
|
sequence_tokens = sequence_tokens[:max_length - n_added_tokens]
|
||||||
|
sequence = self.add_special_tokens_single_sentence(sequence_tokens)
|
||||||
|
else:
|
||||||
|
sequence_tokens = self.convert_tokens_to_ids(self.tokenize(text, **kwargs))
|
||||||
|
if max_length:
|
||||||
|
information["overflowing_tokens"] = sequence_tokens[max_length:]
|
||||||
|
sequence_tokens = sequence_tokens[:max_length]
|
||||||
|
sequence = sequence_tokens
|
||||||
|
|
||||||
|
if output_mask:
|
||||||
|
information["mask"] = [0] * len(sequence)
|
||||||
|
|
||||||
|
information["sequence"] = sequence
|
||||||
|
else:
|
||||||
|
first_sentence_tokens = [self._convert_token_to_id(token) for token in self.tokenize(text, **kwargs)]
|
||||||
|
second_sentence_tokens = [self._convert_token_to_id(token) for token in self.tokenize(text_pair, **kwargs)]
|
||||||
|
f_len, s_len = len(first_sentence_tokens), len(second_sentence_tokens)
|
||||||
|
n_added_tokens = self.num_added_tokens(pair=True)
|
||||||
|
|
||||||
|
if add_special_tokens:
|
||||||
|
if max_length:
|
||||||
|
if len(first_sentence_tokens) + n_added_tokens >= max_length:
|
||||||
|
logger.warning("The first sequence is longer than the maximum specified length. This sequence will not be truncated.")
|
||||||
|
else:
|
||||||
|
if f_len + s_len + self.num_added_tokens(pair=True) > max_length:
|
||||||
|
information["overflowing_tokens"] = second_sentence_tokens[max_length - f_len - n_added_tokens:]
|
||||||
|
second_sentence_tokens = second_sentence_tokens[:max_length - f_len - n_added_tokens]
|
||||||
|
|
||||||
|
encoded_sequence = self.add_special_tokens_sentences_pair(
|
||||||
|
first_sentence_tokens,
|
||||||
|
second_sentence_tokens,
|
||||||
|
output_mask
|
||||||
|
)
|
||||||
|
|
||||||
|
if output_mask:
|
||||||
|
sequence, information["mask"] = encoded_sequence
|
||||||
|
else:
|
||||||
|
sequence = encoded_sequence
|
||||||
|
|
||||||
|
information["sequence"] = sequence
|
||||||
|
else:
|
||||||
|
logger.warning("No special tokens were added. The two sequences have been concatenated.")
|
||||||
|
sequence = first_sentence_tokens + second_sentence_tokens
|
||||||
|
|
||||||
|
if max_length:
|
||||||
|
information["overflowing_tokens"] = sequence[max_length:]
|
||||||
|
sequence = sequence[:max_length]
|
||||||
|
if output_mask:
|
||||||
|
information["mask"] = [0] * len(sequence)
|
||||||
|
|
||||||
|
information["sequence"] = sequence
|
||||||
|
|
||||||
|
return information
|
||||||
|
|
||||||
if text_pair is None:
|
if text_pair is None:
|
||||||
if add_special_tokens:
|
if add_special_tokens:
|
||||||
sequence_tokens = self.convert_tokens_to_ids(self.tokenize(text, **kwargs))
|
sequence_tokens = self.convert_tokens_to_ids(self.tokenize(text, **kwargs))
|
||||||
@@ -725,12 +817,17 @@ class PreTrainedTokenizer(object):
|
|||||||
if add_special_tokens:
|
if add_special_tokens:
|
||||||
if max_length:
|
if max_length:
|
||||||
if len(first_sentence_tokens) + self.num_added_tokens(pair=True) >= max_length:
|
if len(first_sentence_tokens) + self.num_added_tokens(pair=True) >= max_length:
|
||||||
logger.warning("The first sequence is longer than the maximum specified length. This sequence will not be truncated.")
|
logger.warning(
|
||||||
|
"The first sequence is longer than the maximum specified length. This sequence will not be truncated.")
|
||||||
else:
|
else:
|
||||||
if len(second_sentence_tokens) + len(first_sentence_tokens) + self.num_added_tokens(pair=True) > max_length:
|
if len(second_sentence_tokens) + len(first_sentence_tokens) + self.num_added_tokens(
|
||||||
second_sentence_tokens = second_sentence_tokens[:max_length - len(first_sentence_tokens) - self.num_added_tokens(pair=True)]
|
pair=True) > max_length:
|
||||||
|
second_sentence_tokens = second_sentence_tokens[
|
||||||
|
:max_length - len(first_sentence_tokens) - self.num_added_tokens(
|
||||||
|
pair=True)]
|
||||||
|
|
||||||
return self.add_special_tokens_sentences_pair(first_sentence_tokens, second_sentence_tokens, output_mask)
|
return self.add_special_tokens_sentences_pair(first_sentence_tokens, second_sentence_tokens,
|
||||||
|
output_mask)
|
||||||
else:
|
else:
|
||||||
if max_length:
|
if max_length:
|
||||||
first_sentence_tokens = first_sentence_tokens[:max_length]
|
first_sentence_tokens = first_sentence_tokens[:max_length]
|
||||||
|
|||||||
Reference in New Issue
Block a user