[Wav2Vec2] SpecAugment Fast (#11764)

* first try

* finish
This commit is contained in:
Patrick von Platen
2021-05-25 13:59:52 +01:00
committed by GitHub
parent f086652b16
commit 7630c11f32
2 changed files with 53 additions and 82 deletions

View File

@@ -478,26 +478,17 @@ class Wav2Vec2UtilsTest(unittest.TestCase):
mask_prob = 0.5
mask_length = 1
mask = _compute_mask_indices((batch_size, sequence_length), mask_prob, mask_length)
mask = _compute_mask_indices((batch_size, sequence_length), mask_prob, mask_length, torch_device)
self.assertListEqual(mask.sum(axis=-1).tolist(), [mask_prob * sequence_length for _ in range(batch_size)])
attention_mask = torch.ones((batch_size, sequence_length), device=torch_device, dtype=torch.long)
attention_mask[:, -sequence_length // 2 :] = 0
mask = _compute_mask_indices(
(batch_size, sequence_length), mask_prob, mask_length, attention_mask=attention_mask
)
self.assertListEqual(mask.sum(axis=-1).tolist(), [mask_prob * sequence_length // 2 for _ in range(batch_size)])
def test_compute_mask_indices_overlap(self):
batch_size = 4
sequence_length = 60
mask_prob = 0.5
mask_length = 4
mask = _compute_mask_indices((batch_size, sequence_length), mask_prob, mask_length)
mask = _compute_mask_indices((batch_size, sequence_length), mask_prob, mask_length, torch_device)
# because of overlap there is a range of possible masks
for batch_sum in mask.sum(axis=-1):
@@ -506,22 +497,6 @@ class Wav2Vec2UtilsTest(unittest.TestCase):
list(range(int(mask_prob // mask_length * sequence_length), int(mask_prob * sequence_length))),
)
attention_mask = torch.ones((batch_size, sequence_length), device=torch_device, dtype=torch.long)
attention_mask[:, -sequence_length // 2 :] = 0
mask = _compute_mask_indices(
(batch_size, sequence_length), mask_prob, mask_length, attention_mask=attention_mask
)
# because of overlap there is a range of possible masks
for batch_sum in mask.sum(axis=-1):
self.assertIn(
int(batch_sum),
list(
range(int(mask_prob // mask_length * sequence_length // 2), int(mask_prob * sequence_length // 2))
),
)
@require_torch
@slow