Add Fine-Tuning for Wav2Vec2 (#10145)
* add encode labels function to tokenizer * start adding finetuning * init dropout * upload * correct convert script * apply changes * fix second typo * make first dummy training run * adapt convert script * push confg for comparison * remove conf * finish training * adapt data collator * add research folder * update according to fairseq feedback * some minor corrections * refactor masking indices a bit * some minor changes * clean tokenizer * finish clean-up * remove previous logic * update run script * correct training * finish changes * finish model * correct bug * fix training a bit more * add some tests * finish gradient checkpointing * finish example * correct gradient checkpointing * improve tokenization method * revert changes in tokenizer * revert general change * adapt fine-tuning * update * save intermediate test * Update README.md * finish finetuning * delete conversion script * Update src/transformers/models/wav2vec2/configuration_wav2vec2.py * Update src/transformers/models/wav2vec2/processing_wav2vec2.py Co-authored-by: Lysandre Debut <lysandre@huggingface.co> * finish wav2vec2 script * finish wav2vec2 fine-tuning * finalize test * correct test * adapt tests * finish * remove test file Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
This commit is contained in:
committed by
GitHub
parent
3c733f3208
commit
0234de8418
@@ -92,6 +92,33 @@ class Wav2Vec2Config(PretrainedConfig):
|
||||
Whether do apply `stable` layer norm architecture of the Transformer encoder. ``do_stable_layer_norm is
|
||||
True`` corresponds to applying layer norm before the attention layer, whereas ``do_stable_layer_norm is
|
||||
False`` corresponds to applying layer norm after the attention layer.
|
||||
freeze_feat_extract_train (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
||||
Whether to freeze the weights of the feature extractor when training.
|
||||
apply_spec_augment (:obj:`bool`, `optional`, defaults to :obj:`True`):
|
||||
Whether to apply *SpecAugment* data augmentation to the outputs of the feature extractor. For reference see
|
||||
`SpecAugment: A Simple Data Augmentation Method for Automatic Speech Recognition
|
||||
<https://arxiv.org/abs/1904.08779>`__.
|
||||
mask_time_prob (:obj:`float`, `optional`, defaults to 0.05):
|
||||
Propability of each feature vector along the time axis to be chosen as the start of the vector span to be
|
||||
masked. Approximately ``mask_time_prob * sequence_length // mask_time_length`` feature vectors will be
|
||||
masked along the time axis. This is only relevant if ``apply_spec_augment is True``.
|
||||
mask_time_length (:obj:`int`, `optional`, defaults to 10):
|
||||
Length of vector span along the time axis.
|
||||
mask_feature_prob (:obj:`float`, `optional`, defaults to 0.0):
|
||||
Propability of each feature vector along the feature axis to be chosen as the start of the vector span to
|
||||
be masked. Approximately ``mask_time_prob * hidden_size // mask_time_length`` feature vectors will be
|
||||
masked along the time axis. This is only relevant if ``apply_spec_augment is True``.
|
||||
mask_feature_length (:obj:`int`, `optional`, defaults to 10):
|
||||
Length of vector span along the feature axis.
|
||||
ctc_loss_reduction (:obj:`str`, `optional`, defaults to :obj:`"sum"`):
|
||||
Specifies the reduction to apply to the output of ``torch.nn.CTCLoss``. Only relevant when training an
|
||||
instance of :class:`~transformers.Wav2Vec2ForCTC`.
|
||||
ctc_zero_infinity (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
Whether to zero infinite losses and the associated gradients of ``torch.nn.CTCLoss``. Infinite losses
|
||||
mainly occur when the inputs are too short to be aligned to the targets. Only relevant when training an
|
||||
instance of :class:`~transformers.Wav2Vec2ForCTC`.
|
||||
gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
If True, use gradient checkpointing to save memory at the expense of slower backward pass.
|
||||
|
||||
Example::
|
||||
|
||||
@@ -116,12 +143,15 @@ class Wav2Vec2Config(PretrainedConfig):
|
||||
num_attention_heads=12,
|
||||
intermediate_size=3072,
|
||||
hidden_act="gelu",
|
||||
hidden_dropout_prob=0.1, # TODO(PVP) this is most likely not correctly set yet - correct when adding train
|
||||
attention_probs_dropout_prob=0.1, # TODO(PVP) this is most likely not correctly set yet - correct when adding train
|
||||
hidden_dropout=0.1,
|
||||
activation_dropout=0.1,
|
||||
attention_dropout=0.1,
|
||||
feat_proj_dropout=0.1,
|
||||
final_dropout=0.1,
|
||||
layerdrop=0.1,
|
||||
initializer_range=0.02,
|
||||
layer_norm_eps=1e-5,
|
||||
feat_extract_norm="group",
|
||||
feat_extract_dropout=0.0,
|
||||
feat_extract_activation="gelu",
|
||||
conv_dim=(512, 512, 512, 512, 512, 512, 512),
|
||||
conv_stride=(5, 2, 2, 2, 2, 2, 2),
|
||||
@@ -130,6 +160,15 @@ class Wav2Vec2Config(PretrainedConfig):
|
||||
num_conv_pos_embeddings=128,
|
||||
num_conv_pos_embedding_groups=16,
|
||||
do_stable_layer_norm=False,
|
||||
freeze_feat_extract_train=True,
|
||||
apply_spec_augment=True,
|
||||
mask_time_prob=0.05,
|
||||
mask_time_length=10,
|
||||
mask_feature_prob=0.0,
|
||||
mask_feature_length=10,
|
||||
ctc_loss_reduction="sum",
|
||||
ctc_zero_infinity=False,
|
||||
gradient_checkpointing=False,
|
||||
pad_token_id=0,
|
||||
bos_token_id=1,
|
||||
eos_token_id=2,
|
||||
@@ -138,7 +177,6 @@ class Wav2Vec2Config(PretrainedConfig):
|
||||
super().__init__(**kwargs, pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id)
|
||||
self.hidden_size = hidden_size
|
||||
self.feat_extract_norm = feat_extract_norm
|
||||
self.feat_extract_dropout = feat_extract_dropout
|
||||
self.feat_extract_activation = feat_extract_activation
|
||||
self.conv_dim = list(conv_dim)
|
||||
self.conv_stride = list(conv_stride)
|
||||
@@ -151,12 +189,18 @@ class Wav2Vec2Config(PretrainedConfig):
|
||||
self.intermediate_size = intermediate_size
|
||||
self.hidden_act = hidden_act
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.hidden_dropout_prob = hidden_dropout_prob
|
||||
self.attention_probs_dropout_prob = attention_probs_dropout_prob
|
||||
self.hidden_dropout = hidden_dropout
|
||||
self.attention_dropout = attention_dropout
|
||||
self.activation_dropout = activation_dropout
|
||||
self.feat_proj_dropout = feat_proj_dropout
|
||||
self.final_dropout = final_dropout
|
||||
self.layerdrop = layerdrop
|
||||
self.layer_norm_eps = layer_norm_eps
|
||||
self.initializer_range = initializer_range
|
||||
self.vocab_size = vocab_size
|
||||
self.do_stable_layer_norm = do_stable_layer_norm
|
||||
self.freeze_feat_extract_train = freeze_feat_extract_train
|
||||
self.gradient_checkpointing = gradient_checkpointing
|
||||
|
||||
if (
|
||||
(len(self.conv_stride) != self.num_feat_extract_layers)
|
||||
@@ -169,3 +213,14 @@ class Wav2Vec2Config(PretrainedConfig):
|
||||
f"but is `len(config.conv_dim) = {len(self.conv_dim)}`, `len(config.conv_stride)"
|
||||
f"= {len(self.conv_stride)}`, `len(config.conv_kernel) = {len(self.conv_kernel)}`."
|
||||
)
|
||||
|
||||
# fine-tuning config parameters for SpecAugment: https://arxiv.org/abs/1904.08779
|
||||
self.apply_spec_augment = apply_spec_augment
|
||||
self.mask_time_prob = mask_time_prob
|
||||
self.mask_time_length = mask_time_length
|
||||
self.mask_feature_prob = mask_feature_prob
|
||||
self.mask_feature_length = mask_feature_length
|
||||
|
||||
# ctc loss
|
||||
self.ctc_loss_reduction = ctc_loss_reduction
|
||||
self.ctc_zero_infinity = ctc_zero_infinity
|
||||
|
||||
@@ -20,26 +20,27 @@ import argparse
|
||||
import fairseq
|
||||
import torch
|
||||
|
||||
from transformers import Wav2Vec2Config, Wav2Vec2ForCTC, logging
|
||||
from transformers import Wav2Vec2Config, Wav2Vec2ForCTC, Wav2Vec2Model, logging
|
||||
|
||||
|
||||
logging.set_verbosity_info()
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
MAPPING = {
|
||||
"post_extract_proj": "wav2vec2.feature_projection.projection",
|
||||
"encoder.pos_conv.0": "wav2vec2.encoder.pos_conv_embed.conv",
|
||||
"self_attn.k_proj": "wav2vec2.encoder.layers.*.attention.k_proj",
|
||||
"self_attn.v_proj": "wav2vec2.encoder.layers.*.attention.v_proj",
|
||||
"self_attn.q_proj": "wav2vec2.encoder.layers.*.attention.q_proj",
|
||||
"self_attn.out_proj": "wav2vec2.encoder.layers.*.attention.out_proj",
|
||||
"self_attn_layer_norm": "wav2vec2.encoder.layers.*.layer_norm",
|
||||
"fc1": "wav2vec2.encoder.layers.*.feed_forward.intermediate_dense",
|
||||
"fc2": "wav2vec2.encoder.layers.*.feed_forward.output_dense",
|
||||
"final_layer_norm": "wav2vec2.encoder.layers.*.final_layer_norm",
|
||||
"encoder.layer_norm": "wav2vec2.encoder.layer_norm",
|
||||
"w2v_model.layer_norm": "wav2vec2.feature_projection.layer_norm",
|
||||
"post_extract_proj": "feature_projection.projection",
|
||||
"encoder.pos_conv.0": "encoder.pos_conv_embed.conv",
|
||||
"self_attn.k_proj": "encoder.layers.*.attention.k_proj",
|
||||
"self_attn.v_proj": "encoder.layers.*.attention.v_proj",
|
||||
"self_attn.q_proj": "encoder.layers.*.attention.q_proj",
|
||||
"self_attn.out_proj": "encoder.layers.*.attention.out_proj",
|
||||
"self_attn_layer_norm": "encoder.layers.*.layer_norm",
|
||||
"fc1": "encoder.layers.*.feed_forward.intermediate_dense",
|
||||
"fc2": "encoder.layers.*.feed_forward.output_dense",
|
||||
"final_layer_norm": "encoder.layers.*.final_layer_norm",
|
||||
"encoder.layer_norm": "encoder.layer_norm",
|
||||
"w2v_model.layer_norm": "feature_projection.layer_norm",
|
||||
"w2v_encoder.proj": "lm_head",
|
||||
"mask_emb": "masked_spec_embed",
|
||||
}
|
||||
|
||||
|
||||
@@ -47,7 +48,11 @@ def set_recursively(hf_pointer, key, value, full_name, weight_type):
|
||||
for attribute in key.split("."):
|
||||
hf_pointer = getattr(hf_pointer, attribute)
|
||||
|
||||
hf_shape = getattr(hf_pointer, weight_type).shape
|
||||
if weight_type is not None:
|
||||
hf_shape = getattr(hf_pointer, weight_type).shape
|
||||
else:
|
||||
hf_shape = hf_pointer.shape
|
||||
|
||||
assert (
|
||||
hf_shape == value.shape
|
||||
), f"Shape of hf {key + '.' + weight_type} is {hf_shape}, but should be {value.shape} for {full_name}"
|
||||
@@ -59,26 +64,32 @@ def set_recursively(hf_pointer, key, value, full_name, weight_type):
|
||||
hf_pointer.weight_v.data = value
|
||||
elif weight_type == "bias":
|
||||
hf_pointer.bias.data = value
|
||||
logger.info(f"{key + '.' + weight_type} was initialized from {full_name}.")
|
||||
else:
|
||||
hf_pointer.data = value
|
||||
|
||||
logger.info(f"{key + '.' + weight_type if weight_type is not None else ''} was initialized from {full_name}.")
|
||||
|
||||
|
||||
def recursively_load_weights(fairseq_model, hf_model):
|
||||
def recursively_load_weights(fairseq_model, hf_model, is_finetuned):
|
||||
unused_weights = []
|
||||
fairseq_dict = fairseq_model.state_dict()
|
||||
|
||||
feature_extractor = hf_model.wav2vec2.feature_extractor if is_finetuned else hf_model.feature_extractor
|
||||
|
||||
for name, value in fairseq_dict.items():
|
||||
is_used = False
|
||||
if "conv_layers" in name:
|
||||
load_conv_layer(
|
||||
name,
|
||||
value,
|
||||
hf_model.wav2vec2.feature_extractor,
|
||||
feature_extractor,
|
||||
unused_weights,
|
||||
hf_model.config.feat_extract_norm == "group",
|
||||
)
|
||||
is_used = True
|
||||
else:
|
||||
for key, mapped_key in MAPPING.items():
|
||||
mapped_key = "wav2vec2." + mapped_key if (is_finetuned and mapped_key != "lm_head") else mapped_key
|
||||
if key in name:
|
||||
is_used = True
|
||||
if "*" in mapped_key:
|
||||
@@ -92,6 +103,8 @@ def recursively_load_weights(fairseq_model, hf_model):
|
||||
weight_type = "weight"
|
||||
elif "bias" in name:
|
||||
weight_type = "bias"
|
||||
else:
|
||||
weight_type = None
|
||||
set_recursively(hf_model, mapped_key, value, name, weight_type)
|
||||
continue
|
||||
if not is_used:
|
||||
@@ -137,18 +150,32 @@ def load_conv_layer(full_name, value, feature_extractor, unused_weights, use_gro
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def convert_wav2vec2_checkpoint(checkpoint_path, pytorch_dump_folder_path, dict_path=None):
|
||||
def convert_wav2vec2_checkpoint(
|
||||
checkpoint_path, pytorch_dump_folder_path, config_path=None, dict_path=None, is_finetuned=True
|
||||
):
|
||||
"""
|
||||
Copy/paste/tweak model's weights to transformers design.
|
||||
"""
|
||||
hf_wav2vec = Wav2Vec2ForCTC(Wav2Vec2Config())
|
||||
if config_path is not None:
|
||||
config = Wav2Vec2Config.from_pretrained(config_path)
|
||||
else:
|
||||
config = Wav2Vec2Config()
|
||||
|
||||
if is_finetuned:
|
||||
hf_wav2vec = Wav2Vec2ForCTC(config)
|
||||
else:
|
||||
hf_wav2vec = Wav2Vec2Model(config)
|
||||
|
||||
if is_finetuned:
|
||||
model, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task(
|
||||
[checkpoint_path], arg_overrides={"data": dict_path}
|
||||
)
|
||||
else:
|
||||
model, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task([checkpoint_path])
|
||||
|
||||
model, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task(
|
||||
[checkpoint_path], arg_overrides={"data": dict_path}
|
||||
)
|
||||
model = model[0].eval()
|
||||
|
||||
recursively_load_weights(model, hf_wav2vec)
|
||||
recursively_load_weights(model, hf_wav2vec, is_finetuned)
|
||||
|
||||
hf_wav2vec.save_pretrained(pytorch_dump_folder_path)
|
||||
|
||||
@@ -158,5 +185,11 @@ if __name__ == "__main__":
|
||||
parser.add_argument("--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.")
|
||||
parser.add_argument("--checkpoint_path", default=None, type=str, help="Path to fairseq checkpoint")
|
||||
parser.add_argument("--dict_path", default=None, type=str, help="Path to dict of fine-tuned model")
|
||||
parser.add_argument("--config_path", default=None, type=str, help="Path to hf config.json of model to convert")
|
||||
parser.add_argument(
|
||||
"--not_finetuned", action="store_true", help="Whether the model to convert is a fine-tuned model or not"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
convert_wav2vec2_checkpoint(args.checkpoint_path, args.pytorch_dump_folder_path, args.dict_path)
|
||||
convert_wav2vec2_checkpoint(
|
||||
args.checkpoint_path, args.pytorch_dump_folder_path, args.config_path, args.dict_path, not args.not_finetuned
|
||||
)
|
||||
|
||||
@@ -14,10 +14,10 @@
|
||||
# limitations under the License.
|
||||
""" PyTorch Wav2Vec2 model. """
|
||||
|
||||
|
||||
import warnings
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torch.utils.checkpoint
|
||||
@@ -44,6 +44,77 @@ WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
||||
]
|
||||
|
||||
|
||||
def _compute_mask_indices(
|
||||
shape: Tuple[int, int],
|
||||
mask_prob: float,
|
||||
mask_length: int,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
min_masks: int = 0,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Computes random mask spans for a given shape
|
||||
|
||||
Args:
|
||||
shape: the the shape for which to compute masks.
|
||||
should be of size 2 where first element is batch size and 2nd is timesteps
|
||||
attention_mask: optional padding mask of the same size as shape, which will prevent masking padded elements
|
||||
mask_prob: probability for each token to be chosen as start of the span to be masked. this will be multiplied by
|
||||
number of timesteps divided by length of mask span to mask approximately this percentage of all elements.
|
||||
however due to overlaps, the actual number will be smaller (unless no_overlap is True)
|
||||
mask_length: size of the mask
|
||||
min_masks: minimum number of masked spans
|
||||
|
||||
Adapted from `fairseq's data_utils.py
|
||||
<https://github.com/pytorch/fairseq/blob/e0788f7007a8473a76db573985031f3c94201e79/fairseq/data/data_utils.py#L376>`__.
|
||||
"""
|
||||
bsz, all_sz = shape
|
||||
mask = np.full((bsz, all_sz), False)
|
||||
|
||||
all_num_mask = int(
|
||||
# add a random number for probabilistic rounding
|
||||
mask_prob * all_sz / float(mask_length)
|
||||
+ np.random.rand()
|
||||
)
|
||||
|
||||
all_num_mask = max(min_masks, all_num_mask)
|
||||
|
||||
mask_idcs = []
|
||||
padding_mask = attention_mask.ne(1) if attention_mask is not None else None
|
||||
for i in range(bsz):
|
||||
if padding_mask is not None:
|
||||
sz = all_sz - padding_mask[i].long().sum().item()
|
||||
num_mask = int(
|
||||
# add a random number for probabilistic rounding
|
||||
mask_prob * sz / float(mask_length)
|
||||
+ np.random.rand()
|
||||
)
|
||||
num_mask = max(min_masks, num_mask)
|
||||
else:
|
||||
sz = all_sz
|
||||
num_mask = all_num_mask
|
||||
|
||||
lengths = np.full(num_mask, mask_length)
|
||||
|
||||
if sum(lengths) == 0:
|
||||
lengths[0] = min(mask_length, sz - 1)
|
||||
|
||||
min_len = min(lengths)
|
||||
if sz - min_len <= num_mask:
|
||||
min_len = sz - num_mask - 1
|
||||
|
||||
mask_idc = np.random.choice(sz - min_len, num_mask, replace=False)
|
||||
mask_idc = np.asarray([mask_idc[j] + offset for j in range(len(mask_idc)) for offset in range(lengths[j])])
|
||||
mask_idcs.append(np.unique(mask_idc[mask_idc < sz]))
|
||||
|
||||
min_len = min([len(m) for m in mask_idcs])
|
||||
for i, mask_idc in enumerate(mask_idcs):
|
||||
if len(mask_idc) > min_len:
|
||||
mask_idc = np.random.choice(mask_idc, min_len, replace=False)
|
||||
mask[i, mask_idc] = True
|
||||
|
||||
return mask
|
||||
|
||||
|
||||
class Wav2Vec2NoLayerNormConvLayer(nn.Module):
|
||||
def __init__(self, config, layer_id=0):
|
||||
super().__init__()
|
||||
@@ -57,12 +128,10 @@ class Wav2Vec2NoLayerNormConvLayer(nn.Module):
|
||||
stride=config.conv_stride[layer_id],
|
||||
bias=config.conv_bias,
|
||||
)
|
||||
self.dropout = nn.Dropout(config.feat_extract_dropout)
|
||||
self.activation = ACT2FN[config.feat_extract_activation]
|
||||
|
||||
def forward(self, hidden_states):
|
||||
hidden_states = self.conv(hidden_states)
|
||||
hidden_states = self.dropout(hidden_states)
|
||||
hidden_states = self.activation(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
@@ -80,13 +149,11 @@ class Wav2Vec2LayerNormConvLayer(nn.Module):
|
||||
stride=config.conv_stride[layer_id],
|
||||
bias=config.conv_bias,
|
||||
)
|
||||
self.dropout = nn.Dropout(config.feat_extract_dropout)
|
||||
self.layer_norm = nn.LayerNorm(self.out_conv_dim, elementwise_affine=True)
|
||||
self.activation = ACT2FN[config.feat_extract_activation]
|
||||
|
||||
def forward(self, hidden_states):
|
||||
hidden_states = self.conv(hidden_states)
|
||||
hidden_states = self.dropout(hidden_states)
|
||||
|
||||
hidden_states = hidden_states.transpose(-2, -1)
|
||||
hidden_states = self.layer_norm(hidden_states)
|
||||
@@ -109,14 +176,12 @@ class Wav2Vec2GroupNormConvLayer(nn.Module):
|
||||
stride=config.conv_stride[layer_id],
|
||||
bias=config.conv_bias,
|
||||
)
|
||||
self.dropout = nn.Dropout(config.feat_extract_dropout)
|
||||
self.activation = ACT2FN[config.feat_extract_activation]
|
||||
|
||||
self.layer_norm = nn.GroupNorm(num_groups=self.out_conv_dim, num_channels=self.out_conv_dim, affine=True)
|
||||
|
||||
def forward(self, hidden_states):
|
||||
hidden_states = self.conv(hidden_states)
|
||||
hidden_states = self.dropout(hidden_states)
|
||||
hidden_states = self.layer_norm(hidden_states)
|
||||
hidden_states = self.activation(hidden_states)
|
||||
return hidden_states
|
||||
@@ -178,6 +243,10 @@ class Wav2Vec2FeatureExtractor(nn.Module):
|
||||
)
|
||||
self.conv_layers = nn.ModuleList(conv_layers)
|
||||
|
||||
def _freeze_parameters(self):
|
||||
for param in self.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
def forward(self, input_values):
|
||||
hidden_states = input_values[:, None]
|
||||
for conv_layer in self.conv_layers:
|
||||
@@ -191,7 +260,7 @@ class Wav2Vec2FeatureProjection(nn.Module):
|
||||
super().__init__()
|
||||
self.layer_norm = nn.LayerNorm(config.conv_dim[-1], eps=config.layer_norm_eps)
|
||||
self.projection = nn.Linear(config.conv_dim[-1], config.hidden_size)
|
||||
self.dropout = nn.Dropout(config.feat_extract_dropout)
|
||||
self.dropout = nn.Dropout(config.feat_proj_dropout)
|
||||
|
||||
def forward(self, hidden_states):
|
||||
hidden_states = self.layer_norm(hidden_states)
|
||||
@@ -346,7 +415,7 @@ class Wav2Vec2Attention(nn.Module):
|
||||
class Wav2Vec2FeedForward(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.intermediate_dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||
self.intermediate_dropout = nn.Dropout(config.activation_dropout)
|
||||
|
||||
self.intermediate_dense = nn.Linear(config.hidden_size, config.intermediate_size)
|
||||
if isinstance(config.hidden_act, str):
|
||||
@@ -355,7 +424,7 @@ class Wav2Vec2FeedForward(nn.Module):
|
||||
self.intermediate_act_fn = config.hidden_act
|
||||
|
||||
self.output_dense = nn.Linear(config.intermediate_size, config.hidden_size)
|
||||
self.output_dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||
self.output_dropout = nn.Dropout(config.hidden_dropout)
|
||||
|
||||
def forward(self, hidden_states):
|
||||
hidden_states = self.intermediate_dense(hidden_states)
|
||||
@@ -381,10 +450,10 @@ class Wav2Vec2EncoderLayer(nn.Module):
|
||||
self.attention = Wav2Vec2Attention(
|
||||
embed_dim=config.hidden_size,
|
||||
num_heads=config.num_attention_heads,
|
||||
dropout=config.hidden_dropout_prob,
|
||||
dropout=config.attention_dropout,
|
||||
is_decoder=False,
|
||||
)
|
||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||
self.dropout = nn.Dropout(config.hidden_dropout)
|
||||
self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||
self.feed_forward = Wav2Vec2FeedForward(config)
|
||||
self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||
@@ -401,7 +470,12 @@ class Wav2Vec2EncoderLayer(nn.Module):
|
||||
hidden_states = hidden_states + self.feed_forward(hidden_states)
|
||||
hidden_states = self.final_layer_norm(hidden_states)
|
||||
|
||||
return hidden_states, attn_weights
|
||||
outputs = (hidden_states,)
|
||||
|
||||
if output_attentions:
|
||||
outputs += (attn_weights,)
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
class Wav2Vec2EncoderLayerStableLayerNorm(nn.Module):
|
||||
@@ -410,10 +484,10 @@ class Wav2Vec2EncoderLayerStableLayerNorm(nn.Module):
|
||||
self.attention = Wav2Vec2Attention(
|
||||
embed_dim=config.hidden_size,
|
||||
num_heads=config.num_attention_heads,
|
||||
dropout=config.hidden_dropout_prob,
|
||||
dropout=config.attention_dropout,
|
||||
is_decoder=False,
|
||||
)
|
||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||
self.dropout = nn.Dropout(config.hidden_dropout)
|
||||
self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||
self.feed_forward = Wav2Vec2FeedForward(config)
|
||||
self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||
@@ -428,7 +502,12 @@ class Wav2Vec2EncoderLayerStableLayerNorm(nn.Module):
|
||||
hidden_states = attn_residual + hidden_states
|
||||
hidden_states = hidden_states + self.feed_forward(self.final_layer_norm(hidden_states))
|
||||
|
||||
return hidden_states, attn_weights
|
||||
outputs = (hidden_states,)
|
||||
|
||||
if output_attentions:
|
||||
outputs += (attn_weights,)
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
class Wav2Vec2Encoder(nn.Module):
|
||||
@@ -437,8 +516,7 @@ class Wav2Vec2Encoder(nn.Module):
|
||||
self.config = config
|
||||
self.pos_conv_embed = Wav2Vec2PositionalConvEmbedding(config)
|
||||
self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||
# IMPORTANT: the param for dropout is probs wrong
|
||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||
self.dropout = nn.Dropout(config.hidden_dropout)
|
||||
self.layers = nn.ModuleList([Wav2Vec2EncoderLayer(config) for _ in range(config.num_hidden_layers)])
|
||||
|
||||
def forward(
|
||||
@@ -471,12 +549,32 @@ class Wav2Vec2Encoder(nn.Module):
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
hidden_states, attn_weights = layer(
|
||||
hidden_states, attention_mask=attention_mask, output_attentions=output_attentions
|
||||
)
|
||||
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
|
||||
dropout_probability = np.random.uniform(0, 1)
|
||||
if self.training and (dropout_probability < self.config.layerdrop): # skip the layer
|
||||
layer_outputs = (None, None)
|
||||
else:
|
||||
if getattr(self.config, "gradient_checkpointing", False) and self.training:
|
||||
# create gradient checkpointing function
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
return module(*inputs, output_attentions)
|
||||
|
||||
return custom_forward
|
||||
|
||||
layer_outputs = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(layer),
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
)
|
||||
else:
|
||||
layer_outputs = layer(
|
||||
hidden_states, attention_mask=attention_mask, output_attentions=output_attentions
|
||||
)
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
if output_attentions:
|
||||
all_self_attentions = all_self_attentions + (attn_weights,)
|
||||
all_self_attentions = all_self_attentions + (layer_outputs[1],)
|
||||
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
@@ -496,8 +594,7 @@ class Wav2Vec2EncoderStableLayerNorm(nn.Module):
|
||||
self.config = config
|
||||
self.pos_conv_embed = Wav2Vec2PositionalConvEmbedding(config)
|
||||
self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||
# IMPORTANT: the param for dropout is probs wrong
|
||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||
self.dropout = nn.Dropout(config.hidden_dropout)
|
||||
self.layers = nn.ModuleList(
|
||||
[Wav2Vec2EncoderLayerStableLayerNorm(config) for _ in range(config.num_hidden_layers)]
|
||||
)
|
||||
@@ -531,12 +628,32 @@ class Wav2Vec2EncoderStableLayerNorm(nn.Module):
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
hidden_states, attn_weights = layer(
|
||||
hidden_states, attention_mask=attention_mask, output_attentions=output_attentions
|
||||
)
|
||||
# add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
|
||||
dropout_probability = np.random.uniform(0, 1)
|
||||
if self.training and (dropout_probability < self.config.layerdrop): # skip the layer
|
||||
layer_outputs = (None, None)
|
||||
else:
|
||||
if getattr(self.config, "gradient_checkpointing", False) and self.training:
|
||||
# create gradient checkpointing function
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
return module(*inputs, output_attentions)
|
||||
|
||||
return custom_forward
|
||||
|
||||
layer_outputs = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(layer),
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
)
|
||||
else:
|
||||
layer_outputs = layer(
|
||||
hidden_states, attention_mask=attention_mask, output_attentions=output_attentions
|
||||
)
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
if output_attentions:
|
||||
all_self_attentions = all_self_attentions + (attn_weights,)
|
||||
all_self_attentions = all_self_attentions + (layer_outputs[1],)
|
||||
|
||||
hidden_states = self.layer_norm(hidden_states)
|
||||
|
||||
@@ -584,7 +701,7 @@ class Wav2Vec2PreTrainedModel(PreTrainedModel):
|
||||
def _conv_out_length(input_length, kernel_size, stride):
|
||||
# 1D convolutional layer output length formula taken
|
||||
# from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
|
||||
return torch.floor((input_length - kernel_size) / stride + 1)
|
||||
return (input_length - kernel_size) // stride + 1
|
||||
|
||||
for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
|
||||
input_lengths = _conv_out_length(input_lengths, kernel_size, stride)
|
||||
@@ -659,6 +776,8 @@ class Wav2Vec2Model(Wav2Vec2PreTrainedModel):
|
||||
self.feature_extractor = Wav2Vec2FeatureExtractor(config)
|
||||
self.feature_projection = Wav2Vec2FeatureProjection(config)
|
||||
|
||||
self.masked_spec_embed = nn.Parameter(torch.FloatTensor(config.hidden_size).uniform_())
|
||||
|
||||
if config.do_stable_layer_norm:
|
||||
self.encoder = Wav2Vec2EncoderStableLayerNorm(config)
|
||||
else:
|
||||
@@ -726,6 +845,30 @@ class Wav2Vec2Model(Wav2Vec2PreTrainedModel):
|
||||
|
||||
hidden_states = self.feature_projection(hidden_states)
|
||||
|
||||
if self.config.apply_spec_augment and self.training:
|
||||
batch_size, sequence_length, hidden_size = hidden_states.size()
|
||||
|
||||
# apply SpecAugment along time axis
|
||||
if self.config.mask_time_prob > 0:
|
||||
mask_time_indices = _compute_mask_indices(
|
||||
(batch_size, sequence_length),
|
||||
self.config.mask_time_prob,
|
||||
self.config.mask_time_length,
|
||||
attention_mask=attention_mask,
|
||||
min_masks=2,
|
||||
)
|
||||
hidden_states[torch.from_numpy(mask_time_indices)] = self.masked_spec_embed.to(hidden_states.dtype)
|
||||
|
||||
# apply SpecAugment along feature axis
|
||||
if self.config.mask_feature_prob > 0:
|
||||
mask_feature_indices = _compute_mask_indices(
|
||||
(batch_size, hidden_size),
|
||||
self.config.mask_feature_prob,
|
||||
self.config.mask_feature_length,
|
||||
)
|
||||
mask_feature_indices = torch.from_numpy(mask_feature_indices).to(hidden_states.device)
|
||||
hidden_states[mask_feature_indices[:, None].expand(-1, sequence_length, -1)] = 0
|
||||
|
||||
encoder_outputs = self.encoder(
|
||||
hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
@@ -756,7 +899,7 @@ class Wav2Vec2ForMaskedLM(Wav2Vec2PreTrainedModel):
|
||||
)
|
||||
|
||||
self.wav2vec2 = Wav2Vec2Model(config)
|
||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||
self.dropout = nn.Dropout(config.final_dropout)
|
||||
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size)
|
||||
|
||||
self.init_weights()
|
||||
@@ -773,7 +916,7 @@ class Wav2Vec2ForMaskedLM(Wav2Vec2PreTrainedModel):
|
||||
labels=None,
|
||||
):
|
||||
r"""
|
||||
labels (:obj:`Float.LongTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
|
||||
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
|
||||
TODO(PVP): Fill out when adding training
|
||||
|
||||
Returns:
|
||||
@@ -831,11 +974,18 @@ class Wav2Vec2ForCTC(Wav2Vec2PreTrainedModel):
|
||||
super().__init__(config)
|
||||
|
||||
self.wav2vec2 = Wav2Vec2Model(config)
|
||||
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
||||
self.dropout = nn.Dropout(config.final_dropout)
|
||||
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size)
|
||||
|
||||
self.init_weights()
|
||||
|
||||
def freeze_feature_extractor(self):
|
||||
"""
|
||||
Calling this function will disable the gradient computation for the feature extractor so that its parameter
|
||||
will not be updated during training.
|
||||
"""
|
||||
self.wav2vec2.feature_extractor._freeze_parameters()
|
||||
|
||||
@add_start_docstrings_to_model_forward(WAV_2_VEC_2_INPUTS_DOCSTRING)
|
||||
@replace_return_docstrings(output_type=BaseModelOutput, config_class=_CONFIG_FOR_DOC)
|
||||
def forward(
|
||||
@@ -848,8 +998,11 @@ class Wav2Vec2ForCTC(Wav2Vec2PreTrainedModel):
|
||||
labels=None,
|
||||
):
|
||||
r"""
|
||||
labels (:obj:`Float.LongTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
|
||||
TODO(PVP): Fill out when adding training
|
||||
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, target_length)`, `optional`):
|
||||
Labels for connectionist temporal classification. Note that ``target_length`` has to be smaller or equal to
|
||||
the sequence length of the output logits. Indices are selected in ``[-100, 0, ..., config.vocab_size -
|
||||
1]``. All labels set to ``-100`` are ignored (masked), the loss is only computed for labels in ``[0, ...,
|
||||
config.vocab_size - 1]``.
|
||||
|
||||
Returns:
|
||||
|
||||
@@ -873,9 +1026,18 @@ class Wav2Vec2ForCTC(Wav2Vec2PreTrainedModel):
|
||||
|
||||
>>> input_values = processor(ds["speech"][0], return_tensors="pt").input_values # Batch size 1
|
||||
>>> logits = model(input_values).logits
|
||||
|
||||
>>> predicted_ids = torch.argmax(logits, dim=-1)
|
||||
|
||||
>>> transcription = processor.decode(predicted_ids[0])
|
||||
|
||||
>>> # compute loss
|
||||
>>> target_transcription = "A MAN SAID TO THE UNIVERSE SIR I EXIST"
|
||||
|
||||
>>> # wrap processor as target processor to encode labels
|
||||
>>> with processor.as_target_processor():
|
||||
>>> labels = processor(transcription, return_tensors="pt").input_ids
|
||||
|
||||
>>> loss = model(input_values, labels=labels).loss
|
||||
"""
|
||||
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
@@ -893,8 +1055,38 @@ class Wav2Vec2ForCTC(Wav2Vec2PreTrainedModel):
|
||||
|
||||
logits = self.lm_head(hidden_states)
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
|
||||
# retrieve loss input_lengths from attention_mask
|
||||
attention_mask = (
|
||||
attention_mask if attention_mask is not None else torch.ones_like(input_values, dtype=torch.long)
|
||||
)
|
||||
input_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1))
|
||||
|
||||
# assuming that padded tokens are filled with -100
|
||||
# when not being attended to
|
||||
labels_mask = labels >= 0
|
||||
target_lengths = labels_mask.sum(-1)
|
||||
flattened_targets = labels.masked_select(labels_mask)
|
||||
|
||||
log_probs = F.log_softmax(logits, dim=-1).transpose(0, 1)
|
||||
|
||||
with torch.backends.cudnn.flags(enabled=False):
|
||||
loss = F.ctc_loss(
|
||||
log_probs,
|
||||
flattened_targets,
|
||||
input_lengths,
|
||||
target_lengths,
|
||||
blank=self.config.pad_token_id,
|
||||
reduction=self.config.ctc_loss_reduction,
|
||||
zero_infinity=self.config.ctc_zero_infinity,
|
||||
)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[1:]
|
||||
return output
|
||||
return ((loss,) + output) if loss is not None else output
|
||||
|
||||
return CausalLMOutput(logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions)
|
||||
return CausalLMOutput(
|
||||
loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions
|
||||
)
|
||||
|
||||
@@ -115,6 +115,16 @@ class Wav2Vec2Processor:
|
||||
"""
|
||||
return self.current_processor(*args, **kwargs)
|
||||
|
||||
def pad(self, *args, **kwargs):
|
||||
"""
|
||||
When used in normal mode, this method forwards all its arguments to Wav2Vec2FeatureExtractor's
|
||||
:meth:`~transformers.Wav2Vec2FeatureExtractor.pad` and returns its output. If used in the context
|
||||
:meth:`~transformers.Wav2Vec2Processor.as_target_processor` this method forwards all its arguments to
|
||||
Wav2Vec2CTCTokenizer's :meth:`~transformers.Wav2Vec2CTCTokenizer.pad`. Please refer to the docstring of the
|
||||
above two methods for more information.
|
||||
"""
|
||||
return self.current_processor.pad(*args, **kwargs)
|
||||
|
||||
def batch_decode(self, *args, **kwargs):
|
||||
"""
|
||||
This method forwards all its arguments to Wav2Vec2CTCTokenizer's
|
||||
|
||||
@@ -509,11 +509,18 @@ class Trainer:
|
||||
|
||||
# Build the sampler.
|
||||
if self.args.group_by_length:
|
||||
model_input_name = self.tokenizer.model_input_names[0] if self.tokenizer is not None else None
|
||||
if num_processes <= 1:
|
||||
return LengthGroupedSampler(self.train_dataset, self.args.train_batch_size)
|
||||
return LengthGroupedSampler(
|
||||
self.train_dataset, self.args.train_batch_size, model_input_name=model_input_name
|
||||
)
|
||||
else:
|
||||
return DistributedLengthGroupedSampler(
|
||||
self.train_dataset, self.args.train_batch_size, num_replicas=num_processes, rank=process_index
|
||||
self.train_dataset,
|
||||
self.args.train_batch_size,
|
||||
num_replicas=num_processes,
|
||||
rank=process_index,
|
||||
model_input_name=model_input_name,
|
||||
)
|
||||
|
||||
else:
|
||||
|
||||
@@ -452,16 +452,23 @@ class LengthGroupedSampler(Sampler):
|
||||
keeping a bit of randomness.
|
||||
"""
|
||||
|
||||
def __init__(self, dataset: Dataset, batch_size: int, lengths: Optional[List[int]] = None):
|
||||
def __init__(
|
||||
self,
|
||||
dataset: Dataset,
|
||||
batch_size: int,
|
||||
lengths: Optional[List[int]] = None,
|
||||
model_input_name: Optional[str] = None,
|
||||
):
|
||||
self.dataset = dataset
|
||||
self.batch_size = batch_size
|
||||
self.model_input_name = model_input_name if model_input_name is not None else "input_ids"
|
||||
if lengths is None:
|
||||
if not isinstance(dataset[0], dict) or "input_ids" not in dataset[0]:
|
||||
if not isinstance(dataset[0], dict) or model_input_name not in dataset[0]:
|
||||
raise ValueError(
|
||||
"Can only automatically infer lengths for datasets whose items are dictionaries with an "
|
||||
"'input_ids' key."
|
||||
f"'{self.model_input_name}' key."
|
||||
)
|
||||
lengths = [len(feature["input_ids"]) for feature in dataset]
|
||||
lengths = [len(feature[self.model_input_name]) for feature in dataset]
|
||||
self.lengths = lengths
|
||||
|
||||
def __len__(self):
|
||||
@@ -487,6 +494,7 @@ class DistributedLengthGroupedSampler(DistributedSampler):
|
||||
seed: int = 0,
|
||||
drop_last: bool = False,
|
||||
lengths: Optional[List[int]] = None,
|
||||
model_input_name: Optional[str] = None,
|
||||
):
|
||||
if num_replicas is None:
|
||||
if not dist.is_available():
|
||||
@@ -513,14 +521,15 @@ class DistributedLengthGroupedSampler(DistributedSampler):
|
||||
self.num_samples = math.ceil(len(self.dataset) / self.num_replicas)
|
||||
self.total_size = self.num_samples * self.num_replicas
|
||||
self.seed = seed
|
||||
self.model_input_name = model_input_name if model_input_name is not None else "input_ids"
|
||||
|
||||
if lengths is None:
|
||||
if not isinstance(dataset[0], dict) or "input_ids" not in dataset[0]:
|
||||
if not isinstance(dataset[0], dict) or self.model_input_name not in dataset[0]:
|
||||
raise ValueError(
|
||||
"Can only automatically infer lengths for datasets whose items are dictionaries with an "
|
||||
"'input_ids' key."
|
||||
f"'{self.model_input_name}' key."
|
||||
)
|
||||
lengths = [len(feature["input_ids"]) for feature in dataset]
|
||||
lengths = [len(feature[self.model_input_name]) for feature in dataset]
|
||||
self.lengths = lengths
|
||||
|
||||
def __iter__(self) -> Iterator:
|
||||
|
||||
Reference in New Issue
Block a user