[Wav2Vec2] Correctly pad mask indices for PreTraining (#12748)
* fix_torch_device_generate_test
* remove @
* start adding tests
* correct wav2vec2 pretraining
* up
* up

Co-authored-by: Patrick von Platen <patrick@huggingface.co>
This commit is contained in:
committed by
GitHub
parent
5f2791c7c1
commit
2e9fb13fb1
@@ -174,11 +174,23 @@ class FlaxDataCollatorForWav2Vec2Pretraining:
|
||||
)
|
||||
mask_indices_seq_length = self.model._get_feat_extract_output_lengths(batch["input_values"].shape[-1])
|
||||
|
||||
batch_size = batch["input_values"].shape[0]
|
||||
|
||||
if batch["attention_mask"] is not None:
|
||||
output_lengths = self.model._get_feat_extract_output_lengths(batch["attention_mask"].sum(-1))
|
||||
attention_mask = np.zeros((batch_size, mask_indices_seq_length), dtype=np.int8)
|
||||
|
||||
# these two operations make sure that all values
|
||||
# before the output lengths indices are attended to
|
||||
attention_mask[(np.arange(attention_mask.shape[0]), output_lengths - 1)] = 1
|
||||
attention_mask = jnp.flip(jnp.flip(attention_mask, -1).cumsum(-1), -1).astype("bool")
|
||||
|
||||
# sample randomly masked indices
|
||||
batch["mask_time_indices"] = _compute_mask_indices(
|
||||
(batch["input_values"].shape[0], mask_indices_seq_length),
|
||||
(batch_size, mask_indices_seq_length),
|
||||
self.model.config.mask_time_prob,
|
||||
self.model.config.mask_time_length,
|
||||
attention_mask=attention_mask,
|
||||
min_masks=2,
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user