AutoTokenizer: infer the class from the tokenizer config if possible (#12208)

* AutoTokenizer: infer the class from the tokenizer config if possible

* Add tests

* Update src/transformers/models/auto/tokenization_auto.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
Sylvain Gugger
2021-06-17 12:39:22 -04:00
committed by GitHub
parent 0daadc1919
commit adb70eda4d
3 changed files with 168 additions and 10 deletions

View File

@@ -1745,6 +1745,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
if tokenizer_config_file is not None:
with open(tokenizer_config_file, encoding="utf-8") as tokenizer_config_handle:
init_kwargs = json.load(tokenizer_config_handle)
init_kwargs.pop("tokenizer_class", None)
saved_init_inputs = init_kwargs.pop("init_inputs", ())
if not init_inputs:
init_inputs = saved_init_inputs
@@ -1920,6 +1921,14 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
# add_type_field=True to allow dicts in the kwargs / differentiate from AddedToken serialization
tokenizer_config = convert_added_tokens(tokenizer_config, add_type_field=True)
# Add tokenizer class to the tokenizer config to be able to reload it with from_pretrained
tokenizer_class = self.__class__.__name__
# Remove the Fast at the end unless we have a special `PreTrainedTokenizerFast`
if tokenizer_class.endswith("Fast") and tokenizer_class != "PreTrainedTokenizerFast":
tokenizer_class = tokenizer_class[:-4]
tokenizer_config["tokenizer_class"] = tokenizer_class
with open(tokenizer_config_file, "w", encoding="utf-8") as f:
f.write(json.dumps(tokenizer_config, ensure_ascii=False))
logger.info(f"tokenizer config file saved in {tokenizer_config_file}")