Add Flax image captioning example (#14864)
* add image captioning example * update README * fix style & quality * simplify * apply review suggestions * Apply suggestions from code review Co-authored-by: Suraj Patil <surajp815@gmail.com> * Apply suggestions from code review Co-authored-by: Suraj Patil <surajp815@gmail.com> * Apply review suggestions * add comments about using np instead jax array * remove unused lines * add model creation script * only support from_pretrained * fix style * fix * not use cache_dir when creating model * fix tokenizer creation * update README * fix quality * apply suggestion * simplify some blocks * Update examples/flax/image-captioning/README.md * Update examples/flax/image-captioning/run_image_captioning_flax.py Co-authored-by: Suraj Patil <surajp815@gmail.com> * apply suggestion Co-authored-by: ydshieh <ydshieh@users.noreply.github.com> Co-authored-by: Suraj Patil <surajp815@gmail.com>
This commit is contained in:
68
examples/flax/image-captioning/README.md
Normal file
68
examples/flax/image-captioning/README.md
Normal file
@@ -0,0 +1,68 @@
|
|||||||
|
# Image Captioning (vision-encoder-text-decoder model) training example
|
||||||
|
|
||||||
|
The following example showcases how to finetune a vision-encoder-text-decoder model for image captioning
|
||||||
|
using the JAX/Flax backend, leveraging 🤗 Transformers library's [FlaxVisionEncoderDecoderModel](https://huggingface.co/docs/transformers/model_doc/visionencoderdecoder#transformers.FlaxVisionEncoderDecoderModel).
|
||||||
|
|
||||||
|
JAX/Flax allows you to trace pure functions and compile them into efficient, fused accelerator code on both GPU and TPU.
|
||||||
|
Models written in JAX/Flax are **immutable** and updated in a purely functional
|
||||||
|
way which enables simple and efficient model parallelism.
|
||||||
|
|
||||||
|
`run_image_captioning_flax.py` is a lightweight example of how to download and preprocess a dataset from the 🤗 Datasets
|
||||||
|
library or use your own files (jsonlines or csv), then fine-tune one of the architectures above on it.
|
||||||
|
|
||||||
|
For custom datasets in `jsonlines` format please see: https://huggingface.co/docs/datasets/loading_datasets.html#json-files and you also will find examples of these below.
|
||||||
|
|
||||||
|
### Download COCO dataset (2017)
|
||||||
|
This example uses COCO dataset (2017) through a custom dataset script, which requires users to manually download the
|
||||||
|
COCO dataset before training.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
mkdir data
|
||||||
|
cd data
|
||||||
|
wget http://images.cocodataset.org/zips/train2017.zip
|
||||||
|
wget http://images.cocodataset.org/zips/val2017.zip
|
||||||
|
wget http://images.cocodataset.org/zips/test2017.zip
|
||||||
|
wget http://images.cocodataset.org/annotations/annotations_trainval2017.zip
|
||||||
|
wget http://images.cocodataset.org/annotations/image_info_test2017.zip
|
||||||
|
cd ..
|
||||||
|
```
|
||||||
|
|
||||||
|
### Create a model from a vision encoder model and a text decoder model
|
||||||
|
Next, we create a [FlaxVisionEncoderDecoderModel](https://huggingface.co/docs/transformers/model_doc/visionencoderdecoder#transformers.FlaxVisionEncoderDecoderModel) instance from a pre-trained vision encoder ([ViT](https://huggingface.co/docs/transformers/model_doc/vit#transformers.FlaxViTModel)) and a pre-trained text decoder ([GPT2](https://huggingface.co/docs/transformers/model_doc/gpt2#transformers.FlaxGPT2Model)):
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python3 create_model_from_encoder_decoder_models.py \
|
||||||
|
--output_dir model \
|
||||||
|
--encoder_model_name_or_path google/vit-base-patch16-224-in21k \
|
||||||
|
--decoder_model_name_or_path gpt2
|
||||||
|
```
|
||||||
|
|
||||||
|
### Train the model
|
||||||
|
Finally, we can run the example script to train the model:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python3 run_image_captioning_flax.py \
|
||||||
|
--output_dir ./image-captioning-training-results \
|
||||||
|
--model_name_or_path model \
|
||||||
|
--dataset_name ydshieh/coco_dataset_script \
|
||||||
|
--dataset_config_name=2017 \
|
||||||
|
--data_dir $PWD/data \
|
||||||
|
--image_column image_path \
|
||||||
|
--caption_column caption \
|
||||||
|
--do_train --do_eval --predict_with_generate \
|
||||||
|
--num_train_epochs 1 \
|
||||||
|
--eval_steps 500 \
|
||||||
|
--learning_rate 3e-5 --warmup_steps 0 \
|
||||||
|
--per_device_train_batch_size 32 \
|
||||||
|
--per_device_eval_batch_size 32 \
|
||||||
|
--overwrite_output_dir \
|
||||||
|
--max_target_length 32 \
|
||||||
|
--num_beams 8 \
|
||||||
|
--preprocessing_num_workers 16 \
|
||||||
|
--logging_steps 10 \
|
||||||
|
--block_size 16384 \
|
||||||
|
--push_to_hub
|
||||||
|
```
|
||||||
|
|
||||||
|
This should finish in about 1h30 on Cloud TPU, with validation loss and ROUGE2 score of 2.0153 and 14.64 respectively
|
||||||
|
after 1 epoch. Training statistics can be accessed on [Models](https://huggingface.co/ydshieh/image-captioning-training-results/tensorboard).
|
||||||
@@ -0,0 +1,118 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
# coding=utf-8
|
||||||
|
# Copyright 2022 The HuggingFace Team All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
"""
|
||||||
|
Create a VisionEncoderDecoderModel instance from pretrained encoder/decoder models.
|
||||||
|
|
||||||
|
The cross-attention will be randomly initialized.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from transformers import (
|
||||||
|
AutoConfig,
|
||||||
|
AutoFeatureExtractor,
|
||||||
|
AutoTokenizer,
|
||||||
|
FlaxVisionEncoderDecoderModel,
|
||||||
|
HfArgumentParser,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ModelArguments:
|
||||||
|
"""
|
||||||
|
Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
|
||||||
|
"""
|
||||||
|
|
||||||
|
output_dir: str = field(
|
||||||
|
metadata={"help": "The output directory where the model will be written."},
|
||||||
|
)
|
||||||
|
encoder_model_name_or_path: str = field(
|
||||||
|
metadata={
|
||||||
|
"help": "The encoder model checkpoint for weights initialization."
|
||||||
|
"Don't set if you want to train an encoder model from scratch."
|
||||||
|
},
|
||||||
|
)
|
||||||
|
decoder_model_name_or_path: str = field(
|
||||||
|
metadata={
|
||||||
|
"help": "The decoder model checkpoint for weights initialization."
|
||||||
|
"Don't set if you want to train a decoder model from scratch."
|
||||||
|
},
|
||||||
|
)
|
||||||
|
encoder_config_name: Optional[str] = field(
|
||||||
|
default=None, metadata={"help": "Pretrained encoder config name or path if not the same as encoder_model_name"}
|
||||||
|
)
|
||||||
|
decoder_config_name: Optional[str] = field(
|
||||||
|
default=None, metadata={"help": "Pretrained decoder config name or path if not the same as decoder_model_name"}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = HfArgumentParser((ModelArguments,))
|
||||||
|
(model_args,) = parser.parse_args_into_dataclasses()
|
||||||
|
|
||||||
|
# Load pretrained model and tokenizer
|
||||||
|
|
||||||
|
# Use explicit specified encoder config
|
||||||
|
if model_args.encoder_config_name:
|
||||||
|
encoder_config = AutoConfig.from_pretrained(model_args.encoder_config_name)
|
||||||
|
# Use pretrained encoder model's config
|
||||||
|
else:
|
||||||
|
encoder_config = AutoConfig.from_pretrained(model_args.encoder_model_name_or_path)
|
||||||
|
|
||||||
|
# Use explicit specified decoder config
|
||||||
|
if model_args.decoder_config_name:
|
||||||
|
decoder_config = AutoConfig.from_pretrained(model_args.decoder_config_name)
|
||||||
|
# Use pretrained decoder model's config
|
||||||
|
else:
|
||||||
|
decoder_config = AutoConfig.from_pretrained(model_args.decoder_model_name_or_path)
|
||||||
|
|
||||||
|
# necessary for `from_encoder_decoder_pretrained` when `decoder_config` is passed
|
||||||
|
decoder_config.is_decoder = True
|
||||||
|
decoder_config.add_cross_attention = True
|
||||||
|
|
||||||
|
model = FlaxVisionEncoderDecoderModel.from_encoder_decoder_pretrained(
|
||||||
|
encoder_pretrained_model_name_or_path=model_args.encoder_model_name_or_path,
|
||||||
|
decoder_pretrained_model_name_or_path=model_args.decoder_model_name_or_path,
|
||||||
|
encoder_config=encoder_config,
|
||||||
|
decoder_config=decoder_config,
|
||||||
|
)
|
||||||
|
|
||||||
|
# GPT2 only has bos/eos tokens but not decoder_start/pad tokens
|
||||||
|
decoder_start_token_id = decoder_config.decoder_start_token_id
|
||||||
|
pad_token_id = decoder_config.pad_token_id
|
||||||
|
if decoder_start_token_id is None:
|
||||||
|
decoder_start_token_id = decoder_config.bos_token_id
|
||||||
|
if pad_token_id is None:
|
||||||
|
pad_token_id = decoder_config.eos_token_id
|
||||||
|
|
||||||
|
# This is necessary to make Flax's generate() work
|
||||||
|
model.config.eos_token_id = decoder_config.eos_token_id
|
||||||
|
model.config.decoder_start_token_id = decoder_start_token_id
|
||||||
|
model.config.pad_token_id = pad_token_id
|
||||||
|
|
||||||
|
feature_extractor = AutoFeatureExtractor.from_pretrained(model_args.encoder_model_name_or_path)
|
||||||
|
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(model_args.decoder_model_name_or_path)
|
||||||
|
tokenizer.pad_token = tokenizer.convert_ids_to_tokens(model.config.pad_token_id)
|
||||||
|
|
||||||
|
model.save_pretrained(model_args.output_dir)
|
||||||
|
feature_extractor.save_pretrained(model_args.output_dir)
|
||||||
|
tokenizer.save_pretrained(model_args.output_dir)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
1202
examples/flax/image-captioning/run_image_captioning_flax.py
Normal file
1202
examples/flax/image-captioning/run_image_captioning_flax.py
Normal file
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user