From dcc9bb3252fdad12897851f22f57c3130ba38a7a Mon Sep 17 00:00:00 2001 From: LysandreJik Date: Thu, 19 Sep 2019 09:29:48 +0200 Subject: [PATCH] Modified encode to return only lists. Added a more complete encode_plus method --- pytorch_transformers/tokenization_utils.py | 111 +++++++++++++++++++-- 1 file changed, 104 insertions(+), 7 deletions(-) diff --git a/pytorch_transformers/tokenization_utils.py b/pytorch_transformers/tokenization_utils.py index 97cf242511..3a3ebd49be 100644 --- a/pytorch_transformers/tokenization_utils.py +++ b/pytorch_transformers/tokenization_utils.py @@ -535,7 +535,7 @@ class PreTrainedTokenizer(object): """ if pair: - initial_tokens_len = sum([len(encoded) for encoded in self.encode("This is a sequence", "This is another")]) + initial_tokens_len = len(self.encode("This is a sequence") + self.encode("This is another")) final_tokens = self.encode("This is a sequence", "This is another", add_special_tokens=True) # In some models (e.g. GPT-2), there is no sequence pair encoding. @@ -693,10 +693,39 @@ class PreTrainedTokenizer(object): def _convert_token_to_id(self, token): raise NotImplementedError - def encode(self, text, text_pair=None, add_special_tokens=False, output_mask=False, max_length=None, **kwargs): + def encode(self, text, text_pair=None, add_special_tokens=False, **kwargs): """ Converts a string in a sequence of ids (integer), using the tokenizer and vocabulary. - + + Same as doing ``self.convert_tokens_to_ids(self.tokenize(text))``. + + Args: + text: The first sequence to be encoded. + text_pair: Optional second sequence to be encoded. + add_special_tokens: if set to ``True``, the sequences will be encoded with the special tokens relative + to their model. + """ + if text_pair is None: + if add_special_tokens: + sequence_tokens = self.convert_tokens_to_ids(self.tokenize(text, **kwargs)) + return self.add_special_tokens_single_sentence(sequence_tokens) + else: + ids = self.convert_tokens_to_ids(self.tokenize(text, **kwargs)) + return ids + + first_sentence_tokens = [self._convert_token_to_id(token) for token in self.tokenize(text, **kwargs)] + second_sentence_tokens = [self._convert_token_to_id(token) for token in self.tokenize(text_pair, **kwargs)] + + if add_special_tokens: + return self.add_special_tokens_sentences_pair(first_sentence_tokens, second_sentence_tokens) + else: + logger.warning("No special tokens were added. The two sequences have been concatenated.") + return first_sentence_tokens + second_sentence_tokens + + def encode_plus(self, text, text_pair=None, add_special_tokens=False, output_mask=False, max_length=None, **kwargs): + """ + Converts a string in a sequence of ids (integer), using the tokenizer and vocabulary. + Same as doing ``self.convert_tokens_to_ids(self.tokenize(text))``. Args: @@ -709,6 +738,69 @@ class PreTrainedTokenizer(object): max_length: if set to a number, will limit the total sequence returned so that it has a maximum length. **kwargs: passed to the `self.tokenize()` method """ + + information = {} + + if text_pair is None: + n_added_tokens = self.num_added_tokens() + if add_special_tokens: + sequence_tokens = self.convert_tokens_to_ids(self.tokenize(text, **kwargs)) + if max_length: + information["overflowing_tokens"] = sequence_tokens[max_length - n_added_tokens:] + sequence_tokens = sequence_tokens[:max_length - n_added_tokens] + sequence = self.add_special_tokens_single_sentence(sequence_tokens) + else: + sequence_tokens = self.convert_tokens_to_ids(self.tokenize(text, **kwargs)) + if max_length: + information["overflowing_tokens"] = sequence_tokens[max_length:] + sequence_tokens = sequence_tokens[:max_length] + sequence = sequence_tokens + + if output_mask: + information["mask"] = [0] * len(sequence) + + information["sequence"] = sequence + else: + first_sentence_tokens = [self._convert_token_to_id(token) for token in self.tokenize(text, **kwargs)] + second_sentence_tokens = [self._convert_token_to_id(token) for token in self.tokenize(text_pair, **kwargs)] + f_len, s_len = len(first_sentence_tokens), len(second_sentence_tokens) + n_added_tokens = self.num_added_tokens(pair=True) + + if add_special_tokens: + if max_length: + if len(first_sentence_tokens) + n_added_tokens >= max_length: + logger.warning("The first sequence is longer than the maximum specified length. This sequence will not be truncated.") + else: + if f_len + s_len + self.num_added_tokens(pair=True) > max_length: + information["overflowing_tokens"] = second_sentence_tokens[max_length - f_len - n_added_tokens:] + second_sentence_tokens = second_sentence_tokens[:max_length - f_len - n_added_tokens] + + encoded_sequence = self.add_special_tokens_sentences_pair( + first_sentence_tokens, + second_sentence_tokens, + output_mask + ) + + if output_mask: + sequence, information["mask"] = encoded_sequence + else: + sequence = encoded_sequence + + information["sequence"] = sequence + else: + logger.warning("No special tokens were added. The two sequences have been concatenated.") + sequence = first_sentence_tokens + second_sentence_tokens + + if max_length: + information["overflowing_tokens"] = sequence[max_length:] + sequence = sequence[:max_length] + if output_mask: + information["mask"] = [0] * len(sequence) + + information["sequence"] = sequence + + return information + if text_pair is None: if add_special_tokens: sequence_tokens = self.convert_tokens_to_ids(self.tokenize(text, **kwargs)) @@ -725,12 +817,17 @@ class PreTrainedTokenizer(object): if add_special_tokens: if max_length: if len(first_sentence_tokens) + self.num_added_tokens(pair=True) >= max_length: - logger.warning("The first sequence is longer than the maximum specified length. This sequence will not be truncated.") + logger.warning( + "The first sequence is longer than the maximum specified length. This sequence will not be truncated.") else: - if len(second_sentence_tokens) + len(first_sentence_tokens) + self.num_added_tokens(pair=True) > max_length: - second_sentence_tokens = second_sentence_tokens[:max_length - len(first_sentence_tokens) - self.num_added_tokens(pair=True)] + if len(second_sentence_tokens) + len(first_sentence_tokens) + self.num_added_tokens( + pair=True) > max_length: + second_sentence_tokens = second_sentence_tokens[ + :max_length - len(first_sentence_tokens) - self.num_added_tokens( + pair=True)] - return self.add_special_tokens_sentences_pair(first_sentence_tokens, second_sentence_tokens, output_mask) + return self.add_special_tokens_sentences_pair(first_sentence_tokens, second_sentence_tokens, + output_mask) else: if max_length: first_sentence_tokens = first_sentence_tokens[:max_length]