From 5f3ea66bc0c27ad2a8761fdf8489cf7d72257b93 Mon Sep 17 00:00:00 2001 From: Matt Date: Tue, 4 Apr 2023 16:05:22 +0100 Subject: [PATCH] Add TF port of BLIP (#22090) * Initial commit * more stash commit * Yet another stash commit * yet more stash commit * Mostly working except for docs / repo consistency * Stop importing model list from torch file * Add TF BLIP models to docs * Add auto classes * Move get_text_features and get_image_features * Update src/transformers/models/blip/modeling_tf_blip.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update src/transformers/models/blip/modeling_tf_blip.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update src/transformers/models/blip/modeling_tf_blip.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update src/transformers/models/blip/modeling_tf_blip_text.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update src/transformers/models/blip/modeling_tf_blip.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update src/transformers/models/blip/modeling_tf_blip.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update src/transformers/models/blip/modeling_tf_blip.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update src/transformers/models/blip/modeling_tf_blip.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update src/transformers/models/blip/modeling_tf_blip.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update tests/models/blip/test_modeling_tf_blip.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update tests/models/blip/test_modeling_tf_blip.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update src/transformers/models/blip/modeling_tf_blip.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update src/transformers/models/blip/modeling_tf_blip.py Co-authored-by: Joao Gante * Update tests/models/blip/test_modeling_tf_blip_text.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update src/transformers/models/blip/modeling_tf_blip_text.py Co-authored-by: Joao Gante * Update src/transformers/models/blip/modeling_tf_blip.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Use channels_last convolutions in TF (better performance + compatibility) * Remove _shape function * Move multi-line statement to one line in PT + TF * Specify tf.keras.layers instead of importing from it * Remove test_gradient_checkpointing and empty test_training methods * move some multi-line statements to one line * Update docstring for generate * Remove pruned heads set * Remove self.seq_len_dim * Fixed issues with loss computation, should resolve some tests. Also ensured that the PT version follows the config for output_attentions and output_hidden_states * ensure original model follows config in more cases * Skip the same cross-attention tests in the PT tests - didn't realize we did it twice! * Add training args throughout the models and layers * make fixup * Fix docstring for inputs_embeds * Add docstring for is_decoder * Add docstrings to text models * Remove redundant computation * Add unpack_inputs / keras_serializable * Add modeling_tf_blip to doctests * Add config classes for keras serialization * Changes to allow model porting with pt-to-tf * Quick fix to decoder head and test tweaks * Revert an issue with masking the embeddings outputs * Allow missing keys in some equivalence tests (for unused layers) * Add tf-pt equivalence tests back in * Update src/transformers/models/blip/modeling_tf_blip.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update src/transformers/models/blip/modeling_tf_blip_text.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update src/transformers/models/blip/modeling_tf_blip_text.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * make fixup * Refactor invert_attention_mask out into tf_utils * Re-enable cross-tests on the PT side too --------- Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> Co-authored-by: Joao Gante Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> --- docs/source/en/index.mdx | 2 +- docs/source/en/model_doc/blip.mdx | 40 +- src/transformers/__init__.py | 22 + src/transformers/commands/pt_to_tf.py | 20 +- src/transformers/modeling_tf_utils.py | 33 + .../models/auto/modeling_tf_auto.py | 2 + src/transformers/models/blip/__init__.py | 42 +- src/transformers/models/blip/modeling_blip.py | 40 +- .../models/blip/modeling_tf_blip.py | 1753 +++++++++++++++++ .../models/blip/modeling_tf_blip_text.py | 1013 ++++++++++ .../models/blip_2/modeling_blip_2.py | 4 +- src/transformers/tf_utils.py | 28 + src/transformers/utils/dummy_tf_objects.py | 52 + tests/models/blip/test_modeling_blip.py | 6 + tests/models/blip/test_modeling_blip_text.py | 3 + tests/models/blip/test_modeling_tf_blip.py | 824 ++++++++ .../models/blip/test_modeling_tf_blip_text.py | 170 ++ tests/test_modeling_common.py | 18 +- tests/test_modeling_tf_common.py | 24 +- utils/check_repo.py | 7 + utils/documentation_tests.txt | 1 + 21 files changed, 4059 insertions(+), 45 deletions(-) create mode 100644 src/transformers/models/blip/modeling_tf_blip.py create mode 100644 src/transformers/models/blip/modeling_tf_blip_text.py create mode 100644 tests/models/blip/test_modeling_tf_blip.py create mode 100644 tests/models/blip/test_modeling_tf_blip_text.py diff --git a/docs/source/en/index.mdx b/docs/source/en/index.mdx index 216425b29f..8c73cc98dc 100644 --- a/docs/source/en/index.mdx +++ b/docs/source/en/index.mdx @@ -269,7 +269,7 @@ Flax), PyTorch, and/or TensorFlow. | BiT | ❌ | ❌ | ✅ | ❌ | ❌ | | Blenderbot | ✅ | ✅ | ✅ | ✅ | ✅ | | BlenderbotSmall | ✅ | ✅ | ✅ | ✅ | ✅ | -| BLIP | ❌ | ❌ | ✅ | ❌ | ❌ | +| BLIP | ❌ | ❌ | ✅ | ✅ | ❌ | | BLIP-2 | ❌ | ❌ | ✅ | ❌ | ❌ | | BLOOM | ❌ | ✅ | ✅ | ❌ | ❌ | | BridgeTower | ❌ | ❌ | ✅ | ❌ | ❌ | diff --git a/docs/source/en/model_doc/blip.mdx b/docs/source/en/model_doc/blip.mdx index 42116f4869..12cc26f418 100644 --- a/docs/source/en/model_doc/blip.mdx +++ b/docs/source/en/model_doc/blip.mdx @@ -1,4 +1,4 @@ -