Expose get_config() on ModelTesters (#12812)

* Expose get_config() on ModelTesters

* Typo
This commit is contained in:
Lysandre Debut
2021-07-21 10:13:11 +02:00
committed by GitHub
parent cabcc75171
commit c3d9ac7607
53 changed files with 1249 additions and 1193 deletions

View File

@@ -21,6 +21,7 @@ import tempfile
import unittest
import requests
from transformers import CLIPConfig, CLIPTextConfig, CLIPVisionConfig
from transformers.file_utils import is_torch_available, is_vision_available
from transformers.testing_utils import require_torch, require_vision, slow, torch_device
@@ -32,7 +33,7 @@ if is_torch_available():
import torch
from torch import nn
from transformers import CLIPConfig, CLIPModel, CLIPTextConfig, CLIPTextModel, CLIPVisionConfig, CLIPVisionModel
from transformers import CLIPModel, CLIPTextModel, CLIPVisionModel
from transformers.models.clip.modeling_clip import CLIP_PRETRAINED_MODEL_ARCHIVE_LIST
@@ -77,7 +78,12 @@ class CLIPVisionModelTester:
def prepare_config_and_inputs(self):
pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
config = CLIPVisionConfig(
config = self.get_config()
return config, pixel_values
def get_config(self):
return CLIPVisionConfig(
image_size=self.image_size,
patch_size=self.patch_size,
num_channels=self.num_channels,
@@ -90,8 +96,6 @@ class CLIPVisionModelTester:
initializer_range=self.initializer_range,
)
return config, pixel_values
def create_and_check_model(self, config, pixel_values):
model = CLIPVisionModel(config=config)
model.to(torch_device)
@@ -323,7 +327,12 @@ class CLIPTextModelTester:
if self.use_input_mask:
input_mask = random_attention_mask([self.batch_size, self.seq_length])
config = CLIPTextConfig(
config = self.get_config()
return config, input_ids, input_mask
def get_config(self):
return CLIPTextConfig(
vocab_size=self.vocab_size,
hidden_size=self.hidden_size,
num_hidden_layers=self.num_hidden_layers,
@@ -335,8 +344,6 @@ class CLIPTextModelTester:
initializer_range=self.initializer_range,
)
return config, input_ids, input_mask
def create_and_check_model(self, config, input_ids, input_mask):
model = CLIPTextModel(config=config)
model.to(torch_device)
@@ -409,10 +416,15 @@ class CLIPModelTester:
text_config, input_ids, attention_mask = self.text_model_tester.prepare_config_and_inputs()
vision_config, pixel_values = self.vision_model_tester.prepare_config_and_inputs()
config = CLIPConfig.from_text_vision_configs(text_config, vision_config, projection_dim=64)
config = self.get_config()
return config, input_ids, attention_mask, pixel_values
def get_config(self):
# Build the combined CLIP config from the configs returned by the text and
# vision sub-testers' own get_config() methods; projection_dim is fixed at 64.
return CLIPConfig.from_text_vision_configs(
self.text_model_tester.get_config(), self.vision_model_tester.get_config(), projection_dim=64
)
def create_and_check_model(self, config, input_ids, attention_mask, pixel_values):
model = CLIPModel(config).to(torch_device).eval()
result = model(input_ids, pixel_values, attention_mask)