[Wav2Vec2] Padded vectors should not allowed to be sampled (#12764)

* fix_torch_device_generate_test

* remove @

* finish

* correct script

* correct script
This commit is contained in:
Patrick von Platen
2021-07-16 18:07:08 +01:00
committed by GitHub
parent 6e87010060
commit b4b562d834
5 changed files with 117 additions and 27 deletions

View File

@@ -176,6 +176,7 @@ class FlaxDataCollatorForWav2Vec2Pretraining:
batch_size = batch["input_values"].shape[0]
attention_mask = None
if batch["attention_mask"] is not None:
output_lengths = self.model._get_feat_extract_output_lengths(batch["attention_mask"].sum(-1))
attention_mask = np.zeros((batch_size, mask_indices_seq_length), dtype=np.int8)
@@ -198,6 +199,7 @@ class FlaxDataCollatorForWav2Vec2Pretraining:
batch["sampled_negative_indices"] = _sample_negative_indices(
(batch["mask_time_indices"].shape + (self.model.config.proj_codevector_dim,)),
self.model.config.num_negatives,
attention_mask=attention_mask,
)
return batch