[Wav2Vec2] Padded vectors should not be allowed to be sampled (#12764)

* fix_torch_device_generate_test

* remove @

* finish

* correct script

* correct script
This commit is contained in:
Patrick von Platen
2021-07-16 18:07:08 +01:00
committed by GitHub
parent 6e87010060
commit b4b562d834
5 changed files with 117 additions and 27 deletions

View File

@@ -306,6 +306,48 @@ class FlaxWav2Vec2UtilsTest(unittest.TestCase):
# => this means that `unique()` yields a single value for `hidden_size` dim
self.assertTrue(np.unique(negatives, axis=-1).shape, (num_negatives, batch_size, sequence_length, 1))
def test_sample_negatives_with_attn_mask(self):
    """Check that `_sample_negative_indices` honours the attention mask.

    The second half of the last batch element is marked as padding. The test
    verifies that none of the corresponding flattened indices are sampled,
    that a sampled negative never coincides with the positive vector at the
    same position, and that whole feature vectors (not slices) are gathered.
    """
    batch_size = 2
    sequence_length = 10
    hidden_size = 4
    num_negatives = 3

    # each feature vector consists of one repeated value (its time-step index)
    features = (np.arange(sequence_length * hidden_size) // hidden_size).reshape(
        sequence_length, hidden_size
    )

    # second half of last input tensor is padded
    attention_mask = np.ones((batch_size, sequence_length), dtype=np.int8)
    attention_mask[-1, sequence_length // 2 :] = 0

    # indices of the padded positions once features are flattened to (B x T)
    forbidden_indices = (
        np.arange(sequence_length // 2, sequence_length, dtype=np.int32) + (batch_size - 1) * sequence_length
    ).tolist()

    features = np.broadcast_to(features[None, :], (batch_size, sequence_length, hidden_size))

    negative_indices = _sample_negative_indices(features.shape, num_negatives, attention_mask=attention_mask)

    # make sure that no padding tokens are sampled
    self.assertTrue(all(idx not in negative_indices for idx in forbidden_indices))

    features = features.reshape(-1, hidden_size)  # BTC => (BxT)C
    # take negative vectors from sampled indices
    sampled_negatives = features[negative_indices.reshape(-1)]
    negatives = sampled_negatives.reshape(batch_size, sequence_length, num_negatives, hidden_size).transpose(
        2, 0, 1, 3
    )

    self.assertEqual(negatives.shape, (num_negatives, batch_size, sequence_length, hidden_size))

    # make sure no negatively sampled vector is actually a positive one
    for negative in negatives:
        self.assertTrue(((negative - features.reshape(negative.shape)) == 0).sum() == 0.0)

    # make sure that full vectors are sampled and not just slices of vectors
    # => this means that `unique()` yields a single value for `hidden_size` dim
    # NOTE: `assertEqual` is required here; `assertTrue(shape, expected)` would treat
    # `expected` as the failure message and pass for any non-empty shape.
    self.assertEqual(
        np.unique(negatives, axis=-1).shape, (num_negatives, batch_size, sequence_length, 1)
    )
@require_flax
@require_datasets

View File

@@ -633,6 +633,37 @@ class Wav2Vec2UtilsTest(unittest.TestCase):
# make sure that full vectors are sampled and not values of vectors => this means that `unique()` yields a single value for `hidden_size` dim
self.assertTrue(negatives.unique(dim=-1).shape, (num_negatives, batch_size, sequence_length, 1))
def test_sample_negatives_with_attn_mask(self):
    """Check that `Wav2Vec2ForPreTraining._sample_negatives` honours the attention mask.

    The second half of the last batch element is marked as padding and its
    feature vectors are overwritten with -100. The test verifies that no
    sampled negative contains that sentinel, that a sampled negative never
    coincides with the positive vector at the same position, and that whole
    feature vectors (not individual values) are sampled.
    """
    batch_size = 2
    sequence_length = 10
    hidden_size = 4
    num_negatives = 3

    # second half of last input tensor is padded
    attention_mask = torch.ones((batch_size, sequence_length), dtype=torch.long, device=torch_device)
    attention_mask[-1, sequence_length // 2 :] = 0

    # each feature vector consists of one repeated value (its time-step index)
    features = (torch.arange(sequence_length * hidden_size, device=torch_device) // hidden_size).view(
        sequence_length, hidden_size
    )
    features = features[None, :].expand(batch_size, sequence_length, hidden_size).contiguous()

    # replace masked feature vectors with -100 to test that those are not sampled
    features = torch.where(attention_mask[:, :, None].expand(features.shape).bool(), features, -100)

    negatives = Wav2Vec2ForPreTraining._sample_negatives(features, num_negatives, attention_mask=attention_mask)

    # every unmasked feature value is >= 0, so a negative value means a padded vector was sampled
    self.assertTrue((negatives >= 0).all().item())

    self.assertEqual(tuple(negatives.shape), (num_negatives, batch_size, sequence_length, hidden_size))

    # make sure no negatively sampled vector is actually a positive one
    for negative in negatives:
        self.assertTrue(((negative - features) == 0).sum() == 0.0)

    # make sure that full vectors are sampled and not values of vectors
    # => this means that `unique()` yields a single value for `hidden_size` dim
    # NOTE: `assertEqual` is required here; `assertTrue(shape, expected)` would treat
    # `expected` as the failure message and pass for any non-empty shape.
    self.assertEqual(
        tuple(negatives.unique(dim=-1).shape), (num_negatives, batch_size, sequence_length, 1)
    )
@require_torch
@require_datasets