From d490b5d5003654f104af3abd0556e598335b5650 Mon Sep 17 00:00:00 2001 From: Funtowicz Morgan Date: Thu, 20 Feb 2020 00:58:04 +0100 Subject: [PATCH] Fast Tokenizers save pretrained should return the list of generated file paths. (#2918) * Correctly return the tuple of generated file(s) when calling save_pretrained Signed-off-by: Morgan Funtowicz * Quality and format. Signed-off-by: Morgan Funtowicz --- src/transformers/tokenization_utils.py | 3 ++- tests/test_tokenization_fast.py | 18 ++++++++++++++++++ 2 files changed, 20 insertions(+), 1 deletion(-) diff --git a/src/transformers/tokenization_utils.py b/src/transformers/tokenization_utils.py index 761b1d29bb..e3af47b037 100644 --- a/src/transformers/tokenization_utils.py +++ b/src/transformers/tokenization_utils.py @@ -1822,4 +1822,5 @@ class PreTrainedTokenizerFast(PreTrainedTokenizer): folder, file = save_directory, self.vocab_files_names["vocab_file"] else: folder, file = os.path.split(os.path.abspath(save_directory)) - self._tokenizer.save(folder, file) + + return tuple(self._tokenizer.save(folder, file)) diff --git a/tests/test_tokenization_fast.py b/tests/test_tokenization_fast.py index fdcebd0117..18a7eac00a 100644 --- a/tests/test_tokenization_fast.py +++ b/tests/test_tokenization_fast.py @@ -236,6 +236,9 @@ class FastTokenizerMatchingTest(unittest.TestCase): # Check alignment for build_inputs_with_special_tokens self.assert_build_inputs_with_special_tokens(tokenizer_r, tokenizer_p) + # Check the number of returned files for save_vocabulary + self.assertEqual(len(tokenizer_r.save_vocabulary(".")), len(tokenizer_p.save_vocabulary("."))) + @require_torch def test_transfoxl(self): for tokenizer_name in TransfoXLTokenizer.pretrained_vocab_files_map["pretrained_vocab_file"].keys(): @@ -272,6 +275,9 @@ class FastTokenizerMatchingTest(unittest.TestCase): # Check alignment for build_inputs_with_special_tokens self.assert_build_inputs_with_special_tokens(tokenizer_r, tokenizer_p) + # Check the number of returned files for save_vocabulary + self.assertEqual(len(tokenizer_r.save_vocabulary(".")), len(tokenizer_p.save_vocabulary("."))) + def test_distilbert(self): for tokenizer_name in DistilBertTokenizer.pretrained_vocab_files_map["vocab_file"].keys(): tokenizer_p = DistilBertTokenizer.from_pretrained(tokenizer_name) @@ -308,6 +314,9 @@ class FastTokenizerMatchingTest(unittest.TestCase): # Check alignment for build_inputs_with_special_tokens self.assert_build_inputs_with_special_tokens(tokenizer_r, tokenizer_p) + # Check the number of returned files for save_vocabulary + self.assertEqual(len(tokenizer_r.save_vocabulary(".")), len(tokenizer_p.save_vocabulary("."))) + def test_gpt2(self): for tokenizer_name in GPT2Tokenizer.pretrained_vocab_files_map["vocab_file"].keys(): tokenizer_p = GPT2Tokenizer.from_pretrained(tokenizer_name) @@ -343,6 +352,9 @@ class FastTokenizerMatchingTest(unittest.TestCase): # Check alignment for build_inputs_with_special_tokens self.assert_build_inputs_with_special_tokens(tokenizer_r, tokenizer_p) + # Check the number of returned files for save_vocabulary + self.assertEqual(len(tokenizer_r.save_vocabulary(".")), len(tokenizer_p.save_vocabulary("."))) + def test_roberta(self): for tokenizer_name in RobertaTokenizer.pretrained_vocab_files_map["vocab_file"].keys(): tokenizer_p = RobertaTokenizer.from_pretrained(tokenizer_name) @@ -378,6 +390,9 @@ class FastTokenizerMatchingTest(unittest.TestCase): # Check alignment for build_inputs_with_special_tokens self.assert_build_inputs_with_special_tokens(tokenizer_r, tokenizer_p) + # Check the number of returned files for save_vocabulary + self.assertEqual(len(tokenizer_r.save_vocabulary(".")), len(tokenizer_p.save_vocabulary("."))) + def test_openai(self): for tokenizer_name in OpenAIGPTTokenizer.pretrained_vocab_files_map["vocab_file"].keys(): tokenizer_p = OpenAIGPTTokenizer.from_pretrained(tokenizer_name) @@ -413,6 +428,9 @@ class FastTokenizerMatchingTest(unittest.TestCase): # Check alignment for build_inputs_with_special_tokens self.assert_build_inputs_with_special_tokens(tokenizer_r, tokenizer_p) + # Check the number of returned files for save_vocabulary + self.assertEqual(len(tokenizer_r.save_vocabulary(".")), len(tokenizer_p.save_vocabulary("."))) + if __name__ == "__main__": unittest.main()