committed by
GitHub
parent
f086652b16
commit
7630c11f32
@@ -48,71 +48,67 @@ def _compute_mask_indices(
|
||||
shape: Tuple[int, int],
|
||||
mask_prob: float,
|
||||
mask_length: int,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
device: torch.device,
|
||||
min_masks: int = 0,
|
||||
) -> np.ndarray:
|
||||
) -> torch.tensor:
|
||||
"""
|
||||
Computes random mask spans for a given shape
|
||||
Computes random mask spans for a given shape. Used to implement `SpecAugment: A Simple Data Augmentation Method for
|
||||
ASR <https://arxiv.org/abs/1904.08779>`__.
|
||||
|
||||
Args:
|
||||
shape: the the shape for which to compute masks.
|
||||
should be of size 2 where first element is batch size and 2nd is timesteps
|
||||
attention_mask: optional padding mask of the same size as shape, which will prevent masking padded elements
|
||||
mask_prob: probability for each token to be chosen as start of the span to be masked. this will be multiplied by
|
||||
number of timesteps divided by length of mask span to mask approximately this percentage of all elements.
|
||||
however due to overlaps, the actual number will be smaller (unless no_overlap is True)
|
||||
mask_length: size of the mask
|
||||
min_masks: minimum number of masked spans
|
||||
|
||||
Adapted from `fairseq's data_utils.py
|
||||
<https://github.com/pytorch/fairseq/blob/e0788f7007a8473a76db573985031f3c94201e79/fairseq/data/data_utils.py#L376>`__.
|
||||
"""
|
||||
bsz, all_sz = shape
|
||||
mask = np.full((bsz, all_sz), False)
|
||||
batch_size, sequence_length = shape
|
||||
|
||||
all_num_mask = int(
|
||||
# add a random number for probabilistic rounding
|
||||
mask_prob * all_sz / float(mask_length)
|
||||
+ np.random.rand()
|
||||
if mask_length < 1:
|
||||
raise ValueError("`mask_length` has to be bigger than 0.")
|
||||
|
||||
if mask_length > sequence_length:
|
||||
raise ValueError(
|
||||
f"`mask_length` has to be smaller than `sequence_length`, but got `mask_length`: {mask_length} and `sequence_length`: {sequence_length}`"
|
||||
)
|
||||
|
||||
all_num_mask = max(min_masks, all_num_mask)
|
||||
# compute number of masked spans in batch
|
||||
num_masked_spans = int(mask_prob * sequence_length / mask_length + torch.rand((1,)).item())
|
||||
num_masked_spans = max(num_masked_spans, min_masks)
|
||||
|
||||
mask_idcs = []
|
||||
padding_mask = attention_mask.ne(1) if attention_mask is not None else None
|
||||
for i in range(bsz):
|
||||
if padding_mask is not None:
|
||||
sz = all_sz - padding_mask[i].long().sum().item()
|
||||
num_mask = int(
|
||||
# add a random number for probabilistic rounding
|
||||
mask_prob * sz / float(mask_length)
|
||||
+ np.random.rand()
|
||||
# make sure num masked indices <= sequence_length
|
||||
if num_masked_spans * mask_length > sequence_length:
|
||||
num_masked_spans = sequence_length // mask_length
|
||||
|
||||
# SpecAugment mask to fill
|
||||
spec_aug_mask = torch.zeros((batch_size, sequence_length), device=device, dtype=torch.bool)
|
||||
|
||||
# uniform distribution to sample from, make sure that offset samples are < sequence_length
|
||||
uniform_dist = torch.ones((batch_size, sequence_length - (mask_length - 1)), device=device)
|
||||
|
||||
# get random indices to mask
|
||||
spec_aug_mask_idxs = torch.multinomial(uniform_dist, num_masked_spans)
|
||||
|
||||
# expand masked indices to masked spans
|
||||
spec_aug_mask_idxs = (
|
||||
spec_aug_mask_idxs.unsqueeze(dim=-1)
|
||||
.expand((batch_size, num_masked_spans, mask_length))
|
||||
.reshape(batch_size, num_masked_spans * mask_length)
|
||||
)
|
||||
num_mask = max(min_masks, num_mask)
|
||||
else:
|
||||
sz = all_sz
|
||||
num_mask = all_num_mask
|
||||
offsets = (
|
||||
torch.arange(mask_length, device=device)[None, None, :]
|
||||
.expand((batch_size, num_masked_spans, mask_length))
|
||||
.reshape(batch_size, num_masked_spans * mask_length)
|
||||
)
|
||||
spec_aug_mask_idxs = spec_aug_mask_idxs + offsets
|
||||
|
||||
lengths = np.full(num_mask, mask_length)
|
||||
# scatter indices to mask
|
||||
spec_aug_mask = spec_aug_mask.scatter(1, spec_aug_mask_idxs, True)
|
||||
|
||||
if sum(lengths) == 0:
|
||||
lengths[0] = min(mask_length, sz - 1)
|
||||
|
||||
min_len = min(lengths)
|
||||
if sz - min_len <= num_mask:
|
||||
min_len = sz - num_mask - 1
|
||||
|
||||
mask_idc = np.random.choice(sz - min_len, num_mask, replace=False)
|
||||
mask_idc = np.asarray([mask_idc[j] + offset for j in range(len(mask_idc)) for offset in range(lengths[j])])
|
||||
mask_idcs.append(np.unique(mask_idc[mask_idc < sz]))
|
||||
|
||||
min_len = min([len(m) for m in mask_idcs])
|
||||
for i, mask_idc in enumerate(mask_idcs):
|
||||
if len(mask_idc) > min_len:
|
||||
mask_idc = np.random.choice(mask_idc, min_len, replace=False)
|
||||
mask[i, mask_idc] = True
|
||||
|
||||
return mask
|
||||
return spec_aug_mask
|
||||
|
||||
|
||||
class Wav2Vec2NoLayerNormConvLayer(nn.Module):
|
||||
@@ -847,21 +843,21 @@ class Wav2Vec2Model(Wav2Vec2PreTrainedModel):
|
||||
if self.config.mask_time_prob > 0:
|
||||
mask_time_indices = _compute_mask_indices(
|
||||
(batch_size, sequence_length),
|
||||
self.config.mask_time_prob,
|
||||
self.config.mask_time_length,
|
||||
attention_mask=attention_mask,
|
||||
mask_prob=self.config.mask_time_prob,
|
||||
mask_length=self.config.mask_time_length,
|
||||
device=hidden_states.device,
|
||||
min_masks=2,
|
||||
)
|
||||
hidden_states[torch.from_numpy(mask_time_indices)] = self.masked_spec_embed.to(hidden_states.dtype)
|
||||
hidden_states[mask_time_indices] = self.masked_spec_embed.to(hidden_states.dtype)
|
||||
|
||||
# apply SpecAugment along feature axis
|
||||
if self.config.mask_feature_prob > 0:
|
||||
mask_feature_indices = _compute_mask_indices(
|
||||
(batch_size, hidden_size),
|
||||
self.config.mask_feature_prob,
|
||||
self.config.mask_feature_length,
|
||||
mask_prob=self.config.mask_feature_prob,
|
||||
mask_length=self.config.mask_feature_length,
|
||||
device=hidden_states.device,
|
||||
)
|
||||
mask_feature_indices = torch.from_numpy(mask_feature_indices).to(hidden_states.device)
|
||||
hidden_states[mask_feature_indices[:, None].expand(-1, sequence_length, -1)] = 0
|
||||
|
||||
encoder_outputs = self.encoder(
|
||||
|
||||
@@ -478,26 +478,17 @@ class Wav2Vec2UtilsTest(unittest.TestCase):
|
||||
mask_prob = 0.5
|
||||
mask_length = 1
|
||||
|
||||
mask = _compute_mask_indices((batch_size, sequence_length), mask_prob, mask_length)
|
||||
mask = _compute_mask_indices((batch_size, sequence_length), mask_prob, mask_length, torch_device)
|
||||
|
||||
self.assertListEqual(mask.sum(axis=-1).tolist(), [mask_prob * sequence_length for _ in range(batch_size)])
|
||||
|
||||
attention_mask = torch.ones((batch_size, sequence_length), device=torch_device, dtype=torch.long)
|
||||
attention_mask[:, -sequence_length // 2 :] = 0
|
||||
|
||||
mask = _compute_mask_indices(
|
||||
(batch_size, sequence_length), mask_prob, mask_length, attention_mask=attention_mask
|
||||
)
|
||||
|
||||
self.assertListEqual(mask.sum(axis=-1).tolist(), [mask_prob * sequence_length // 2 for _ in range(batch_size)])
|
||||
|
||||
def test_compute_mask_indices_overlap(self):
|
||||
batch_size = 4
|
||||
sequence_length = 60
|
||||
mask_prob = 0.5
|
||||
mask_length = 4
|
||||
|
||||
mask = _compute_mask_indices((batch_size, sequence_length), mask_prob, mask_length)
|
||||
mask = _compute_mask_indices((batch_size, sequence_length), mask_prob, mask_length, torch_device)
|
||||
|
||||
# because of overlap there is a range of possible masks
|
||||
for batch_sum in mask.sum(axis=-1):
|
||||
@@ -506,22 +497,6 @@ class Wav2Vec2UtilsTest(unittest.TestCase):
|
||||
list(range(int(mask_prob // mask_length * sequence_length), int(mask_prob * sequence_length))),
|
||||
)
|
||||
|
||||
attention_mask = torch.ones((batch_size, sequence_length), device=torch_device, dtype=torch.long)
|
||||
attention_mask[:, -sequence_length // 2 :] = 0
|
||||
|
||||
mask = _compute_mask_indices(
|
||||
(batch_size, sequence_length), mask_prob, mask_length, attention_mask=attention_mask
|
||||
)
|
||||
|
||||
# because of overlap there is a range of possible masks
|
||||
for batch_sum in mask.sum(axis=-1):
|
||||
self.assertIn(
|
||||
int(batch_sum),
|
||||
list(
|
||||
range(int(mask_prob // mask_length * sequence_length // 2), int(mask_prob * sequence_length // 2))
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@require_torch
|
||||
@slow
|
||||
|
||||
Reference in New Issue
Block a user