Add FlaxCLIPTextModelWithProjection (#25254)
* Add FlaxClipTextModelWithProjection
This is necessary to support the Flax port of Stable Diffusion XL: fb6d705fb5/text_encoder_2/config.json (L3)
Co-authored-by: Martin Müller <martin.muller.me@gmail.com>
Co-authored-by: Juan Acevedo <juancevedo@gmail.com>
* Use FlaxCLIPTextModelOutput
* make fix-copies again
* Apply suggestions from code review
Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
* Use `return_dict` for consistency with other uses.
Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
* Fix docstring example.
* Add new model to FlaxCLIPTextModelTest
* Add to IGNORE_NON_AUTO_CONFIGURED list
* Fix naming convention.
---------
Co-authored-by: Martin Müller <martin.muller.me@gmail.com>
Co-authored-by: Juan Acevedo <juancevedo@gmail.com>
Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
This commit is contained in:
@@ -19,7 +19,12 @@ if is_flax_available():
|
||||
convert_pytorch_state_dict_to_flax,
|
||||
load_flax_weights_in_pytorch_model,
|
||||
)
|
||||
from transformers.models.clip.modeling_flax_clip import FlaxCLIPModel, FlaxCLIPTextModel, FlaxCLIPVisionModel
|
||||
from transformers.models.clip.modeling_flax_clip import (
|
||||
FlaxCLIPModel,
|
||||
FlaxCLIPTextModel,
|
||||
FlaxCLIPTextModelWithProjection,
|
||||
FlaxCLIPVisionModel,
|
||||
)
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
@@ -315,7 +320,7 @@ class FlaxCLIPTextModelTester:
|
||||
|
||||
@require_flax
|
||||
class FlaxCLIPTextModelTest(FlaxModelTesterMixin, unittest.TestCase):
|
||||
all_model_classes = (FlaxCLIPTextModel,) if is_flax_available() else ()
|
||||
all_model_classes = (FlaxCLIPTextModel, FlaxCLIPTextModelWithProjection) if is_flax_available() else ()
|
||||
|
||||
def setUp(self):
|
||||
self.model_tester = FlaxCLIPTextModelTester(self)
|
||||
|
||||
Reference in New Issue
Block a user