Update modeling_wav2vec2.py (#15423)

* Update modeling_wav2vec2.py

With very tiny sound files (less than 0.1 seconds) the num_masked_span can be too long. The issue is described in issue #15366 and discussed with @patrickvonplaten.

* correct errors with mask time indices

* remove bogus file

* make fix-copies

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
peregilk
2022-01-31 21:22:11 +01:00
committed by GitHub
parent d984b10335
commit 125a2882b4
8 changed files with 129 additions and 14 deletions

View File

@@ -990,6 +990,23 @@ class Wav2Vec2UtilsTest(unittest.TestCase):
self.assertTrue(mask[:2, sequence_length // 2 :].sum() == 0)
def test_compute_mask_indices_short_audio(self):
batch_size = 4
sequence_length = 100
mask_prob = 0.05
mask_length = 10
attention_mask = torch.ones((batch_size, sequence_length), dtype=torch.long, device=torch_device)
# force one example to be heavily padded
attention_mask[0, 5:] = 0
mask = _compute_mask_indices(
(batch_size, sequence_length), mask_prob, mask_length, attention_mask=attention_mask, min_masks=2
)
# make sure that non-padded examples cannot be padded
self.assertFalse(mask[0][attention_mask[0].to(torch.bool).cpu()].any())
def test_compute_perplexity(self):
probs = torch.arange(100, device=torch_device).reshape(2, 5, 10) / 100