Added integration tests for sequence builders.
This commit is contained in:
@@ -125,6 +125,17 @@ class BertTokenizationTest(CommonTestCases.CommonTokenizerTester):
|
|||||||
self.assertFalse(_is_punctuation(u"A"))
|
self.assertFalse(_is_punctuation(u"A"))
|
||||||
self.assertFalse(_is_punctuation(u" "))
|
self.assertFalse(_is_punctuation(u" "))
|
||||||
|
|
||||||
|
def test_sequence_builders(self):
    """Check the special-token wrapping done by the BERT sequence builders.

    Encodes two short strings, feeds the resulting id lists to
    ``add_special_tokens_single_sentence`` and
    ``add_special_tokens_sentences_pair``, and compares each result with a
    hand-built id list.
    """
    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

    text = tokenizer.encode("sequence builders")
    text_2 = tokenizer.encode("multi-sequence build")

    encoded_sentence = tokenizer.add_special_tokens_single_sentence(text)
    encoded_pair = tokenizer.add_special_tokens_sentences_pair(text, text_2)

    # NOTE(review): 101 and 102 are presumably the [CLS] and [SEP] ids of the
    # bert-base-uncased vocabulary -- confirm against the vocab file.
    # unittest assertions are used instead of bare ``assert`` so the checks
    # still run under ``python -O`` and report a readable diff on failure.
    self.assertListEqual(encoded_sentence, [101] + text + [102])
    self.assertListEqual(encoded_pair, [101] + text + [102] + text_2 + [102])
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|||||||
@@ -71,10 +71,22 @@ class RobertaTokenizationTest(CommonTestCases.CommonTokenizerTester):
|
|||||||
[0, 31414, 232, 328, 2]
|
[0, 31414, 232, 328, 2]
|
||||||
)
|
)
|
||||||
self.assertListEqual(
|
self.assertListEqual(
|
||||||
tokenizer.encode('Hello world! cécé herlolip'),
|
tokenizer.encode('Hello world! cécé herlolip 418'),
|
||||||
[0, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588, 2]
|
[0, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588, 2]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_sequence_builders(self):
    """Check the special-token wrapping done by the RoBERTa sequence builders.

    Encodes two short strings, feeds the resulting id lists to
    ``add_special_tokens_single_sentence`` and
    ``add_special_tokens_sentences_pair``, and compares each result with a
    hand-built id list.
    """
    tokenizer = RobertaTokenizer.from_pretrained("roberta-base")

    text = tokenizer.encode("sequence builders")
    text_2 = tokenizer.encode("multi-sequence build")

    encoded_sentence = tokenizer.add_special_tokens_single_sentence(text)
    encoded_pair = tokenizer.add_special_tokens_sentences_pair(text, text_2)

    # NOTE(review): 0 and 2 are presumably the <s> and </s> ids of the
    # roberta-base vocabulary -- confirm against the vocab file. The expected
    # pair layout places two consecutive 2s between the segments.
    # unittest assertions are used instead of bare ``assert`` so the checks
    # still run under ``python -O`` and report a readable diff on failure.
    self.assertListEqual(encoded_sentence, [0] + text + [2])
    self.assertListEqual(encoded_pair, [0] + text + [2, 2] + text_2 + [2])
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|||||||
@@ -66,6 +66,17 @@ class XLMTokenizationTest(CommonTestCases.CommonTokenizerTester):
|
|||||||
self.assertListEqual(
|
self.assertListEqual(
|
||||||
tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)
|
tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)
|
||||||
|
|
||||||
|
def test_sequence_builders(self):
    """Check the special-token wrapping done by the XLM sequence builders.

    Encodes two short strings, feeds the resulting id lists to
    ``add_special_tokens_single_sentence`` and
    ``add_special_tokens_sentences_pair``, and compares each result with a
    hand-built id list.
    """
    tokenizer = XLMTokenizer.from_pretrained("xlm-mlm-en-2048")

    text = tokenizer.encode("sequence builders")
    text_2 = tokenizer.encode("multi-sequence build")

    encoded_sentence = tokenizer.add_special_tokens_single_sentence(text)
    encoded_pair = tokenizer.add_special_tokens_sentences_pair(text, text_2)

    # NOTE(review): id 1 is used as both the opening and the separating
    # token here; presumably it is the </s> id of the xlm-mlm-en-2048
    # vocabulary -- confirm against the vocab file.
    # unittest assertions are used instead of bare ``assert`` so the checks
    # still run under ``python -O`` and report a readable diff on failure.
    self.assertListEqual(encoded_sentence, [1] + text + [1])
    self.assertListEqual(encoded_pair, [1] + text + [1] + text_2 + [1])
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|||||||
@@ -89,6 +89,18 @@ class XLNetTokenizationTest(CommonTestCases.CommonTokenizerTester):
|
|||||||
u'9', u'2', u'0', u'0', u'0', u',', SPIECE_UNDERLINE + u'and', SPIECE_UNDERLINE + u'this',
|
u'9', u'2', u'0', u'0', u'0', u',', SPIECE_UNDERLINE + u'and', SPIECE_UNDERLINE + u'this',
|
||||||
SPIECE_UNDERLINE + u'is', SPIECE_UNDERLINE + u'f', u'al', u'se', u'.'])
|
SPIECE_UNDERLINE + u'is', SPIECE_UNDERLINE + u'f', u'al', u'se', u'.'])
|
||||||
|
|
||||||
|
def test_sequence_builders(self):
    """Check the special-token wrapping done by the XLNet sequence builders.

    Encodes two short strings, feeds the resulting id lists to
    ``add_special_tokens_single_sentence`` and
    ``add_special_tokens_sentences_pair``, and compares each result with a
    hand-built id list. Unlike a prefix-style layout, the expected lists
    here carry the special ids at the end of the sequence.
    """
    tokenizer = XLNetTokenizer.from_pretrained("xlnet-base-cased")

    text = tokenizer.encode("sequence builders")
    text_2 = tokenizer.encode("multi-sequence build")

    encoded_sentence = tokenizer.add_special_tokens_single_sentence(text)
    encoded_pair = tokenizer.add_special_tokens_sentences_pair(text, text_2)

    # NOTE(review): 4 and 3 are presumably the <sep> and <cls> ids of the
    # xlnet-base-cased vocabulary -- confirm against the vocab file.
    # unittest assertions are used instead of bare ``assert`` so the checks
    # still run under ``python -O`` and report a readable diff on failure.
    self.assertListEqual(encoded_sentence, text + [4, 3])
    self.assertListEqual(encoded_pair, text + [4] + text_2 + [4, 3])
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|||||||
Reference in New Issue
Block a user