committed by
GitHub
parent
f086652b16
commit
7630c11f32
@@ -478,26 +478,17 @@ class Wav2Vec2UtilsTest(unittest.TestCase):
|
||||
mask_prob = 0.5
|
||||
mask_length = 1
|
||||
|
||||
mask = _compute_mask_indices((batch_size, sequence_length), mask_prob, mask_length)
|
||||
mask = _compute_mask_indices((batch_size, sequence_length), mask_prob, mask_length, torch_device)
|
||||
|
||||
self.assertListEqual(mask.sum(axis=-1).tolist(), [mask_prob * sequence_length for _ in range(batch_size)])
|
||||
|
||||
attention_mask = torch.ones((batch_size, sequence_length), device=torch_device, dtype=torch.long)
|
||||
attention_mask[:, -sequence_length // 2 :] = 0
|
||||
|
||||
mask = _compute_mask_indices(
|
||||
(batch_size, sequence_length), mask_prob, mask_length, attention_mask=attention_mask
|
||||
)
|
||||
|
||||
self.assertListEqual(mask.sum(axis=-1).tolist(), [mask_prob * sequence_length // 2 for _ in range(batch_size)])
|
||||
|
||||
def test_compute_mask_indices_overlap(self):
|
||||
batch_size = 4
|
||||
sequence_length = 60
|
||||
mask_prob = 0.5
|
||||
mask_length = 4
|
||||
|
||||
mask = _compute_mask_indices((batch_size, sequence_length), mask_prob, mask_length)
|
||||
mask = _compute_mask_indices((batch_size, sequence_length), mask_prob, mask_length, torch_device)
|
||||
|
||||
# because of overlap there is a range of possible masks
|
||||
for batch_sum in mask.sum(axis=-1):
|
||||
@@ -506,22 +497,6 @@ class Wav2Vec2UtilsTest(unittest.TestCase):
|
||||
list(range(int(mask_prob // mask_length * sequence_length), int(mask_prob * sequence_length))),
|
||||
)
|
||||
|
||||
attention_mask = torch.ones((batch_size, sequence_length), device=torch_device, dtype=torch.long)
|
||||
attention_mask[:, -sequence_length // 2 :] = 0
|
||||
|
||||
mask = _compute_mask_indices(
|
||||
(batch_size, sequence_length), mask_prob, mask_length, attention_mask=attention_mask
|
||||
)
|
||||
|
||||
# because of overlap there is a range of possible masks
|
||||
for batch_sum in mask.sum(axis=-1):
|
||||
self.assertIn(
|
||||
int(batch_sum),
|
||||
list(
|
||||
range(int(mask_prob // mask_length * sequence_length // 2), int(mask_prob * sequence_length // 2))
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@require_torch
|
||||
@slow
|
||||
|
||||
Reference in New Issue
Block a user