Fast Tokenizers save pretrained should return the list of generated file paths. (#2918)

* Correctly return the tuple of generated file(s) when calling save_pretrained

Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>

* Quality and format.

Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>
This commit is contained in:
Funtowicz Morgan
2020-02-20 00:58:04 +01:00
committed by GitHub
parent 2708b44ee9
commit d490b5d500
2 changed files with 20 additions and 1 deletions

View File

@@ -1822,4 +1822,5 @@ class PreTrainedTokenizerFast(PreTrainedTokenizer):
folder, file = save_directory, self.vocab_files_names["vocab_file"]
else:
folder, file = os.path.split(os.path.abspath(save_directory))
self._tokenizer.save(folder, file)
return tuple(self._tokenizer.save(folder, file))

View File

@@ -236,6 +236,9 @@ class FastTokenizerMatchingTest(unittest.TestCase):
# Check alignment for build_inputs_with_special_tokens
self.assert_build_inputs_with_special_tokens(tokenizer_r, tokenizer_p)
# Check the number of returned files for save_vocabulary
self.assertEqual(len(tokenizer_r.save_vocabulary(".")), len(tokenizer_p.save_vocabulary(".")))
@require_torch
def test_transfoxl(self):
for tokenizer_name in TransfoXLTokenizer.pretrained_vocab_files_map["pretrained_vocab_file"].keys():
@@ -272,6 +275,9 @@ class FastTokenizerMatchingTest(unittest.TestCase):
# Check alignment for build_inputs_with_special_tokens
self.assert_build_inputs_with_special_tokens(tokenizer_r, tokenizer_p)
# Check the number of returned files for save_vocabulary
self.assertEqual(len(tokenizer_r.save_vocabulary(".")), len(tokenizer_p.save_vocabulary(".")))
def test_distilbert(self):
for tokenizer_name in DistilBertTokenizer.pretrained_vocab_files_map["vocab_file"].keys():
tokenizer_p = DistilBertTokenizer.from_pretrained(tokenizer_name)
@@ -308,6 +314,9 @@ class FastTokenizerMatchingTest(unittest.TestCase):
# Check alignment for build_inputs_with_special_tokens
self.assert_build_inputs_with_special_tokens(tokenizer_r, tokenizer_p)
# Check the number of returned files for save_vocabulary
self.assertEqual(len(tokenizer_r.save_vocabulary(".")), len(tokenizer_p.save_vocabulary(".")))
def test_gpt2(self):
for tokenizer_name in GPT2Tokenizer.pretrained_vocab_files_map["vocab_file"].keys():
tokenizer_p = GPT2Tokenizer.from_pretrained(tokenizer_name)
@@ -343,6 +352,9 @@ class FastTokenizerMatchingTest(unittest.TestCase):
# Check alignment for build_inputs_with_special_tokens
self.assert_build_inputs_with_special_tokens(tokenizer_r, tokenizer_p)
# Check the number of returned files for save_vocabulary
self.assertEqual(len(tokenizer_r.save_vocabulary(".")), len(tokenizer_p.save_vocabulary(".")))
def test_roberta(self):
for tokenizer_name in RobertaTokenizer.pretrained_vocab_files_map["vocab_file"].keys():
tokenizer_p = RobertaTokenizer.from_pretrained(tokenizer_name)
@@ -378,6 +390,9 @@ class FastTokenizerMatchingTest(unittest.TestCase):
# Check alignment for build_inputs_with_special_tokens
self.assert_build_inputs_with_special_tokens(tokenizer_r, tokenizer_p)
# Check the number of returned files for save_vocabulary
self.assertEqual(len(tokenizer_r.save_vocabulary(".")), len(tokenizer_p.save_vocabulary(".")))
def test_openai(self):
for tokenizer_name in OpenAIGPTTokenizer.pretrained_vocab_files_map["vocab_file"].keys():
tokenizer_p = OpenAIGPTTokenizer.from_pretrained(tokenizer_name)
@@ -413,6 +428,9 @@ class FastTokenizerMatchingTest(unittest.TestCase):
# Check alignment for build_inputs_with_special_tokens
self.assert_build_inputs_with_special_tokens(tokenizer_r, tokenizer_p)
# Check the number of returned files for save_vocabulary
self.assertEqual(len(tokenizer_r.save_vocabulary(".")), len(tokenizer_p.save_vocabulary(".")))
if __name__ == "__main__":
unittest.main()