Make sure custom configs work with Transformers (#15569)

* Make sure custom configs work with Transformers

* Apply code review suggestions
This commit is contained in:
Sylvain Gugger
2022-02-09 10:04:44 -05:00
committed by GitHub
parent 7732d0fe7a
commit 1f60bc46f3
5 changed files with 37 additions and 6 deletions

View File

@@ -59,14 +59,14 @@ from transformers.testing_utils import (
sys.path.append(str(Path(__file__).parent.parent / "utils"))
from test_module.custom_configuration import CustomConfig # noqa E402
from test_module.custom_configuration import CustomConfig, NoSuperInitConfig # noqa E402
if is_torch_available():
import torch
from torch import nn
from test_module.custom_modeling import CustomModel
from test_module.custom_modeling import CustomModel, NoSuperInitModel
from transformers import (
BERT_PRETRAINED_MODEL_ARCHIVE_LIST,
MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING,
@@ -2091,6 +2091,15 @@ class ModelUtilsTest(TestCasePlus):
model = AutoModel.from_pretrained(TINY_T5, torch_dtype=torch.float16)
self.assertEqual(model.dtype, torch.float16)
def test_no_super_init_config_and_model(self):
config = NoSuperInitConfig(attribute=32)
model = NoSuperInitModel(config)
with tempfile.TemporaryDirectory() as tmp_dir:
model.save_pretrained(tmp_dir)
model = NoSuperInitModel.from_pretrained(tmp_dir)
@require_torch
@is_staging_test