Add TFViTModel (#13778)
* Start the work for TFViTModel * Convert to TF code - need to check in the follow up commits * Clean up model code * Expose TFViTModel * make style * make quality * Add test * make style & quality * Fix some imports * fix wrong usage - *kwargs => ** kwargs * Fix Conv2D weight loading (PT->TF) issue * Add tests for images with different sizes + fix model * Fix some common tests for TFViTModel * Use inputs instead of input_ids in test_compile_tf_model * Add a comment about transpose and Conv2D in convert_tf_weight_name_to_pt_weight_name * Avoid transpose in TFViT call * Fix Conv2D issue in load_tf2_weights_in_pytorch_model * Use tf.keras.layers.Conv2D instead of tf.nn.conv2d * Using simpler heuristic to detect Conv2D layer * Change convert_tf_weight_name_to_pt_weight_name to return TransposeType * Check tf_weight_shape is not None before using it * Apply suggestions from code review Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * fix missing comma * fix input dtype Co-authored-by: ydshieh <ydshieh@users.noreply.github.com> Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
@@ -1476,6 +1476,8 @@ class ModelTesterMixin:
|
||||
tf_inputs_dict[key] = tensor
|
||||
elif key == "input_values":
|
||||
tf_inputs_dict[key] = tf.convert_to_tensor(tensor.numpy(), dtype=tf.float32)
|
||||
elif key == "pixel_values":
|
||||
tf_inputs_dict[key] = tf.convert_to_tensor(tensor.numpy(), dtype=tf.float32)
|
||||
else:
|
||||
tf_inputs_dict[key] = tf.convert_to_tensor(tensor.numpy(), dtype=tf.int32)
|
||||
|
||||
@@ -1525,6 +1527,8 @@ class ModelTesterMixin:
|
||||
tf_inputs_dict[key] = tf.convert_to_tensor(tensor, dtype=tf.int32)
|
||||
elif key == "input_values":
|
||||
tf_inputs_dict[key] = tf.convert_to_tensor(tensor.numpy(), dtype=tf.float32)
|
||||
elif key == "pixel_values":
|
||||
tf_inputs_dict[key] = tf.convert_to_tensor(tensor.numpy(), dtype=tf.float32)
|
||||
else:
|
||||
tf_inputs_dict[key] = tf.convert_to_tensor(tensor.numpy(), dtype=tf.int32)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user