From 7630c11f326025ef67a99db913de9fafcfc0704d Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 25 May 2021 13:59:52 +0100 Subject: [PATCH] [Wav2Vec2] SpecAugment Fast (#11764) * first try * finish --- .../models/wav2vec2/modeling_wav2vec2.py | 106 +++++++++--------- tests/test_modeling_wav2vec2.py | 29 +---- 2 files changed, 53 insertions(+), 82 deletions(-) diff --git a/src/transformers/models/wav2vec2/modeling_wav2vec2.py b/src/transformers/models/wav2vec2/modeling_wav2vec2.py index e55e6179ed..cd9183c427 100755 --- a/src/transformers/models/wav2vec2/modeling_wav2vec2.py +++ b/src/transformers/models/wav2vec2/modeling_wav2vec2.py @@ -48,71 +48,67 @@ def _compute_mask_indices( shape: Tuple[int, int], mask_prob: float, mask_length: int, - attention_mask: Optional[torch.Tensor] = None, + device: torch.device, min_masks: int = 0, -) -> np.ndarray: +) -> torch.tensor: """ - Computes random mask spans for a given shape + Computes random mask spans for a given shape. Used to implement `SpecAugment: A Simple Data Augmentation Method for + ASR `__. Args: shape: the the shape for which to compute masks. should be of size 2 where first element is batch size and 2nd is timesteps - attention_mask: optional padding mask of the same size as shape, which will prevent masking padded elements mask_prob: probability for each token to be chosen as start of the span to be masked. this will be multiplied by number of timesteps divided by length of mask span to mask approximately this percentage of all elements. however due to overlaps, the actual number will be smaller (unless no_overlap is True) mask_length: size of the mask min_masks: minimum number of masked spans - Adapted from `fairseq's data_utils.py - `__. """ - bsz, all_sz = shape - mask = np.full((bsz, all_sz), False) + batch_size, sequence_length = shape - all_num_mask = int( - # add a random number for probabilistic rounding - mask_prob * all_sz / float(mask_length) - + np.random.rand() + if mask_length < 1: + raise ValueError("`mask_length` has to be bigger than 0.") + + if mask_length > sequence_length: + raise ValueError( + f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length} and `sequence_length`: {sequence_length}`" + ) + + # compute number of masked spans in batch + num_masked_spans = int(mask_prob * sequence_length / mask_length + torch.rand((1,)).item()) + num_masked_spans = max(num_masked_spans, min_masks) + + # make sure num masked indices <= sequence_length + if num_masked_spans * mask_length > sequence_length: + num_masked_spans = sequence_length // mask_length + + # SpecAugment mask to fill + spec_aug_mask = torch.zeros((batch_size, sequence_length), device=device, dtype=torch.bool) + + # uniform distribution to sample from, make sure that offset samples are < sequence_length + uniform_dist = torch.ones((batch_size, sequence_length - (mask_length - 1)), device=device) + + # get random indices to mask + spec_aug_mask_idxs = torch.multinomial(uniform_dist, num_masked_spans) + + # expand masked indices to masked spans + spec_aug_mask_idxs = ( + spec_aug_mask_idxs.unsqueeze(dim=-1) + .expand((batch_size, num_masked_spans, mask_length)) + .reshape(batch_size, num_masked_spans * mask_length) ) + offsets = ( + torch.arange(mask_length, device=device)[None, None, :] + .expand((batch_size, num_masked_spans, mask_length)) + .reshape(batch_size, num_masked_spans * mask_length) + ) + spec_aug_mask_idxs = spec_aug_mask_idxs + offsets - all_num_mask = max(min_masks, all_num_mask) + # scatter indices to mask + spec_aug_mask = spec_aug_mask.scatter(1, spec_aug_mask_idxs, True) - mask_idcs = [] - padding_mask = attention_mask.ne(1) if attention_mask is not None else None - for i in range(bsz): - if padding_mask is not None: - sz = all_sz - padding_mask[i].long().sum().item() - num_mask = int( - # add a random number for probabilistic rounding - mask_prob * sz / float(mask_length) - + np.random.rand() - ) - num_mask = max(min_masks, num_mask) - else: - sz = all_sz - num_mask = all_num_mask - - lengths = np.full(num_mask, mask_length) - - if sum(lengths) == 0: - lengths[0] = min(mask_length, sz - 1) - - min_len = min(lengths) - if sz - min_len <= num_mask: - min_len = sz - num_mask - 1 - - mask_idc = np.random.choice(sz - min_len, num_mask, replace=False) - mask_idc = np.asarray([mask_idc[j] + offset for j in range(len(mask_idc)) for offset in range(lengths[j])]) - mask_idcs.append(np.unique(mask_idc[mask_idc < sz])) - - min_len = min([len(m) for m in mask_idcs]) - for i, mask_idc in enumerate(mask_idcs): - if len(mask_idc) > min_len: - mask_idc = np.random.choice(mask_idc, min_len, replace=False) - mask[i, mask_idc] = True - - return mask + return spec_aug_mask class Wav2Vec2NoLayerNormConvLayer(nn.Module): @@ -847,21 +843,21 @@ class Wav2Vec2Model(Wav2Vec2PreTrainedModel): if self.config.mask_time_prob > 0: mask_time_indices = _compute_mask_indices( (batch_size, sequence_length), - self.config.mask_time_prob, - self.config.mask_time_length, - attention_mask=attention_mask, + mask_prob=self.config.mask_time_prob, + mask_length=self.config.mask_time_length, + device=hidden_states.device, min_masks=2, ) - hidden_states[torch.from_numpy(mask_time_indices)] = self.masked_spec_embed.to(hidden_states.dtype) + hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype) # apply SpecAugment along feature axis if self.config.mask_feature_prob > 0: mask_feature_indices = _compute_mask_indices( (batch_size, hidden_size), - self.config.mask_feature_prob, - self.config.mask_feature_length, + mask_prob=self.config.mask_feature_prob, + mask_length=self.config.mask_feature_length, + device=hidden_states.device, ) - mask_feature_indices = torch.from_numpy(mask_feature_indices).to(hidden_states.device) hidden_states[mask_feature_indices[:, None].expand(-1, sequence_length, -1)] = 0 encoder_outputs = self.encoder( diff --git a/tests/test_modeling_wav2vec2.py b/tests/test_modeling_wav2vec2.py index f2bb897e55..c43515df0d 100644 --- a/tests/test_modeling_wav2vec2.py +++ b/tests/test_modeling_wav2vec2.py @@ -478,26 +478,17 @@ class Wav2Vec2UtilsTest(unittest.TestCase): mask_prob = 0.5 mask_length = 1 - mask = _compute_mask_indices((batch_size, sequence_length), mask_prob, mask_length) + mask = _compute_mask_indices((batch_size, sequence_length), mask_prob, mask_length, torch_device) self.assertListEqual(mask.sum(axis=-1).tolist(), [mask_prob * sequence_length for _ in range(batch_size)]) - attention_mask = torch.ones((batch_size, sequence_length), device=torch_device, dtype=torch.long) - attention_mask[:, -sequence_length // 2 :] = 0 - - mask = _compute_mask_indices( - (batch_size, sequence_length), mask_prob, mask_length, attention_mask=attention_mask - ) - - self.assertListEqual(mask.sum(axis=-1).tolist(), [mask_prob * sequence_length // 2 for _ in range(batch_size)]) - def test_compute_mask_indices_overlap(self): batch_size = 4 sequence_length = 60 mask_prob = 0.5 mask_length = 4 - mask = _compute_mask_indices((batch_size, sequence_length), mask_prob, mask_length) + mask = _compute_mask_indices((batch_size, sequence_length), mask_prob, mask_length, torch_device) # because of overlap there is a range of possible masks for batch_sum in mask.sum(axis=-1): @@ -506,22 +497,6 @@ class Wav2Vec2UtilsTest(unittest.TestCase): list(range(int(mask_prob // mask_length * sequence_length), int(mask_prob * sequence_length))), ) - attention_mask = torch.ones((batch_size, sequence_length), device=torch_device, dtype=torch.long) - attention_mask[:, -sequence_length // 2 :] = 0 - - mask = _compute_mask_indices( - (batch_size, sequence_length), mask_prob, mask_length, attention_mask=attention_mask - ) - - # because of overlap there is a range of possible masks - for batch_sum in mask.sum(axis=-1): - self.assertIn( - int(batch_sum), - list( - range(int(mask_prob // mask_length * sequence_length // 2), int(mask_prob * sequence_length // 2)) - ), - ) - @require_torch @slow