|
|
|
|
@@ -172,6 +172,35 @@ class FastTokenizerMatchingTest(unittest.TestCase):
|
|
|
|
|
self.assertEqual(len(tokens[key].shape), 2)
|
|
|
|
|
self.assertEqual(tokens[key].shape[-1], 6)
|
|
|
|
|
|
|
|
|
|
def assert_build_inputs_with_special_tokens(self, tokenizer_r, tokenizer_p):
|
|
|
|
|
# Input string
|
|
|
|
|
input_simple = tokenizer_p.tokenize("This is a sample input")
|
|
|
|
|
input_pair = tokenizer_p.tokenize("This is a sample pair")
|
|
|
|
|
|
|
|
|
|
# Generate output
|
|
|
|
|
output_r = tokenizer_r.build_inputs_with_special_tokens(input_simple)
|
|
|
|
|
output_p = tokenizer_p.build_inputs_with_special_tokens(input_simple)
|
|
|
|
|
self.assertEqual(output_p, output_r)
|
|
|
|
|
|
|
|
|
|
# Generate pair output
|
|
|
|
|
output_r = tokenizer_r.build_inputs_with_special_tokens(input_simple, input_pair)
|
|
|
|
|
output_p = tokenizer_p.build_inputs_with_special_tokens(input_simple, input_pair)
|
|
|
|
|
self.assertEqual(output_p, output_r)
|
|
|
|
|
|
|
|
|
|
# Input tokens id
|
|
|
|
|
input_simple = tokenizer_p.encode("This is a sample input")
|
|
|
|
|
input_pair = tokenizer_p.encode("This is a sample pair")
|
|
|
|
|
|
|
|
|
|
# Generate output
|
|
|
|
|
output_r = tokenizer_r.build_inputs_with_special_tokens(input_simple)
|
|
|
|
|
output_p = tokenizer_p.build_inputs_with_special_tokens(input_simple)
|
|
|
|
|
self.assertEqual(output_p, output_r)
|
|
|
|
|
|
|
|
|
|
# Generate pair output
|
|
|
|
|
output_r = tokenizer_r.build_inputs_with_special_tokens(input_simple, input_pair)
|
|
|
|
|
output_p = tokenizer_p.build_inputs_with_special_tokens(input_simple, input_pair)
|
|
|
|
|
self.assertEqual(output_p, output_r)
|
|
|
|
|
|
|
|
|
|
def test_bert(self):
|
|
|
|
|
for tokenizer_name in BertTokenizer.pretrained_vocab_files_map["vocab_file"].keys():
|
|
|
|
|
tokenizer_p = BertTokenizer.from_pretrained(tokenizer_name)
|
|
|
|
|
@@ -204,6 +233,9 @@ class FastTokenizerMatchingTest(unittest.TestCase):
|
|
|
|
|
# Check for dynamic encoding sequence handling in batch_encode_plus
|
|
|
|
|
self.assert_batch_encode_dynamic_overflowing(tokenizer_r)
|
|
|
|
|
|
|
|
|
|
# Check alignment for build_inputs_with_special_tokens
|
|
|
|
|
self.assert_build_inputs_with_special_tokens(tokenizer_r, tokenizer_p)
|
|
|
|
|
|
|
|
|
|
@require_torch
|
|
|
|
|
def test_transfoxl(self):
|
|
|
|
|
for tokenizer_name in TransfoXLTokenizer.pretrained_vocab_files_map["pretrained_vocab_file"].keys():
|
|
|
|
|
@@ -237,6 +269,9 @@ class FastTokenizerMatchingTest(unittest.TestCase):
|
|
|
|
|
# Check for dynamic encoding sequence handling in batch_encode_plus
|
|
|
|
|
self.assertRaises(ValueError, self.assert_batch_encode_dynamic_overflowing, tokenizer_r)
|
|
|
|
|
|
|
|
|
|
# Check alignment for build_inputs_with_special_tokens
|
|
|
|
|
self.assert_build_inputs_with_special_tokens(tokenizer_r, tokenizer_p)
|
|
|
|
|
|
|
|
|
|
def test_distilbert(self):
|
|
|
|
|
for tokenizer_name in DistilBertTokenizer.pretrained_vocab_files_map["vocab_file"].keys():
|
|
|
|
|
tokenizer_p = DistilBertTokenizer.from_pretrained(tokenizer_name)
|
|
|
|
|
@@ -270,6 +305,9 @@ class FastTokenizerMatchingTest(unittest.TestCase):
|
|
|
|
|
# Check for dynamic encoding sequence handling in batch_encode_plus
|
|
|
|
|
self.assert_batch_encode_dynamic_overflowing(tokenizer_r)
|
|
|
|
|
|
|
|
|
|
# Check alignment for build_inputs_with_special_tokens
|
|
|
|
|
self.assert_build_inputs_with_special_tokens(tokenizer_r, tokenizer_p)
|
|
|
|
|
|
|
|
|
|
def test_gpt2(self):
|
|
|
|
|
for tokenizer_name in GPT2Tokenizer.pretrained_vocab_files_map["vocab_file"].keys():
|
|
|
|
|
tokenizer_p = GPT2Tokenizer.from_pretrained(tokenizer_name)
|
|
|
|
|
@@ -302,6 +340,9 @@ class FastTokenizerMatchingTest(unittest.TestCase):
|
|
|
|
|
# Check for dynamic encoding sequence handling in batch_encode_plus
|
|
|
|
|
self.assertRaises(ValueError, self.assert_batch_encode_dynamic_overflowing, tokenizer_r)
|
|
|
|
|
|
|
|
|
|
# Check alignment for build_inputs_with_special_tokens
|
|
|
|
|
self.assert_build_inputs_with_special_tokens(tokenizer_r, tokenizer_p)
|
|
|
|
|
|
|
|
|
|
def test_roberta(self):
|
|
|
|
|
for tokenizer_name in RobertaTokenizer.pretrained_vocab_files_map["vocab_file"].keys():
|
|
|
|
|
tokenizer_p = RobertaTokenizer.from_pretrained(tokenizer_name)
|
|
|
|
|
@@ -334,6 +375,9 @@ class FastTokenizerMatchingTest(unittest.TestCase):
|
|
|
|
|
# Check for dynamic encoding sequence handling in batch_encode_plus
|
|
|
|
|
self.assert_batch_encode_dynamic_overflowing(tokenizer_r)
|
|
|
|
|
|
|
|
|
|
# Check alignment for build_inputs_with_special_tokens
|
|
|
|
|
self.assert_build_inputs_with_special_tokens(tokenizer_r, tokenizer_p)
|
|
|
|
|
|
|
|
|
|
def test_openai(self):
|
|
|
|
|
for tokenizer_name in OpenAIGPTTokenizer.pretrained_vocab_files_map["vocab_file"].keys():
|
|
|
|
|
tokenizer_p = OpenAIGPTTokenizer.from_pretrained(tokenizer_name)
|
|
|
|
|
@@ -366,6 +410,9 @@ class FastTokenizerMatchingTest(unittest.TestCase):
|
|
|
|
|
# Check for dynamic encoding sequence handling in batch_encode_plus
|
|
|
|
|
self.assertRaises(ValueError, self.assert_batch_encode_dynamic_overflowing, tokenizer_r)
|
|
|
|
|
|
|
|
|
|
# Check alignment for build_inputs_with_special_tokens
|
|
|
|
|
self.assert_build_inputs_with_special_tokens(tokenizer_r, tokenizer_p)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
|
|
|
|
|
unittest.main()
|
|
|
|
|
|