Expose get_config() on ModelTesters (#12812)
* Expose get_config() on ModelTesters
* Typo
This commit is contained in:
@@ -21,6 +21,7 @@ import tempfile
|
||||
import unittest
|
||||
|
||||
import requests
|
||||
from transformers import CLIPConfig, CLIPTextConfig, CLIPVisionConfig
|
||||
from transformers.file_utils import is_torch_available, is_vision_available
|
||||
from transformers.testing_utils import require_torch, require_vision, slow, torch_device
|
||||
|
||||
@@ -32,7 +33,7 @@ if is_torch_available():
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from transformers import CLIPConfig, CLIPModel, CLIPTextConfig, CLIPTextModel, CLIPVisionConfig, CLIPVisionModel
|
||||
from transformers import CLIPModel, CLIPTextModel, CLIPVisionModel
|
||||
from transformers.models.clip.modeling_clip import CLIP_PRETRAINED_MODEL_ARCHIVE_LIST
|
||||
|
||||
|
||||
@@ -77,7 +78,12 @@ class CLIPVisionModelTester:
|
||||
|
||||
def prepare_config_and_inputs(self):
|
||||
pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
|
||||
config = CLIPVisionConfig(
|
||||
config = self.get_config()
|
||||
|
||||
return config, pixel_values
|
||||
|
||||
def get_config(self):
|
||||
return CLIPVisionConfig(
|
||||
image_size=self.image_size,
|
||||
patch_size=self.patch_size,
|
||||
num_channels=self.num_channels,
|
||||
@@ -90,8 +96,6 @@ class CLIPVisionModelTester:
|
||||
initializer_range=self.initializer_range,
|
||||
)
|
||||
|
||||
return config, pixel_values
|
||||
|
||||
def create_and_check_model(self, config, pixel_values):
|
||||
model = CLIPVisionModel(config=config)
|
||||
model.to(torch_device)
|
||||
@@ -323,7 +327,12 @@ class CLIPTextModelTester:
|
||||
if self.use_input_mask:
|
||||
input_mask = random_attention_mask([self.batch_size, self.seq_length])
|
||||
|
||||
config = CLIPTextConfig(
|
||||
config = self.get_config()
|
||||
|
||||
return config, input_ids, input_mask
|
||||
|
||||
def get_config(self):
|
||||
return CLIPTextConfig(
|
||||
vocab_size=self.vocab_size,
|
||||
hidden_size=self.hidden_size,
|
||||
num_hidden_layers=self.num_hidden_layers,
|
||||
@@ -335,8 +344,6 @@ class CLIPTextModelTester:
|
||||
initializer_range=self.initializer_range,
|
||||
)
|
||||
|
||||
return config, input_ids, input_mask
|
||||
|
||||
def create_and_check_model(self, config, input_ids, input_mask):
|
||||
model = CLIPTextModel(config=config)
|
||||
model.to(torch_device)
|
||||
@@ -409,10 +416,15 @@ class CLIPModelTester:
|
||||
text_config, input_ids, attention_mask = self.text_model_tester.prepare_config_and_inputs()
|
||||
vision_config, pixel_values = self.vision_model_tester.prepare_config_and_inputs()
|
||||
|
||||
config = CLIPConfig.from_text_vision_configs(text_config, vision_config, projection_dim=64)
|
||||
config = self.get_config()
|
||||
|
||||
return config, input_ids, attention_mask, pixel_values
|
||||
|
||||
def get_config(self):
|
||||
return CLIPConfig.from_text_vision_configs(
|
||||
self.text_model_tester.get_config(), self.vision_model_tester.get_config(), projection_dim=64
|
||||
)
|
||||
|
||||
def create_and_check_model(self, config, input_ids, attention_mask, pixel_values):
|
||||
model = CLIPModel(config).to(torch_device).eval()
|
||||
result = model(input_ids, pixel_values, attention_mask)
|
||||
|
||||
Reference in New Issue
Block a user