[Wav2Vec2] Padded vectors should not be allowed to be sampled (#12764)
* fix_torch_device_generate_test * remove @ * finish * correct script * correct script
This commit is contained in:
committed by
GitHub
parent
6e87010060
commit
b4b562d834
@@ -176,6 +176,7 @@ class FlaxDataCollatorForWav2Vec2Pretraining:
|
||||
|
||||
batch_size = batch["input_values"].shape[0]
|
||||
|
||||
attention_mask = None
|
||||
if batch["attention_mask"] is not None:
|
||||
output_lengths = self.model._get_feat_extract_output_lengths(batch["attention_mask"].sum(-1))
|
||||
attention_mask = np.zeros((batch_size, mask_indices_seq_length), dtype=np.int8)
|
||||
@@ -198,6 +199,7 @@ class FlaxDataCollatorForWav2Vec2Pretraining:
|
||||
batch["sampled_negative_indices"] = _sample_negative_indices(
|
||||
(batch["mask_time_indices"].shape + (self.model.config.proj_codevector_dim,)),
|
||||
self.model.config.num_negatives,
|
||||
attention_mask=attention_mask,
|
||||
)
|
||||
|
||||
return batch
|
||||
|
||||
Reference in New Issue
Block a user