Add Fine-Tuning for Wav2Vec2 (#10145)

* add encode labels function to tokenizer

* start adding finetuning

* init dropout

* upload

* correct convert script

* apply changes

* fix second typo

* make first dummy training run

* adapt convert script

* push confg for comparison

* remove conf

* finish training

* adapt data collator

* add research folder

* update according to fairseq feedback

* some minor corrections

* refactor masking indices a bit

* some minor changes

* clean tokenizer

* finish clean-up

* remove previous logic

* update run script

* correct training

* finish changes

* finish model

* correct bug

* fix training a bit more

* add some tests

* finish gradient checkpointing

* finish example

* correct gradient checkpointing

* improve tokenization method

* revert changes in tokenizer

* revert general change

* adapt fine-tuning

* update

* save intermediate test

* Update README.md

* finish finetuning

* delete conversion script

* Update src/transformers/models/wav2vec2/configuration_wav2vec2.py

* Update src/transformers/models/wav2vec2/processing_wav2vec2.py

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>

* finish wav2vec2 script

* finish wav2vec2 fine-tuning

* finalize test

* correct test

* adapt tests

* finish

* remove test file

Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
This commit is contained in:
Patrick von Platen
2021-03-01 12:13:17 +03:00
committed by GitHub
parent 3c733f3208
commit 0234de8418
14 changed files with 932 additions and 84 deletions

View File

@@ -92,6 +92,33 @@ class Wav2Vec2Config(PretrainedConfig):
Whether do apply `stable` layer norm architecture of the Transformer encoder. ``do_stable_layer_norm is
True`` corresponds to applying layer norm before the attention layer, whereas ``do_stable_layer_norm is
False`` corresponds to applying layer norm after the attention layer.
freeze_feat_extract_train (:obj:`bool`, `optional`, defaults to :obj:`True`):
Whether to freeze the weights of the feature extractor when training.
apply_spec_augment (:obj:`bool`, `optional`, defaults to :obj:`True`):
Whether to apply *SpecAugment* data augmentation to the outputs of the feature extractor. For reference see
`SpecAugment: A Simple Data Augmentation Method for Automatic Speech Recognition
<https://arxiv.org/abs/1904.08779>`__.
mask_time_prob (:obj:`float`, `optional`, defaults to 0.05):
Propability of each feature vector along the time axis to be chosen as the start of the vector span to be
masked. Approximately ``mask_time_prob * sequence_length // mask_time_length`` feature vectors will be
masked along the time axis. This is only relevant if ``apply_spec_augment is True``.
mask_time_length (:obj:`int`, `optional`, defaults to 10):
Length of vector span along the time axis.
mask_feature_prob (:obj:`float`, `optional`, defaults to 0.0):
Propability of each feature vector along the feature axis to be chosen as the start of the vector span to
be masked. Approximately ``mask_time_prob * hidden_size // mask_time_length`` feature vectors will be
masked along the time axis. This is only relevant if ``apply_spec_augment is True``.
mask_feature_length (:obj:`int`, `optional`, defaults to 10):
Length of vector span along the feature axis.
ctc_loss_reduction (:obj:`str`, `optional`, defaults to :obj:`"sum"`):
Specifies the reduction to apply to the output of ``torch.nn.CTCLoss``. Only relevant when training an
instance of :class:`~transformers.Wav2Vec2ForCTC`.
ctc_zero_infinity (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether to zero infinite losses and the associated gradients of ``torch.nn.CTCLoss``. Infinite losses
mainly occur when the inputs are too short to be aligned to the targets. Only relevant when training an
instance of :class:`~transformers.Wav2Vec2ForCTC`.
gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`):
If True, use gradient checkpointing to save memory at the expense of slower backward pass.
Example::
@@ -116,12 +143,15 @@ class Wav2Vec2Config(PretrainedConfig):
num_attention_heads=12,
intermediate_size=3072,
hidden_act="gelu",
hidden_dropout_prob=0.1, # TODO(PVP) this is most likely not correctly set yet - correct when adding train
attention_probs_dropout_prob=0.1, # TODO(PVP) this is most likely not correctly set yet - correct when adding train
hidden_dropout=0.1,
activation_dropout=0.1,
attention_dropout=0.1,
feat_proj_dropout=0.1,
final_dropout=0.1,
layerdrop=0.1,
initializer_range=0.02,
layer_norm_eps=1e-5,
feat_extract_norm="group",
feat_extract_dropout=0.0,
feat_extract_activation="gelu",
conv_dim=(512, 512, 512, 512, 512, 512, 512),
conv_stride=(5, 2, 2, 2, 2, 2, 2),
@@ -130,6 +160,15 @@ class Wav2Vec2Config(PretrainedConfig):
num_conv_pos_embeddings=128,
num_conv_pos_embedding_groups=16,
do_stable_layer_norm=False,
freeze_feat_extract_train=True,
apply_spec_augment=True,
mask_time_prob=0.05,
mask_time_length=10,
mask_feature_prob=0.0,
mask_feature_length=10,
ctc_loss_reduction="sum",
ctc_zero_infinity=False,
gradient_checkpointing=False,
pad_token_id=0,
bos_token_id=1,
eos_token_id=2,
@@ -138,7 +177,6 @@ class Wav2Vec2Config(PretrainedConfig):
super().__init__(**kwargs, pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id)
self.hidden_size = hidden_size
self.feat_extract_norm = feat_extract_norm
self.feat_extract_dropout = feat_extract_dropout
self.feat_extract_activation = feat_extract_activation
self.conv_dim = list(conv_dim)
self.conv_stride = list(conv_stride)
@@ -151,12 +189,18 @@ class Wav2Vec2Config(PretrainedConfig):
self.intermediate_size = intermediate_size
self.hidden_act = hidden_act
self.num_attention_heads = num_attention_heads
self.hidden_dropout_prob = hidden_dropout_prob
self.attention_probs_dropout_prob = attention_probs_dropout_prob
self.hidden_dropout = hidden_dropout
self.attention_dropout = attention_dropout
self.activation_dropout = activation_dropout
self.feat_proj_dropout = feat_proj_dropout
self.final_dropout = final_dropout
self.layerdrop = layerdrop
self.layer_norm_eps = layer_norm_eps
self.initializer_range = initializer_range
self.vocab_size = vocab_size
self.do_stable_layer_norm = do_stable_layer_norm
self.freeze_feat_extract_train = freeze_feat_extract_train
self.gradient_checkpointing = gradient_checkpointing
if (
(len(self.conv_stride) != self.num_feat_extract_layers)
@@ -169,3 +213,14 @@ class Wav2Vec2Config(PretrainedConfig):
f"but is `len(config.conv_dim) = {len(self.conv_dim)}`, `len(config.conv_stride)"
f"= {len(self.conv_stride)}`, `len(config.conv_kernel) = {len(self.conv_kernel)}`."
)
# fine-tuning config parameters for SpecAugment: https://arxiv.org/abs/1904.08779
self.apply_spec_augment = apply_spec_augment
self.mask_time_prob = mask_time_prob
self.mask_time_length = mask_time_length
self.mask_feature_prob = mask_feature_prob
self.mask_feature_length = mask_feature_length
# ctc loss
self.ctc_loss_reduction = ctc_loss_reduction
self.ctc_zero_infinity = ctc_zero_infinity

View File

@@ -20,26 +20,27 @@ import argparse
import fairseq
import torch
from transformers import Wav2Vec2Config, Wav2Vec2ForCTC, logging
from transformers import Wav2Vec2Config, Wav2Vec2ForCTC, Wav2Vec2Model, logging
logging.set_verbosity_info()
logger = logging.get_logger(__name__)
MAPPING = {
"post_extract_proj": "wav2vec2.feature_projection.projection",
"encoder.pos_conv.0": "wav2vec2.encoder.pos_conv_embed.conv",
"self_attn.k_proj": "wav2vec2.encoder.layers.*.attention.k_proj",
"self_attn.v_proj": "wav2vec2.encoder.layers.*.attention.v_proj",
"self_attn.q_proj": "wav2vec2.encoder.layers.*.attention.q_proj",
"self_attn.out_proj": "wav2vec2.encoder.layers.*.attention.out_proj",
"self_attn_layer_norm": "wav2vec2.encoder.layers.*.layer_norm",
"fc1": "wav2vec2.encoder.layers.*.feed_forward.intermediate_dense",
"fc2": "wav2vec2.encoder.layers.*.feed_forward.output_dense",
"final_layer_norm": "wav2vec2.encoder.layers.*.final_layer_norm",
"encoder.layer_norm": "wav2vec2.encoder.layer_norm",
"w2v_model.layer_norm": "wav2vec2.feature_projection.layer_norm",
"post_extract_proj": "feature_projection.projection",
"encoder.pos_conv.0": "encoder.pos_conv_embed.conv",
"self_attn.k_proj": "encoder.layers.*.attention.k_proj",
"self_attn.v_proj": "encoder.layers.*.attention.v_proj",
"self_attn.q_proj": "encoder.layers.*.attention.q_proj",
"self_attn.out_proj": "encoder.layers.*.attention.out_proj",
"self_attn_layer_norm": "encoder.layers.*.layer_norm",
"fc1": "encoder.layers.*.feed_forward.intermediate_dense",
"fc2": "encoder.layers.*.feed_forward.output_dense",
"final_layer_norm": "encoder.layers.*.final_layer_norm",
"encoder.layer_norm": "encoder.layer_norm",
"w2v_model.layer_norm": "feature_projection.layer_norm",
"w2v_encoder.proj": "lm_head",
"mask_emb": "masked_spec_embed",
}
@@ -47,7 +48,11 @@ def set_recursively(hf_pointer, key, value, full_name, weight_type):
for attribute in key.split("."):
hf_pointer = getattr(hf_pointer, attribute)
hf_shape = getattr(hf_pointer, weight_type).shape
if weight_type is not None:
hf_shape = getattr(hf_pointer, weight_type).shape
else:
hf_shape = hf_pointer.shape
assert (
hf_shape == value.shape
), f"Shape of hf {key + '.' + weight_type} is {hf_shape}, but should be {value.shape} for {full_name}"
@@ -59,26 +64,32 @@ def set_recursively(hf_pointer, key, value, full_name, weight_type):
hf_pointer.weight_v.data = value
elif weight_type == "bias":
hf_pointer.bias.data = value
logger.info(f"{key + '.' + weight_type} was initialized from {full_name}.")
else:
hf_pointer.data = value
logger.info(f"{key + '.' + weight_type if weight_type is not None else ''} was initialized from {full_name}.")
def recursively_load_weights(fairseq_model, hf_model):
def recursively_load_weights(fairseq_model, hf_model, is_finetuned):
unused_weights = []
fairseq_dict = fairseq_model.state_dict()
feature_extractor = hf_model.wav2vec2.feature_extractor if is_finetuned else hf_model.feature_extractor
for name, value in fairseq_dict.items():
is_used = False
if "conv_layers" in name:
load_conv_layer(
name,
value,
hf_model.wav2vec2.feature_extractor,
feature_extractor,
unused_weights,
hf_model.config.feat_extract_norm == "group",
)
is_used = True
else:
for key, mapped_key in MAPPING.items():
mapped_key = "wav2vec2." + mapped_key if (is_finetuned and mapped_key != "lm_head") else mapped_key
if key in name:
is_used = True
if "*" in mapped_key:
@@ -92,6 +103,8 @@ def recursively_load_weights(fairseq_model, hf_model):
weight_type = "weight"
elif "bias" in name:
weight_type = "bias"
else:
weight_type = None
set_recursively(hf_model, mapped_key, value, name, weight_type)
continue
if not is_used:
@@ -137,18 +150,32 @@ def load_conv_layer(full_name, value, feature_extractor, unused_weights, use_gro
@torch.no_grad()
def convert_wav2vec2_checkpoint(checkpoint_path, pytorch_dump_folder_path, dict_path=None):
def convert_wav2vec2_checkpoint(
checkpoint_path, pytorch_dump_folder_path, config_path=None, dict_path=None, is_finetuned=True
):
"""
Copy/paste/tweak model's weights to transformers design.
"""
hf_wav2vec = Wav2Vec2ForCTC(Wav2Vec2Config())
if config_path is not None:
config = Wav2Vec2Config.from_pretrained(config_path)
else:
config = Wav2Vec2Config()
if is_finetuned:
hf_wav2vec = Wav2Vec2ForCTC(config)
else:
hf_wav2vec = Wav2Vec2Model(config)
if is_finetuned:
model, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task(
[checkpoint_path], arg_overrides={"data": dict_path}
)
else:
model, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task([checkpoint_path])
model, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task(
[checkpoint_path], arg_overrides={"data": dict_path}
)
model = model[0].eval()
recursively_load_weights(model, hf_wav2vec)
recursively_load_weights(model, hf_wav2vec, is_finetuned)
hf_wav2vec.save_pretrained(pytorch_dump_folder_path)
@@ -158,5 +185,11 @@ if __name__ == "__main__":
parser.add_argument("--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.")
parser.add_argument("--checkpoint_path", default=None, type=str, help="Path to fairseq checkpoint")
parser.add_argument("--dict_path", default=None, type=str, help="Path to dict of fine-tuned model")
parser.add_argument("--config_path", default=None, type=str, help="Path to hf config.json of model to convert")
parser.add_argument(
"--not_finetuned", action="store_true", help="Whether the model to convert is a fine-tuned model or not"
)
args = parser.parse_args()
convert_wav2vec2_checkpoint(args.checkpoint_path, args.pytorch_dump_folder_path, args.dict_path)
convert_wav2vec2_checkpoint(
args.checkpoint_path, args.pytorch_dump_folder_path, args.config_path, args.dict_path, not args.not_finetuned
)

View File

@@ -14,10 +14,10 @@
# limitations under the License.
""" PyTorch Wav2Vec2 model. """
import warnings
from typing import Optional, Tuple
import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
@@ -44,6 +44,77 @@ WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST = [
]
def _compute_mask_indices(
shape: Tuple[int, int],
mask_prob: float,
mask_length: int,
attention_mask: Optional[torch.Tensor] = None,
min_masks: int = 0,
) -> np.ndarray:
"""
Computes random mask spans for a given shape
Args:
shape: the the shape for which to compute masks.
should be of size 2 where first element is batch size and 2nd is timesteps
attention_mask: optional padding mask of the same size as shape, which will prevent masking padded elements
mask_prob: probability for each token to be chosen as start of the span to be masked. this will be multiplied by
number of timesteps divided by length of mask span to mask approximately this percentage of all elements.
however due to overlaps, the actual number will be smaller (unless no_overlap is True)
mask_length: size of the mask
min_masks: minimum number of masked spans
Adapted from `fairseq's data_utils.py
<https://github.com/pytorch/fairseq/blob/e0788f7007a8473a76db573985031f3c94201e79/fairseq/data/data_utils.py#L376>`__.
"""
bsz, all_sz = shape
mask = np.full((bsz, all_sz), False)
all_num_mask = int(
# add a random number for probabilistic rounding
mask_prob * all_sz / float(mask_length)
+ np.random.rand()
)
all_num_mask = max(min_masks, all_num_mask)
mask_idcs = []
padding_mask = attention_mask.ne(1) if attention_mask is not None else None
for i in range(bsz):
if padding_mask is not None:
sz = all_sz - padding_mask[i].long().sum().item()
num_mask = int(
# add a random number for probabilistic rounding
mask_prob * sz / float(mask_length)
+ np.random.rand()
)
num_mask = max(min_masks, num_mask)
else:
sz = all_sz
num_mask = all_num_mask
lengths = np.full(num_mask, mask_length)
if sum(lengths) == 0:
lengths[0] = min(mask_length, sz - 1)
min_len = min(lengths)
if sz - min_len <= num_mask:
min_len = sz - num_mask - 1
mask_idc = np.random.choice(sz - min_len, num_mask, replace=False)
mask_idc = np.asarray([mask_idc[j] + offset for j in range(len(mask_idc)) for offset in range(lengths[j])])
mask_idcs.append(np.unique(mask_idc[mask_idc < sz]))
min_len = min([len(m) for m in mask_idcs])
for i, mask_idc in enumerate(mask_idcs):
if len(mask_idc) > min_len:
mask_idc = np.random.choice(mask_idc, min_len, replace=False)
mask[i, mask_idc] = True
return mask
class Wav2Vec2NoLayerNormConvLayer(nn.Module):
def __init__(self, config, layer_id=0):
super().__init__()
@@ -57,12 +128,10 @@ class Wav2Vec2NoLayerNormConvLayer(nn.Module):
stride=config.conv_stride[layer_id],
bias=config.conv_bias,
)
self.dropout = nn.Dropout(config.feat_extract_dropout)
self.activation = ACT2FN[config.feat_extract_activation]
def forward(self, hidden_states):
hidden_states = self.conv(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.activation(hidden_states)
return hidden_states
@@ -80,13 +149,11 @@ class Wav2Vec2LayerNormConvLayer(nn.Module):
stride=config.conv_stride[layer_id],
bias=config.conv_bias,
)
self.dropout = nn.Dropout(config.feat_extract_dropout)
self.layer_norm = nn.LayerNorm(self.out_conv_dim, elementwise_affine=True)
self.activation = ACT2FN[config.feat_extract_activation]
def forward(self, hidden_states):
hidden_states = self.conv(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = hidden_states.transpose(-2, -1)
hidden_states = self.layer_norm(hidden_states)
@@ -109,14 +176,12 @@ class Wav2Vec2GroupNormConvLayer(nn.Module):
stride=config.conv_stride[layer_id],
bias=config.conv_bias,
)
self.dropout = nn.Dropout(config.feat_extract_dropout)
self.activation = ACT2FN[config.feat_extract_activation]
self.layer_norm = nn.GroupNorm(num_groups=self.out_conv_dim, num_channels=self.out_conv_dim, affine=True)
def forward(self, hidden_states):
hidden_states = self.conv(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.layer_norm(hidden_states)
hidden_states = self.activation(hidden_states)
return hidden_states
@@ -178,6 +243,10 @@ class Wav2Vec2FeatureExtractor(nn.Module):
)
self.conv_layers = nn.ModuleList(conv_layers)
def _freeze_parameters(self):
for param in self.parameters():
param.requires_grad = False
def forward(self, input_values):
hidden_states = input_values[:, None]
for conv_layer in self.conv_layers:
@@ -191,7 +260,7 @@ class Wav2Vec2FeatureProjection(nn.Module):
super().__init__()
self.layer_norm = nn.LayerNorm(config.conv_dim[-1], eps=config.layer_norm_eps)
self.projection = nn.Linear(config.conv_dim[-1], config.hidden_size)
self.dropout = nn.Dropout(config.feat_extract_dropout)
self.dropout = nn.Dropout(config.feat_proj_dropout)
def forward(self, hidden_states):
hidden_states = self.layer_norm(hidden_states)
@@ -346,7 +415,7 @@ class Wav2Vec2Attention(nn.Module):
class Wav2Vec2FeedForward(nn.Module):
def __init__(self, config):
super().__init__()
self.intermediate_dropout = nn.Dropout(config.hidden_dropout_prob)
self.intermediate_dropout = nn.Dropout(config.activation_dropout)
self.intermediate_dense = nn.Linear(config.hidden_size, config.intermediate_size)
if isinstance(config.hidden_act, str):
@@ -355,7 +424,7 @@ class Wav2Vec2FeedForward(nn.Module):
self.intermediate_act_fn = config.hidden_act
self.output_dense = nn.Linear(config.intermediate_size, config.hidden_size)
self.output_dropout = nn.Dropout(config.hidden_dropout_prob)
self.output_dropout = nn.Dropout(config.hidden_dropout)
def forward(self, hidden_states):
hidden_states = self.intermediate_dense(hidden_states)
@@ -381,10 +450,10 @@ class Wav2Vec2EncoderLayer(nn.Module):
self.attention = Wav2Vec2Attention(
embed_dim=config.hidden_size,
num_heads=config.num_attention_heads,
dropout=config.hidden_dropout_prob,
dropout=config.attention_dropout,
is_decoder=False,
)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.dropout = nn.Dropout(config.hidden_dropout)
self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.feed_forward = Wav2Vec2FeedForward(config)
self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
@@ -401,7 +470,12 @@ class Wav2Vec2EncoderLayer(nn.Module):
hidden_states = hidden_states + self.feed_forward(hidden_states)
hidden_states = self.final_layer_norm(hidden_states)
return hidden_states, attn_weights
outputs = (hidden_states,)
if output_attentions:
outputs += (attn_weights,)
return outputs
class Wav2Vec2EncoderLayerStableLayerNorm(nn.Module):
@@ -410,10 +484,10 @@ class Wav2Vec2EncoderLayerStableLayerNorm(nn.Module):
self.attention = Wav2Vec2Attention(
embed_dim=config.hidden_size,
num_heads=config.num_attention_heads,
dropout=config.hidden_dropout_prob,
dropout=config.attention_dropout,
is_decoder=False,
)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.dropout = nn.Dropout(config.hidden_dropout)
self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.feed_forward = Wav2Vec2FeedForward(config)
self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
@@ -428,7 +502,12 @@ class Wav2Vec2EncoderLayerStableLayerNorm(nn.Module):
hidden_states = attn_residual + hidden_states
hidden_states = hidden_states + self.feed_forward(self.final_layer_norm(hidden_states))
return hidden_states, attn_weights
outputs = (hidden_states,)
if output_attentions:
outputs += (attn_weights,)
return outputs
class Wav2Vec2Encoder(nn.Module):
@@ -437,8 +516,7 @@ class Wav2Vec2Encoder(nn.Module):
self.config = config
self.pos_conv_embed = Wav2Vec2PositionalConvEmbedding(config)
self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
# IMPORTANT: the param for dropout is probs wrong
self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.dropout = nn.Dropout(config.hidden_dropout)
self.layers = nn.ModuleList([Wav2Vec2EncoderLayer(config) for _ in range(config.num_hidden_layers)])
def forward(
@@ -471,12 +549,32 @@ class Wav2Vec2Encoder(nn.Module):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
hidden_states, attn_weights = layer(
hidden_states, attention_mask=attention_mask, output_attentions=output_attentions
)
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
dropout_probability = np.random.uniform(0, 1)
if self.training and (dropout_probability < self.config.layerdrop): # skip the layer
layer_outputs = (None, None)
else:
if getattr(self.config, "gradient_checkpointing", False) and self.training:
# create gradient checkpointing function
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs, output_attentions)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(layer),
hidden_states,
attention_mask,
)
else:
layer_outputs = layer(
hidden_states, attention_mask=attention_mask, output_attentions=output_attentions
)
hidden_states = layer_outputs[0]
if output_attentions:
all_self_attentions = all_self_attentions + (attn_weights,)
all_self_attentions = all_self_attentions + (layer_outputs[1],)
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
@@ -496,8 +594,7 @@ class Wav2Vec2EncoderStableLayerNorm(nn.Module):
self.config = config
self.pos_conv_embed = Wav2Vec2PositionalConvEmbedding(config)
self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
# IMPORTANT: the param for dropout is probs wrong
self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.dropout = nn.Dropout(config.hidden_dropout)
self.layers = nn.ModuleList(
[Wav2Vec2EncoderLayerStableLayerNorm(config) for _ in range(config.num_hidden_layers)]
)
@@ -531,12 +628,32 @@ class Wav2Vec2EncoderStableLayerNorm(nn.Module):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
hidden_states, attn_weights = layer(
hidden_states, attention_mask=attention_mask, output_attentions=output_attentions
)
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
dropout_probability = np.random.uniform(0, 1)
if self.training and (dropout_probability < self.config.layerdrop): # skip the layer
layer_outputs = (None, None)
else:
if getattr(self.config, "gradient_checkpointing", False) and self.training:
# create gradient checkpointing function
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs, output_attentions)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(layer),
hidden_states,
attention_mask,
)
else:
layer_outputs = layer(
hidden_states, attention_mask=attention_mask, output_attentions=output_attentions
)
hidden_states = layer_outputs[0]
if output_attentions:
all_self_attentions = all_self_attentions + (attn_weights,)
all_self_attentions = all_self_attentions + (layer_outputs[1],)
hidden_states = self.layer_norm(hidden_states)
@@ -584,7 +701,7 @@ class Wav2Vec2PreTrainedModel(PreTrainedModel):
def _conv_out_length(input_length, kernel_size, stride):
# 1D convolutional layer output length formula taken
# from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
return torch.floor((input_length - kernel_size) / stride + 1)
return (input_length - kernel_size) // stride + 1
for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
input_lengths = _conv_out_length(input_lengths, kernel_size, stride)
@@ -659,6 +776,8 @@ class Wav2Vec2Model(Wav2Vec2PreTrainedModel):
self.feature_extractor = Wav2Vec2FeatureExtractor(config)
self.feature_projection = Wav2Vec2FeatureProjection(config)
self.masked_spec_embed = nn.Parameter(torch.FloatTensor(config.hidden_size).uniform_())
if config.do_stable_layer_norm:
self.encoder = Wav2Vec2EncoderStableLayerNorm(config)
else:
@@ -726,6 +845,30 @@ class Wav2Vec2Model(Wav2Vec2PreTrainedModel):
hidden_states = self.feature_projection(hidden_states)
if self.config.apply_spec_augment and self.training:
batch_size, sequence_length, hidden_size = hidden_states.size()
# apply SpecAugment along time axis
if self.config.mask_time_prob > 0:
mask_time_indices = _compute_mask_indices(
(batch_size, sequence_length),
self.config.mask_time_prob,
self.config.mask_time_length,
attention_mask=attention_mask,
min_masks=2,
)
hidden_states[torch.from_numpy(mask_time_indices)] = self.masked_spec_embed.to(hidden_states.dtype)
# apply SpecAugment along feature axis
if self.config.mask_feature_prob > 0:
mask_feature_indices = _compute_mask_indices(
(batch_size, hidden_size),
self.config.mask_feature_prob,
self.config.mask_feature_length,
)
mask_feature_indices = torch.from_numpy(mask_feature_indices).to(hidden_states.device)
hidden_states[mask_feature_indices[:, None].expand(-1, sequence_length, -1)] = 0
encoder_outputs = self.encoder(
hidden_states,
attention_mask=attention_mask,
@@ -756,7 +899,7 @@ class Wav2Vec2ForMaskedLM(Wav2Vec2PreTrainedModel):
)
self.wav2vec2 = Wav2Vec2Model(config)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.dropout = nn.Dropout(config.final_dropout)
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size)
self.init_weights()
@@ -773,7 +916,7 @@ class Wav2Vec2ForMaskedLM(Wav2Vec2PreTrainedModel):
labels=None,
):
r"""
labels (:obj:`Float.LongTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
TODO(PVP): Fill out when adding training
Returns:
@@ -831,11 +974,18 @@ class Wav2Vec2ForCTC(Wav2Vec2PreTrainedModel):
super().__init__(config)
self.wav2vec2 = Wav2Vec2Model(config)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.dropout = nn.Dropout(config.final_dropout)
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size)
self.init_weights()
def freeze_feature_extractor(self):
"""
Calling this function will disable the gradient computation for the feature extractor so that its parameter
will not be updated during training.
"""
self.wav2vec2.feature_extractor._freeze_parameters()
@add_start_docstrings_to_model_forward(WAV_2_VEC_2_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=BaseModelOutput, config_class=_CONFIG_FOR_DOC)
def forward(
@@ -848,8 +998,11 @@ class Wav2Vec2ForCTC(Wav2Vec2PreTrainedModel):
labels=None,
):
r"""
labels (:obj:`Float.LongTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
TODO(PVP): Fill out when adding training
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, target_length)`, `optional`):
Labels for connectionist temporal classification. Note that ``target_length`` has to be smaller or equal to
the sequence length of the output logits. Indices are selected in ``[-100, 0, ..., config.vocab_size -
1]``. All labels set to ``-100`` are ignored (masked), the loss is only computed for labels in ``[0, ...,
config.vocab_size - 1]``.
Returns:
@@ -873,9 +1026,18 @@ class Wav2Vec2ForCTC(Wav2Vec2PreTrainedModel):
>>> input_values = processor(ds["speech"][0], return_tensors="pt").input_values # Batch size 1
>>> logits = model(input_values).logits
>>> predicted_ids = torch.argmax(logits, dim=-1)
>>> transcription = processor.decode(predicted_ids[0])
>>> # compute loss
>>> target_transcription = "A MAN SAID TO THE UNIVERSE SIR I EXIST"
>>> # wrap processor as target processor to encode labels
>>> with processor.as_target_processor():
>>> labels = processor(transcription, return_tensors="pt").input_ids
>>> loss = model(input_values, labels=labels).loss
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
@@ -893,8 +1055,38 @@ class Wav2Vec2ForCTC(Wav2Vec2PreTrainedModel):
logits = self.lm_head(hidden_states)
loss = None
if labels is not None:
# retrieve loss input_lengths from attention_mask
attention_mask = (
attention_mask if attention_mask is not None else torch.ones_like(input_values, dtype=torch.long)
)
input_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1))
# assuming that padded tokens are filled with -100
# when not being attended to
labels_mask = labels >= 0
target_lengths = labels_mask.sum(-1)
flattened_targets = labels.masked_select(labels_mask)
log_probs = F.log_softmax(logits, dim=-1).transpose(0, 1)
with torch.backends.cudnn.flags(enabled=False):
loss = F.ctc_loss(
log_probs,
flattened_targets,
input_lengths,
target_lengths,
blank=self.config.pad_token_id,
reduction=self.config.ctc_loss_reduction,
zero_infinity=self.config.ctc_zero_infinity,
)
if not return_dict:
output = (logits,) + outputs[1:]
return output
return ((loss,) + output) if loss is not None else output
return CausalLMOutput(logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions)
return CausalLMOutput(
loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions
)

View File

@@ -115,6 +115,16 @@ class Wav2Vec2Processor:
"""
return self.current_processor(*args, **kwargs)
def pad(self, *args, **kwargs):
"""
When used in normal mode, this method forwards all its arguments to Wav2Vec2FeatureExtractor's
:meth:`~transformers.Wav2Vec2FeatureExtractor.pad` and returns its output. If used in the context
:meth:`~transformers.Wav2Vec2Processor.as_target_processor` this method forwards all its arguments to
Wav2Vec2CTCTokenizer's :meth:`~transformers.Wav2Vec2CTCTokenizer.pad`. Please refer to the docstring of the
above two methods for more information.
"""
return self.current_processor.pad(*args, **kwargs)
def batch_decode(self, *args, **kwargs):
"""
This method forwards all its arguments to Wav2Vec2CTCTokenizer's