Add TF port of BLIP (#22090)
* Initial commit * more stash commit * Yet another stash commit * yet more stash commit * Mostly working except for docs / repo consistency * Stop importing model list from torch file * Add TF BLIP models to docs * Add auto classes * Move get_text_features and get_image_features * Update src/transformers/models/blip/modeling_tf_blip.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update src/transformers/models/blip/modeling_tf_blip.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update src/transformers/models/blip/modeling_tf_blip.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update src/transformers/models/blip/modeling_tf_blip_text.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update src/transformers/models/blip/modeling_tf_blip.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update src/transformers/models/blip/modeling_tf_blip.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update src/transformers/models/blip/modeling_tf_blip.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update src/transformers/models/blip/modeling_tf_blip.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update src/transformers/models/blip/modeling_tf_blip.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update tests/models/blip/test_modeling_tf_blip.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update tests/models/blip/test_modeling_tf_blip.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update src/transformers/models/blip/modeling_tf_blip.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update src/transformers/models/blip/modeling_tf_blip.py Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> * Update tests/models/blip/test_modeling_tf_blip_text.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update src/transformers/models/blip/modeling_tf_blip_text.py Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> * Update src/transformers/models/blip/modeling_tf_blip.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Use channels_last convolutions in TF (better performance + compatibility) * Remove _shape function * Move multi-line statement to one line in PT + TF * Specify tf.keras.layers instead of importing from it * Remove test_gradient_checkpointing and empty test_training methods * move some multi-line statements to one line * Update docstring for generate * Remove pruned heads set * Remove self.seq_len_dim * Fixed issues with loss computation, should resolve some tests. Also ensured that the PT version follows the config for output_attentions and output_hidden_states * ensure original model follows config in more cases * Skip the same cross-attention tests in the PT tests - didn't realize we did it twice! * Add training args throughout the models and layers * make fixup * Fix docstring for inputs_embeds * Add docstring for is_decoder * Add docstrings to text models * Remove redundant computation * Add unpack_inputs / keras_serializable * Add modeling_tf_blip to doctests * Add config classes for keras serialization * Changes to allow model porting with pt-to-tf * Quick fix to decoder head and test tweaks * Revert an issue with masking the embeddings outputs * Allow missing keys in some equivalence tests (for unused layers) * Add tf-pt equivalence tests back in * Update src/transformers/models/blip/modeling_tf_blip.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update src/transformers/models/blip/modeling_tf_blip_text.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * Update src/transformers/models/blip/modeling_tf_blip_text.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * make fixup * Refactor invert_attention_mask out into tf_utils * Re-enable cross-tests on the PT side too --------- Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com> Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
@@ -1984,7 +1984,7 @@ class ModelTesterMixin:
|
||||
self.check_pt_tf_outputs(tf_outputs, pt_outputs, type(pt_model))
|
||||
|
||||
@is_pt_tf_cross_test
|
||||
def test_pt_tf_model_equivalence(self):
|
||||
def test_pt_tf_model_equivalence(self, allow_missing_keys=False):
|
||||
import transformers
|
||||
|
||||
for model_class in self.all_model_classes:
|
||||
@@ -2036,8 +2036,12 @@ class ModelTesterMixin:
|
||||
# Check we can load pt model in tf and vice-versa with model => model functions
|
||||
# Here requires `tf_inputs_dict` to build `tf_model`
|
||||
tf_inputs_dict = self.prepare_tf_inputs_from_pt_inputs(pt_inputs_dict)
|
||||
tf_model = transformers.load_pytorch_model_in_tf2_model(tf_model, pt_model, tf_inputs=tf_inputs_dict)
|
||||
pt_model = transformers.load_tf2_model_in_pytorch_model(pt_model, tf_model)
|
||||
tf_model = transformers.load_pytorch_model_in_tf2_model(
|
||||
tf_model, pt_model, tf_inputs=tf_inputs_dict, allow_missing_keys=allow_missing_keys
|
||||
)
|
||||
pt_model = transformers.load_tf2_model_in_pytorch_model(
|
||||
pt_model, tf_model, allow_missing_keys=allow_missing_keys
|
||||
)
|
||||
|
||||
# Original test: check without `labels`
|
||||
self.check_pt_tf_models(tf_model, pt_model, pt_inputs_dict)
|
||||
@@ -2049,11 +2053,15 @@ class ModelTesterMixin:
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
pt_checkpoint_path = os.path.join(tmpdirname, "pt_model.bin")
|
||||
torch.save(pt_model.state_dict(), pt_checkpoint_path)
|
||||
tf_model = transformers.load_pytorch_checkpoint_in_tf2_model(tf_model, pt_checkpoint_path)
|
||||
tf_model = transformers.load_pytorch_checkpoint_in_tf2_model(
|
||||
tf_model, pt_checkpoint_path, allow_missing_keys=allow_missing_keys
|
||||
)
|
||||
|
||||
tf_checkpoint_path = os.path.join(tmpdirname, "tf_model.h5")
|
||||
tf_model.save_weights(tf_checkpoint_path)
|
||||
pt_model = transformers.load_tf2_checkpoint_in_pytorch_model(pt_model, tf_checkpoint_path)
|
||||
pt_model = transformers.load_tf2_checkpoint_in_pytorch_model(
|
||||
pt_model, tf_checkpoint_path, allow_missing_keys=allow_missing_keys
|
||||
)
|
||||
|
||||
# Original test: check without `labels`
|
||||
self.check_pt_tf_models(tf_model, pt_model, pt_inputs_dict)
|
||||
|
||||
Reference in New Issue
Block a user