Fix for fast tokenizers save_pretrained compatibility with Python. (#2933)

* Renamed the files generated by tokenizers when calling save_pretrained to match Python.

Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

* Added save_vocabulary tests.

Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

* Remove python quick and dirty fix for clean Rust impl.

Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

* Bump tokenizers dependency to 0.5.1

Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

* TransfoXLTokenizerFast uses a json vocabulary file + warning about incompatibility between Python and Rust

Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

* Added some save_pretrained / from_pretrained unittests.

Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

* Update tokenizers to 0.5.2

Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

* Quality and format.

Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

* flake8

Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

* Making sure there is really a bug in unittest

* Fix TransfoXL constructor vocab_file / pretrained_vocab_file mixin.

Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>
This commit is contained in:
Funtowicz Morgan
2020-02-25 00:20:42 +01:00
committed by GitHub
parent ee60840ee6
commit 4cd9c0971c
4 changed files with 83 additions and 24 deletions

View File

@@ -258,6 +258,20 @@ class FastTokenizerMatchingTest(unittest.TestCase):
output_p = tokenizer_p.build_inputs_with_special_tokens(input_simple, input_pair)
self.assertEqual(output_p, output_r)
def assert_save_pretrained(self, tokenizer_r, tokenizer_p):
    """Check that the Rust and Python tokenizers save and reload consistently.

    Both tokenizers save their vocabulary into the current directory and the
    returned file sequences are compared. Each tokenizer is then reloaded
    from that directory through its own ``from_pretrained`` and the special
    token attributes of the reloaded Rust tokenizer are checked.

    Args:
        tokenizer_r: the Rust-backed (fast) tokenizer under test.
        tokenizer_p: the Python tokenizer used as the reference.
    """
    # Both implementations must report the same saved files
    saved_by_rust = tokenizer_r.save_vocabulary(".")
    saved_by_python = tokenizer_p.save_vocabulary(".")
    self.assertSequenceEqual(saved_by_rust, saved_by_python)

    # Reload each tokenizer from what was just written to the current directory
    reloaded_rust = tokenizer_r.from_pretrained(".")
    reloaded_python = tokenizer_p.from_pretrained(".")

    # Every special token exposed by the reloaded Python tokenizer must also
    # exist as an attribute on the reloaded Rust tokenizer
    for token_name in reloaded_python.special_tokens_map:
        self.assertTrue(hasattr(reloaded_rust, token_name))
        # Value / id equality checks are currently disabled:
        # self.assertEqual(getattr(reloaded_rust, token_name), getattr(reloaded_python, token_name))
        # self.assertEqual(getattr(reloaded_rust, token_name + "_id"), getattr(reloaded_python, token_name + "_id"))
def test_bert(self):
for tokenizer_name in BertTokenizer.pretrained_vocab_files_map["vocab_file"].keys():
tokenizer_p = BertTokenizer.from_pretrained(tokenizer_name)
@@ -294,7 +308,7 @@ class FastTokenizerMatchingTest(unittest.TestCase):
self.assert_build_inputs_with_special_tokens(tokenizer_r, tokenizer_p)
# Check the number of returned files for save_vocabulary
self.assertEqual(len(tokenizer_r.save_vocabulary(".")), len(tokenizer_p.save_vocabulary(".")))
self.assert_save_pretrained(tokenizer_r, tokenizer_p)
# Check for padding
self.assert_padding(tokenizer_r, tokenizer_p)
@@ -335,12 +349,26 @@ class FastTokenizerMatchingTest(unittest.TestCase):
# Check alignment for build_inputs_with_special_tokens
self.assert_build_inputs_with_special_tokens(tokenizer_r, tokenizer_p)
# Check the number of returned files for save_vocabulary
self.assertEqual(len(tokenizer_r.save_vocabulary(".")), len(tokenizer_p.save_vocabulary(".")))
# Check for padding
self.assertRaises(ValueError, self.assert_padding, tokenizer_r, tokenizer_p)
# Check the number of returned files for save_vocabulary
# TransfoXL tokenizers come in a special format which is not compatible at all
# with Rust tokenizers. We ensure the errors are correctly raised
tokenizer_r_files = tokenizer_r.save_pretrained(".")
self.assertSequenceEqual(
tokenizer_r_files, ["./vocab.json", "./special_tokens_map.json", "./added_tokens.json"]
)
# Check loading Python-tokenizer save through Rust doesn't work (and the opposite)
self.assertRaises(ValueError, tokenizer_p.from_pretrained, *tokenizer_r_files)
self.assertRaises(ValueError, tokenizer_r.from_pretrained, *tokenizer_p.save_pretrained("."))
# Check loading works for Python to Python and Rust to Rust
# Issue: https://github.com/huggingface/transformers/issues/3000
# self.assertIsNotNone(tokenizer_p.__class__.from_pretrained('./'))
self.assertIsNotNone(tokenizer_r.__class__.from_pretrained("./"))
def test_distilbert(self):
for tokenizer_name in DistilBertTokenizer.pretrained_vocab_files_map["vocab_file"].keys():
tokenizer_p = DistilBertTokenizer.from_pretrained(tokenizer_name)
@@ -378,7 +406,7 @@ class FastTokenizerMatchingTest(unittest.TestCase):
self.assert_build_inputs_with_special_tokens(tokenizer_r, tokenizer_p)
# Check the number of returned files for save_vocabulary
self.assertEqual(len(tokenizer_r.save_vocabulary(".")), len(tokenizer_p.save_vocabulary(".")))
self.assert_save_pretrained(tokenizer_r, tokenizer_p)
# Check for padding
self.assert_padding(tokenizer_r, tokenizer_p)
@@ -419,7 +447,7 @@ class FastTokenizerMatchingTest(unittest.TestCase):
self.assert_build_inputs_with_special_tokens(tokenizer_r, tokenizer_p)
# Check the number of returned files for save_vocabulary
self.assertEqual(len(tokenizer_r.save_vocabulary(".")), len(tokenizer_p.save_vocabulary(".")))
self.assert_save_pretrained(tokenizer_r, tokenizer_p)
# Check for padding
self.assertRaises(ValueError, self.assert_padding, tokenizer_r, tokenizer_p)
@@ -460,7 +488,7 @@ class FastTokenizerMatchingTest(unittest.TestCase):
self.assert_build_inputs_with_special_tokens(tokenizer_r, tokenizer_p)
# Check the number of returned files for save_vocabulary
self.assertEqual(len(tokenizer_r.save_vocabulary(".")), len(tokenizer_p.save_vocabulary(".")))
self.assert_save_pretrained(tokenizer_r, tokenizer_p)
# Check for padding
# TODO: Re-enable this test as soon as Roberta align with the python tokenizer.
@@ -501,12 +529,10 @@ class FastTokenizerMatchingTest(unittest.TestCase):
# Check alignment for build_inputs_with_special_tokens
self.assert_build_inputs_with_special_tokens(tokenizer_r, tokenizer_p)
# Check the number of returned files for save_vocabulary
self.assertEqual(len(tokenizer_r.save_vocabulary(".")), len(tokenizer_p.save_vocabulary(".")))
# Check for padding
self.assertRaises(ValueError, self.assert_padding, tokenizer_r, tokenizer_p)
if __name__ == "__main__":
unittest.main()
# Check the number of returned files for save_vocabulary
self.assert_save_pretrained(tokenizer_r, tokenizer_p)