From f5b0c1ecf077e4b79c0e9d556502bbc19db74144 Mon Sep 17 00:00:00 2001 From: Suraj Patil Date: Tue, 6 Jul 2021 11:12:47 +0530 Subject: [PATCH] [Flax] Fix hybrid clip (#12519) * fix saving and loading * update readme --- .../jax-projects/hybrid_clip/README.md | 30 +++++++++- .../hybrid_clip/configuration_hybrid_clip.py | 49 ++++++++++++---- .../hybrid_clip/modeling_hybrid_clip.py | 58 +++++++++++++++++-- .../hybrid_clip/run_hybrid_clip.py | 6 ++ 4 files changed, 127 insertions(+), 16 deletions(-) diff --git a/examples/research_projects/jax-projects/hybrid_clip/README.md b/examples/research_projects/jax-projects/hybrid_clip/README.md index 6ab7a58876..6137731660 100644 --- a/examples/research_projects/jax-projects/hybrid_clip/README.md +++ b/examples/research_projects/jax-projects/hybrid_clip/README.md @@ -68,6 +68,34 @@ export MODEL_DIR="./clip-roberta-base ln -s ~/transformers/examples/flax/summarization/run_hybrid_clip.py run_hybrid_clip.py ``` +## How to use the `FlaxHybridCLIP` model: + +The `FlaxHybridCLIP` class lets you load any text and vision encoder model to create a dual encoder. +Here is an example of how to load the model using pre-trained text and vision models. + +```python +from modeling_hybrid_clip import FlaxHybridCLIP + +model = FlaxHybridCLIP.from_text_vision_pretrained("bert-base-uncased", "openai/clip-vit-base-patch32") + +# save the model +model.save_pretrained("bert-clip") + +# load the saved model +model = FlaxHybridCLIP.from_pretrained("bert-clip") +``` + +If the checkpoints are in PyTorch then one could pass `text_from_pt=True` and `vision_from_pt=True`. This will load the +PyTorch checkpoints, convert them to Flax, and load the model. 
+ +```python +model = FlaxHybridCLIP.from_text_vision_pretrained("bert-base-uncased", "openai/clip-vit-base-patch32", text_from_pt=True, vision_from_pt=True) +``` + +This loads both the text and vision encoders using pre-trained weights, the projection layers are randomly +initialized except for CLIP's vision model. If you use CLIP to initialize the vision model then the vision projection weights are also +loaded using the pre-trained weights. + ## Prepare the dataset We will use the MS-COCO dataset to train our dual encoder model. MS-COCO contains over 82,000 images, each of which has at least 5 different caption annotations. The dataset is usually used for image captioning tasks, but we can repurpose the image-caption pairs to train our dual encoder model for image search. @@ -124,7 +152,7 @@ with open("coco_dataset/valid_dataset.json", "w") as f: Next we can run the example script to train the model: ```bash -python run_clip.py \ +python run_hybrid_clip.py \ --output_dir ${MODEL_DIR} \ --text_model_name_or_path="roberta-base" \ --vision_model_name_or_path="openai/clip-vit-base-patch32" \ diff --git a/examples/research_projects/jax-projects/hybrid_clip/configuration_hybrid_clip.py b/examples/research_projects/jax-projects/hybrid_clip/configuration_hybrid_clip.py index 00e30c16a0..1a2c51f554 100644 --- a/examples/research_projects/jax-projects/hybrid_clip/configuration_hybrid_clip.py +++ b/examples/research_projects/jax-projects/hybrid_clip/configuration_hybrid_clip.py @@ -25,31 +25,58 @@ class HybridCLIPConfig(PretrainedConfig): Dimentionality of text and vision projection layers. kwargs (`optional`): Dictionary of keyword arguments. 
+ + Examples:: + + >>> from transformers import BertConfig, CLIPConfig, HybridCLIPConfig, FlaxHybridCLIP + + >>> # Initializing a BERT and CLIP configuration + >>> config_text = BertConfig() + >>> config_vision = CLIPConfig() + + >>> config = HybridCLIPConfig.from_text_vision_configs(config_text, config_vision, projection_dim=512) + + >>> # Initializing a BERT and CLIPVision model + >>> model = FlaxHybridCLIP(config=config) + + >>> # Accessing the model configuration + >>> config_text = model.config.text_config + >>> config_vision = model.config.vision_config + + >>> # Saving the model, including its configuration + >>> model.save_pretrained('my-model') + + >>> # loading model and config from pretrained folder + >>> hybrid_clip_config = HybridCLIPConfig.from_pretrained('my-model') + >>> model = FlaxHybridCLIP.from_pretrained('my-model', config=hybrid_clip_config) """ model_type = "hybrid-clip" is_composition = True - def __init__(self, text_config_dict, vision_config_dict, projection_dim=512, **kwargs): + def __init__(self, projection_dim=512, **kwargs): super().__init__(**kwargs) - if text_config_dict is None: - raise ValueError("`text_config_dict` can not be `None`.") + if "text_config" not in kwargs: + raise ValueError("`text_config` can not be `None`.") - if vision_config_dict is None: - raise ValueError("`vision_config_dict` can not be `None`.") + if "vision_config" not in kwargs: + raise ValueError("`vision_config` can not be `None`.") - text_model_type = text_config_dict.pop("model_type") - vision_model_type = vision_config_dict.pop("model_type") + text_config = kwargs.pop("text_config") + vision_config = kwargs.pop("vision_config") + + text_model_type = text_config.pop("model_type") + vision_model_type = vision_config.pop("model_type") from transformers import AutoConfig - self.text_config = AutoConfig.for_model(text_model_type, **text_config_dict) + self.text_config = AutoConfig.for_model(text_model_type, **text_config) if vision_model_type == 
"clip": - self.vision_config = AutoConfig.for_model(vision_model_type, **vision_config_dict).vision_config + self.vision_config = AutoConfig.for_model(vision_model_type, **vision_config).vision_config else: - self.vision_config = AutoConfig.for_model(vision_model_type, **vision_config_dict) + self.vision_config = AutoConfig.for_model(vision_model_type, **vision_config) self.projection_dim = projection_dim self.initializer_factor = 1.0 @@ -64,7 +91,7 @@ class HybridCLIPConfig(PretrainedConfig): :class:`HybridCLIPConfig`: An instance of a configuration object """ - return cls(text_config_dict=text_config.to_dict(), vision_config_dict=vision_config.to_dict(), **kwargs) + return cls(text_config=text_config.to_dict(), vision_config=vision_config.to_dict(), **kwargs) def to_dict(self): """ diff --git a/examples/research_projects/jax-projects/hybrid_clip/modeling_hybrid_clip.py b/examples/research_projects/jax-projects/hybrid_clip/modeling_hybrid_clip.py index 8ea7831d49..af4786eaf1 100644 --- a/examples/research_projects/jax-projects/hybrid_clip/modeling_hybrid_clip.py +++ b/examples/research_projects/jax-projects/hybrid_clip/modeling_hybrid_clip.py @@ -123,7 +123,7 @@ class FlaxHybridCLIPModule(nn.Module): class FlaxHybridCLIP(FlaxPreTrainedModel): - config: HybridCLIPConfig + config_class = HybridCLIPConfig module_class = FlaxHybridCLIPModule def __init__( @@ -304,6 +304,58 @@ class FlaxHybridCLIP(FlaxPreTrainedModel): *model_args, **kwargs, ) -> FlaxPreTrainedModel: + """ + Params: + text_model_name_or_path (:obj: `str`, `optional`): + Information necessary to initiate the text model. Can be either: + + - A string, the `model id` of a pretrained model hosted inside a model repo on huggingface.co. + Valid model ids can be located at the root-level, like ``bert-base-uncased``, or namespaced under + a user or organization name, like ``dbmdz/bert-base-german-cased``. 
+ - A path to a `directory` containing model weights saved using + :func:`~transformers.FlaxPreTrainedModel.save_pretrained`, e.g., ``./my_model_directory/``. + - A path or url to a `PyTorch checkpoint folder` (e.g., ``./pt_model``). In + this case, ``text_from_pt`` should be set to :obj:`True` and a configuration object should be provided + as ``config`` argument. This loading path is slower than converting the PyTorch checkpoint into + a Flax model using the provided conversion scripts and loading the Flax model afterwards. + + vision_model_name_or_path (:obj: `str`, `optional`, defaults to `None`): + Information necessary to initiate the vision model. Can be either: + + - A string, the `model id` of a pretrained model hosted inside a model repo on huggingface.co. + Valid model ids can be located at the root-level, like ``bert-base-uncased``, or namespaced under + a user or organization name, like ``dbmdz/bert-base-german-cased``. + - A path to a `directory` containing model weights saved using + :func:`~transformers.FlaxPreTrainedModel.save_pretrained`, e.g., ``./my_model_directory/``. + - A path or url to a `PyTorch checkpoint folder` (e.g., ``./pt_model``). In + this case, ``vision_from_pt`` should be set to :obj:`True` and a configuration object should be provided + as ``config`` argument. This loading path is slower than converting the PyTorch checkpoint into + a Flax model using the provided conversion scripts and loading the Flax model afterwards. + + model_args (remaining positional arguments, `optional`): + All remaining positional arguments will be passed to the underlying model's ``__init__`` method. + + kwargs (remaining dictionary of keyword arguments, `optional`): + Can be used to update the configuration object (after it is loaded) and initiate the model (e.g., + :obj:`output_attentions=True`). + + - To update the text configuration, use the prefix `text_` for each configuration parameter. 
+ - To update the vision configuration, use the prefix `vision_` for each configuration parameter. + - To update the parent model configuration, do not use a prefix for each configuration parameter. + + Behaves differently depending on whether a :obj:`config` is provided or automatically loaded. + + Example:: + + >>> from modeling_hybrid_clip import FlaxHybridCLIP + >>> # initialize a model from pretrained BERT and CLIP models. Note that the projection layers will be randomly initialized. + >>> # If using CLIP's vision model the vision projection layer will be initialized using pre-trained weights + >>> model = FlaxHybridCLIP.from_text_vision_pretrained('bert-base-uncased', 'openai/clip-vit-base-patch32') + >>> # saving model after fine-tuning + >>> model.save_pretrained("./bert-clip") + >>> # load fine-tuned model + >>> model = FlaxHybridCLIP.from_pretrained("./bert-clip") + """ kwargs_text = { argument[len("text_") :]: value for argument, value in kwargs.items() if argument.startswith("text_") @@ -333,9 +385,7 @@ class FlaxHybridCLIP(FlaxPreTrainedModel): text_config = AutoConfig.from_pretrained(text_model_name_or_path) kwargs_text["config"] = text_config - text_model = FlaxAutoModel.from_pretrained( - text_model_name_or_path, *model_args, from_pt=True, **kwargs_text - ) + text_model = FlaxAutoModel.from_pretrained(text_model_name_or_path, *model_args, **kwargs_text) vision_model = kwargs_vision.pop("model", None) if vision_model is None: diff --git a/examples/research_projects/jax-projects/hybrid_clip/run_hybrid_clip.py b/examples/research_projects/jax-projects/hybrid_clip/run_hybrid_clip.py index 8c23e5f819..b9200e0b2b 100644 --- a/examples/research_projects/jax-projects/hybrid_clip/run_hybrid_clip.py +++ b/examples/research_projects/jax-projects/hybrid_clip/run_hybrid_clip.py @@ -87,6 +87,10 @@ class ModelArguments: "Don't set if you want to train a model from scratch." 
}, ) + from_pt: bool = field( + default=True, + metadata={"help": "whether to load the text and vision model using PyTorch checkpoints."}, + ) config_name: Optional[str] = field( default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"} ) @@ -332,6 +336,8 @@ def main(): model_args.vision_model_name_or_path, seed=training_args.seed, dtype=getattr(jnp, model_args.dtype), + text_from_pt=model_args.from_pt, + vision_from_pt=model_args.from_pt, ) config = model.config # set seed for torch dataloaders