Make sure custom configs work with Transformers (#15569)
* Make sure custom configs work with Transformers * Apply code review suggestions
This commit is contained in:
@@ -59,14 +59,14 @@ from transformers.testing_utils import (
|
||||
|
||||
sys.path.append(str(Path(__file__).parent.parent / "utils"))
|
||||
|
||||
from test_module.custom_configuration import CustomConfig # noqa E402
|
||||
from test_module.custom_configuration import CustomConfig, NoSuperInitConfig # noqa E402
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from test_module.custom_modeling import CustomModel
|
||||
from test_module.custom_modeling import CustomModel, NoSuperInitModel
|
||||
from transformers import (
|
||||
BERT_PRETRAINED_MODEL_ARCHIVE_LIST,
|
||||
MODEL_FOR_CAUSAL_IMAGE_MODELING_MAPPING,
|
||||
@@ -2091,6 +2091,15 @@ class ModelUtilsTest(TestCasePlus):
|
||||
model = AutoModel.from_pretrained(TINY_T5, torch_dtype=torch.float16)
|
||||
self.assertEqual(model.dtype, torch.float16)
|
||||
|
||||
def test_no_super_init_config_and_model(self):
|
||||
config = NoSuperInitConfig(attribute=32)
|
||||
model = NoSuperInitModel(config)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
model.save_pretrained(tmp_dir)
|
||||
|
||||
model = NoSuperInitModel.from_pretrained(tmp_dir)
|
||||
|
||||
|
||||
@require_torch
|
||||
@is_staging_test
|
||||
|
||||
Reference in New Issue
Block a user