Save code of registered custom models (#15379)

* Allow dynamic modules to use relative imports

* Work for configs

* Fix last merge conflict

* Save code of registered custom objects

* Map strings to strings

* Fix test

* Add tokenizer

* Rework tests

* Tests

* Ignore fixtures py files for tests

* Tokenizer test + fix collection

* With full path

* Rework integration

* Fix typo

* Remove changes in conftest

* Test for tokenizers

* Add documentation

* Update docs/source/custom_models.mdx

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>

* Add file structure and file content

* Add more doc

* Style

* Update docs/source/custom_models.mdx

Co-authored-by: Suraj Patil <surajp815@gmail.com>

* Address review comments

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
Co-authored-by: Suraj Patil <surajp815@gmail.com>
This commit is contained in:
Sylvain Gugger
2022-02-02 10:44:37 -05:00
committed by GitHub
parent 623d8cb475
commit 44b21f117b
23 changed files with 630 additions and 295 deletions

View File

View File

@@ -0,0 +1,9 @@
from transformers import PretrainedConfig
class CustomConfig(PretrainedConfig):
model_type = "custom"
def __init__(self, attribute=1, **kwargs):
self.attribute = attribute
super().__init__(**kwargs)

View File

@@ -0,0 +1,20 @@
import torch
from transformers import PreTrainedModel
from .custom_configuration import CustomConfig
class CustomModel(PreTrainedModel):
config_class = CustomConfig
base_model_prefix = "custom"
def __init__(self, config):
super().__init__(config)
self.linear = torch.nn.Linear(config.hidden_size, config.hidden_size)
def forward(self, x):
return self.linear(x)
def _init_weights(self, module):
pass

View File

@@ -0,0 +1,5 @@
from transformers import BertTokenizer
class CustomTokenizer(BertTokenizer):
pass

View File

@@ -0,0 +1,8 @@
from transformers import BertTokenizerFast
from .custom_tokenization import CustomTokenizer
class CustomTokenizerFast(BertTokenizerFast):
slow_tokenizer_class = CustomTokenizer
pass