[Flax] Fix hybrid clip (#12519)
* fix saving and loading * update readme
This commit is contained in:
@@ -68,6 +68,34 @@ export MODEL_DIR="./clip-roberta-base
|
||||
ln -s ~/transformers/examples/flax/summarization/run_hybrid_clip.py run_hybrid_clip.py
|
||||
```
|
||||
|
||||
## How to use the `FlaxHybridCLIP` model:
|
||||
|
||||
The `FlaxHybridCLIP` class let's you load any text and vision encoder model to create a dual encoder.
|
||||
Here is an example of how to load the model using pre-trained text and vision models.
|
||||
|
||||
```python
|
||||
from modeling_hybrid_clip import FlaxHybridCLIP
|
||||
|
||||
model = FlaxHybridCLIP.from_text_vision_pretrained("bert-base-uncased", "openai/clip-vit-base-patch32")
|
||||
|
||||
# save the model
|
||||
model.save_pretrained("bert-clip")
|
||||
|
||||
# load the saved model
|
||||
model = FlaxHybridCLIP.from_pretrained("bert-clip")
|
||||
```
|
||||
|
||||
If the checkpoints are in PyTorch then one could pass `text_from_pt=True` and `vision_from_pt=True`. This will load the model
|
||||
PyTorch checkpoints convert them to flax and load the model.
|
||||
|
||||
```python
|
||||
model = FlaxHybridCLIP.from_text_vision_pretrained("bert-base-uncased", "openai/clip-vit-base-patch32", text_from_pt=True, vision_from_pt=True)
|
||||
```
|
||||
|
||||
This loads both the text and vision encoders using pre-trained weights, the projection layers are randomly
|
||||
initialized except for CLIP's vision model. If you use CLIP to initialize the vision model then the vision projection weights are also
|
||||
loaded using the pre-trained weights.
|
||||
|
||||
## Prepare the dataset
|
||||
|
||||
We will use the MS-COCO dataset to train our dual encoder model. MS-COCO contains over 82,000 images, each of which has at least 5 different caption annotations. The dataset is usually used for image captioning tasks, but we can repurpose the image-caption pairs to train our dual encoder model for image search.
|
||||
@@ -124,7 +152,7 @@ with open("coco_dataset/valid_dataset.json", "w") as f:
|
||||
Next we can run the example script to train the model:
|
||||
|
||||
```bash
|
||||
python run_clip.py \
|
||||
python run_hybrid_clip.py \
|
||||
--output_dir ${MODEL_DIR} \
|
||||
--text_model_name_or_path="roberta-base" \
|
||||
--vision_model_name_or_path="openai/clip-vit-base-patch32" \
|
||||
|
||||
@@ -25,31 +25,58 @@ class HybridCLIPConfig(PretrainedConfig):
|
||||
Dimentionality of text and vision projection layers.
|
||||
kwargs (`optional`):
|
||||
Dictionary of keyword arguments.
|
||||
|
||||
Examples::
|
||||
|
||||
>>> from transformers import BertConfig, CLIPConfig, HybridCLIPConfig, FlaxHybridCLIP
|
||||
|
||||
>>> # Initializing a BERT and CLIP configuration
|
||||
>>> config_text = BertConfig()
|
||||
>>> config_vision = CLIPConfig()
|
||||
|
||||
>>> config = HybridCLIPConfig.from_text_vision_configs(config_text, config_vision, projection_dim=512)
|
||||
|
||||
>>> # Initializing a BERT and CLIPVision model
|
||||
>>> model = EncoderDecoderModel(config=config)
|
||||
|
||||
>>> # Accessing the model configuration
|
||||
>>> config_text = model.config.text_config
|
||||
>>> config_vision = model.config.vision_config
|
||||
|
||||
>>> # Saving the model, including its configuration
|
||||
>>> model.save_pretrained('my-model')
|
||||
|
||||
>>> # loading model and config from pretrained folder
|
||||
>>> encoder_decoder_config = HybridCLIPConfig.from_pretrained('my-model')
|
||||
>>> model = FlaxHybridCLIP.from_pretrained('my-model', config=encoder_decoder_config)
|
||||
"""
|
||||
|
||||
model_type = "hybrid-clip"
|
||||
is_composition = True
|
||||
|
||||
def __init__(self, text_config_dict, vision_config_dict, projection_dim=512, **kwargs):
|
||||
def __init__(self, projection_dim=512, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
if text_config_dict is None:
|
||||
raise ValueError("`text_config_dict` can not be `None`.")
|
||||
if "text_config" not in kwargs:
|
||||
raise ValueError("`text_config` can not be `None`.")
|
||||
|
||||
if vision_config_dict is None:
|
||||
raise ValueError("`vision_config_dict` can not be `None`.")
|
||||
if "vision_config" not in kwargs:
|
||||
raise ValueError("`vision_config` can not be `None`.")
|
||||
|
||||
text_model_type = text_config_dict.pop("model_type")
|
||||
vision_model_type = vision_config_dict.pop("model_type")
|
||||
text_config = kwargs.pop("text_config")
|
||||
vision_config = kwargs.pop("vision_config")
|
||||
|
||||
text_model_type = text_config.pop("model_type")
|
||||
vision_model_type = vision_config.pop("model_type")
|
||||
|
||||
from transformers import AutoConfig
|
||||
|
||||
self.text_config = AutoConfig.for_model(text_model_type, **text_config_dict)
|
||||
self.text_config = AutoConfig.for_model(text_model_type, **text_config)
|
||||
|
||||
if vision_model_type == "clip":
|
||||
self.vision_config = AutoConfig.for_model(vision_model_type, **vision_config_dict).vision_config
|
||||
self.vision_config = AutoConfig.for_model(vision_model_type, **vision_config).vision_config
|
||||
else:
|
||||
self.vision_config = AutoConfig.for_model(vision_model_type, **vision_config_dict)
|
||||
self.vision_config = AutoConfig.for_model(vision_model_type, **vision_config)
|
||||
|
||||
self.projection_dim = projection_dim
|
||||
self.initializer_factor = 1.0
|
||||
@@ -64,7 +91,7 @@ class HybridCLIPConfig(PretrainedConfig):
|
||||
:class:`HybridCLIPConfig`: An instance of a configuration object
|
||||
"""
|
||||
|
||||
return cls(text_config_dict=text_config.to_dict(), vision_config_dict=vision_config.to_dict(), **kwargs)
|
||||
return cls(text_config=text_config.to_dict(), vision_config=vision_config.to_dict(), **kwargs)
|
||||
|
||||
def to_dict(self):
|
||||
"""
|
||||
|
||||
@@ -123,7 +123,7 @@ class FlaxHybridCLIPModule(nn.Module):
|
||||
|
||||
|
||||
class FlaxHybridCLIP(FlaxPreTrainedModel):
|
||||
config: HybridCLIPConfig
|
||||
config_class = HybridCLIPConfig
|
||||
module_class = FlaxHybridCLIPModule
|
||||
|
||||
def __init__(
|
||||
@@ -304,6 +304,58 @@ class FlaxHybridCLIP(FlaxPreTrainedModel):
|
||||
*model_args,
|
||||
**kwargs,
|
||||
) -> FlaxPreTrainedModel:
|
||||
"""
|
||||
Params:
|
||||
text_model_name_or_path (:obj: `str`, `optional`):
|
||||
Information necessary to initiate the text model. Can be either:
|
||||
|
||||
- A string, the `model id` of a pretrained model hosted inside a model repo on huggingface.co.
|
||||
Valid model ids can be located at the root-level, like ``bert-base-uncased``, or namespaced under
|
||||
a user or organization name, like ``dbmdz/bert-base-german-cased``.
|
||||
- A path to a `directory` containing model weights saved using
|
||||
:func:`~transformers.FlaxPreTrainedModel.save_pretrained`, e.g., ``./my_model_directory/``.
|
||||
- A path or url to a `PyTorch checkpoint folder` (e.g, ``./pt_model``). In
|
||||
this case, ``from_pt`` should be set to :obj:`True` and a configuration object should be provided
|
||||
as ``config`` argument. This loading path is slower than converting the PyTorch checkpoint in
|
||||
a Flax model using the provided conversion scripts and loading the Flax model afterwards.
|
||||
|
||||
vision_model_name_or_path (:obj: `str`, `optional`, defaults to `None`):
|
||||
Information necessary to initiate the vision model. Can be either:
|
||||
|
||||
- A string, the `model id` of a pretrained model hosted inside a model repo on huggingface.co.
|
||||
Valid model ids can be located at the root-level, like ``bert-base-uncased``, or namespaced under
|
||||
a user or organization name, like ``dbmdz/bert-base-german-cased``.
|
||||
- A path to a `directory` containing model weights saved using
|
||||
:func:`~transformers.FlaxPreTrainedModel.save_pretrained`, e.g., ``./my_model_directory/``.
|
||||
- A path or url to a `PyTorch checkpoint folder` (e.g, ``./pt_model``). In
|
||||
this case, ``from_pt`` should be set to :obj:`True` and a configuration object should be provided
|
||||
as ``config`` argument. This loading path is slower than converting the PyTorch checkpoint in
|
||||
a Flax model using the provided conversion scripts and loading the Flax model afterwards.
|
||||
|
||||
model_args (remaining positional arguments, `optional`):
|
||||
All remaning positional arguments will be passed to the underlying model's ``__init__`` method.
|
||||
|
||||
kwargs (remaining dictionary of keyword arguments, `optional`):
|
||||
Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
|
||||
:obj:`output_attentions=True`).
|
||||
|
||||
- To update the text configuration, use the prefix `text_` for each configuration parameter.
|
||||
- To update the vision configuration, use the prefix `vision_` for each configuration parameter.
|
||||
- To update the parent model configuration, do not use a prefix for each configuration parameter.
|
||||
|
||||
Behaves differently depending on whether a :obj:`config` is provided or automatically loaded.
|
||||
|
||||
Example::
|
||||
|
||||
>>> from transformers import FlaxHybridCLIP
|
||||
>>> # initialize a model from pretrained BERT and CLIP models. Note that the projection layers will be randomly initialized.
|
||||
>>> # If using CLIP's vision model the vision projection layer will be initialized using pre-trained weights
|
||||
>>> model = FlaxHybridCLIP.from_text_vision_pretrained('bert-base-uncased', 'openai/clip-vit-base-patch32')
|
||||
>>> # saving model after fine-tuning
|
||||
>>> model.save_pretrained("./bert-clip")
|
||||
>>> # load fine-tuned model
|
||||
>>> model = FlaxHybridCLIP.from_pretrained("./bert-clip")
|
||||
"""
|
||||
|
||||
kwargs_text = {
|
||||
argument[len("text_") :]: value for argument, value in kwargs.items() if argument.startswith("text_")
|
||||
@@ -333,9 +385,7 @@ class FlaxHybridCLIP(FlaxPreTrainedModel):
|
||||
text_config = AutoConfig.from_pretrained(text_model_name_or_path)
|
||||
kwargs_text["config"] = text_config
|
||||
|
||||
text_model = FlaxAutoModel.from_pretrained(
|
||||
text_model_name_or_path, *model_args, from_pt=True, **kwargs_text
|
||||
)
|
||||
text_model = FlaxAutoModel.from_pretrained(text_model_name_or_path, *model_args, **kwargs_text)
|
||||
|
||||
vision_model = kwargs_vision.pop("model", None)
|
||||
if vision_model is None:
|
||||
|
||||
@@ -87,6 +87,10 @@ class ModelArguments:
|
||||
"Don't set if you want to train a model from scratch."
|
||||
},
|
||||
)
|
||||
from_pt: bool = field(
|
||||
default=True,
|
||||
metadata={"help": "whether to load the text and vision model using PyTorch checkpoints."},
|
||||
)
|
||||
config_name: Optional[str] = field(
|
||||
default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
|
||||
)
|
||||
@@ -332,6 +336,8 @@ def main():
|
||||
model_args.vision_model_name_or_path,
|
||||
seed=training_args.seed,
|
||||
dtype=getattr(jnp, model_args.dtype),
|
||||
text_from_pt=model_args.from_pt,
|
||||
vision_from_pt=model_args.from_pt,
|
||||
)
|
||||
config = model.config
|
||||
# set seed for torch dataloaders
|
||||
|
||||
Reference in New Issue
Block a user