Expose get_config() on ModelTesters (#12812)

* Expose get_config() on ModelTesters

* Typo
This commit is contained in:
Lysandre Debut
2021-07-21 10:13:11 +02:00
committed by GitHub
parent cabcc75171
commit c3d9ac7607
53 changed files with 1249 additions and 1193 deletions

View File

@@ -18,6 +18,7 @@
import inspect
import unittest
from transformers import ViTConfig
from transformers.file_utils import cached_property, is_torch_available, is_vision_available
from transformers.testing_utils import require_torch, require_vision, slow, torch_device
@@ -29,7 +30,7 @@ if is_torch_available():
import torch
from torch import nn
from transformers import ViTConfig, ViTForImageClassification, ViTModel
from transformers import ViTForImageClassification, ViTModel
from transformers.models.vit.modeling_vit import VIT_PRETRAINED_MODEL_ARCHIVE_LIST, to_2tuple
@@ -86,7 +87,12 @@ class ViTModelTester:
if self.use_labels:
labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
config = ViTConfig(
config = self.get_config()
return config, pixel_values, labels
def get_config(self):
return ViTConfig(
image_size=self.image_size,
patch_size=self.patch_size,
num_channels=self.num_channels,
@@ -101,8 +107,6 @@ class ViTModelTester:
initializer_range=self.initializer_range,
)
return config, pixel_values, labels
def create_and_check_model(self, config, pixel_values, labels):
model = ViTModel(config=config)
model.to(torch_device)