Save code of registered custom models (#15379)
* Allow dynamic modules to use relative imports * Work for configs * Fix last merge conflict * Save code of registered custom objects * Map strings to strings * Fix test * Add tokenizer * Rework tests * Tests * Ignore fixtures py files for tests * Tokenizer test + fix collection * With full path * Rework integration * Fix typo * Remove changes in conftest * Test for tokenizers * Add documentation * Update docs/source/custom_models.mdx Co-authored-by: Lysandre Debut <lysandre@huggingface.co> * Add file structure and file content * Add more doc * Style * Update docs/source/custom_models.mdx Co-authored-by: Suraj Patil <surajp815@gmail.com> * Address review comments Co-authored-by: Lysandre Debut <lysandre@huggingface.co> Co-authored-by: Suraj Patil <surajp815@gmail.com>
This commit is contained in:
0
utils/test_module/__init__.py
Normal file
0
utils/test_module/__init__.py
Normal file
9
utils/test_module/custom_configuration.py
Normal file
9
utils/test_module/custom_configuration.py
Normal file
@@ -0,0 +1,9 @@
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
|
||||
class CustomConfig(PretrainedConfig):
|
||||
model_type = "custom"
|
||||
|
||||
def __init__(self, attribute=1, **kwargs):
|
||||
self.attribute = attribute
|
||||
super().__init__(**kwargs)
|
||||
20
utils/test_module/custom_modeling.py
Normal file
20
utils/test_module/custom_modeling.py
Normal file
@@ -0,0 +1,20 @@
|
||||
import torch
|
||||
|
||||
from transformers import PreTrainedModel
|
||||
|
||||
from .custom_configuration import CustomConfig
|
||||
|
||||
|
||||
class CustomModel(PreTrainedModel):
|
||||
config_class = CustomConfig
|
||||
base_model_prefix = "custom"
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.linear = torch.nn.Linear(config.hidden_size, config.hidden_size)
|
||||
|
||||
def forward(self, x):
|
||||
return self.linear(x)
|
||||
|
||||
def _init_weights(self, module):
|
||||
pass
|
||||
5
utils/test_module/custom_tokenization.py
Normal file
5
utils/test_module/custom_tokenization.py
Normal file
@@ -0,0 +1,5 @@
|
||||
from transformers import BertTokenizer
|
||||
|
||||
|
||||
class CustomTokenizer(BertTokenizer):
|
||||
pass
|
||||
8
utils/test_module/custom_tokenization_fast.py
Normal file
8
utils/test_module/custom_tokenization_fast.py
Normal file
@@ -0,0 +1,8 @@
|
||||
from transformers import BertTokenizerFast
|
||||
|
||||
from .custom_tokenization import CustomTokenizer
|
||||
|
||||
|
||||
class CustomTokenizerFast(BertTokenizerFast):
|
||||
slow_tokenizer_class = CustomTokenizer
|
||||
pass
|
||||
Reference in New Issue
Block a user