[Wav2Vec2] Correctly pad mask indices for PreTraining (#12748)
* fix_torch_device_generate_test * remove @ * start adding tests * correct wav2vec2 pretraining * up * up Co-authored-by: Patrick von Platen <patrick@huggingface.co>
This commit is contained in:
committed by
GitHub
parent
5f2791c7c1
commit
2e9fb13fb1
@@ -174,11 +174,23 @@ class FlaxDataCollatorForWav2Vec2Pretraining:
|
|||||||
)
|
)
|
||||||
mask_indices_seq_length = self.model._get_feat_extract_output_lengths(batch["input_values"].shape[-1])
|
mask_indices_seq_length = self.model._get_feat_extract_output_lengths(batch["input_values"].shape[-1])
|
||||||
|
|
||||||
|
batch_size = batch["input_values"].shape[0]
|
||||||
|
|
||||||
|
if batch["attention_mask"] is not None:
|
||||||
|
output_lengths = self.model._get_feat_extract_output_lengths(batch["attention_mask"].sum(-1))
|
||||||
|
attention_mask = np.zeros((batch_size, mask_indices_seq_length), dtype=np.int8)
|
||||||
|
|
||||||
|
# these two operations makes sure that all values
|
||||||
|
# before the output lengths indices are attended to
|
||||||
|
attention_mask[(np.arange(attention_mask.shape[0]), output_lengths - 1)] = 1
|
||||||
|
attention_mask = jnp.flip(jnp.flip(attention_mask, -1).cumsum(-1), -1).astype("bool")
|
||||||
|
|
||||||
# sample randomly masked indices
|
# sample randomly masked indices
|
||||||
batch["mask_time_indices"] = _compute_mask_indices(
|
batch["mask_time_indices"] = _compute_mask_indices(
|
||||||
(batch["input_values"].shape[0], mask_indices_seq_length),
|
(batch_size, mask_indices_seq_length),
|
||||||
self.model.config.mask_time_prob,
|
self.model.config.mask_time_prob,
|
||||||
self.model.config.mask_time_length,
|
self.model.config.mask_time_length,
|
||||||
|
attention_mask=attention_mask,
|
||||||
min_masks=2,
|
min_masks=2,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -172,12 +172,33 @@ class DataCollatorForWav2Vec2Pretraining:
|
|||||||
)
|
)
|
||||||
mask_indices_seq_length = self.model._get_feat_extract_output_lengths(batch["input_values"].shape[-1])
|
mask_indices_seq_length = self.model._get_feat_extract_output_lengths(batch["input_values"].shape[-1])
|
||||||
|
|
||||||
|
batch_size = batch["input_values"].shape[0]
|
||||||
|
|
||||||
|
# make sure that no loss is computed on padded inputs
|
||||||
|
if batch["attention_mask"] is not None:
|
||||||
|
# compute real output lengths according to convolution formula
|
||||||
|
output_lengths = self.model._get_feat_extract_output_lengths(batch["attention_mask"].sum(-1)).to(
|
||||||
|
torch.long
|
||||||
|
)
|
||||||
|
|
||||||
|
attention_mask = torch.zeros(
|
||||||
|
(batch_size, mask_indices_seq_length), dtype=torch.long, device=batch["input_values"].device
|
||||||
|
)
|
||||||
|
|
||||||
|
# these two operations makes sure that all values
|
||||||
|
# before the output lengths indices are attended to
|
||||||
|
attention_mask[
|
||||||
|
(torch.arange(attention_mask.shape[0], device=batch["input_values"].device), output_lengths - 1)
|
||||||
|
] = 1
|
||||||
|
attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool()
|
||||||
|
|
||||||
# sample randomly masked indices
|
# sample randomly masked indices
|
||||||
batch["mask_time_indices"] = _compute_mask_indices(
|
batch["mask_time_indices"] = _compute_mask_indices(
|
||||||
(batch["input_values"].shape[0], mask_indices_seq_length),
|
(batch_size, mask_indices_seq_length),
|
||||||
self.model.config.mask_time_prob,
|
self.model.config.mask_time_prob,
|
||||||
self.model.config.mask_time_length,
|
self.model.config.mask_time_length,
|
||||||
device=batch["input_values"].device,
|
device=batch["input_values"].device,
|
||||||
|
attention_mask=attention_mask,
|
||||||
min_masks=2,
|
min_masks=2,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -47,6 +47,7 @@ def _compute_mask_indices(
|
|||||||
mask_prob: float,
|
mask_prob: float,
|
||||||
mask_length: int,
|
mask_length: int,
|
||||||
device: torch.device,
|
device: torch.device,
|
||||||
|
attention_mask: Optional[torch.tensor] = None,
|
||||||
min_masks: int = 0,
|
min_masks: int = 0,
|
||||||
) -> torch.tensor:
|
) -> torch.tensor:
|
||||||
"""
|
"""
|
||||||
@@ -813,7 +814,10 @@ class HubertModel(HubertPreTrainedModel):
|
|||||||
|
|
||||||
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Model._mask_hidden_states
|
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Model._mask_hidden_states
|
||||||
def _mask_hidden_states(
|
def _mask_hidden_states(
|
||||||
self, hidden_states: torch.FloatTensor, mask_time_indices: Optional[torch.FloatTensor] = None
|
self,
|
||||||
|
hidden_states: torch.FloatTensor,
|
||||||
|
mask_time_indices: Optional[torch.FloatTensor] = None,
|
||||||
|
attention_mask: Optional[torch.LongTensor] = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Masks extracted features along time axis and/or along feature axis according to `SpecAugment
|
Masks extracted features along time axis and/or along feature axis according to `SpecAugment
|
||||||
@@ -836,6 +840,7 @@ class HubertModel(HubertPreTrainedModel):
|
|||||||
mask_prob=self.config.mask_time_prob,
|
mask_prob=self.config.mask_time_prob,
|
||||||
mask_length=self.config.mask_time_length,
|
mask_length=self.config.mask_time_length,
|
||||||
device=hidden_states.device,
|
device=hidden_states.device,
|
||||||
|
attention_mask=attention_mask,
|
||||||
min_masks=2,
|
min_masks=2,
|
||||||
)
|
)
|
||||||
hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)
|
hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)
|
||||||
@@ -847,6 +852,7 @@ class HubertModel(HubertPreTrainedModel):
|
|||||||
mask_prob=self.config.mask_feature_prob,
|
mask_prob=self.config.mask_feature_prob,
|
||||||
mask_length=self.config.mask_feature_length,
|
mask_length=self.config.mask_feature_length,
|
||||||
device=hidden_states.device,
|
device=hidden_states.device,
|
||||||
|
attention_mask=attention_mask,
|
||||||
)
|
)
|
||||||
hidden_states[mask_feature_indices[:, None].expand(-1, sequence_length, -1)] = 0
|
hidden_states[mask_feature_indices[:, None].expand(-1, sequence_length, -1)] = 0
|
||||||
|
|
||||||
|
|||||||
@@ -107,6 +107,7 @@ def _compute_mask_indices(
|
|||||||
shape: Tuple[int, int],
|
shape: Tuple[int, int],
|
||||||
mask_prob: float,
|
mask_prob: float,
|
||||||
mask_length: int,
|
mask_length: int,
|
||||||
|
attention_mask: Optional[np.ndarray] = None,
|
||||||
min_masks: int = 0,
|
min_masks: int = 0,
|
||||||
) -> np.ndarray:
|
) -> np.ndarray:
|
||||||
"""
|
"""
|
||||||
@@ -166,6 +167,10 @@ def _compute_mask_indices(
|
|||||||
# scatter indices to mask
|
# scatter indices to mask
|
||||||
np.put_along_axis(spec_aug_mask, spec_aug_mask_idxs, 1, -1)
|
np.put_along_axis(spec_aug_mask, spec_aug_mask_idxs, 1, -1)
|
||||||
|
|
||||||
|
if attention_mask is not None:
|
||||||
|
# make sure padded input ids cannot be masked
|
||||||
|
spec_aug_mask = np.where(attention_mask, spec_aug_mask, False)
|
||||||
|
|
||||||
return spec_aug_mask
|
return spec_aug_mask
|
||||||
|
|
||||||
|
|
||||||
@@ -873,6 +878,7 @@ class FlaxWav2Vec2Module(nn.Module):
|
|||||||
"""
|
"""
|
||||||
extract_features = self.feature_extractor(input_values)
|
extract_features = self.feature_extractor(input_values)
|
||||||
|
|
||||||
|
# make sure that no loss is computed on padded inputs
|
||||||
if attention_mask is not None:
|
if attention_mask is not None:
|
||||||
# compute real output lengths according to convolution formula
|
# compute real output lengths according to convolution formula
|
||||||
output_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1).astype("i4"))
|
output_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1).astype("i4"))
|
||||||
|
|||||||
@@ -120,6 +120,7 @@ def _compute_mask_indices(
|
|||||||
mask_prob: float,
|
mask_prob: float,
|
||||||
mask_length: int,
|
mask_length: int,
|
||||||
device: torch.device,
|
device: torch.device,
|
||||||
|
attention_mask: Optional[torch.tensor] = None,
|
||||||
min_masks: int = 0,
|
min_masks: int = 0,
|
||||||
) -> torch.tensor:
|
) -> torch.tensor:
|
||||||
"""
|
"""
|
||||||
@@ -179,6 +180,10 @@ def _compute_mask_indices(
|
|||||||
# scatter indices to mask
|
# scatter indices to mask
|
||||||
spec_aug_mask = spec_aug_mask.scatter(1, spec_aug_mask_idxs, True)
|
spec_aug_mask = spec_aug_mask.scatter(1, spec_aug_mask_idxs, True)
|
||||||
|
|
||||||
|
if attention_mask is not None:
|
||||||
|
# make sure padded input ids cannot be masked
|
||||||
|
spec_aug_mask = torch.where(attention_mask.bool(), spec_aug_mask, False)
|
||||||
|
|
||||||
return spec_aug_mask
|
return spec_aug_mask
|
||||||
|
|
||||||
|
|
||||||
@@ -950,7 +955,10 @@ class Wav2Vec2Model(Wav2Vec2PreTrainedModel):
|
|||||||
self.init_weights()
|
self.init_weights()
|
||||||
|
|
||||||
def _mask_hidden_states(
|
def _mask_hidden_states(
|
||||||
self, hidden_states: torch.FloatTensor, mask_time_indices: Optional[torch.FloatTensor] = None
|
self,
|
||||||
|
hidden_states: torch.FloatTensor,
|
||||||
|
mask_time_indices: Optional[torch.FloatTensor] = None,
|
||||||
|
attention_mask: Optional[torch.LongTensor] = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Masks extracted features along time axis and/or along feature axis according to `SpecAugment
|
Masks extracted features along time axis and/or along feature axis according to `SpecAugment
|
||||||
@@ -973,6 +981,7 @@ class Wav2Vec2Model(Wav2Vec2PreTrainedModel):
|
|||||||
mask_prob=self.config.mask_time_prob,
|
mask_prob=self.config.mask_time_prob,
|
||||||
mask_length=self.config.mask_time_length,
|
mask_length=self.config.mask_time_length,
|
||||||
device=hidden_states.device,
|
device=hidden_states.device,
|
||||||
|
attention_mask=attention_mask,
|
||||||
min_masks=2,
|
min_masks=2,
|
||||||
)
|
)
|
||||||
hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)
|
hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)
|
||||||
@@ -984,6 +993,7 @@ class Wav2Vec2Model(Wav2Vec2PreTrainedModel):
|
|||||||
mask_prob=self.config.mask_feature_prob,
|
mask_prob=self.config.mask_feature_prob,
|
||||||
mask_length=self.config.mask_feature_length,
|
mask_length=self.config.mask_feature_length,
|
||||||
device=hidden_states.device,
|
device=hidden_states.device,
|
||||||
|
attention_mask=attention_mask,
|
||||||
)
|
)
|
||||||
hidden_states[mask_feature_indices[:, None].expand(-1, sequence_length, -1)] = 0
|
hidden_states[mask_feature_indices[:, None].expand(-1, sequence_length, -1)] = 0
|
||||||
|
|
||||||
@@ -1049,7 +1059,9 @@ class Wav2Vec2Model(Wav2Vec2PreTrainedModel):
|
|||||||
attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool()
|
attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool()
|
||||||
|
|
||||||
hidden_states, extract_features = self.feature_projection(extract_features)
|
hidden_states, extract_features = self.feature_projection(extract_features)
|
||||||
hidden_states = self._mask_hidden_states(hidden_states, mask_time_indices=mask_time_indices)
|
hidden_states = self._mask_hidden_states(
|
||||||
|
hidden_states, mask_time_indices=mask_time_indices, attention_mask=attention_mask
|
||||||
|
)
|
||||||
|
|
||||||
encoder_outputs = self.encoder(
|
encoder_outputs = self.encoder(
|
||||||
hidden_states,
|
hidden_states,
|
||||||
|
|||||||
@@ -245,6 +245,24 @@ class FlaxWav2Vec2UtilsTest(unittest.TestCase):
|
|||||||
for batch_sum in mask.sum(axis=-1):
|
for batch_sum in mask.sum(axis=-1):
|
||||||
self.assertTrue(int(batch_sum) <= mask_prob * sequence_length)
|
self.assertTrue(int(batch_sum) <= mask_prob * sequence_length)
|
||||||
|
|
||||||
|
def test_compute_mask_indices_attn_mask_overlap(self):
|
||||||
|
batch_size = 4
|
||||||
|
sequence_length = 80
|
||||||
|
mask_prob = 0.5
|
||||||
|
mask_length = 4
|
||||||
|
|
||||||
|
attention_mask = np.ones((batch_size, sequence_length), dtype=np.int32)
|
||||||
|
attention_mask[:2, sequence_length // 2 :] = 0
|
||||||
|
|
||||||
|
mask = _compute_mask_indices(
|
||||||
|
(batch_size, sequence_length), mask_prob, mask_length, attention_mask=attention_mask
|
||||||
|
)
|
||||||
|
|
||||||
|
for batch_sum in mask.sum(axis=-1):
|
||||||
|
self.assertTrue(int(batch_sum) <= mask_prob * sequence_length)
|
||||||
|
|
||||||
|
self.assertTrue(mask[:2, sequence_length // 2 :].sum() == 0)
|
||||||
|
|
||||||
def test_compute_perplexity(self):
|
def test_compute_perplexity(self):
|
||||||
probs = np.arange(100).reshape(2, 5, 10) / 100
|
probs = np.arange(100).reshape(2, 5, 10) / 100
|
||||||
|
|
||||||
|
|||||||
@@ -580,6 +580,24 @@ class Wav2Vec2UtilsTest(unittest.TestCase):
|
|||||||
for batch_sum in mask.sum(axis=-1):
|
for batch_sum in mask.sum(axis=-1):
|
||||||
self.assertTrue(int(batch_sum) <= mask_prob * sequence_length)
|
self.assertTrue(int(batch_sum) <= mask_prob * sequence_length)
|
||||||
|
|
||||||
|
def test_compute_mask_indices_attn_mask_overlap(self):
|
||||||
|
batch_size = 4
|
||||||
|
sequence_length = 80
|
||||||
|
mask_prob = 0.5
|
||||||
|
mask_length = 4
|
||||||
|
|
||||||
|
attention_mask = torch.ones((batch_size, sequence_length), dtype=torch.long, device=torch_device)
|
||||||
|
attention_mask[:2, sequence_length // 2 :] = 0
|
||||||
|
|
||||||
|
mask = _compute_mask_indices(
|
||||||
|
(batch_size, sequence_length), mask_prob, mask_length, device=torch_device, attention_mask=attention_mask
|
||||||
|
)
|
||||||
|
|
||||||
|
for batch_sum in mask.sum(axis=-1):
|
||||||
|
self.assertTrue(int(batch_sum) <= mask_prob * sequence_length)
|
||||||
|
|
||||||
|
self.assertTrue(mask[:2, sequence_length // 2 :].sum() == 0)
|
||||||
|
|
||||||
def test_compute_perplexity(self):
|
def test_compute_perplexity(self):
|
||||||
probs = torch.arange(100, device=torch_device).reshape(2, 5, 10) / 100
|
probs = torch.arange(100, device=torch_device).reshape(2, 5, 10) / 100
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user