Add SigLIP (#26522)

* Add first draft

* Use appropriate gelu function

* More improvements

* More improvements

* More improvements

* Convert checkpoint

* More improvements

* Improve docs, remove print statements

* More improvements

* Add link

* remove unused masking function

* begin tokenizer

* do_lower_case

* debug

* set split_special_tokens=True

* Remove script

* Fix style

* Fix rebase

* Use same design as CLIP

* Add fast tokenizer

* Add SiglipTokenizer to init, remove extra_ids

* Improve conversion script

* Use smaller inputs in conversion script

* Update conversion script

* More improvements

* Add processor to conversion script

* Add tests

* Remove print statements

* Add tokenizer tests

* Fix more tests

* More improvements related to weight initialization

* More improvements

* Make more tests pass

* More improvements

* More improvements

* Add copied from

* Add canonicalize_text

* Enable fast tokenizer tests

* More improvements

* Fix most slow tokenizer tests

* Address comments

* Fix style

* Remove script

* Address some comments

* Add copied from to tests

* Add more copied from

* Add more copied from

* Add more copied from

* Remove is_flax_available

* More updates

* Address comment

* Remove SiglipTokenizerFast for now

* Add caching

* Remove umt5 test

* Add canonicalize_text inside _tokenize, thanks Arthur

* Fix image processor tests

* Skip tests which are not applicable

* Skip test_initialization

* More improvements

* Compare pixel values

* Fix doc tests, add integration test

* Add do_normalize

* Remove causal mask and leverage ignore copy

* Fix attention_mask

* Fix remaining tests

* Fix dummies

* Rename temperature and bias

* Address comments

* Add copied from to tokenizer tests

* Add SiglipVisionModel to auto mapping

* Add copied from to image processor tests

* Improve doc

* Remove SiglipVisionModel from index

* Address comments

* Improve docs

* Simplify config

* Add first draft

* Make it like mistral

* More improvements

* Fix attention_mask

* Fix output_attentions

* Add note in docs

* Convert multilingual model

* Convert large checkpoint

* Convert more checkpoints

* Add pipeline support, correct image_mean and image_std

* Use padding=max_length by default

* Make processor like llava

* Add code snippet

* Convert more checkpoints

* Set keep_punctuation_string=None as in OpenCLIP

* Set normalized=False for special tokens

* Fix doc test

* Update integration test

* Add figure

* Update organization

* Happy new year

* Use AutoModel everywhere

---------

Co-authored-by: patil-suraj <surajp815@gmail.com>
This commit is contained in:
NielsRogge
2024-01-08 18:17:16 +01:00
committed by GitHub
parent 73c88012b7
commit 3b742ea84c
35 changed files with 4254 additions and 4 deletions

View File

View File

@@ -0,0 +1,125 @@
# coding=utf-8
# Copyright 2024 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from transformers.testing_utils import require_torch, require_vision
from transformers.utils import is_vision_available
from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs
if is_vision_available():
from transformers import SiglipImageProcessor
class SiglipImageProcessingTester(unittest.TestCase):
def __init__(
self,
parent,
batch_size=7,
num_channels=3,
image_size=18,
min_resolution=30,
max_resolution=400,
do_resize=True,
size=None,
do_rescale=True,
rescale_factor=1 / 255,
do_normalize=True,
image_mean=[0.5, 0.5, 0.5],
image_std=[0.5, 0.5, 0.5],
):
size = size if size is not None else {"height": 18, "width": 18}
self.parent = parent
self.batch_size = batch_size
self.num_channels = num_channels
self.image_size = image_size
self.min_resolution = min_resolution
self.max_resolution = max_resolution
self.do_resize = do_resize
self.size = size
self.do_rescale = do_rescale
self.rescale_factor = rescale_factor
self.do_normalize = do_normalize
self.image_mean = image_mean
self.image_std = image_std
def prepare_image_processor_dict(self):
return {
"do_resize": self.do_resize,
"size": self.size,
"do_rescale": self.do_rescale,
"rescale_factor": self.rescale_factor,
"do_normalize": self.do_normalize,
"image_mean": self.image_mean,
"image_std": self.image_std,
}
def expected_output_image_shape(self, images):
return self.num_channels, self.size["height"], self.size["width"]
def prepare_image_inputs(self, equal_resolution=False, numpify=False, torchify=False):
return prepare_image_inputs(
batch_size=self.batch_size,
num_channels=self.num_channels,
min_resolution=self.min_resolution,
max_resolution=self.max_resolution,
equal_resolution=equal_resolution,
numpify=numpify,
torchify=torchify,
)
@require_torch
@require_vision
# Copied from tests.models.clip.test_image_processing_clip.CLIPImageProcessingTest with CLIP->Siglip
class SiglipImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
image_processing_class = SiglipImageProcessor if is_vision_available() else None
def setUp(self):
self.image_processor_tester = SiglipImageProcessingTester(self)
@property
def image_processor_dict(self):
return self.image_processor_tester.prepare_image_processor_dict()
# Ignore copy
def test_image_processor_properties(self):
image_processing = self.image_processing_class(**self.image_processor_dict)
self.assertTrue(hasattr(image_processing, "do_resize"))
self.assertTrue(hasattr(image_processing, "size"))
self.assertTrue(hasattr(image_processing, "resample"))
self.assertTrue(hasattr(image_processing, "do_rescale"))
self.assertTrue(hasattr(image_processing, "rescale_factor"))
self.assertTrue(hasattr(image_processing, "do_normalize"))
self.assertTrue(hasattr(image_processing, "image_mean"))
self.assertTrue(hasattr(image_processing, "image_std"))
# Ignore copy
def test_image_processor_from_dict_with_kwargs(self):
image_processor = self.image_processing_class.from_dict(self.image_processor_dict)
self.assertEqual(image_processor.size, {"height": 18, "width": 18})
image_processor = self.image_processing_class.from_dict(
self.image_processor_dict, size={"height": 84, "width": 84}
)
self.assertEqual(image_processor.size, {"height": 84, "width": 84})
@unittest.skip("not supported")
# Ignore copy
def test_call_numpy_4_channels(self):
pass

View File

@@ -0,0 +1,631 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Testing suite for the PyTorch Siglip model. """
import inspect
import os
import tempfile
import unittest
import numpy as np
import requests
from transformers import SiglipConfig, SiglipTextConfig, SiglipVisionConfig
from transformers.testing_utils import (
require_torch,
require_vision,
slow,
torch_device,
)
from transformers.utils import is_torch_available, is_vision_available
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import (
ModelTesterMixin,
_config_zero_init,
floats_tensor,
ids_tensor,
random_attention_mask,
)
from ...test_pipeline_mixin import PipelineTesterMixin
if is_torch_available():
import torch
from torch import nn
from transformers import SiglipModel, SiglipTextModel, SiglipVisionModel
from transformers.models.siglip.modeling_siglip import SIGLIP_PRETRAINED_MODEL_ARCHIVE_LIST
if is_vision_available():
from PIL import Image
from transformers import SiglipProcessor
class SiglipVisionModelTester:
def __init__(
self,
parent,
batch_size=12,
image_size=30,
patch_size=2,
num_channels=3,
is_training=True,
hidden_size=32,
num_hidden_layers=2,
num_attention_heads=4,
intermediate_size=37,
dropout=0.1,
attention_dropout=0.1,
initializer_range=0.02,
scope=None,
):
self.parent = parent
self.batch_size = batch_size
self.image_size = image_size
self.patch_size = patch_size
self.num_channels = num_channels
self.is_training = is_training
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.intermediate_size = intermediate_size
self.dropout = dropout
self.attention_dropout = attention_dropout
self.initializer_range = initializer_range
self.scope = scope
# in ViT, the seq length equals the number of patches
num_patches = (image_size // patch_size) ** 2
self.seq_length = num_patches
# Copied from tests.models.clip.test_modeling_clip.CLIPVisionModelTester.prepare_config_and_inputs
def prepare_config_and_inputs(self):
pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
config = self.get_config()
return config, pixel_values
def get_config(self):
return SiglipVisionConfig(
image_size=self.image_size,
patch_size=self.patch_size,
num_channels=self.num_channels,
hidden_size=self.hidden_size,
num_hidden_layers=self.num_hidden_layers,
num_attention_heads=self.num_attention_heads,
intermediate_size=self.intermediate_size,
dropout=self.dropout,
attention_dropout=self.attention_dropout,
initializer_range=self.initializer_range,
)
def create_and_check_model(self, config, pixel_values):
model = SiglipVisionModel(config=config)
model.to(torch_device)
model.eval()
with torch.no_grad():
result = model(pixel_values)
# expected sequence length = num_patches + 1 (we add 1 for the [CLS] token)
image_size = (self.image_size, self.image_size)
patch_size = (self.patch_size, self.patch_size)
num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, num_patches, self.hidden_size))
self.parent.assertEqual(result.pooler_output.shape, (self.batch_size, self.hidden_size))
# Copied from tests.models.clip.test_modeling_clip.CLIPVisionModelTester.prepare_config_and_inputs_for_common
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
config, pixel_values = config_and_inputs
inputs_dict = {"pixel_values": pixel_values}
return config, inputs_dict
@require_torch
class SiglipVisionModelTest(ModelTesterMixin, unittest.TestCase):
"""
Here we also overwrite some of the tests of test_modeling_common.py, as SIGLIP does not use input_ids, inputs_embeds,
attention_mask and seq_length.
"""
all_model_classes = (SiglipVisionModel,) if is_torch_available() else ()
fx_compatible = False
test_pruning = False
test_resize_embeddings = False
test_head_masking = False
def setUp(self):
self.model_tester = SiglipVisionModelTester(self)
self.config_tester = ConfigTester(
self, config_class=SiglipVisionConfig, has_text_modality=False, hidden_size=37
)
def test_config(self):
self.config_tester.run_common_tests()
@unittest.skip(reason="SIGLIP does not use inputs_embeds")
def test_inputs_embeds(self):
pass
def test_model_common_attributes(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
self.assertIsInstance(model.get_input_embeddings(), (nn.Module))
x = model.get_output_embeddings()
self.assertTrue(x is None or isinstance(x, nn.Linear))
def test_forward_signature(self):
config, _ = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
model = model_class(config)
signature = inspect.signature(model.forward)
# signature.parameters is an OrderedDict => so arg_names order is deterministic
arg_names = [*signature.parameters.keys()]
expected_arg_names = ["pixel_values"]
self.assertListEqual(arg_names[:1], expected_arg_names)
def test_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs)
@unittest.skip(reason="SiglipVisionModel does not support standalone training")
def test_training(self):
pass
@unittest.skip(reason="SiglipVisionModel does not support standalone training")
def test_training_gradient_checkpointing(self):
pass
@unittest.skip(reason="SiglipVisionModel does not support standalone training")
def test_training_gradient_checkpointing_use_reentrant(self):
pass
@unittest.skip(reason="SiglipVisionModel does not support standalone training")
def test_training_gradient_checkpointing_use_reentrant_false(self):
pass
@unittest.skip(reason="SiglipVisionModel has no base class and is not available in MODEL_MAPPING")
def test_save_load_fast_init_from_base(self):
pass
@unittest.skip(reason="SiglipVisionModel has no base class and is not available in MODEL_MAPPING")
def test_save_load_fast_init_to_base(self):
pass
@unittest.skip(reason="Siglip uses the same initialization scheme as the Flax original implementation")
def test_initialization(self):
pass
@slow
def test_model_from_pretrained(self):
for model_name in SIGLIP_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
model = SiglipVisionModel.from_pretrained(model_name)
self.assertIsNotNone(model)
class SiglipTextModelTester:
def __init__(
self,
parent,
batch_size=12,
seq_length=7,
is_training=True,
use_input_mask=True,
use_labels=True,
vocab_size=99,
hidden_size=32,
num_hidden_layers=2,
num_attention_heads=4,
intermediate_size=37,
dropout=0.1,
attention_dropout=0.1,
max_position_embeddings=512,
initializer_range=0.02,
scope=None,
):
self.parent = parent
self.batch_size = batch_size
self.seq_length = seq_length
self.is_training = is_training
self.use_input_mask = use_input_mask
self.use_labels = use_labels
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.intermediate_size = intermediate_size
self.dropout = dropout
self.attention_dropout = attention_dropout
self.max_position_embeddings = max_position_embeddings
self.initializer_range = initializer_range
self.scope = scope
# Copied from tests.models.clip.test_modeling_clip.CLIPTextModelTester.prepare_config_and_inputs
def prepare_config_and_inputs(self):
input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)
input_mask = None
if self.use_input_mask:
input_mask = random_attention_mask([self.batch_size, self.seq_length])
if input_mask is not None:
batch_size, seq_length = input_mask.shape
rnd_start_indices = np.random.randint(1, seq_length - 1, size=(batch_size,))
for batch_idx, start_index in enumerate(rnd_start_indices):
input_mask[batch_idx, :start_index] = 1
input_mask[batch_idx, start_index:] = 0
config = self.get_config()
return config, input_ids, input_mask
def get_config(self):
return SiglipTextConfig(
vocab_size=self.vocab_size,
hidden_size=self.hidden_size,
num_hidden_layers=self.num_hidden_layers,
num_attention_heads=self.num_attention_heads,
intermediate_size=self.intermediate_size,
dropout=self.dropout,
attention_dropout=self.attention_dropout,
max_position_embeddings=self.max_position_embeddings,
initializer_range=self.initializer_range,
)
def create_and_check_model(self, config, input_ids, input_mask):
model = SiglipTextModel(config=config)
model.to(torch_device)
model.eval()
with torch.no_grad():
result = model(input_ids, attention_mask=input_mask)
result = model(input_ids)
self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
self.parent.assertEqual(result.pooler_output.shape, (self.batch_size, self.hidden_size))
# Copied from tests.models.clip.test_modeling_clip.CLIPTextModelTester.prepare_config_and_inputs_for_common
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
config, input_ids, input_mask = config_and_inputs
inputs_dict = {"input_ids": input_ids, "attention_mask": input_mask}
return config, inputs_dict
@require_torch
class SiglipTextModelTest(ModelTesterMixin, unittest.TestCase):
all_model_classes = (SiglipTextModel,) if is_torch_available() else ()
fx_compatible = False
test_pruning = False
test_head_masking = False
model_split_percents = [0.5, 0.8, 0.9]
# Copied from tests.models.clip.test_modeling_clip.CLIPTextModelTest.setUp with CLIP->Siglip
def setUp(self):
self.model_tester = SiglipTextModelTester(self)
self.config_tester = ConfigTester(self, config_class=SiglipTextConfig, hidden_size=37)
# Copied from tests.models.clip.test_modeling_clip.CLIPTextModelTest.test_config
def test_config(self):
self.config_tester.run_common_tests()
# Copied from tests.models.clip.test_modeling_clip.CLIPTextModelTest.test_model
def test_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs)
# Copied from tests.models.clip.test_modeling_clip.CLIPTextModelTest.test_training
def test_training(self):
pass
# Copied from tests.models.clip.test_modeling_clip.CLIPTextModelTest.test_training_gradient_checkpointing
def test_training_gradient_checkpointing(self):
pass
@unittest.skip(
reason="This architecure seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
)
# Copied from tests.models.clip.test_modeling_clip.CLIPTextModelTest.test_training_gradient_checkpointing_use_reentrant
def test_training_gradient_checkpointing_use_reentrant(self):
pass
@unittest.skip(
reason="This architecure seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
)
# Copied from tests.models.clip.test_modeling_clip.CLIPTextModelTest.test_training_gradient_checkpointing_use_reentrant_false
def test_training_gradient_checkpointing_use_reentrant_false(self):
pass
@unittest.skip(reason="Siglip does not use inputs_embeds")
# Copied from tests.models.clip.test_modeling_clip.CLIPTextModelTest.test_inputs_embeds
def test_inputs_embeds(self):
pass
@unittest.skip(reason="SiglipTextModel has no base class and is not available in MODEL_MAPPING")
# Copied from tests.models.clip.test_modeling_clip.CLIPTextModelTest.test_save_load_fast_init_from_base
def test_save_load_fast_init_from_base(self):
pass
@unittest.skip(reason="SiglipTextModel has no base class and is not available in MODEL_MAPPING")
# Copied from tests.models.clip.test_modeling_clip.CLIPTextModelTest.test_save_load_fast_init_to_base
def test_save_load_fast_init_to_base(self):
pass
@unittest.skip(reason="Siglip uses the same initialization scheme as the Flax original implementation")
def test_initialization(self):
pass
@slow
def test_model_from_pretrained(self):
for model_name in SIGLIP_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
model = SiglipTextModel.from_pretrained(model_name)
self.assertIsNotNone(model)
class SiglipModelTester:
def __init__(self, parent, text_kwargs=None, vision_kwargs=None, is_training=True):
if text_kwargs is None:
text_kwargs = {}
if vision_kwargs is None:
vision_kwargs = {}
self.parent = parent
self.text_model_tester = SiglipTextModelTester(parent, **text_kwargs)
self.vision_model_tester = SiglipVisionModelTester(parent, **vision_kwargs)
self.is_training = is_training
# Copied from tests.models.clip.test_modeling_clip.CLIPModelTester.prepare_config_and_inputs
def prepare_config_and_inputs(self):
text_config, input_ids, attention_mask = self.text_model_tester.prepare_config_and_inputs()
vision_config, pixel_values = self.vision_model_tester.prepare_config_and_inputs()
config = self.get_config()
return config, input_ids, attention_mask, pixel_values
def get_config(self):
return SiglipConfig.from_text_vision_configs(
self.text_model_tester.get_config(),
self.vision_model_tester.get_config(),
)
def create_and_check_model(self, config, input_ids, attention_mask, pixel_values):
model = SiglipModel(config).to(torch_device).eval()
with torch.no_grad():
result = model(input_ids, pixel_values, attention_mask)
self.parent.assertEqual(
result.logits_per_image.shape, (self.vision_model_tester.batch_size, self.text_model_tester.batch_size)
)
self.parent.assertEqual(
result.logits_per_text.shape, (self.text_model_tester.batch_size, self.vision_model_tester.batch_size)
)
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
config, input_ids, attention_mask, pixel_values = config_and_inputs
inputs_dict = {
"input_ids": input_ids,
"attention_mask": attention_mask,
"pixel_values": pixel_values,
"return_loss": False,
}
return config, inputs_dict
@require_torch
class SiglipModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (SiglipModel,) if is_torch_available() else ()
pipeline_model_mapping = {"feature-extraction": SiglipModel} if is_torch_available() else {}
fx_compatible = False
test_head_masking = False
test_pruning = False
test_resize_embeddings = False
test_attention_outputs = False
# Copied from tests.models.clip.test_modeling_clip.CLIPModelTest.setUp with CLIP->Siglip
def setUp(self):
self.model_tester = SiglipModelTester(self)
# Copied from tests.models.clip.test_modeling_clip.CLIPModelTest.test_model
def test_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs)
@unittest.skip(reason="Hidden_states is tested in individual model tests")
# Copied from tests.models.clip.test_modeling_clip.CLIPModelTest.test_hidden_states_output
def test_hidden_states_output(self):
pass
@unittest.skip(reason="Inputs_embeds is tested in individual model tests")
# Copied from tests.models.clip.test_modeling_clip.CLIPModelTest.test_inputs_embeds
def test_inputs_embeds(self):
pass
@unittest.skip(reason="Retain_grad is tested in individual model tests")
# Copied from tests.models.clip.test_modeling_clip.CLIPModelTest.test_retain_grad_hidden_states_attentions
def test_retain_grad_hidden_states_attentions(self):
pass
@unittest.skip(reason="SiglipModel does not have input/output embeddings")
# Copied from tests.models.clip.test_modeling_clip.CLIPModelTest.test_model_common_attributes
def test_model_common_attributes(self):
pass
@unittest.skip(reason="SiglipModel does not support training")
def test_training(self):
pass
@unittest.skip(reason="SiglipModel does not support training")
def test_training_gradient_checkpointing(self):
pass
@unittest.skip(reason="SiglipModel does not support training")
def test_training_gradient_checkpointing_use_reentrant(self):
pass
@unittest.skip(reason="SiglipModel does not support training")
def test_training_gradient_checkpointing_use_reentrant_false(self):
pass
@unittest.skip(reason="Siglip uses the same initialization scheme as the Flax original implementation")
def test_initialization(self):
pass
# Copied from tests.models.clip.test_modeling_clip.CLIPModelTest._create_and_check_torchscript with CLIP->Siglip
def _create_and_check_torchscript(self, config, inputs_dict):
if not self.test_torchscript:
return
configs_no_init = _config_zero_init(config) # To be sure we have no Nan
configs_no_init.torchscript = True
configs_no_init.return_dict = False
for model_class in self.all_model_classes:
model = model_class(config=configs_no_init)
model.to(torch_device)
model.eval()
try:
input_ids = inputs_dict["input_ids"]
pixel_values = inputs_dict["pixel_values"] # Siglip needs pixel_values
traced_model = torch.jit.trace(model, (input_ids, pixel_values))
except RuntimeError:
self.fail("Couldn't trace module.")
with tempfile.TemporaryDirectory() as tmp_dir_name:
pt_file_name = os.path.join(tmp_dir_name, "traced_model.pt")
try:
torch.jit.save(traced_model, pt_file_name)
except Exception:
self.fail("Couldn't save module.")
try:
loaded_model = torch.jit.load(pt_file_name)
except Exception:
self.fail("Couldn't load module.")
model.to(torch_device)
model.eval()
loaded_model.to(torch_device)
loaded_model.eval()
model_state_dict = model.state_dict()
loaded_model_state_dict = loaded_model.state_dict()
non_persistent_buffers = {}
for key in loaded_model_state_dict.keys():
if key not in model_state_dict.keys():
non_persistent_buffers[key] = loaded_model_state_dict[key]
loaded_model_state_dict = {
key: value for key, value in loaded_model_state_dict.items() if key not in non_persistent_buffers
}
self.assertEqual(set(model_state_dict.keys()), set(loaded_model_state_dict.keys()))
model_buffers = list(model.buffers())
for non_persistent_buffer in non_persistent_buffers.values():
found_buffer = False
for i, model_buffer in enumerate(model_buffers):
if torch.equal(non_persistent_buffer, model_buffer):
found_buffer = True
break
self.assertTrue(found_buffer)
model_buffers.pop(i)
models_equal = True
for layer_name, p1 in model_state_dict.items():
p2 = loaded_model_state_dict[layer_name]
if p1.data.ne(p2.data).sum() > 0:
models_equal = False
self.assertTrue(models_equal)
# Copied from tests.models.clip.test_modeling_clip.CLIPModelTest.test_load_vision_text_config with CLIP->Siglip
def test_load_vision_text_config(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
# Save SiglipConfig and check if we can load SiglipVisionConfig from it
with tempfile.TemporaryDirectory() as tmp_dir_name:
config.save_pretrained(tmp_dir_name)
vision_config = SiglipVisionConfig.from_pretrained(tmp_dir_name)
self.assertDictEqual(config.vision_config.to_dict(), vision_config.to_dict())
# Save SiglipConfig and check if we can load SiglipTextConfig from it
with tempfile.TemporaryDirectory() as tmp_dir_name:
config.save_pretrained(tmp_dir_name)
text_config = SiglipTextConfig.from_pretrained(tmp_dir_name)
self.assertDictEqual(config.text_config.to_dict(), text_config.to_dict())
@slow
# Copied from tests.models.clip.test_modeling_clip.CLIPModelTest.test_model_from_pretrained with CLIPModel->SiglipModel, CLIP->SIGLIP
def test_model_from_pretrained(self):
for model_name in SIGLIP_PRETRAINED_MODEL_ARCHIVE_LIST[:1]:
model = SiglipModel.from_pretrained(model_name)
self.assertIsNotNone(model)
# We will verify our results on an image of cute cats
def prepare_img():
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
return image
@require_vision
@require_torch
class SiglipModelIntegrationTest(unittest.TestCase):
@slow
def test_inference(self):
model_name = "google/siglip-base-patch16-224"
model = SiglipModel.from_pretrained(model_name).to(torch_device)
processor = SiglipProcessor.from_pretrained(model_name)
image = prepare_img()
inputs = processor(
text=["a photo of 2 cats", "a photo of 2 dogs"], images=image, padding="max_length", return_tensors="pt"
).to(torch_device)
# forward pass
with torch.no_grad():
outputs = model(**inputs)
logits_per_image = outputs.logits_per_image
logits_per_text = outputs.logits_per_text
# verify the logits
self.assertEqual(
logits_per_image.shape,
torch.Size((inputs.pixel_values.shape[0], inputs.input_ids.shape[0])),
)
self.assertEqual(
logits_per_text.shape,
torch.Size((inputs.input_ids.shape[0], inputs.pixel_values.shape[0])),
)
expected_logits = torch.tensor([[-0.7567, -10.3354]], device=torch_device)
self.assertTrue(torch.allclose(outputs.logits_per_image, expected_logits, atol=1e-3))
# verify the probs
probs = torch.sigmoid(logits_per_image) # these are the probabilities
expected_probs = torch.tensor([[3.1937e-01, 3.2463e-05]], device=torch_device)
self.assertTrue(torch.allclose(probs, expected_probs, atol=1e-3))

View File

@@ -0,0 +1,462 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
import tempfile
import unittest
from transformers import SPIECE_UNDERLINE, AddedToken, BatchEncoding, SiglipTokenizer
from transformers.testing_utils import get_tests_dir, require_sentencepiece, require_tokenizers, slow
from transformers.utils import cached_property, is_tf_available, is_torch_available
from ...test_tokenization_common import TokenizerTesterMixin
SAMPLE_VOCAB = get_tests_dir("fixtures/test_sentencepiece.model")
if is_torch_available():
FRAMEWORK = "pt"
elif is_tf_available():
FRAMEWORK = "tf"
else:
FRAMEWORK = "jax"
@require_sentencepiece
@require_tokenizers
class SiglipTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
tokenizer_class = SiglipTokenizer
test_rust_tokenizer = False
test_sentencepiece = True
test_sentencepiece_ignore_case = True
# Copied from tests.models.t5.test_tokenization_t5.T5TokenizationTest.setUp with T5->Siglip
def setUp(self):
super().setUp()
# We have a SentencePiece fixture for testing
tokenizer = SiglipTokenizer(SAMPLE_VOCAB)
tokenizer.save_pretrained(self.tmpdirname)
# Copied from tests.models.t5.test_tokenization_t5.T5TokenizationTest.test_convert_token_and_id with T5->Siglip
def test_convert_token_and_id(self):
"""Test ``_convert_token_to_id`` and ``_convert_id_to_token``."""
token = "<s>"
token_id = 1
self.assertEqual(self.get_tokenizer()._convert_token_to_id(token), token_id)
self.assertEqual(self.get_tokenizer()._convert_id_to_token(token_id), token)
def test_get_vocab(self):
vocab_keys = list(self.get_tokenizer().get_vocab().keys())
self.assertEqual(vocab_keys[0], "<unk>")
self.assertEqual(vocab_keys[1], "<s>")
def test_full_tokenizer(self):
tokenizer = SiglipTokenizer(SAMPLE_VOCAB)
tokens = tokenizer.tokenize("This is a test")
self.assertListEqual(tokens, ["▁this", "▁is", "▁a", "▁t", "est"])
self.assertListEqual(tokenizer.convert_tokens_to_ids(tokens), [66, 46, 10, 170, 382])
tokens = tokenizer.tokenize("I was born in 92000, and this is falsé.")
self.assertListEqual(
tokens,
[
SPIECE_UNDERLINE,
"i",
SPIECE_UNDERLINE + "was",
SPIECE_UNDERLINE + "b",
"or",
"n",
SPIECE_UNDERLINE + "in",
SPIECE_UNDERLINE + "",
"9",
"2",
"0",
"0",
"0",
SPIECE_UNDERLINE + "and",
SPIECE_UNDERLINE + "this",
SPIECE_UNDERLINE + "is",
SPIECE_UNDERLINE + "f",
"al",
"s",
"é",
],
)
ids = tokenizer.convert_tokens_to_ids(tokens)
self.assertListEqual(ids, [7, 23, 21, 84, 55, 24, 19, 7, 0, 602, 347, 347, 347, 12, 66, 46, 72, 80, 6, 0])
back_tokens = tokenizer.convert_ids_to_tokens(ids)
self.assertListEqual(
back_tokens,
[
SPIECE_UNDERLINE,
"i",
SPIECE_UNDERLINE + "was",
SPIECE_UNDERLINE + "b",
"or",
"n",
SPIECE_UNDERLINE + "in",
SPIECE_UNDERLINE + "",
"<unk>",
"2",
"0",
"0",
"0",
SPIECE_UNDERLINE + "and",
SPIECE_UNDERLINE + "this",
SPIECE_UNDERLINE + "is",
SPIECE_UNDERLINE + "f",
"al",
"s",
"<unk>",
],
)
@cached_property
def siglip_tokenizer(self):
return SiglipTokenizer.from_pretrained("google/siglip-base-patch16-224")
# Copied from tests.models.t5.test_tokenization_t5.T5TokenizationTest.get_tokenizer with T5->Siglip
def get_tokenizer(self, **kwargs) -> SiglipTokenizer:
return self.tokenizer_class.from_pretrained(self.tmpdirname, **kwargs)
# Copied from tests.models.t5.test_tokenization_t5.T5TokenizationTest.test_rust_and_python_full_tokenizers with T5->Siglip
def test_rust_and_python_full_tokenizers(self):
if not self.test_rust_tokenizer:
return
tokenizer = self.get_tokenizer()
rust_tokenizer = self.get_rust_tokenizer()
sequence = "I was born in 92000, and this is falsé."
tokens = tokenizer.tokenize(sequence)
rust_tokens = rust_tokenizer.tokenize(sequence)
self.assertListEqual(tokens, rust_tokens)
ids = tokenizer.encode(sequence, add_special_tokens=False)
rust_ids = rust_tokenizer.encode(sequence, add_special_tokens=False)
self.assertListEqual(ids, rust_ids)
rust_tokenizer = self.get_rust_tokenizer()
ids = tokenizer.encode(sequence)
rust_ids = rust_tokenizer.encode(sequence)
self.assertListEqual(ids, rust_ids)
def test_eos_treatment(self):
tokenizer = self.siglip_tokenizer
batch_with_eos_added = tokenizer(["hi</s>", "I went to the gym</s>", "</s>"])
batch_without_eos_added = tokenizer(["hi", "I went to the gym", ""])
self.assertListEqual(batch_with_eos_added["input_ids"], batch_without_eos_added["input_ids"])
def test_prepare_batch(self):
tokenizer = self.siglip_tokenizer
src_text = ["A long paragraph for summarization.", "Another paragraph for summarization."]
expected_src_tokens = [262, 266, 476, 8532, 270, 4460, 3949, 1682, tokenizer.eos_token_id]
batch = tokenizer(src_text, padding=True, return_tensors=FRAMEWORK)
self.assertIsInstance(batch, BatchEncoding)
if FRAMEWORK != "jax":
result = list(batch.input_ids.numpy()[0])
else:
result = list(batch.input_ids.tolist()[0])
self.assertListEqual(expected_src_tokens, result)
self.assertEqual((2, 9), batch.input_ids.shape)
def test_empty_target_text(self):
tokenizer = self.siglip_tokenizer
src_text = ["A long paragraph for summarization.", "Another paragraph for summarization."]
batch = tokenizer(src_text, padding=True, return_tensors=FRAMEWORK)
# check if input_ids are returned and no decoder_input_ids
self.assertIn("input_ids", batch)
self.assertNotIn("decoder_input_ids", batch)
self.assertNotIn("decoder_attention_mask", batch)
def test_max_length(self):
tokenizer = self.siglip_tokenizer
tgt_text = ["Summary of the text.", "Another summary."]
targets = tokenizer(
text_target=tgt_text, max_length=32, padding="max_length", truncation=True, return_tensors=FRAMEWORK
)
self.assertEqual(32, targets["input_ids"].shape[1])
def test_eos_in_input(self):
tokenizer = self.siglip_tokenizer
src_text = ["A long paragraph for summarization. </s>"]
tgt_text = ["Summary of the text. </s>"]
expected_src_tokens = [262, 266, 476, 8532, 270, 4460, 3949, 1682, 1]
expected_tgt_tokens = [6254, 267, 260, 1443, 1]
batch = tokenizer(src_text, text_target=tgt_text)
self.assertEqual(expected_src_tokens, batch["input_ids"][0])
self.assertEqual(expected_tgt_tokens, batch["labels"][0])
@unittest.skip(reason="SiglipTokenizer strips the punctuation")
def test_subword_regularization_tokenizer(self):
pass
@unittest.skip(reason="SiglipTokenizer strips the punctuation")
def test_pickle_subword_regularization_tokenizer(self):
pass
# Copied from tests.models.t5.test_tokenization_t5.T5TokenizationTest.test_special_tokens_initialization with T5->Siglip
def test_special_tokens_initialization(self):
for tokenizer, pretrained_name, kwargs in self.tokenizers_list:
with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name})"):
added_tokens = [f"<extra_id_{i}>" for i in range(100)] + [AddedToken("<special>", lstrip=True)]
tokenizer_r = self.rust_tokenizer_class.from_pretrained(
pretrained_name, additional_special_tokens=added_tokens, **kwargs
)
tokenizer_cr = self.rust_tokenizer_class.from_pretrained(
pretrained_name, additional_special_tokens=added_tokens, **kwargs, from_slow=True
)
tokenizer_p = self.tokenizer_class.from_pretrained(
pretrained_name, additional_special_tokens=added_tokens, **kwargs
)
p_output = tokenizer_p.encode("Hey this is a <special> token")
r_output = tokenizer_r.encode("Hey this is a <special> token")
cr_output = tokenizer_cr.encode("Hey this is a <special> token")
special_token_id = tokenizer_r.encode("<special>", add_special_tokens=False)[0]
self.assertEqual(p_output, r_output)
self.assertEqual(cr_output, r_output)
self.assertTrue(special_token_id in p_output)
self.assertTrue(special_token_id in r_output)
self.assertTrue(special_token_id in cr_output)
# Copied from tests.models.t5.test_tokenization_t5.T5TokenizationTest.test_special_tokens_initialization_with_non_empty_additional_special_tokens with T5->Siglip
def test_special_tokens_initialization_with_non_empty_additional_special_tokens(self):
tokenizer_list = []
if self.test_slow_tokenizer:
tokenizer_list.append((self.tokenizer_class, self.get_tokenizer()))
if self.test_rust_tokenizer:
tokenizer_list.append((self.rust_tokenizer_class, self.get_rust_tokenizer()))
for tokenizer_class, tokenizer_utils in tokenizer_list:
with tempfile.TemporaryDirectory() as tmp_dir:
tokenizer_utils.save_pretrained(tmp_dir)
with open(os.path.join(tmp_dir, "special_tokens_map.json"), encoding="utf-8") as json_file:
special_tokens_map = json.load(json_file)
with open(os.path.join(tmp_dir, "tokenizer_config.json"), encoding="utf-8") as json_file:
tokenizer_config = json.load(json_file)
added_tokens_extra_ids = [f"<extra_id_{i}>" for i in range(100)]
special_tokens_map["additional_special_tokens"] = added_tokens_extra_ids + [
"an_additional_special_token"
]
tokenizer_config["additional_special_tokens"] = added_tokens_extra_ids + [
"an_additional_special_token"
]
with open(os.path.join(tmp_dir, "special_tokens_map.json"), "w", encoding="utf-8") as outfile:
json.dump(special_tokens_map, outfile)
with open(os.path.join(tmp_dir, "tokenizer_config.json"), "w", encoding="utf-8") as outfile:
json.dump(tokenizer_config, outfile)
# the following checks allow us to verify that our test works as expected, i.e. that the tokenizer takes
# into account the new value of additional_special_tokens given in the "tokenizer_config.json" and
# "special_tokens_map.json" files
tokenizer_without_change_in_init = tokenizer_class.from_pretrained(
tmp_dir,
)
self.assertIn(
"an_additional_special_token", tokenizer_without_change_in_init.additional_special_tokens
)
# self.assertIn("an_additional_special_token",tokenizer_without_change_in_init.get_vocab()) # BySiglipTokenization no vocab
self.assertEqual(
["an_additional_special_token"],
tokenizer_without_change_in_init.convert_ids_to_tokens(
tokenizer_without_change_in_init.convert_tokens_to_ids(["an_additional_special_token"])
),
)
# Now we test that we can change the value of additional_special_tokens in the from_pretrained
new_added_tokens = added_tokens_extra_ids + [AddedToken("a_new_additional_special_token", lstrip=True)]
tokenizer = tokenizer_class.from_pretrained(
tmp_dir,
additional_special_tokens=new_added_tokens,
)
self.assertIn("a_new_additional_special_token", tokenizer.additional_special_tokens)
self.assertEqual(
["a_new_additional_special_token"],
tokenizer.convert_ids_to_tokens(
tokenizer.convert_tokens_to_ids(["a_new_additional_special_token"])
),
)
def test_sentencepiece_tokenize_and_convert_tokens_to_string(self):
"""Test ``_tokenize`` and ``convert_tokens_to_string``."""
if not self.test_sentencepiece:
return
tokenizer = self.get_tokenizer()
text = "This is text to test the tokenizer."
if self.test_sentencepiece_ignore_case:
text = text.lower()
tokens = tokenizer.tokenize(text)
self.assertTrue(len(tokens) > 0)
# check if converting back to original text works
reverse_text = tokenizer.convert_tokens_to_string(tokens)
if self.test_sentencepiece_ignore_case:
reverse_text = reverse_text.lower()
expected_text = "this is text to test the tokenizer"
self.assertEqual(reverse_text, expected_text)
special_tokens = tokenizer.all_special_tokens
special_tokens_string = tokenizer.convert_tokens_to_string(special_tokens)
for special_token in special_tokens:
self.assertIn(special_token, special_tokens_string)
if self.test_rust_tokenizer:
rust_tokenizer = self.get_rust_tokenizer()
special_tokens_string_rust = rust_tokenizer.convert_tokens_to_string(special_tokens)
self.assertEqual(special_tokens_string, special_tokens_string_rust)
# overwritten from `test_tokenization_common` since Siglip has no max length
# Copied from tests.models.t5.test_tokenization_t5.T5TokenizationTest.test_pretrained_model_lists with T5->Siglip
def test_pretrained_model_lists(self):
# We should have at least one default checkpoint for each tokenizer
# We should specify the max input length as well (used in some part to list the pretrained checkpoints)
self.assertGreaterEqual(len(self.tokenizer_class.pretrained_vocab_files_map), 1)
self.assertGreaterEqual(len(list(self.tokenizer_class.pretrained_vocab_files_map.values())[0]), 1)
@slow
def test_tokenizer_integration(self):
tokenizer = SiglipTokenizer.from_pretrained("google/siglip-base-patch16-224")
# fmt: off
texts = [
'the real mountain view',
'Zürich',
'San Francisco',
'a picture of a laptop with the lockscreen on, a cup of cappucino, salt and pepper grinders. The view through the window reveals lake Zürich and the Alps in the background of the city.',
]
expected_input_ids = [
[260, 638, 3293, 870, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
[262, 761, 5879, 5345, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
[262, 264, 452, 20563, 15949, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
[262, 266, 1357, 267, 262, 266, 4429, 275, 260, 3940, 6360, 277, 262, 266, 3064, 267, 3549, 388, 16538, 296, 298, 2617, 263, 4869, 14998, 264, 260, 870, 393, 260, 1710, 7958, 4324, 262, 761, 5879, 5345, 263, 260, 1518, 388, 264, 268, 260, 1970, 267, 260, 741, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
]
# fmt: on
for text, expected in zip(texts, expected_input_ids):
input_ids = tokenizer(text, padding="max_length").input_ids
self.assertListEqual(input_ids, expected)
def test_some_edge_cases(self):
tokenizer = SiglipTokenizer.from_pretrained("google/siglip-base-patch16-224", legacy=False)
sp_tokens = tokenizer.sp_model.encode("</s>>", out_type=str)
self.assertEqual(sp_tokens, ["</", "s", ">", ">"])
tokens = tokenizer.tokenize("</s>>")
self.assertNotEqual(sp_tokens, tokens)
self.assertEqual(tokens, ["</s>"])
tokens = tokenizer.tokenize("")
self.assertEqual(tokens, [])
self.assertEqual(tokens, tokenizer.sp_model.encode("", out_type=str))
tokens = tokenizer.tokenize(" ")
self.assertEqual(tokens, [])
self.assertEqual(tokens, tokenizer.sp_model.encode(" ", out_type=str))
tokens = tokenizer.tokenize("")
self.assertEqual(tokens, [])
self.assertEqual(tokens, tokenizer.sp_model.encode("", out_type=str))
tokens = tokenizer.tokenize("")
self.assertEqual(tokens, [])
self.assertEqual(tokens, tokenizer.sp_model.encode("", out_type=str))
@require_sentencepiece
@require_tokenizers
class CommonSpmIntegrationTests(unittest.TestCase):
"""
A class that regroups important test to make sure that we properly handle the special tokens.
"""
@classmethod
def setUpClass(cls):
tokenizer = SiglipTokenizer(SAMPLE_VOCAB, extra_ids=0, legacy=False)
tokenizer.add_special_tokens(
{"additional_special_tokens": [AddedToken("<extra_id_0>", rstrip=False, lstrip=False)]}
)
cls.tokenizer = tokenizer
def test_add_dummy_prefix(self):
# make sure `'▁'` is prepended, and outputs match sp_model's
# `sentencepiece.NormalizerSpec.add_dummy_prefix` attribute
input_ids = self.tokenizer.encode(". Hello", add_special_tokens=False)
self.assertEqual(input_ids, [37, 86, 20])
self.assertEqual(input_ids, [37, 86, 20])
tokens = self.tokenizer.tokenize(". Hello")
self.assertEqual(tokens, ["▁he", "ll", "o"])
tokens = self.tokenizer.tokenize("")
self.assertEqual(tokens, [])
self.assertEqual(tokens, self.tokenizer.sp_model.encode("", out_type=str))
tokens = self.tokenizer.tokenize(" ")
self.assertEqual(tokens, [])
self.assertEqual(tokens, self.tokenizer.sp_model.encode(" ", out_type=str))
tokens = self.tokenizer.tokenize("")
self.assertEqual(tokens, [])
self.assertEqual(tokens, self.tokenizer.sp_model.encode("", out_type=str))
def test_remove_extra_whitespaces(self):
# make sure the extra spaces are eaten
# sentencepiece.NormalizerSpec.remove_extra_whitespaces attribute
input_ids = self.tokenizer.encode(" . Hello", add_special_tokens=False)
self.assertEqual(input_ids, [37, 86, 20])
self.assertEqual(input_ids, [37, 86, 20])
tokens = self.tokenizer.tokenize(" . Hello")
self.assertEqual(tokens, ["▁he", "ll", "o"])
# `'▁'` is also a whitespace
input_ids = self.tokenizer.encode("▁He is not")
self.assertEqual(input_ids, [37, 46, 44, 2])
tokens = self.tokenizer.tokenize("▁He is not")
self.assertEqual(tokens, ["▁he", "▁is", "▁not"]) # no extra space added
input_ids = self.tokenizer.encode("▁He is not ▁He")
self.assertEqual(input_ids, [37, 46, 44, 37, 2])
tokens = self.tokenizer.tokenize("▁He is not ▁He")
self.assertEqual(tokens, ["▁he", "▁is", "▁not", "▁he"]) # spaces are eaten by spm even if not start

View File

@@ -241,3 +241,37 @@ class ZeroShotImageClassificationPipelineTests(unittest.TestCase):
]
* 5,
)
@slow
@require_torch
def test_siglip_model_pt(self):
image_classifier = pipeline(
task="zero-shot-image-classification",
model="google/siglip-base-patch16-224",
)
# This is an image of 2 cats with remotes and no planes
image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
output = image_classifier(image, candidate_labels=["2 cats", "a plane", "a remote"])
self.assertEqual(
nested_simplify(output),
[
{"score": 0.198, "label": "2 cats"},
{"score": 0.0, "label": "a remote"},
{"score": 0.0, "label": "a plane"},
],
)
output = image_classifier([image] * 5, candidate_labels=["2 cats", "a plane", "a remote"], batch_size=2)
self.assertEqual(
nested_simplify(output),
[
[
{"score": 0.198, "label": "2 cats"},
{"score": 0.0, "label": "a remote"},
{"score": 0.0, "label": "a plane"},
]
]
* 5,
)