Add sdpa for Vivit (#33757)
* chore:add sdpa to vivit * fix:failing slow test_inference_interpolate_pos_encoding(failing on main branch too) * chore:fix nits * ci:fix repo consistency failure * chore:add info and benchmark to model doc * [run_slow] vivit * chore:revert interpolation test fix for new issue * [run_slow] vivit * [run_slow] vivit * [run_slow] vivit * chore:add fallback for output_attentions being True * [run_slow] vivit * style:make fixup * [run_slow] vivit
This commit is contained in:
@@ -65,6 +65,8 @@ class VivitModelTester:
|
||||
layer_norm_eps=1e-06,
|
||||
qkv_bias=True,
|
||||
scope=None,
|
||||
attn_implementation="eager",
|
||||
mask_ratio=0.5,
|
||||
):
|
||||
self.parent = parent
|
||||
self.batch_size = batch_size
|
||||
@@ -86,12 +88,15 @@ class VivitModelTester:
|
||||
self.layer_norm_eps = layer_norm_eps
|
||||
self.qkv_bias = qkv_bias
|
||||
self.scope = scope
|
||||
self.attn_implementation = attn_implementation
|
||||
|
||||
self.seq_length = (
|
||||
(self.image_size // self.tubelet_size[2])
|
||||
* (self.image_size // self.tubelet_size[1])
|
||||
* (self.num_frames // self.tubelet_size[0])
|
||||
) + 1 # CLS token
|
||||
self.mask_ratio = mask_ratio
|
||||
self.num_masks = int(mask_ratio * self.seq_length)
|
||||
|
||||
def prepare_config_and_inputs(self):
|
||||
pixel_values = floats_tensor(
|
||||
@@ -122,6 +127,7 @@ class VivitModelTester:
|
||||
initializer_range=self.initializer_range,
|
||||
layer_norm_eps=self.layer_norm_eps,
|
||||
qkv_bias=self.qkv_bias,
|
||||
attn_implementation=self.attn_implementation,
|
||||
)
|
||||
config.num_labels = self.num_labels
|
||||
return config
|
||||
|
||||
Reference in New Issue
Block a user