Add FlaxCLIPTextModelWithProjection (#25254)

* Add FlaxClipTextModelWithProjection

This is necessary to support the Flax port of Stable Diffusion XL: fb6d705fb5/text_encoder_2/config.json (L3)

Co-authored-by: Martin Müller <martin.muller.me@gmail.com>
Co-authored-by: Juan Acevedo <juancevedo@gmail.com>

* Use FlaxCLIPTextModelOutput

* make fix-copies again

* Apply suggestions from code review

Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* Use `return_dict` for consistency with other uses.

Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* Fix docstring example.

* Add new model to FlaxCLIPTextModelTest

* Add to IGNORE_NON_AUTO_CONFIGURED list

* Fix naming convention.

---------

Co-authored-by: Martin Müller <martin.muller.me@gmail.com>
Co-authored-by: Juan Acevedo <juancevedo@gmail.com>
Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
This commit is contained in:
Pedro Cuenca
2023-08-25 10:58:14 +02:00
committed by GitHub
parent 8968fface4
commit cb8e3ee25f
7 changed files with 126 additions and 2 deletions

View File

@@ -19,7 +19,12 @@ if is_flax_available():
convert_pytorch_state_dict_to_flax,
load_flax_weights_in_pytorch_model,
)
from transformers.models.clip.modeling_flax_clip import FlaxCLIPModel, FlaxCLIPTextModel, FlaxCLIPVisionModel
from transformers.models.clip.modeling_flax_clip import (
FlaxCLIPModel,
FlaxCLIPTextModel,
FlaxCLIPTextModelWithProjection,
FlaxCLIPVisionModel,
)
if is_torch_available():
import torch
@@ -315,7 +320,7 @@ class FlaxCLIPTextModelTester:
@require_flax
class FlaxCLIPTextModelTest(FlaxModelTesterMixin, unittest.TestCase):
all_model_classes = (FlaxCLIPTextModel,) if is_flax_available() else ()
all_model_classes = (FlaxCLIPTextModel, FlaxCLIPTextModelWithProjection) if is_flax_available() else ()
def setUp(self):
self.model_tester = FlaxCLIPTextModelTester(self)