[Wav2Vec2] Correctly pad mask indices for PreTraining (#12748)
* fix_torch_device_generate_test * remove @ * start adding tests * correct wav2vec2 pretraining * up * up Co-authored-by: Patrick von Platen <patrick@huggingface.co>
This commit is contained in:
committed by
GitHub
parent
5f2791c7c1
commit
2e9fb13fb1
@@ -245,6 +245,24 @@ class FlaxWav2Vec2UtilsTest(unittest.TestCase):
|
||||
for batch_sum in mask.sum(axis=-1):
|
||||
self.assertTrue(int(batch_sum) <= mask_prob * sequence_length)
|
||||
|
||||
def test_compute_mask_indices_attn_mask_overlap(self):
|
||||
batch_size = 4
|
||||
sequence_length = 80
|
||||
mask_prob = 0.5
|
||||
mask_length = 4
|
||||
|
||||
attention_mask = np.ones((batch_size, sequence_length), dtype=np.int32)
|
||||
attention_mask[:2, sequence_length // 2 :] = 0
|
||||
|
||||
mask = _compute_mask_indices(
|
||||
(batch_size, sequence_length), mask_prob, mask_length, attention_mask=attention_mask
|
||||
)
|
||||
|
||||
for batch_sum in mask.sum(axis=-1):
|
||||
self.assertTrue(int(batch_sum) <= mask_prob * sequence_length)
|
||||
|
||||
self.assertTrue(mask[:2, sequence_length // 2 :].sum() == 0)
|
||||
|
||||
def test_compute_perplexity(self):
|
||||
probs = np.arange(100).reshape(2, 5, 10) / 100
|
||||
|
||||
|
||||
@@ -580,6 +580,24 @@ class Wav2Vec2UtilsTest(unittest.TestCase):
|
||||
for batch_sum in mask.sum(axis=-1):
|
||||
self.assertTrue(int(batch_sum) <= mask_prob * sequence_length)
|
||||
|
||||
def test_compute_mask_indices_attn_mask_overlap(self):
|
||||
batch_size = 4
|
||||
sequence_length = 80
|
||||
mask_prob = 0.5
|
||||
mask_length = 4
|
||||
|
||||
attention_mask = torch.ones((batch_size, sequence_length), dtype=torch.long, device=torch_device)
|
||||
attention_mask[:2, sequence_length // 2 :] = 0
|
||||
|
||||
mask = _compute_mask_indices(
|
||||
(batch_size, sequence_length), mask_prob, mask_length, device=torch_device, attention_mask=attention_mask
|
||||
)
|
||||
|
||||
for batch_sum in mask.sum(axis=-1):
|
||||
self.assertTrue(int(batch_sum) <= mask_prob * sequence_length)
|
||||
|
||||
self.assertTrue(mask[:2, sequence_length // 2 :].sum() == 0)
|
||||
|
||||
def test_compute_perplexity(self):
|
||||
probs = torch.arange(100, device=torch_device).reshape(2, 5, 10) / 100
|
||||
|
||||
|
||||
Reference in New Issue
Block a user