[Wav2Vec2] Correctly pad mask indices for PreTraining (#12748)

* fix_torch_device_generate_test

* remove @

* start adding tests

* correct wav2vec2 pretraining

* up

* up

Co-authored-by: Patrick von Platen <patrick@huggingface.co>
This commit is contained in:
Patrick von Platen
2021-07-15 21:40:25 +01:00
committed by GitHub
parent 5f2791c7c1
commit 2e9fb13fb1
7 changed files with 98 additions and 5 deletions

View File

@@ -174,11 +174,23 @@ class FlaxDataCollatorForWav2Vec2Pretraining:
)
mask_indices_seq_length = self.model._get_feat_extract_output_lengths(batch["input_values"].shape[-1])
batch_size = batch["input_values"].shape[0]
if batch["attention_mask"] is not None:
output_lengths = self.model._get_feat_extract_output_lengths(batch["attention_mask"].sum(-1))
attention_mask = np.zeros((batch_size, mask_indices_seq_length), dtype=np.int8)
# these two operations makes sure that all values
# before the output lengths indices are attended to
attention_mask[(np.arange(attention_mask.shape[0]), output_lengths - 1)] = 1
attention_mask = jnp.flip(jnp.flip(attention_mask, -1).cumsum(-1), -1).astype("bool")
# sample randomly masked indices
batch["mask_time_indices"] = _compute_mask_indices(
(batch["input_values"].shape[0], mask_indices_seq_length),
(batch_size, mask_indices_seq_length),
self.model.config.mask_time_prob,
self.model.config.mask_time_length,
attention_mask=attention_mask,
min_masks=2,
)