Fast Tokenizers save pretrained should return the list of generated file paths. (#2918)
* Correctly return the tuple of generated file(s) when calling save_pretrained Signed-off-by: Morgan Funtowicz <morgan@huggingface.co> * Quality and format. Signed-off-by: Morgan Funtowicz <morgan@huggingface.co>
This commit is contained in:
@@ -1822,4 +1822,5 @@ class PreTrainedTokenizerFast(PreTrainedTokenizer):
|
|||||||
folder, file = save_directory, self.vocab_files_names["vocab_file"]
|
folder, file = save_directory, self.vocab_files_names["vocab_file"]
|
||||||
else:
|
else:
|
||||||
folder, file = os.path.split(os.path.abspath(save_directory))
|
folder, file = os.path.split(os.path.abspath(save_directory))
|
||||||
self._tokenizer.save(folder, file)
|
|
||||||
|
return tuple(self._tokenizer.save(folder, file))
|
||||||
|
|||||||
@@ -236,6 +236,9 @@ class FastTokenizerMatchingTest(unittest.TestCase):
|
|||||||
# Check alignment for build_inputs_with_special_tokens
|
# Check alignment for build_inputs_with_special_tokens
|
||||||
self.assert_build_inputs_with_special_tokens(tokenizer_r, tokenizer_p)
|
self.assert_build_inputs_with_special_tokens(tokenizer_r, tokenizer_p)
|
||||||
|
|
||||||
|
# Check the number of returned files for save_vocabulary
|
||||||
|
self.assertEqual(len(tokenizer_r.save_vocabulary(".")), len(tokenizer_p.save_vocabulary(".")))
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
def test_transfoxl(self):
|
def test_transfoxl(self):
|
||||||
for tokenizer_name in TransfoXLTokenizer.pretrained_vocab_files_map["pretrained_vocab_file"].keys():
|
for tokenizer_name in TransfoXLTokenizer.pretrained_vocab_files_map["pretrained_vocab_file"].keys():
|
||||||
@@ -272,6 +275,9 @@ class FastTokenizerMatchingTest(unittest.TestCase):
|
|||||||
# Check alignment for build_inputs_with_special_tokens
|
# Check alignment for build_inputs_with_special_tokens
|
||||||
self.assert_build_inputs_with_special_tokens(tokenizer_r, tokenizer_p)
|
self.assert_build_inputs_with_special_tokens(tokenizer_r, tokenizer_p)
|
||||||
|
|
||||||
|
# Check the number of returned files for save_vocabulary
|
||||||
|
self.assertEqual(len(tokenizer_r.save_vocabulary(".")), len(tokenizer_p.save_vocabulary(".")))
|
||||||
|
|
||||||
def test_distilbert(self):
|
def test_distilbert(self):
|
||||||
for tokenizer_name in DistilBertTokenizer.pretrained_vocab_files_map["vocab_file"].keys():
|
for tokenizer_name in DistilBertTokenizer.pretrained_vocab_files_map["vocab_file"].keys():
|
||||||
tokenizer_p = DistilBertTokenizer.from_pretrained(tokenizer_name)
|
tokenizer_p = DistilBertTokenizer.from_pretrained(tokenizer_name)
|
||||||
@@ -308,6 +314,9 @@ class FastTokenizerMatchingTest(unittest.TestCase):
|
|||||||
# Check alignment for build_inputs_with_special_tokens
|
# Check alignment for build_inputs_with_special_tokens
|
||||||
self.assert_build_inputs_with_special_tokens(tokenizer_r, tokenizer_p)
|
self.assert_build_inputs_with_special_tokens(tokenizer_r, tokenizer_p)
|
||||||
|
|
||||||
|
# Check the number of returned files for save_vocabulary
|
||||||
|
self.assertEqual(len(tokenizer_r.save_vocabulary(".")), len(tokenizer_p.save_vocabulary(".")))
|
||||||
|
|
||||||
def test_gpt2(self):
|
def test_gpt2(self):
|
||||||
for tokenizer_name in GPT2Tokenizer.pretrained_vocab_files_map["vocab_file"].keys():
|
for tokenizer_name in GPT2Tokenizer.pretrained_vocab_files_map["vocab_file"].keys():
|
||||||
tokenizer_p = GPT2Tokenizer.from_pretrained(tokenizer_name)
|
tokenizer_p = GPT2Tokenizer.from_pretrained(tokenizer_name)
|
||||||
@@ -343,6 +352,9 @@ class FastTokenizerMatchingTest(unittest.TestCase):
|
|||||||
# Check alignment for build_inputs_with_special_tokens
|
# Check alignment for build_inputs_with_special_tokens
|
||||||
self.assert_build_inputs_with_special_tokens(tokenizer_r, tokenizer_p)
|
self.assert_build_inputs_with_special_tokens(tokenizer_r, tokenizer_p)
|
||||||
|
|
||||||
|
# Check the number of returned files for save_vocabulary
|
||||||
|
self.assertEqual(len(tokenizer_r.save_vocabulary(".")), len(tokenizer_p.save_vocabulary(".")))
|
||||||
|
|
||||||
def test_roberta(self):
|
def test_roberta(self):
|
||||||
for tokenizer_name in RobertaTokenizer.pretrained_vocab_files_map["vocab_file"].keys():
|
for tokenizer_name in RobertaTokenizer.pretrained_vocab_files_map["vocab_file"].keys():
|
||||||
tokenizer_p = RobertaTokenizer.from_pretrained(tokenizer_name)
|
tokenizer_p = RobertaTokenizer.from_pretrained(tokenizer_name)
|
||||||
@@ -378,6 +390,9 @@ class FastTokenizerMatchingTest(unittest.TestCase):
|
|||||||
# Check alignment for build_inputs_with_special_tokens
|
# Check alignment for build_inputs_with_special_tokens
|
||||||
self.assert_build_inputs_with_special_tokens(tokenizer_r, tokenizer_p)
|
self.assert_build_inputs_with_special_tokens(tokenizer_r, tokenizer_p)
|
||||||
|
|
||||||
|
# Check the number of returned files for save_vocabulary
|
||||||
|
self.assertEqual(len(tokenizer_r.save_vocabulary(".")), len(tokenizer_p.save_vocabulary(".")))
|
||||||
|
|
||||||
def test_openai(self):
|
def test_openai(self):
|
||||||
for tokenizer_name in OpenAIGPTTokenizer.pretrained_vocab_files_map["vocab_file"].keys():
|
for tokenizer_name in OpenAIGPTTokenizer.pretrained_vocab_files_map["vocab_file"].keys():
|
||||||
tokenizer_p = OpenAIGPTTokenizer.from_pretrained(tokenizer_name)
|
tokenizer_p = OpenAIGPTTokenizer.from_pretrained(tokenizer_name)
|
||||||
@@ -413,6 +428,9 @@ class FastTokenizerMatchingTest(unittest.TestCase):
|
|||||||
# Check alignment for build_inputs_with_special_tokens
|
# Check alignment for build_inputs_with_special_tokens
|
||||||
self.assert_build_inputs_with_special_tokens(tokenizer_r, tokenizer_p)
|
self.assert_build_inputs_with_special_tokens(tokenizer_r, tokenizer_p)
|
||||||
|
|
||||||
|
# Check the number of returned files for save_vocabulary
|
||||||
|
self.assertEqual(len(tokenizer_r.save_vocabulary(".")), len(tokenizer_p.save_vocabulary(".")))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|||||||
Reference in New Issue
Block a user