Fix gradient_checkpointing backward compatibility (#14408)
* Fix gradient_checkpointing backward compatibility * Remove needless line * make sure mask prob is big enough and length small enough * Fix tests Co-authored-by: patrickvonplaten <patrick.v.platen@gmail.com>
This commit is contained in:
@@ -78,6 +78,8 @@ class Wav2Vec2ModelTester:
|
||||
layer_norm_eps=1e-5,
|
||||
hidden_act="gelu",
|
||||
initializer_range=0.02,
|
||||
mask_time_prob=0.5,
|
||||
mask_time_length=2,
|
||||
vocab_size=32,
|
||||
do_stable_layer_norm=False,
|
||||
scope=None,
|
||||
@@ -105,6 +107,8 @@ class Wav2Vec2ModelTester:
|
||||
self.initializer_range = initializer_range
|
||||
self.vocab_size = vocab_size
|
||||
self.do_stable_layer_norm = do_stable_layer_norm
|
||||
self.mask_time_prob = mask_time_prob
|
||||
self.mask_time_length = mask_time_length
|
||||
self.scope = scope
|
||||
|
||||
output_seq_length = self.seq_length
|
||||
@@ -131,6 +135,8 @@ class Wav2Vec2ModelTester:
|
||||
conv_stride=self.conv_stride,
|
||||
conv_kernel=self.conv_kernel,
|
||||
conv_bias=self.conv_bias,
|
||||
mask_time_prob=self.mask_time_prob,
|
||||
mask_time_length=self.mask_time_length,
|
||||
num_conv_pos_embeddings=self.num_conv_pos_embeddings,
|
||||
num_conv_pos_embedding_groups=self.num_conv_pos_embedding_groups,
|
||||
num_hidden_layers=self.num_hidden_layers,
|
||||
|
||||
Reference in New Issue
Block a user