[Wav2Vec2] Padded vectors should not be allowed to be sampled (#12764)
* fix_torch_device_generate_test * remove @ * finish * correct script * correct script
This commit is contained in:
committed by
GitHub
parent
6e87010060
commit
b4b562d834
@@ -176,6 +176,7 @@ class FlaxDataCollatorForWav2Vec2Pretraining:
|
|||||||
|
|
||||||
batch_size = batch["input_values"].shape[0]
|
batch_size = batch["input_values"].shape[0]
|
||||||
|
|
||||||
|
attention_mask = None
|
||||||
if batch["attention_mask"] is not None:
|
if batch["attention_mask"] is not None:
|
||||||
output_lengths = self.model._get_feat_extract_output_lengths(batch["attention_mask"].sum(-1))
|
output_lengths = self.model._get_feat_extract_output_lengths(batch["attention_mask"].sum(-1))
|
||||||
attention_mask = np.zeros((batch_size, mask_indices_seq_length), dtype=np.int8)
|
attention_mask = np.zeros((batch_size, mask_indices_seq_length), dtype=np.int8)
|
||||||
@@ -198,6 +199,7 @@ class FlaxDataCollatorForWav2Vec2Pretraining:
|
|||||||
batch["sampled_negative_indices"] = _sample_negative_indices(
|
batch["sampled_negative_indices"] = _sample_negative_indices(
|
||||||
(batch["mask_time_indices"].shape + (self.model.config.proj_codevector_dim,)),
|
(batch["mask_time_indices"].shape + (self.model.config.proj_codevector_dim,)),
|
||||||
self.model.config.num_negatives,
|
self.model.config.num_negatives,
|
||||||
|
attention_mask=attention_mask,
|
||||||
)
|
)
|
||||||
|
|
||||||
return batch
|
return batch
|
||||||
|
|||||||
@@ -174,7 +174,7 @@ def _compute_mask_indices(
|
|||||||
return spec_aug_mask
|
return spec_aug_mask
|
||||||
|
|
||||||
|
|
||||||
def _sample_negative_indices(features_shape: Tuple, num_negatives: int):
|
def _sample_negative_indices(features_shape: Tuple, num_negatives: int, attention_mask: Optional[np.ndarray] = None):
|
||||||
"""
|
"""
|
||||||
Sample `num_negatives` vectors from feature vectors.
|
Sample `num_negatives` vectors from feature vectors.
|
||||||
"""
|
"""
|
||||||
@@ -186,11 +186,13 @@ def _sample_negative_indices(features_shape: Tuple, num_negatives: int):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# get `num_negatives` random vector indices from the same utterance
|
# get `num_negatives` random vector indices from the same utterance
|
||||||
sampled_negative_indices = np.random.randint(
|
sampled_negative_indices = []
|
||||||
low=0,
|
for batch_idx in range(batch_size):
|
||||||
high=sequence_length - 1,
|
high = attention_mask[batch_idx].sum() - 1 if attention_mask is not None else sequence_length - 1
|
||||||
size=(batch_size, num_negatives * sequence_length),
|
sampled_indices_slice = np.random.randint(0, high, size=(num_negatives * sequence_length,))
|
||||||
)
|
sampled_negative_indices.append(sampled_indices_slice)
|
||||||
|
|
||||||
|
sampled_negative_indices = np.asarray(sampled_negative_indices, dtype=np.int32)
|
||||||
|
|
||||||
# generate indices of the positive vectors themselves, repeat them `num_negatives` times
|
# generate indices of the positive vectors themselves, repeat them `num_negatives` times
|
||||||
feature_indices = np.broadcast_to(np.arange(sequence_length)[:, None], (sequence_length, num_negatives)).flatten()
|
feature_indices = np.broadcast_to(np.arange(sequence_length)[:, None], (sequence_length, num_negatives)).flatten()
|
||||||
|
|||||||
@@ -877,6 +877,18 @@ class Wav2Vec2PreTrainedModel(PreTrainedModel):
|
|||||||
|
|
||||||
return input_lengths
|
return input_lengths
|
||||||
|
|
||||||
|
def _get_feature_vector_attention_mask(self, feature_vector_length: int, attention_mask: torch.LongTensor):
|
||||||
|
output_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(torch.long)
|
||||||
|
batch_size = attention_mask.shape[0]
|
||||||
|
|
||||||
|
attention_mask = torch.zeros(
|
||||||
|
(batch_size, feature_vector_length), dtype=attention_mask.dtype, device=attention_mask.device
|
||||||
|
)
|
||||||
|
# these two operations make sure that all values before the output lengths idxs are attended to
|
||||||
|
attention_mask[(torch.arange(attention_mask.shape[0], device=attention_mask.device), output_lengths - 1)] = 1
|
||||||
|
attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool()
|
||||||
|
return attention_mask
|
||||||
|
|
||||||
|
|
||||||
WAV_2_VEC_2_START_DOCSTRING = r"""
|
WAV_2_VEC_2_START_DOCSTRING = r"""
|
||||||
Wav2Vec2 was proposed in `wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations
|
Wav2Vec2 was proposed in `wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations
|
||||||
@@ -1044,19 +1056,8 @@ class Wav2Vec2Model(Wav2Vec2PreTrainedModel):
|
|||||||
extract_features = extract_features.transpose(1, 2)
|
extract_features = extract_features.transpose(1, 2)
|
||||||
|
|
||||||
if attention_mask is not None:
|
if attention_mask is not None:
|
||||||
# compute real output lengths according to convolution formula
|
# compute reduced attention_mask corresponding to feature vectors
|
||||||
output_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(torch.long)
|
attention_mask = self._get_feature_vector_attention_mask(extract_features.shape[1], attention_mask)
|
||||||
|
|
||||||
attention_mask = torch.zeros(
|
|
||||||
extract_features.shape[:2], dtype=extract_features.dtype, device=extract_features.device
|
|
||||||
)
|
|
||||||
|
|
||||||
# these two operations makes sure that all values
|
|
||||||
# before the output lengths indices are attended to
|
|
||||||
attention_mask[
|
|
||||||
(torch.arange(attention_mask.shape[0], device=extract_features.device), output_lengths - 1)
|
|
||||||
] = 1
|
|
||||||
attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool()
|
|
||||||
|
|
||||||
hidden_states, extract_features = self.feature_projection(extract_features)
|
hidden_states, extract_features = self.feature_projection(extract_features)
|
||||||
hidden_states = self._mask_hidden_states(
|
hidden_states = self._mask_hidden_states(
|
||||||
@@ -1111,7 +1112,9 @@ class Wav2Vec2ForPreTraining(Wav2Vec2PreTrainedModel):
|
|||||||
self.wav2vec2.feature_extractor._freeze_parameters()
|
self.wav2vec2.feature_extractor._freeze_parameters()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _sample_negatives(features: torch.FloatTensor, num_negatives: int):
|
def _sample_negatives(
|
||||||
|
features: torch.FloatTensor, num_negatives: int, attention_mask: Optional[torch.LongTensor] = None
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Sample `num_negatives` vectors from feature vectors.
|
Sample `num_negatives` vectors from feature vectors.
|
||||||
"""
|
"""
|
||||||
@@ -1125,12 +1128,15 @@ class Wav2Vec2ForPreTraining(Wav2Vec2PreTrainedModel):
|
|||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
# get `num_negatives` random vector indices from the same utterance
|
# get `num_negatives` random vector indices from the same utterance
|
||||||
sampled_negative_indices = torch.randint(
|
sampled_negative_indices = []
|
||||||
low=0,
|
for batch_idx in range(batch_size):
|
||||||
high=sequence_length - 1,
|
high = attention_mask[batch_idx].sum() - 1 if attention_mask is not None else sequence_length - 1
|
||||||
size=(batch_size, num_negatives * sequence_length),
|
sampled_indices_slice = torch.randint(
|
||||||
device=features.device,
|
0, high, size=(num_negatives * sequence_length,), device=features.device
|
||||||
)
|
)
|
||||||
|
sampled_negative_indices.append(sampled_indices_slice)
|
||||||
|
|
||||||
|
sampled_negative_indices = torch.stack(sampled_negative_indices)
|
||||||
|
|
||||||
# generate indices of the positive vectors themselves, repeat them `num_negatives` times
|
# generate indices of the positive vectors themselves, repeat them `num_negatives` times
|
||||||
feature_indices = (
|
feature_indices = (
|
||||||
@@ -1263,7 +1269,14 @@ class Wav2Vec2ForPreTraining(Wav2Vec2PreTrainedModel):
|
|||||||
if self.training:
|
if self.training:
|
||||||
# for training, we sample negatives
|
# for training, we sample negatives
|
||||||
# 3. sample K negatives (distractors) quantized states for contrastive loss
|
# 3. sample K negatives (distractors) quantized states for contrastive loss
|
||||||
negative_quantized_features = self._sample_negatives(quantized_features, self.config.num_negatives)
|
# if attention_mask is passed, make sure that padded feature vectors cannot be sampled
|
||||||
|
if attention_mask is not None:
|
||||||
|
# compute reduced attention_mask corresponding to feature vectors
|
||||||
|
attention_mask = self._get_feature_vector_attention_mask(extract_features.shape[1], attention_mask)
|
||||||
|
|
||||||
|
negative_quantized_features = self._sample_negatives(
|
||||||
|
quantized_features, self.config.num_negatives, attention_mask=attention_mask
|
||||||
|
)
|
||||||
|
|
||||||
# 4. compute logits, corresponding to `logs = sim(c_t, [q_t, \sim{q}_t]) / \kappa`
|
# 4. compute logits, corresponding to `logs = sim(c_t, [q_t, \sim{q}_t]) / \kappa`
|
||||||
# of equation (3) in https://arxiv.org/pdf/2006.11477.pdf
|
# of equation (3) in https://arxiv.org/pdf/2006.11477.pdf
|
||||||
|
|||||||
@@ -306,6 +306,48 @@ class FlaxWav2Vec2UtilsTest(unittest.TestCase):
|
|||||||
# => this means that `unique()` yields a single value for `hidden_size` dim
|
# => this means that `unique()` yields a single value for `hidden_size` dim
|
||||||
self.assertTrue(np.unique(negatives, axis=-1).shape, (num_negatives, batch_size, sequence_length, 1))
|
self.assertTrue(np.unique(negatives, axis=-1).shape, (num_negatives, batch_size, sequence_length, 1))
|
||||||
|
|
||||||
|
def test_sample_negatives_with_attn_mask(self):
|
||||||
|
batch_size = 2
|
||||||
|
sequence_length = 10
|
||||||
|
hidden_size = 4
|
||||||
|
num_negatives = 3
|
||||||
|
|
||||||
|
features = (np.arange(sequence_length * hidden_size) // hidden_size).reshape(
|
||||||
|
sequence_length, hidden_size
|
||||||
|
) # each value in vector consists of same value
|
||||||
|
|
||||||
|
# second half of last input tensor is padded
|
||||||
|
attention_mask = np.ones((batch_size, sequence_length), dtype=np.int8)
|
||||||
|
attention_mask[-1, sequence_length // 2 :] = 0
|
||||||
|
|
||||||
|
forbidden_indices = (
|
||||||
|
np.arange(sequence_length // 2, sequence_length, dtype=np.int32) + (batch_size - 1) * sequence_length
|
||||||
|
).tolist()
|
||||||
|
|
||||||
|
features = np.broadcast_to(features[None, :], (batch_size, sequence_length, hidden_size))
|
||||||
|
|
||||||
|
negative_indices = _sample_negative_indices(features.shape, num_negatives, attention_mask=attention_mask)
|
||||||
|
|
||||||
|
# make sure that no padding tokens are sampled
|
||||||
|
self.assertTrue(all([idx not in negative_indices for idx in forbidden_indices]))
|
||||||
|
|
||||||
|
features = features.reshape(-1, hidden_size) # BTC => (BxT)C
|
||||||
|
# take negative vectors from sampled indices
|
||||||
|
sampled_negatives = features[negative_indices.reshape(-1)]
|
||||||
|
negatives = sampled_negatives.reshape(batch_size, sequence_length, num_negatives, hidden_size).transpose(
|
||||||
|
2, 0, 1, 3
|
||||||
|
)
|
||||||
|
|
||||||
|
self.assertTrue(negatives.shape == (num_negatives, batch_size, sequence_length, hidden_size))
|
||||||
|
|
||||||
|
# make sure no negatively sampled vector is actually a positive one
|
||||||
|
for negative in negatives:
|
||||||
|
self.assertTrue(((negative - features.reshape(negative.shape)) == 0).sum() == 0.0)
|
||||||
|
|
||||||
|
# make sure that full vectors are sampled and not just slices of vectors
|
||||||
|
# => this means that `unique()` yields a single value for `hidden_size` dim
|
||||||
|
self.assertTrue(np.unique(negatives, axis=-1).shape, (num_negatives, batch_size, sequence_length, 1))
|
||||||
|
|
||||||
|
|
||||||
@require_flax
|
@require_flax
|
||||||
@require_datasets
|
@require_datasets
|
||||||
|
|||||||
@@ -633,6 +633,37 @@ class Wav2Vec2UtilsTest(unittest.TestCase):
|
|||||||
# make sure that full vectors are sampled and not values of vectors => this means that `unique()` yields a single value for `hidden_size` dim
|
# make sure that full vectors are sampled and not values of vectors => this means that `unique()` yields a single value for `hidden_size` dim
|
||||||
self.assertTrue(negatives.unique(dim=-1).shape, (num_negatives, batch_size, sequence_length, 1))
|
self.assertTrue(negatives.unique(dim=-1).shape, (num_negatives, batch_size, sequence_length, 1))
|
||||||
|
|
||||||
|
def test_sample_negatives_with_attn_mask(self):
|
||||||
|
batch_size = 2
|
||||||
|
sequence_length = 10
|
||||||
|
hidden_size = 4
|
||||||
|
num_negatives = 3
|
||||||
|
|
||||||
|
# second half of last input tensor is padded
|
||||||
|
attention_mask = torch.ones((batch_size, sequence_length), dtype=torch.long, device=torch_device)
|
||||||
|
attention_mask[-1, sequence_length // 2 :] = 0
|
||||||
|
|
||||||
|
features = (torch.arange(sequence_length * hidden_size, device=torch_device) // hidden_size).view(
|
||||||
|
sequence_length, hidden_size
|
||||||
|
) # each value in vector consists of same value
|
||||||
|
features = features[None, :].expand(batch_size, sequence_length, hidden_size).contiguous()
|
||||||
|
|
||||||
|
# replace masked feature vectors with -100 to test that those are not sampled
|
||||||
|
features = torch.where(attention_mask[:, :, None].expand(features.shape).bool(), features, -100)
|
||||||
|
|
||||||
|
negatives = Wav2Vec2ForPreTraining._sample_negatives(features, num_negatives, attention_mask=attention_mask)
|
||||||
|
|
||||||
|
self.assertTrue((negatives >= 0).all().item())
|
||||||
|
|
||||||
|
self.assertTrue(negatives.shape == (num_negatives, batch_size, sequence_length, hidden_size))
|
||||||
|
|
||||||
|
# make sure no negatively sampled vector is actually a positive one
|
||||||
|
for negative in negatives:
|
||||||
|
self.assertTrue(((negative - features) == 0).sum() == 0.0)
|
||||||
|
|
||||||
|
# make sure that full vectors are sampled and not values of vectors => this means that `unique()` yields a single value for `hidden_size` dim
|
||||||
|
self.assertTrue(negatives.unique(dim=-1).shape, (num_negatives, batch_size, sequence_length, 1))
|
||||||
|
|
||||||
|
|
||||||
@require_torch
|
@require_torch
|
||||||
@require_datasets
|
@require_datasets
|
||||||
|
|||||||
Reference in New Issue
Block a user