* fix #14524 (IndexError when mask prob is too low) * fix formatting * correct documentation, add option for setting min_num_masks * change the semantic meaning of `mask_prob` in _compute_mask_indices With this commit the meaning of `mask_prob` actually adheres to the probability for each vector to be the start of a masked span of length `mask_length`. * fix check_copies test * fix documentation to semantic meaning of `upper bound of overall masking percentage`, revert changes to _compute_mask_indices * fix typo
This commit is contained in:
@@ -854,6 +854,36 @@ class Wav2Vec2UtilsTest(unittest.TestCase):
|
||||
|
||||
self.assertListEqual(mask.sum(axis=-1).tolist(), [mask_prob * sequence_length for _ in range(batch_size)])
|
||||
|
||||
def test_compute_mask_indices_low_prob(self):
    """Check that a very low ``mask_prob`` yields both masked and unmasked calls.

    With the settings below, mask_prob * sequence_length / mask_length
    is 0.05 * 100 / 10 = 0.5 spans. The helper is expected to round that
    value probabilistically (NOTE(review): rounding lives in
    ``_compute_mask_indices``, which is not visible here -- confirm), so
    roughly half of the calls should mask nothing and the other half should
    mask a single span.
    """
    n_trials = 100
    input_shape = (4, 100)  # (batch_size, sequence_length)
    mask_prob = 0.05
    mask_length = 10

    # Both outcomes must be seen strictly more often than this many times.
    min_occurrences = int(n_trials * 0.1)

    trials_with_mask = 0
    trials_without_mask = 0

    for _ in range(n_trials):
        mask_array = _compute_mask_indices(input_shape, mask_prob, mask_length)
        mask_tensor = torch.from_numpy(mask_array).to(torch_device)

        # A trial counts as "masked" as soon as any position is set.
        nothing_masked = not (torch.sum(mask_tensor).item() > 0)
        if nothing_masked:
            trials_without_mask += 1
        else:
            trials_with_mask += 1

    # If each outcome really occurs about half of the time, landing at or
    # below 10 out of 100 on either side is astronomically unlikely, so a
    # failure here points at the masking logic rather than bad luck.
    self.assertGreater(trials_with_mask, min_occurrences)
    self.assertGreater(trials_without_mask, min_occurrences)
|
||||
|
||||
def test_compute_mask_indices_overlap(self):
|
||||
batch_size = 4
|
||||
sequence_length = 80
|
||||
|
||||
Reference in New Issue
Block a user