From 634a3172d869e2ff772b2e0813169641ca9e6cc5 Mon Sep 17 00:00:00 2001 From: LysandreJik Date: Mon, 12 Aug 2019 15:14:15 -0400 Subject: [PATCH] Added integration tests for sequence builders. --- .../tests/tokenization_bert_test.py | 11 +++++++++++ .../tests/tokenization_roberta_test.py | 14 +++++++++++++- .../tests/tokenization_xlm_test.py | 11 +++++++++++ .../tests/tokenization_xlnet_test.py | 12 ++++++++++++ 4 files changed, 47 insertions(+), 1 deletion(-) diff --git a/pytorch_transformers/tests/tokenization_bert_test.py b/pytorch_transformers/tests/tokenization_bert_test.py index 5eb39b729d..db507317a8 100644 --- a/pytorch_transformers/tests/tokenization_bert_test.py +++ b/pytorch_transformers/tests/tokenization_bert_test.py @@ -125,6 +125,17 @@ class BertTokenizationTest(CommonTestCases.CommonTokenizerTester): self.assertFalse(_is_punctuation(u"A")) self.assertFalse(_is_punctuation(u" ")) + def test_sequence_builders(self): + tokenizer = BertTokenizer.from_pretrained("bert-base-uncased") + + text = tokenizer.encode("sequence builders") + text_2 = tokenizer.encode("multi-sequence build") + + encoded_sentence = tokenizer.add_special_tokens_single_sentence(text) + encoded_pair = tokenizer.add_special_tokens_sentences_pair(text, text_2) + + assert encoded_sentence == [101] + text + [102] + assert encoded_pair == [101] + text + [102] + text_2 + [102] if __name__ == '__main__': unittest.main() diff --git a/pytorch_transformers/tests/tokenization_roberta_test.py b/pytorch_transformers/tests/tokenization_roberta_test.py index daefea0fa7..b76b3e311d 100644 --- a/pytorch_transformers/tests/tokenization_roberta_test.py +++ b/pytorch_transformers/tests/tokenization_roberta_test.py @@ -71,10 +71,22 @@ class RobertaTokenizationTest(CommonTestCases.CommonTokenizerTester): [0, 31414, 232, 328, 2] ) self.assertListEqual( - tokenizer.encode('Hello world! cécé herlolip'), + tokenizer.encode('Hello world! cécé herlolip 418'), [0, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588, 2] ) + def test_sequence_builders(self): + tokenizer = RobertaTokenizer.from_pretrained("roberta-base") + + text = tokenizer.encode("sequence builders") + text_2 = tokenizer.encode("multi-sequence build") + + encoded_sentence = tokenizer.add_special_tokens_single_sentence(text) + encoded_pair = tokenizer.add_special_tokens_sentences_pair(text, text_2) + + assert encoded_sentence == [0] + text + [2] + assert encoded_pair == [0] + text + [2, 2] + text_2 + [2] + if __name__ == '__main__': unittest.main() diff --git a/pytorch_transformers/tests/tokenization_xlm_test.py b/pytorch_transformers/tests/tokenization_xlm_test.py index a20e92044f..ede77a1f98 100644 --- a/pytorch_transformers/tests/tokenization_xlm_test.py +++ b/pytorch_transformers/tests/tokenization_xlm_test.py @@ -66,6 +66,17 @@ class XLMTokenizationTest(CommonTestCases.CommonTokenizerTester): self.assertListEqual( tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens) + def test_sequence_builders(self): + tokenizer = XLMTokenizer.from_pretrained("xlm-mlm-en-2048") + + text = tokenizer.encode("sequence builders") + text_2 = tokenizer.encode("multi-sequence build") + + encoded_sentence = tokenizer.add_special_tokens_single_sentence(text) + encoded_pair = tokenizer.add_special_tokens_sentences_pair(text, text_2) + + assert encoded_sentence == [1] + text + [1] + assert encoded_pair == [1] + text + [1] + text_2 + [1] if __name__ == '__main__': unittest.main() diff --git a/pytorch_transformers/tests/tokenization_xlnet_test.py b/pytorch_transformers/tests/tokenization_xlnet_test.py index 08e9e9cb2d..9feab7c0bd 100644 --- a/pytorch_transformers/tests/tokenization_xlnet_test.py +++ b/pytorch_transformers/tests/tokenization_xlnet_test.py @@ -89,6 +89,18 @@ class XLNetTokenizationTest(CommonTestCases.CommonTokenizerTester): u'9', u'2', u'0', u'0', u'0', u',', SPIECE_UNDERLINE + u'and', SPIECE_UNDERLINE + u'this', SPIECE_UNDERLINE + u'is', SPIECE_UNDERLINE + u'f', u'al', u'se', u'.']) + def test_sequence_builders(self): + tokenizer = XLNetTokenizer.from_pretrained("xlnet-base-cased") + + text = tokenizer.encode("sequence builders") + text_2 = tokenizer.encode("multi-sequence build") + + encoded_sentence = tokenizer.add_special_tokens_single_sentence(text) + encoded_pair = tokenizer.add_special_tokens_sentences_pair(text, text_2) + + assert encoded_sentence == text + [4, 3] + assert encoded_pair == text + [4] + text_2 + [4, 3] + if __name__ == '__main__': unittest.main()