AutoTokenizer: infer the class from the tokenizer config if possible (#12208)
* AutoTokenizer: infer the class from the tokenizer config if possible * Add tests * Update src/transformers/models/auto/tokenization_auto.py Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
@@ -1745,6 +1745,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
|
||||
if tokenizer_config_file is not None:
|
||||
with open(tokenizer_config_file, encoding="utf-8") as tokenizer_config_handle:
|
||||
init_kwargs = json.load(tokenizer_config_handle)
|
||||
init_kwargs.pop("tokenizer_class", None)
|
||||
saved_init_inputs = init_kwargs.pop("init_inputs", ())
|
||||
if not init_inputs:
|
||||
init_inputs = saved_init_inputs
|
||||
@@ -1920,6 +1921,14 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
|
||||
|
||||
# add_type_field=True to allow dicts in the kwargs / differentiate from AddedToken serialization
|
||||
tokenizer_config = convert_added_tokens(tokenizer_config, add_type_field=True)
|
||||
|
||||
# Add tokenizer class to the tokenizer config to be able to reload it with from_pretrained
|
||||
tokenizer_class = self.__class__.__name__
|
||||
# Remove the Fast at the end unless we have a special `PreTrainedTokenizerFast`
|
||||
if tokenizer_class.endswith("Fast") and tokenizer_class != "PreTrainedTokenizerFast":
|
||||
tokenizer_class = tokenizer_class[:-4]
|
||||
tokenizer_config["tokenizer_class"] = tokenizer_class
|
||||
|
||||
with open(tokenizer_config_file, "w", encoding="utf-8") as f:
|
||||
f.write(json.dumps(tokenizer_config, ensure_ascii=False))
|
||||
logger.info(f"tokenizer config file saved in {tokenizer_config_file}")
|
||||
|
||||
Reference in New Issue
Block a user