Add TFCLIPModel (#13967)
* Start the work for TFCLIPModel
* Convert to TF code (TODO: loss + doc)
* Clean up
* Fix pooled_output for TFCLIPTextTransformer - using tf.gather_nd
* assert -> raise error
* Expose TFCLIPModel
* Deal with dummy_inputs
* Add tests
* Fix all tests. TODO: manual check weight loading + add more comments
* Fix pt tf equivalence test
* fixes
* update TFCLIPVisionEmbeddings's Conv2D
* Fix loss + overwrite test_pt_tf_model_equivalence from common
* Add a comment about the change about MainLayer in test_keras_save_load
* Set return_loss=True in TFCLIPModelTester + make tests pass
* overwrite test_pt_tf_model_equivalence from tf common
* fix base_model_prefix
* Fix examples
* remove unused
* Apply suggestions from code review
  Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
* apply review suggestions
* change self.pre_layrnorm to self.pre_layernorm
* apply more review suggestions
* return attention probs before dropout (to align with PT)
* fix weight init
* fix
* build doc
* fix missing doc
* fix for test

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
@@ -282,6 +282,8 @@ class TFModelTesterMixin:
|
||||
for module in (import_module(model_class.__module__),)
|
||||
for module_member_name in dir(module)
|
||||
if module_member_name.endswith("MainLayer")
|
||||
# This condition is required, since `modeling_tf_clip.py` has 3 classes whose names end with `MainLayer`.
|
||||
and module_member_name[: -len("MainLayer")] == model_class.__name__[: -len("Model")]
|
||||
for module_member in (getattr(module, module_member_name),)
|
||||
if isinstance(module_member, type)
|
||||
and tf.keras.layers.Layer in module_member.__bases__
|
||||
@@ -458,7 +460,7 @@ class TFModelTesterMixin:
|
||||
"input_ids": tf.keras.Input(batch_shape=(2, max_input), name="input_ids", dtype="int32"),
|
||||
}
|
||||
# TODO: A better way to handle vision models
|
||||
elif model_class.__name__ in ["TFViTModel", "TFViTForImageClassification"]:
|
||||
elif model_class.__name__ in ["TFViTModel", "TFViTForImageClassification", "TFCLIPVisionModel"]:
|
||||
inputs = tf.keras.Input(
|
||||
batch_shape=(
|
||||
3,
|
||||
@@ -469,6 +471,20 @@ class TFModelTesterMixin:
|
||||
name="pixel_values",
|
||||
dtype="float32",
|
||||
)
|
||||
elif model_class.__name__ in ["TFCLIPModel"]:
|
||||
inputs = {
|
||||
"input_ids": tf.keras.Input(batch_shape=(3, max_input), name="input_ids", dtype="int32"),
|
||||
"pixel_values": tf.keras.Input(
|
||||
batch_shape=(
|
||||
3,
|
||||
self.model_tester.vision_model_tester.num_channels,
|
||||
self.model_tester.vision_model_tester.image_size,
|
||||
self.model_tester.vision_model_tester.image_size,
|
||||
),
|
||||
name="pixel_values",
|
||||
dtype="float32",
|
||||
),
|
||||
}
|
||||
elif model_class in get_values(TF_MODEL_FOR_MULTIPLE_CHOICE_MAPPING):
|
||||
inputs = tf.keras.Input(batch_shape=(4, 2, max_input), name="input_ids", dtype="int32")
|
||||
else:
|
||||
@@ -1244,6 +1260,13 @@ def ids_tensor(shape, vocab_size, rng=None, name=None, dtype=None):
|
||||
return output
|
||||
|
||||
|
||||
def random_attention_mask(shape, rng=None, name=None, dtype=None):
    """Create a random attention mask whose first column is always 1.

    The mask is drawn with `ids_tensor(..., vocab_size=2)` and its first column
    is then replaced by ones, so every row attends to at least one token.

    Args:
        shape: Shape of the mask. It is indexed as `shape[0]` and the result is
            sliced with `[:, 1:]`, so it must describe a tensor of rank >= 2.
        rng: Optional random generator, forwarded to `ids_tensor`.
        name: Optional tensor name, forwarded to `ids_tensor`.
        dtype: Optional dtype used for both the random values and the column
            of ones. NOTE(review): with `dtype=None` the two pieces rely on
            `ids_tensor` and `tf.constant` picking the same default dtype --
            confirm against `ids_tensor`.

    Returns:
        The mask tensor produced by concatenating the column of ones with the
        remaining random columns along axis 1.
    """
    # Forward `rng` and `name` instead of discarding them, so that callers can
    # obtain reproducible masks with a seeded generator.
    attn_mask = ids_tensor(shape, vocab_size=2, rng=rng, name=name, dtype=dtype)
    # make sure that at least one token is attended to for each batch
    first_column = tf.constant(value=1, shape=(shape[0], 1), dtype=dtype)
    attn_mask = tf.concat([first_column, attn_mask[:, 1:]], axis=1)
    return attn_mask
|
||||
|
||||
|
||||
def floats_tensor(shape, scale=1.0, rng=None, name=None, dtype=None):
|
||||
"""Creates a random float32 tensor"""
|
||||
if rng is None:
|
||||
|
||||
Reference in New Issue
Block a user