diff --git a/pytorch_transformers/tests/tokenization_tests_commons.py b/pytorch_transformers/tests/tokenization_tests_commons.py index 3da0494ac4..aa6746a758 100644 --- a/pytorch_transformers/tests/tokenization_tests_commons.py +++ b/pytorch_transformers/tests/tokenization_tests_commons.py @@ -186,3 +186,15 @@ class CommonTestCases: for weights_list_2 in weights_lists_2: self.assertListEqual(weights_list, weights_list_2) + + def test_mask_output(self): + if sys.version_info <= (3, 0): + return + + tokenizer = self.get_tokenizer() + + if tokenizer.add_special_tokens_sentences_pair.__qualname__.split('.')[0] != "PreTrainedTokenizer": + seq_0 = "Test this method." + seq_1 = "With these inputs." + sequences, mask = tokenizer.encode(seq_0, seq_1, add_special_tokens=True, output_mask=True) + assert len(sequences) == len(mask) diff --git a/pytorch_transformers/tokenization_utils.py b/pytorch_transformers/tokenization_utils.py index 22ab04ac6a..be49d7eab5 100644 --- a/pytorch_transformers/tokenization_utils.py +++ b/pytorch_transformers/tokenization_utils.py @@ -690,6 +690,8 @@ class PreTrainedTokenizer(object): if add_special_tokens: return self.add_special_tokens_sentences_pair(first_sentence_tokens, second_sentence_tokens, output_mask) else: + if output_mask: + logger.warning("Can't output mask if no special tokens are involved. Please call the method with add_special_tokens set to True.") return first_sentence_tokens, second_sentence_tokens def add_special_tokens_single_sentence(self, token_ids):