[Wav2Vec2] Padded vectors should not allowed to be sampled (#12764)
* fix_torch_device_generate_test * remove @ * finish * correct script * correct script
This commit is contained in:
committed by
GitHub
parent
6e87010060
commit
b4b562d834
@@ -176,6 +176,7 @@ class FlaxDataCollatorForWav2Vec2Pretraining:
|
||||
|
||||
batch_size = batch["input_values"].shape[0]
|
||||
|
||||
attention_mask = None
|
||||
if batch["attention_mask"] is not None:
|
||||
output_lengths = self.model._get_feat_extract_output_lengths(batch["attention_mask"].sum(-1))
|
||||
attention_mask = np.zeros((batch_size, mask_indices_seq_length), dtype=np.int8)
|
||||
@@ -198,6 +199,7 @@ class FlaxDataCollatorForWav2Vec2Pretraining:
|
||||
batch["sampled_negative_indices"] = _sample_negative_indices(
|
||||
(batch["mask_time_indices"].shape + (self.model.config.proj_codevector_dim,)),
|
||||
self.model.config.num_negatives,
|
||||
attention_mask=attention_mask,
|
||||
)
|
||||
|
||||
return batch
|
||||
|
||||
@@ -174,7 +174,7 @@ def _compute_mask_indices(
|
||||
return spec_aug_mask
|
||||
|
||||
|
||||
def _sample_negative_indices(features_shape: Tuple, num_negatives: int):
|
||||
def _sample_negative_indices(features_shape: Tuple, num_negatives: int, attention_mask: Optional[np.ndarray] = None):
|
||||
"""
|
||||
Sample `num_negatives` vectors from feature vectors.
|
||||
"""
|
||||
@@ -186,11 +186,13 @@ def _sample_negative_indices(features_shape: Tuple, num_negatives: int):
|
||||
)
|
||||
|
||||
# get `num_negatives` random vector indices from the same utterance
|
||||
sampled_negative_indices = np.random.randint(
|
||||
low=0,
|
||||
high=sequence_length - 1,
|
||||
size=(batch_size, num_negatives * sequence_length),
|
||||
)
|
||||
sampled_negative_indices = []
|
||||
for batch_idx in range(batch_size):
|
||||
high = attention_mask[batch_idx].sum() - 1 if attention_mask is not None else sequence_length - 1
|
||||
sampled_indices_slice = np.random.randint(0, high, size=(num_negatives * sequence_length,))
|
||||
sampled_negative_indices.append(sampled_indices_slice)
|
||||
|
||||
sampled_negative_indices = np.asarray(sampled_negative_indices, dtype=np.int32)
|
||||
|
||||
# generate indices of the positive vectors themselves, repeat them `num_negatives` times
|
||||
feature_indices = np.broadcast_to(np.arange(sequence_length)[:, None], (sequence_length, num_negatives)).flatten()
|
||||
|
||||
@@ -877,6 +877,18 @@ class Wav2Vec2PreTrainedModel(PreTrainedModel):
|
||||
|
||||
return input_lengths
|
||||
|
||||
def _get_feature_vector_attention_mask(self, feature_vector_length: int, attention_mask: torch.LongTensor):
|
||||
output_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(torch.long)
|
||||
batch_size = attention_mask.shape[0]
|
||||
|
||||
attention_mask = torch.zeros(
|
||||
(batch_size, feature_vector_length), dtype=attention_mask.dtype, device=attention_mask.device
|
||||
)
|
||||
# these two operations makes sure that all values before the output lengths idxs are attended to
|
||||
attention_mask[(torch.arange(attention_mask.shape[0], device=attention_mask.device), output_lengths - 1)] = 1
|
||||
attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool()
|
||||
return attention_mask
|
||||
|
||||
|
||||
WAV_2_VEC_2_START_DOCSTRING = r"""
|
||||
Wav2Vec2 was proposed in `wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations
|
||||
@@ -1044,19 +1056,8 @@ class Wav2Vec2Model(Wav2Vec2PreTrainedModel):
|
||||
extract_features = extract_features.transpose(1, 2)
|
||||
|
||||
if attention_mask is not None:
|
||||
# compute real output lengths according to convolution formula
|
||||
output_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(torch.long)
|
||||
|
||||
attention_mask = torch.zeros(
|
||||
extract_features.shape[:2], dtype=extract_features.dtype, device=extract_features.device
|
||||
)
|
||||
|
||||
# these two operations makes sure that all values
|
||||
# before the output lengths indices are attended to
|
||||
attention_mask[
|
||||
(torch.arange(attention_mask.shape[0], device=extract_features.device), output_lengths - 1)
|
||||
] = 1
|
||||
attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool()
|
||||
# compute reduced attention_mask correponding to feature vectors
|
||||
attention_mask = self._get_feature_vector_attention_mask(extract_features.shape[1], attention_mask)
|
||||
|
||||
hidden_states, extract_features = self.feature_projection(extract_features)
|
||||
hidden_states = self._mask_hidden_states(
|
||||
@@ -1111,7 +1112,9 @@ class Wav2Vec2ForPreTraining(Wav2Vec2PreTrainedModel):
|
||||
self.wav2vec2.feature_extractor._freeze_parameters()
|
||||
|
||||
@staticmethod
|
||||
def _sample_negatives(features: torch.FloatTensor, num_negatives: int):
|
||||
def _sample_negatives(
|
||||
features: torch.FloatTensor, num_negatives: int, attention_mask: Optional[torch.LongTensor] = None
|
||||
):
|
||||
"""
|
||||
Sample `num_negatives` vectors from feature vectors.
|
||||
"""
|
||||
@@ -1125,12 +1128,15 @@ class Wav2Vec2ForPreTraining(Wav2Vec2PreTrainedModel):
|
||||
|
||||
with torch.no_grad():
|
||||
# get `num_negatives` random vector indices from the same utterance
|
||||
sampled_negative_indices = torch.randint(
|
||||
low=0,
|
||||
high=sequence_length - 1,
|
||||
size=(batch_size, num_negatives * sequence_length),
|
||||
device=features.device,
|
||||
sampled_negative_indices = []
|
||||
for batch_idx in range(batch_size):
|
||||
high = attention_mask[batch_idx].sum() - 1 if attention_mask is not None else sequence_length - 1
|
||||
sampled_indices_slice = torch.randint(
|
||||
0, high, size=(num_negatives * sequence_length,), device=features.device
|
||||
)
|
||||
sampled_negative_indices.append(sampled_indices_slice)
|
||||
|
||||
sampled_negative_indices = torch.stack(sampled_negative_indices)
|
||||
|
||||
# generate indices of the positive vectors themselves, repeat them `num_negatives` times
|
||||
feature_indices = (
|
||||
@@ -1263,7 +1269,14 @@ class Wav2Vec2ForPreTraining(Wav2Vec2PreTrainedModel):
|
||||
if self.training:
|
||||
# for training, we sample negatives
|
||||
# 3. sample K negatives (distractors) quantized states for contrastive loss
|
||||
negative_quantized_features = self._sample_negatives(quantized_features, self.config.num_negatives)
|
||||
# if attention_mask is passed, make sure that padded feature vectors cannot be sampled
|
||||
if attention_mask is not None:
|
||||
# compute reduced attention_mask correponding to feature vectors
|
||||
attention_mask = self._get_feature_vector_attention_mask(extract_features.shape[1], attention_mask)
|
||||
|
||||
negative_quantized_features = self._sample_negatives(
|
||||
quantized_features, self.config.num_negatives, attention_mask=attention_mask
|
||||
)
|
||||
|
||||
# 4. compute logits, corresponding to `logs = sim(c_t, [q_t, \sim{q}_t]) / \kappa`
|
||||
# of equation (3) in https://arxiv.org/pdf/2006.11477.pdf
|
||||
|
||||
@@ -306,6 +306,48 @@ class FlaxWav2Vec2UtilsTest(unittest.TestCase):
|
||||
# => this means that `unique()` yields a single value for `hidden_size` dim
|
||||
self.assertTrue(np.unique(negatives, axis=-1).shape, (num_negatives, batch_size, sequence_length, 1))
|
||||
|
||||
def test_sample_negatives_with_attn_mask(self):
|
||||
batch_size = 2
|
||||
sequence_length = 10
|
||||
hidden_size = 4
|
||||
num_negatives = 3
|
||||
|
||||
features = (np.arange(sequence_length * hidden_size) // hidden_size).reshape(
|
||||
sequence_length, hidden_size
|
||||
) # each value in vector consits of same value
|
||||
|
||||
# second half of last input tensor is padded
|
||||
attention_mask = np.ones((batch_size, sequence_length), dtype=np.int8)
|
||||
attention_mask[-1, sequence_length // 2 :] = 0
|
||||
|
||||
forbidden_indices = (
|
||||
np.arange(sequence_length // 2, sequence_length, dtype=np.int32) + (batch_size - 1) * sequence_length
|
||||
).tolist()
|
||||
|
||||
features = np.broadcast_to(features[None, :], (batch_size, sequence_length, hidden_size))
|
||||
|
||||
negative_indices = _sample_negative_indices(features.shape, num_negatives, attention_mask=attention_mask)
|
||||
|
||||
# make sure that no padding tokens are sampled
|
||||
self.assertTrue(all([idx not in negative_indices for idx in forbidden_indices]))
|
||||
|
||||
features = features.reshape(-1, hidden_size) # BTC => (BxT)C
|
||||
# take negative vectors from sampled indices
|
||||
sampled_negatives = features[negative_indices.reshape(-1)]
|
||||
negatives = sampled_negatives.reshape(batch_size, sequence_length, num_negatives, hidden_size).transpose(
|
||||
2, 0, 1, 3
|
||||
)
|
||||
|
||||
self.assertTrue(negatives.shape == (num_negatives, batch_size, sequence_length, hidden_size))
|
||||
|
||||
# make sure no negatively sampled vector is actually a positive one
|
||||
for negative in negatives:
|
||||
self.assertTrue(((negative - features.reshape(negative.shape)) == 0).sum() == 0.0)
|
||||
|
||||
# make sure that full vectors are sampled and not just slices of vectors
|
||||
# => this means that `unique()` yields a single value for `hidden_size` dim
|
||||
self.assertTrue(np.unique(negatives, axis=-1).shape, (num_negatives, batch_size, sequence_length, 1))
|
||||
|
||||
|
||||
@require_flax
|
||||
@require_datasets
|
||||
|
||||
@@ -633,6 +633,37 @@ class Wav2Vec2UtilsTest(unittest.TestCase):
|
||||
# make sure that full vectors are sampled and not values of vectors => this means that `unique()` yields a single value for `hidden_size` dim
|
||||
self.assertTrue(negatives.unique(dim=-1).shape, (num_negatives, batch_size, sequence_length, 1))
|
||||
|
||||
def test_sample_negatives_with_attn_mask(self):
|
||||
batch_size = 2
|
||||
sequence_length = 10
|
||||
hidden_size = 4
|
||||
num_negatives = 3
|
||||
|
||||
# second half of last input tensor is padded
|
||||
attention_mask = torch.ones((batch_size, sequence_length), dtype=torch.long, device=torch_device)
|
||||
attention_mask[-1, sequence_length // 2 :] = 0
|
||||
|
||||
features = (torch.arange(sequence_length * hidden_size, device=torch_device) // hidden_size).view(
|
||||
sequence_length, hidden_size
|
||||
) # each value in vector consits of same value
|
||||
features = features[None, :].expand(batch_size, sequence_length, hidden_size).contiguous()
|
||||
|
||||
# replace masked feature vectors with -100 to test that those are not sampled
|
||||
features = torch.where(attention_mask[:, :, None].expand(features.shape).bool(), features, -100)
|
||||
|
||||
negatives = Wav2Vec2ForPreTraining._sample_negatives(features, num_negatives, attention_mask=attention_mask)
|
||||
|
||||
self.assertTrue((negatives >= 0).all().item())
|
||||
|
||||
self.assertTrue(negatives.shape == (num_negatives, batch_size, sequence_length, hidden_size))
|
||||
|
||||
# make sure no negatively sampled vector is actually a positive one
|
||||
for negative in negatives:
|
||||
self.assertTrue(((negative - features) == 0).sum() == 0.0)
|
||||
|
||||
# make sure that full vectors are sampled and not values of vectors => this means that `unique()` yields a single value for `hidden_size` dim
|
||||
self.assertTrue(negatives.unique(dim=-1).shape, (num_negatives, batch_size, sequence_length, 1))
|
||||
|
||||
|
||||
@require_torch
|
||||
@require_datasets
|
||||
|
||||
Reference in New Issue
Block a user