[Flax] Fix hybrid clip (#12519)
* fix saving and loading * update readme
This commit is contained in:
@@ -68,6 +68,34 @@ export MODEL_DIR="./clip-roberta-base
|
|||||||
ln -s ~/transformers/examples/flax/summarization/run_hybrid_clip.py run_hybrid_clip.py
|
ln -s ~/transformers/examples/flax/summarization/run_hybrid_clip.py run_hybrid_clip.py
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## How to use the `FlaxHybridCLIP` model:
|
||||||
|
|
||||||
|
The `FlaxHybridCLIP` class let's you load any text and vision encoder model to create a dual encoder.
|
||||||
|
Here is an example of how to load the model using pre-trained text and vision models.
|
||||||
|
|
||||||
|
```python
|
||||||
|
from modeling_hybrid_clip import FlaxHybridCLIP
|
||||||
|
|
||||||
|
model = FlaxHybridCLIP.from_text_vision_pretrained("bert-base-uncased", "openai/clip-vit-base-patch32")
|
||||||
|
|
||||||
|
# save the model
|
||||||
|
model.save_pretrained("bert-clip")
|
||||||
|
|
||||||
|
# load the saved model
|
||||||
|
model = FlaxHybridCLIP.from_pretrained("bert-clip")
|
||||||
|
```
|
||||||
|
|
||||||
|
If the checkpoints are in PyTorch then one could pass `text_from_pt=True` and `vision_from_pt=True`. This will load the model
|
||||||
|
PyTorch checkpoints convert them to flax and load the model.
|
||||||
|
|
||||||
|
```python
|
||||||
|
model = FlaxHybridCLIP.from_text_vision_pretrained("bert-base-uncased", "openai/clip-vit-base-patch32", text_from_pt=True, vision_from_pt=True)
|
||||||
|
```
|
||||||
|
|
||||||
|
This loads both the text and vision encoders using pre-trained weights, the projection layers are randomly
|
||||||
|
initialized except for CLIP's vision model. If you use CLIP to initialize the vision model then the vision projection weights are also
|
||||||
|
loaded using the pre-trained weights.
|
||||||
|
|
||||||
## Prepare the dataset
|
## Prepare the dataset
|
||||||
|
|
||||||
We will use the MS-COCO dataset to train our dual encoder model. MS-COCO contains over 82,000 images, each of which has at least 5 different caption annotations. The dataset is usually used for image captioning tasks, but we can repurpose the image-caption pairs to train our dual encoder model for image search.
|
We will use the MS-COCO dataset to train our dual encoder model. MS-COCO contains over 82,000 images, each of which has at least 5 different caption annotations. The dataset is usually used for image captioning tasks, but we can repurpose the image-caption pairs to train our dual encoder model for image search.
|
||||||
@@ -124,7 +152,7 @@ with open("coco_dataset/valid_dataset.json", "w") as f:
|
|||||||
Next we can run the example script to train the model:
|
Next we can run the example script to train the model:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
python run_clip.py \
|
python run_hybrid_clip.py \
|
||||||
--output_dir ${MODEL_DIR} \
|
--output_dir ${MODEL_DIR} \
|
||||||
--text_model_name_or_path="roberta-base" \
|
--text_model_name_or_path="roberta-base" \
|
||||||
--vision_model_name_or_path="openai/clip-vit-base-patch32" \
|
--vision_model_name_or_path="openai/clip-vit-base-patch32" \
|
||||||
|
|||||||
@@ -25,31 +25,58 @@ class HybridCLIPConfig(PretrainedConfig):
|
|||||||
Dimentionality of text and vision projection layers.
|
Dimentionality of text and vision projection layers.
|
||||||
kwargs (`optional`):
|
kwargs (`optional`):
|
||||||
Dictionary of keyword arguments.
|
Dictionary of keyword arguments.
|
||||||
|
|
||||||
|
Examples::
|
||||||
|
|
||||||
|
>>> from transformers import BertConfig, CLIPConfig, HybridCLIPConfig, FlaxHybridCLIP
|
||||||
|
|
||||||
|
>>> # Initializing a BERT and CLIP configuration
|
||||||
|
>>> config_text = BertConfig()
|
||||||
|
>>> config_vision = CLIPConfig()
|
||||||
|
|
||||||
|
>>> config = HybridCLIPConfig.from_text_vision_configs(config_text, config_vision, projection_dim=512)
|
||||||
|
|
||||||
|
>>> # Initializing a BERT and CLIPVision model
|
||||||
|
>>> model = EncoderDecoderModel(config=config)
|
||||||
|
|
||||||
|
>>> # Accessing the model configuration
|
||||||
|
>>> config_text = model.config.text_config
|
||||||
|
>>> config_vision = model.config.vision_config
|
||||||
|
|
||||||
|
>>> # Saving the model, including its configuration
|
||||||
|
>>> model.save_pretrained('my-model')
|
||||||
|
|
||||||
|
>>> # loading model and config from pretrained folder
|
||||||
|
>>> encoder_decoder_config = HybridCLIPConfig.from_pretrained('my-model')
|
||||||
|
>>> model = FlaxHybridCLIP.from_pretrained('my-model', config=encoder_decoder_config)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
model_type = "hybrid-clip"
|
model_type = "hybrid-clip"
|
||||||
is_composition = True
|
is_composition = True
|
||||||
|
|
||||||
def __init__(self, text_config_dict, vision_config_dict, projection_dim=512, **kwargs):
|
def __init__(self, projection_dim=512, **kwargs):
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
if text_config_dict is None:
|
if "text_config" not in kwargs:
|
||||||
raise ValueError("`text_config_dict` can not be `None`.")
|
raise ValueError("`text_config` can not be `None`.")
|
||||||
|
|
||||||
if vision_config_dict is None:
|
if "vision_config" not in kwargs:
|
||||||
raise ValueError("`vision_config_dict` can not be `None`.")
|
raise ValueError("`vision_config` can not be `None`.")
|
||||||
|
|
||||||
text_model_type = text_config_dict.pop("model_type")
|
text_config = kwargs.pop("text_config")
|
||||||
vision_model_type = vision_config_dict.pop("model_type")
|
vision_config = kwargs.pop("vision_config")
|
||||||
|
|
||||||
|
text_model_type = text_config.pop("model_type")
|
||||||
|
vision_model_type = vision_config.pop("model_type")
|
||||||
|
|
||||||
from transformers import AutoConfig
|
from transformers import AutoConfig
|
||||||
|
|
||||||
self.text_config = AutoConfig.for_model(text_model_type, **text_config_dict)
|
self.text_config = AutoConfig.for_model(text_model_type, **text_config)
|
||||||
|
|
||||||
if vision_model_type == "clip":
|
if vision_model_type == "clip":
|
||||||
self.vision_config = AutoConfig.for_model(vision_model_type, **vision_config_dict).vision_config
|
self.vision_config = AutoConfig.for_model(vision_model_type, **vision_config).vision_config
|
||||||
else:
|
else:
|
||||||
self.vision_config = AutoConfig.for_model(vision_model_type, **vision_config_dict)
|
self.vision_config = AutoConfig.for_model(vision_model_type, **vision_config)
|
||||||
|
|
||||||
self.projection_dim = projection_dim
|
self.projection_dim = projection_dim
|
||||||
self.initializer_factor = 1.0
|
self.initializer_factor = 1.0
|
||||||
@@ -64,7 +91,7 @@ class HybridCLIPConfig(PretrainedConfig):
|
|||||||
:class:`HybridCLIPConfig`: An instance of a configuration object
|
:class:`HybridCLIPConfig`: An instance of a configuration object
|
||||||
"""
|
"""
|
||||||
|
|
||||||
return cls(text_config_dict=text_config.to_dict(), vision_config_dict=vision_config.to_dict(), **kwargs)
|
return cls(text_config=text_config.to_dict(), vision_config=vision_config.to_dict(), **kwargs)
|
||||||
|
|
||||||
def to_dict(self):
|
def to_dict(self):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -123,7 +123,7 @@ class FlaxHybridCLIPModule(nn.Module):
|
|||||||
|
|
||||||
|
|
||||||
class FlaxHybridCLIP(FlaxPreTrainedModel):
|
class FlaxHybridCLIP(FlaxPreTrainedModel):
|
||||||
config: HybridCLIPConfig
|
config_class = HybridCLIPConfig
|
||||||
module_class = FlaxHybridCLIPModule
|
module_class = FlaxHybridCLIPModule
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -304,6 +304,58 @@ class FlaxHybridCLIP(FlaxPreTrainedModel):
|
|||||||
*model_args,
|
*model_args,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> FlaxPreTrainedModel:
|
) -> FlaxPreTrainedModel:
|
||||||
|
"""
|
||||||
|
Params:
|
||||||
|
text_model_name_or_path (:obj: `str`, `optional`):
|
||||||
|
Information necessary to initiate the text model. Can be either:
|
||||||
|
|
||||||
|
- A string, the `model id` of a pretrained model hosted inside a model repo on huggingface.co.
|
||||||
|
Valid model ids can be located at the root-level, like ``bert-base-uncased``, or namespaced under
|
||||||
|
a user or organization name, like ``dbmdz/bert-base-german-cased``.
|
||||||
|
- A path to a `directory` containing model weights saved using
|
||||||
|
:func:`~transformers.FlaxPreTrainedModel.save_pretrained`, e.g., ``./my_model_directory/``.
|
||||||
|
- A path or url to a `PyTorch checkpoint folder` (e.g, ``./pt_model``). In
|
||||||
|
this case, ``from_pt`` should be set to :obj:`True` and a configuration object should be provided
|
||||||
|
as ``config`` argument. This loading path is slower than converting the PyTorch checkpoint in
|
||||||
|
a Flax model using the provided conversion scripts and loading the Flax model afterwards.
|
||||||
|
|
||||||
|
vision_model_name_or_path (:obj: `str`, `optional`, defaults to `None`):
|
||||||
|
Information necessary to initiate the vision model. Can be either:
|
||||||
|
|
||||||
|
- A string, the `model id` of a pretrained model hosted inside a model repo on huggingface.co.
|
||||||
|
Valid model ids can be located at the root-level, like ``bert-base-uncased``, or namespaced under
|
||||||
|
a user or organization name, like ``dbmdz/bert-base-german-cased``.
|
||||||
|
- A path to a `directory` containing model weights saved using
|
||||||
|
:func:`~transformers.FlaxPreTrainedModel.save_pretrained`, e.g., ``./my_model_directory/``.
|
||||||
|
- A path or url to a `PyTorch checkpoint folder` (e.g, ``./pt_model``). In
|
||||||
|
this case, ``from_pt`` should be set to :obj:`True` and a configuration object should be provided
|
||||||
|
as ``config`` argument. This loading path is slower than converting the PyTorch checkpoint in
|
||||||
|
a Flax model using the provided conversion scripts and loading the Flax model afterwards.
|
||||||
|
|
||||||
|
model_args (remaining positional arguments, `optional`):
|
||||||
|
All remaning positional arguments will be passed to the underlying model's ``__init__`` method.
|
||||||
|
|
||||||
|
kwargs (remaining dictionary of keyword arguments, `optional`):
|
||||||
|
Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
|
||||||
|
:obj:`output_attentions=True`).
|
||||||
|
|
||||||
|
- To update the text configuration, use the prefix `text_` for each configuration parameter.
|
||||||
|
- To update the vision configuration, use the prefix `vision_` for each configuration parameter.
|
||||||
|
- To update the parent model configuration, do not use a prefix for each configuration parameter.
|
||||||
|
|
||||||
|
Behaves differently depending on whether a :obj:`config` is provided or automatically loaded.
|
||||||
|
|
||||||
|
Example::
|
||||||
|
|
||||||
|
>>> from transformers import FlaxHybridCLIP
|
||||||
|
>>> # initialize a model from pretrained BERT and CLIP models. Note that the projection layers will be randomly initialized.
|
||||||
|
>>> # If using CLIP's vision model the vision projection layer will be initialized using pre-trained weights
|
||||||
|
>>> model = FlaxHybridCLIP.from_text_vision_pretrained('bert-base-uncased', 'openai/clip-vit-base-patch32')
|
||||||
|
>>> # saving model after fine-tuning
|
||||||
|
>>> model.save_pretrained("./bert-clip")
|
||||||
|
>>> # load fine-tuned model
|
||||||
|
>>> model = FlaxHybridCLIP.from_pretrained("./bert-clip")
|
||||||
|
"""
|
||||||
|
|
||||||
kwargs_text = {
|
kwargs_text = {
|
||||||
argument[len("text_") :]: value for argument, value in kwargs.items() if argument.startswith("text_")
|
argument[len("text_") :]: value for argument, value in kwargs.items() if argument.startswith("text_")
|
||||||
@@ -333,9 +385,7 @@ class FlaxHybridCLIP(FlaxPreTrainedModel):
|
|||||||
text_config = AutoConfig.from_pretrained(text_model_name_or_path)
|
text_config = AutoConfig.from_pretrained(text_model_name_or_path)
|
||||||
kwargs_text["config"] = text_config
|
kwargs_text["config"] = text_config
|
||||||
|
|
||||||
text_model = FlaxAutoModel.from_pretrained(
|
text_model = FlaxAutoModel.from_pretrained(text_model_name_or_path, *model_args, **kwargs_text)
|
||||||
text_model_name_or_path, *model_args, from_pt=True, **kwargs_text
|
|
||||||
)
|
|
||||||
|
|
||||||
vision_model = kwargs_vision.pop("model", None)
|
vision_model = kwargs_vision.pop("model", None)
|
||||||
if vision_model is None:
|
if vision_model is None:
|
||||||
|
|||||||
@@ -87,6 +87,10 @@ class ModelArguments:
|
|||||||
"Don't set if you want to train a model from scratch."
|
"Don't set if you want to train a model from scratch."
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
from_pt: bool = field(
|
||||||
|
default=True,
|
||||||
|
metadata={"help": "whether to load the text and vision model using PyTorch checkpoints."},
|
||||||
|
)
|
||||||
config_name: Optional[str] = field(
|
config_name: Optional[str] = field(
|
||||||
default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
|
default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
|
||||||
)
|
)
|
||||||
@@ -332,6 +336,8 @@ def main():
|
|||||||
model_args.vision_model_name_or_path,
|
model_args.vision_model_name_or_path,
|
||||||
seed=training_args.seed,
|
seed=training_args.seed,
|
||||||
dtype=getattr(jnp, model_args.dtype),
|
dtype=getattr(jnp, model_args.dtype),
|
||||||
|
text_from_pt=model_args.from_pt,
|
||||||
|
vision_from_pt=model_args.from_pt,
|
||||||
)
|
)
|
||||||
config = model.config
|
config = model.config
|
||||||
# set seed for torch dataloaders
|
# set seed for torch dataloaders
|
||||||
|
|||||||
Reference in New Issue
Block a user