[Wav2Vec2] SpecAugment Fast (#11764)

* first try

* finish
This commit is contained in:
Patrick von Platen
2021-05-25 13:59:52 +01:00
committed by GitHub
parent f086652b16
commit 7630c11f32
2 changed files with 53 additions and 82 deletions

View File

@@ -48,71 +48,67 @@ def _compute_mask_indices(
shape: Tuple[int, int], shape: Tuple[int, int],
mask_prob: float, mask_prob: float,
mask_length: int, mask_length: int,
attention_mask: Optional[torch.Tensor] = None, device: torch.device,
min_masks: int = 0, min_masks: int = 0,
) -> np.ndarray: ) -> torch.tensor:
""" """
Computes random mask spans for a given shape Computes random mask spans for a given shape. Used to implement `SpecAugment: A Simple Data Augmentation Method for
ASR <https://arxiv.org/abs/1904.08779>`__.
Args: Args:
shape: the the shape for which to compute masks. shape: the the shape for which to compute masks.
should be of size 2 where first element is batch size and 2nd is timesteps should be of size 2 where first element is batch size and 2nd is timesteps
attention_mask: optional padding mask of the same size as shape, which will prevent masking padded elements
mask_prob: probability for each token to be chosen as start of the span to be masked. this will be multiplied by mask_prob: probability for each token to be chosen as start of the span to be masked. this will be multiplied by
number of timesteps divided by length of mask span to mask approximately this percentage of all elements. number of timesteps divided by length of mask span to mask approximately this percentage of all elements.
however due to overlaps, the actual number will be smaller (unless no_overlap is True) however due to overlaps, the actual number will be smaller (unless no_overlap is True)
mask_length: size of the mask mask_length: size of the mask
min_masks: minimum number of masked spans min_masks: minimum number of masked spans
Adapted from `fairseq's data_utils.py
<https://github.com/pytorch/fairseq/blob/e0788f7007a8473a76db573985031f3c94201e79/fairseq/data/data_utils.py#L376>`__.
""" """
bsz, all_sz = shape batch_size, sequence_length = shape
mask = np.full((bsz, all_sz), False)
all_num_mask = int( if mask_length < 1:
# add a random number for probabilistic rounding raise ValueError("`mask_length` has to be bigger than 0.")
mask_prob * all_sz / float(mask_length)
+ np.random.rand() if mask_length > sequence_length:
raise ValueError(
f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length} and `sequence_length`: {sequence_length}`"
) )
all_num_mask = max(min_masks, all_num_mask) # compute number of masked spans in batch
num_masked_spans = int(mask_prob * sequence_length / mask_length + torch.rand((1,)).item())
num_masked_spans = max(num_masked_spans, min_masks)
mask_idcs = [] # make sure num masked indices <= sequence_length
padding_mask = attention_mask.ne(1) if attention_mask is not None else None if num_masked_spans * mask_length > sequence_length:
for i in range(bsz): num_masked_spans = sequence_length // mask_length
if padding_mask is not None:
sz = all_sz - padding_mask[i].long().sum().item() # SpecAugment mask to fill
num_mask = int( spec_aug_mask = torch.zeros((batch_size, sequence_length), device=device, dtype=torch.bool)
# add a random number for probabilistic rounding
mask_prob * sz / float(mask_length) # uniform distribution to sample from, make sure that offset samples are < sequence_length
+ np.random.rand() uniform_dist = torch.ones((batch_size, sequence_length - (mask_length - 1)), device=device)
# get random indices to mask
spec_aug_mask_idxs = torch.multinomial(uniform_dist, num_masked_spans)
# expand masked indices to masked spans
spec_aug_mask_idxs = (
spec_aug_mask_idxs.unsqueeze(dim=-1)
.expand((batch_size, num_masked_spans, mask_length))
.reshape(batch_size, num_masked_spans * mask_length)
) )
num_mask = max(min_masks, num_mask) offsets = (
else: torch.arange(mask_length, device=device)[None, None, :]
sz = all_sz .expand((batch_size, num_masked_spans, mask_length))
num_mask = all_num_mask .reshape(batch_size, num_masked_spans * mask_length)
)
spec_aug_mask_idxs = spec_aug_mask_idxs + offsets
lengths = np.full(num_mask, mask_length) # scatter indices to mask
spec_aug_mask = spec_aug_mask.scatter(1, spec_aug_mask_idxs, True)
if sum(lengths) == 0: return spec_aug_mask
lengths[0] = min(mask_length, sz - 1)
min_len = min(lengths)
if sz - min_len <= num_mask:
min_len = sz - num_mask - 1
mask_idc = np.random.choice(sz - min_len, num_mask, replace=False)
mask_idc = np.asarray([mask_idc[j] + offset for j in range(len(mask_idc)) for offset in range(lengths[j])])
mask_idcs.append(np.unique(mask_idc[mask_idc < sz]))
min_len = min([len(m) for m in mask_idcs])
for i, mask_idc in enumerate(mask_idcs):
if len(mask_idc) > min_len:
mask_idc = np.random.choice(mask_idc, min_len, replace=False)
mask[i, mask_idc] = True
return mask
class Wav2Vec2NoLayerNormConvLayer(nn.Module): class Wav2Vec2NoLayerNormConvLayer(nn.Module):
@@ -847,21 +843,21 @@ class Wav2Vec2Model(Wav2Vec2PreTrainedModel):
if self.config.mask_time_prob > 0: if self.config.mask_time_prob > 0:
mask_time_indices = _compute_mask_indices( mask_time_indices = _compute_mask_indices(
(batch_size, sequence_length), (batch_size, sequence_length),
self.config.mask_time_prob, mask_prob=self.config.mask_time_prob,
self.config.mask_time_length, mask_length=self.config.mask_time_length,
attention_mask=attention_mask, device=hidden_states.device,
min_masks=2, min_masks=2,
) )
hidden_states[torch.from_numpy(mask_time_indices)] = self.masked_spec_embed.to(hidden_states.dtype) hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)
# apply SpecAugment along feature axis # apply SpecAugment along feature axis
if self.config.mask_feature_prob > 0: if self.config.mask_feature_prob > 0:
mask_feature_indices = _compute_mask_indices( mask_feature_indices = _compute_mask_indices(
(batch_size, hidden_size), (batch_size, hidden_size),
self.config.mask_feature_prob, mask_prob=self.config.mask_feature_prob,
self.config.mask_feature_length, mask_length=self.config.mask_feature_length,
device=hidden_states.device,
) )
mask_feature_indices = torch.from_numpy(mask_feature_indices).to(hidden_states.device)
hidden_states[mask_feature_indices[:, None].expand(-1, sequence_length, -1)] = 0 hidden_states[mask_feature_indices[:, None].expand(-1, sequence_length, -1)] = 0
encoder_outputs = self.encoder( encoder_outputs = self.encoder(

View File

@@ -478,26 +478,17 @@ class Wav2Vec2UtilsTest(unittest.TestCase):
mask_prob = 0.5 mask_prob = 0.5
mask_length = 1 mask_length = 1
mask = _compute_mask_indices((batch_size, sequence_length), mask_prob, mask_length) mask = _compute_mask_indices((batch_size, sequence_length), mask_prob, mask_length, torch_device)
self.assertListEqual(mask.sum(axis=-1).tolist(), [mask_prob * sequence_length for _ in range(batch_size)]) self.assertListEqual(mask.sum(axis=-1).tolist(), [mask_prob * sequence_length for _ in range(batch_size)])
attention_mask = torch.ones((batch_size, sequence_length), device=torch_device, dtype=torch.long)
attention_mask[:, -sequence_length // 2 :] = 0
mask = _compute_mask_indices(
(batch_size, sequence_length), mask_prob, mask_length, attention_mask=attention_mask
)
self.assertListEqual(mask.sum(axis=-1).tolist(), [mask_prob * sequence_length // 2 for _ in range(batch_size)])
def test_compute_mask_indices_overlap(self): def test_compute_mask_indices_overlap(self):
batch_size = 4 batch_size = 4
sequence_length = 60 sequence_length = 60
mask_prob = 0.5 mask_prob = 0.5
mask_length = 4 mask_length = 4
mask = _compute_mask_indices((batch_size, sequence_length), mask_prob, mask_length) mask = _compute_mask_indices((batch_size, sequence_length), mask_prob, mask_length, torch_device)
# because of overlap there is a range of possible masks # because of overlap there is a range of possible masks
for batch_sum in mask.sum(axis=-1): for batch_sum in mask.sum(axis=-1):
@@ -506,22 +497,6 @@ class Wav2Vec2UtilsTest(unittest.TestCase):
list(range(int(mask_prob // mask_length * sequence_length), int(mask_prob * sequence_length))), list(range(int(mask_prob // mask_length * sequence_length), int(mask_prob * sequence_length))),
) )
attention_mask = torch.ones((batch_size, sequence_length), device=torch_device, dtype=torch.long)
attention_mask[:, -sequence_length // 2 :] = 0
mask = _compute_mask_indices(
(batch_size, sequence_length), mask_prob, mask_length, attention_mask=attention_mask
)
# because of overlap there is a range of possible masks
for batch_sum in mask.sum(axis=-1):
self.assertIn(
int(batch_sum),
list(
range(int(mask_prob // mask_length * sequence_length // 2), int(mask_prob * sequence_length // 2))
),
)
@require_torch @require_torch
@slow @slow