Save code of registered custom models (#15379)
* Allow dynamic modules to use relative imports * Work for configs * Fix last merge conflict * Save code of registered custom objects * Map strings to strings * Fix test * Add tokenizer * Rework tests * Tests * Ignore fixtures py files for tests * Tokenizer test + fix collection * With full path * Rework integration * Fix typo * Remove changes in conftest * Test for tokenizers * Add documentation * Update docs/source/custom_models.mdx Co-authored-by: Lysandre Debut <lysandre@huggingface.co> * Add file structure and file content * Add more doc * Style * Update docs/source/custom_models.mdx Co-authored-by: Suraj Patil <surajp815@gmail.com> * Address review comments Co-authored-by: Lysandre Debut <lysandre@huggingface.co> Co-authored-by: Suraj Patil <surajp815@gmail.com>
This commit is contained in:
@@ -34,6 +34,7 @@ from packaging import version
|
||||
from requests import HTTPError
|
||||
|
||||
from . import __version__
|
||||
from .dynamic_module_utils import custom_object_save
|
||||
from .file_utils import (
|
||||
EntryNotFoundError,
|
||||
ExplicitEnum,
|
||||
@@ -1435,6 +1436,7 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
|
||||
pretrained_vocab_files_map: Dict[str, Dict[str, str]] = {}
|
||||
pretrained_init_configuration: Dict[str, Dict[str, Any]] = {}
|
||||
max_model_input_sizes: Dict[str, Optional[int]] = {}
|
||||
_auto_class: Optional[str] = None
|
||||
|
||||
# first name has to correspond to main model input name
|
||||
# to make sure `tokenizer.pad(...)` works correctly
|
||||
@@ -2071,6 +2073,11 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
|
||||
if getattr(self, "_processor_class", None) is not None:
|
||||
tokenizer_config["processor_class"] = self._processor_class
|
||||
|
||||
# If we have a custom model, we copy the file defining it in the folder and set the attributes so it can be
|
||||
# loaded from the Hub.
|
||||
if self._auto_class is not None:
|
||||
custom_object_save(self, save_directory, config=tokenizer_config)
|
||||
|
||||
with open(tokenizer_config_file, "w", encoding="utf-8") as f:
|
||||
f.write(json.dumps(tokenizer_config, ensure_ascii=False))
|
||||
logger.info(f"tokenizer config file saved in {tokenizer_config_file}")
|
||||
@@ -3391,6 +3398,26 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
|
||||
"""
|
||||
yield
|
||||
|
||||
@classmethod
|
||||
def register_for_auto_class(cls, auto_class="AutoTokenizer"):
|
||||
"""
|
||||
Register this class with a given auto class. This should only be used for custom tokenizers as the ones in the
|
||||
library are already mapped with `AutoTokenizer`.
|
||||
|
||||
Args:
|
||||
auto_class (`str` or `type`, *optional*, defaults to `"AutoTokenizer"`):
|
||||
The auto class to register this new tokenizer with.
|
||||
"""
|
||||
if not isinstance(auto_class, str):
|
||||
auto_class = auto_class.__name__
|
||||
|
||||
import transformers.models.auto as auto_module
|
||||
|
||||
if not hasattr(auto_module, auto_class):
|
||||
raise ValueError(f"{auto_class} is not a valid auto class.")
|
||||
|
||||
cls._auto_class = auto_class
|
||||
|
||||
def prepare_seq2seq_batch(
|
||||
self,
|
||||
src_texts: List[str],
|
||||
|
||||
Reference in New Issue
Block a user