[Wav2Vec2] Correctly pad mask indices for PreTraining (#12748)

* fix_torch_device_generate_test

* remove @

* start adding tests

* correct wav2vec2 pretraining

* up

* up

Co-authored-by: Patrick von Platen <patrick@huggingface.co>
This commit is contained in:
Patrick von Platen
2021-07-15 21:40:25 +01:00
committed by GitHub
parent 5f2791c7c1
commit 2e9fb13fb1
7 changed files with 98 additions and 5 deletions

View File

@@ -174,11 +174,23 @@ class FlaxDataCollatorForWav2Vec2Pretraining:
)
mask_indices_seq_length = self.model._get_feat_extract_output_lengths(batch["input_values"].shape[-1])
batch_size = batch["input_values"].shape[0]
if batch["attention_mask"] is not None:
output_lengths = self.model._get_feat_extract_output_lengths(batch["attention_mask"].sum(-1))
attention_mask = np.zeros((batch_size, mask_indices_seq_length), dtype=np.int8)
# these two operations makes sure that all values
# before the output lengths indices are attended to
attention_mask[(np.arange(attention_mask.shape[0]), output_lengths - 1)] = 1
attention_mask = jnp.flip(jnp.flip(attention_mask, -1).cumsum(-1), -1).astype("bool")
# sample randomly masked indices
batch["mask_time_indices"] = _compute_mask_indices(
(batch["input_values"].shape[0], mask_indices_seq_length),
(batch_size, mask_indices_seq_length),
self.model.config.mask_time_prob,
self.model.config.mask_time_length,
attention_mask=attention_mask,
min_masks=2,
)

View File

@@ -172,12 +172,33 @@ class DataCollatorForWav2Vec2Pretraining:
)
mask_indices_seq_length = self.model._get_feat_extract_output_lengths(batch["input_values"].shape[-1])
batch_size = batch["input_values"].shape[0]
# make sure that no loss is computed on padded inputs
if batch["attention_mask"] is not None:
# compute real output lengths according to convolution formula
output_lengths = self.model._get_feat_extract_output_lengths(batch["attention_mask"].sum(-1)).to(
torch.long
)
attention_mask = torch.zeros(
(batch_size, mask_indices_seq_length), dtype=torch.long, device=batch["input_values"].device
)
# these two operations makes sure that all values
# before the output lengths indices are attended to
attention_mask[
(torch.arange(attention_mask.shape[0], device=batch["input_values"].device), output_lengths - 1)
] = 1
attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool()
# sample randomly masked indices
batch["mask_time_indices"] = _compute_mask_indices(
(batch["input_values"].shape[0], mask_indices_seq_length),
(batch_size, mask_indices_seq_length),
self.model.config.mask_time_prob,
self.model.config.mask_time_length,
device=batch["input_values"].device,
attention_mask=attention_mask,
min_masks=2,
)