Fix tokenizer saving and loading error (#6026)
* fix tokenizer saving and loading bugs when adding AddedToken to additional special tokens * Add tokenizer test * Style * Style 2 Co-authored-by: Lysandre <lysandre.debut@reseau.eseo.fr>
This commit is contained in:
@@ -1562,6 +1562,8 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
|
|||||||
for key, value in special_tokens_map.items():
|
for key, value in special_tokens_map.items():
|
||||||
if isinstance(value, dict):
|
if isinstance(value, dict):
|
||||||
value = AddedToken(**value)
|
value = AddedToken(**value)
|
||||||
|
elif isinstance(value, list):
|
||||||
|
value = [AddedToken(**token) if isinstance(token, dict) else token for token in value]
|
||||||
setattr(tokenizer, key, value)
|
setattr(tokenizer, key, value)
|
||||||
|
|
||||||
# Add supplementary tokens.
|
# Add supplementary tokens.
|
||||||
@@ -1633,6 +1635,10 @@ class PreTrainedTokenizerBase(SpecialTokensMixin):
|
|||||||
for key, value in self.special_tokens_map_extended.items():
|
for key, value in self.special_tokens_map_extended.items():
|
||||||
if isinstance(value, AddedToken):
|
if isinstance(value, AddedToken):
|
||||||
write_dict[key] = value.__getstate__()
|
write_dict[key] = value.__getstate__()
|
||||||
|
elif isinstance(value, list):
|
||||||
|
write_dict[key] = [
|
||||||
|
token.__getstate__() if isinstance(token, AddedToken) else token for token in value
|
||||||
|
]
|
||||||
else:
|
else:
|
||||||
write_dict[key] = value
|
write_dict[key] = value
|
||||||
f.write(json.dumps(write_dict, ensure_ascii=False))
|
f.write(json.dumps(write_dict, ensure_ascii=False))
|
||||||
|
|||||||
@@ -1165,6 +1165,16 @@ class TokenizerTesterMixin:
|
|||||||
encoded_sequences_batch_padded_1[key], encoded_sequences_batch_padded_2[key],
|
encoded_sequences_batch_padded_1[key], encoded_sequences_batch_padded_2[key],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_added_token_serializable(self):
|
||||||
|
tokenizers = self.get_tokenizers(do_lower_case=False)
|
||||||
|
for tokenizer in tokenizers:
|
||||||
|
new_token = AddedToken("new_token", lstrip=True)
|
||||||
|
tokenizer.add_special_tokens({"additional_special_tokens": [new_token]})
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as tmp_dir_name:
|
||||||
|
tokenizer.save_pretrained(tmp_dir_name)
|
||||||
|
tokenizer.from_pretrained(tmp_dir_name)
|
||||||
|
|
||||||
def test_batch_encode_plus_padding(self):
|
def test_batch_encode_plus_padding(self):
|
||||||
# Test that padded sequences are equivalent between batch_encode_plus and encode_plus
|
# Test that padded sequences are equivalent between batch_encode_plus and encode_plus
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user