Add training support for SigLIP (#31495)

* Add siglip loss function

* Update docs

* Enable training tests
[experimental] Enable gradient checkpointing (GC) training tests, as GC training has worked on my own data

* Remove test_training* overrides to enable training tests
[run_slow] siglip

* Skip training tests for Siglip text model and ImageClassificationModel
[run_slow] siglip

* Skip GC training tests for SiglipForImageClassification

* Explicitly skip training tests for SiglipVisionModel
Add skip reason for training tests for SiglipTextModel

* Remove `# Copied from` comments to fix CI
This commit is contained in:
Billy Cao
2024-07-05 21:50:39 +08:00
committed by GitHub
parent 1556025271
commit 1d3eaa6f7e
3 changed files with 11 additions and 30 deletions

View File

@@ -335,27 +335,19 @@ class SiglipTextModelTest(ModelTesterMixin, unittest.TestCase):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs)
@unittest.skip
# Copied from tests.models.clip.test_modeling_clip.CLIPTextModelTest.test_training
@unittest.skip(reason="SiglipTextModel does not support standalone training")
def test_training(self):
pass
@unittest.skip
# Copied from tests.models.clip.test_modeling_clip.CLIPTextModelTest.test_training_gradient_checkpointing
@unittest.skip(reason="SiglipTextModel does not support standalone training")
def test_training_gradient_checkpointing(self):
pass
@unittest.skip(
reason="This architecure seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
)
# Copied from tests.models.clip.test_modeling_clip.CLIPTextModelTest.test_training_gradient_checkpointing_use_reentrant
@unittest.skip(reason="SiglipTextModel does not support standalone training")
def test_training_gradient_checkpointing_use_reentrant(self):
pass
@unittest.skip(
reason="This architecure seem to not compute gradients properly when using GC, check: https://github.com/huggingface/transformers/pull/27124"
)
# Copied from tests.models.clip.test_modeling_clip.CLIPTextModelTest.test_training_gradient_checkpointing_use_reentrant_false
@unittest.skip(reason="SiglipTextModel does not support standalone training")
def test_training_gradient_checkpointing_use_reentrant_false(self):
pass
@@ -481,22 +473,6 @@ class SiglipModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
def test_model_get_set_embeddings(self):
pass
@unittest.skip(reason="SiglipModel does not support training")
def test_training(self):
pass
@unittest.skip(reason="SiglipModel does not support training")
def test_training_gradient_checkpointing(self):
pass
@unittest.skip(reason="SiglipModel does not support training")
def test_training_gradient_checkpointing_use_reentrant(self):
pass
@unittest.skip(reason="SiglipModel does not support training")
def test_training_gradient_checkpointing_use_reentrant_false(self):
pass
@unittest.skip(reason="Siglip uses the same initialization scheme as the Flax original implementation")
def test_initialization(self):
pass