[Wav2Vec2] Padded vectors should not allowed to be sampled (#12764)
* fix_torch_device_generate_test * remove @ * finish * correct script * correct script
This commit is contained in:
committed by
GitHub
parent
6e87010060
commit
b4b562d834
@@ -633,6 +633,37 @@ class Wav2Vec2UtilsTest(unittest.TestCase):
|
||||
# make sure that full vectors are sampled and not values of vectors => this means that `unique()` yields a single value for `hidden_size` dim
|
||||
self.assertTrue(negatives.unique(dim=-1).shape, (num_negatives, batch_size, sequence_length, 1))
|
||||
|
||||
def test_sample_negatives_with_attn_mask(self):
|
||||
batch_size = 2
|
||||
sequence_length = 10
|
||||
hidden_size = 4
|
||||
num_negatives = 3
|
||||
|
||||
# second half of last input tensor is padded
|
||||
attention_mask = torch.ones((batch_size, sequence_length), dtype=torch.long, device=torch_device)
|
||||
attention_mask[-1, sequence_length // 2 :] = 0
|
||||
|
||||
features = (torch.arange(sequence_length * hidden_size, device=torch_device) // hidden_size).view(
|
||||
sequence_length, hidden_size
|
||||
) # each value in vector consits of same value
|
||||
features = features[None, :].expand(batch_size, sequence_length, hidden_size).contiguous()
|
||||
|
||||
# replace masked feature vectors with -100 to test that those are not sampled
|
||||
features = torch.where(attention_mask[:, :, None].expand(features.shape).bool(), features, -100)
|
||||
|
||||
negatives = Wav2Vec2ForPreTraining._sample_negatives(features, num_negatives, attention_mask=attention_mask)
|
||||
|
||||
self.assertTrue((negatives >= 0).all().item())
|
||||
|
||||
self.assertTrue(negatives.shape == (num_negatives, batch_size, sequence_length, hidden_size))
|
||||
|
||||
# make sure no negatively sampled vector is actually a positive one
|
||||
for negative in negatives:
|
||||
self.assertTrue(((negative - features) == 0).sum() == 0.0)
|
||||
|
||||
# make sure that full vectors are sampled and not values of vectors => this means that `unique()` yields a single value for `hidden_size` dim
|
||||
self.assertTrue(negatives.unique(dim=-1).shape, (num_negatives, batch_size, sequence_length, 1))
|
||||
|
||||
|
||||
@require_torch
|
||||
@require_datasets
|
||||
|
||||
Reference in New Issue
Block a user