[Wav2Vec2] Correctly pad mask indices for PreTraining (#12748)

* fix_torch_device_generate_test

* remove @

* start adding tests

* correct wav2vec2 pretraining

* up

* up

Co-authored-by: Patrick von Platen <patrick@huggingface.co>
This commit is contained in:
Patrick von Platen
2021-07-15 21:40:25 +01:00
committed by GitHub
parent 5f2791c7c1
commit 2e9fb13fb1
7 changed files with 98 additions and 5 deletions

View File

@@ -245,6 +245,24 @@ class FlaxWav2Vec2UtilsTest(unittest.TestCase):
for batch_sum in mask.sum(axis=-1):
self.assertTrue(int(batch_sum) <= mask_prob * sequence_length)
def test_compute_mask_indices_attn_mask_overlap(self):
batch_size = 4
sequence_length = 80
mask_prob = 0.5
mask_length = 4
attention_mask = np.ones((batch_size, sequence_length), dtype=np.int32)
attention_mask[:2, sequence_length // 2 :] = 0
mask = _compute_mask_indices(
(batch_size, sequence_length), mask_prob, mask_length, attention_mask=attention_mask
)
for batch_sum in mask.sum(axis=-1):
self.assertTrue(int(batch_sum) <= mask_prob * sequence_length)
self.assertTrue(mask[:2, sequence_length // 2 :].sum() == 0)
def test_compute_perplexity(self):
probs = np.arange(100).reshape(2, 5, 10) / 100

View File

@@ -580,6 +580,24 @@ class Wav2Vec2UtilsTest(unittest.TestCase):
for batch_sum in mask.sum(axis=-1):
self.assertTrue(int(batch_sum) <= mask_prob * sequence_length)
def test_compute_mask_indices_attn_mask_overlap(self):
batch_size = 4
sequence_length = 80
mask_prob = 0.5
mask_length = 4
attention_mask = torch.ones((batch_size, sequence_length), dtype=torch.long, device=torch_device)
attention_mask[:2, sequence_length // 2 :] = 0
mask = _compute_mask_indices(
(batch_size, sequence_length), mask_prob, mask_length, device=torch_device, attention_mask=attention_mask
)
for batch_sum in mask.sum(axis=-1):
self.assertTrue(int(batch_sum) <= mask_prob * sequence_length)
self.assertTrue(mask[:2, sequence_length // 2 :].sum() == 0)
def test_compute_perplexity(self):
probs = torch.arange(100, device=torch_device).reshape(2, 5, 10) / 100