From 700a748fe6f0ed62185710f20e1c78e083edc14b Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 17 Nov 2021 14:38:56 +0100 Subject: [PATCH] [Wav2Vec2] Add New Wav2Vec2 Translation (#14392) * add new wav2vec2 translation * correct * up * add tests * correct end copy * correct more * up * correct unispeech sat * finish * finalize * finish * up --- ...rt_wav2vec2_seq2seq_original_to_pytorch.py | 353 ++++++++++++++++++ .../modeling_speech_encoder_decoder.py | 6 +- .../models/unispeech/modeling_unispeech.py | 3 +- .../unispeech_sat/modeling_unispeech_sat.py | 3 +- .../models/wav2vec2/configuration_wav2vec2.py | 25 ++ .../models/wav2vec2/modeling_wav2vec2.py | 81 +++- tests/test_modeling_wav2vec2.py | 46 +++ ..._pipelines_automatic_speech_recognition.py | 38 ++ 8 files changed, 544 insertions(+), 11 deletions(-) create mode 100644 src/transformers/models/speech_encoder_decoder/convert_mbart_wav2vec2_seq2seq_original_to_pytorch.py diff --git a/src/transformers/models/speech_encoder_decoder/convert_mbart_wav2vec2_seq2seq_original_to_pytorch.py b/src/transformers/models/speech_encoder_decoder/convert_mbart_wav2vec2_seq2seq_original_to_pytorch.py new file mode 100644 index 0000000000..3c25ab706f --- /dev/null +++ b/src/transformers/models/speech_encoder_decoder/convert_mbart_wav2vec2_seq2seq_original_to_pytorch.py @@ -0,0 +1,353 @@ +# coding=utf-8 +# Copyright 2021 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Convert Wav2Vec2 checkpoint.""" + + +import argparse + +import fairseq +import torch +from torch import nn + +from transformers import ( + MBart50Tokenizer, + MBartConfig, + MBartForCausalLM, + SpeechEncoderDecoderConfig, + SpeechEncoderDecoderModel, + Wav2Vec2Config, + Wav2Vec2FeatureExtractor, + Wav2Vec2Model, + logging, +) + + +logging.set_verbosity_info() +logger = logging.get_logger(__name__) + +MAPPING = { + "post_extract_proj": "feature_projection.projection", + "encoder.pos_conv.0": "encoder.pos_conv_embed.conv", + "self_attn.k_proj": "encoder.layers.*.attention.k_proj", + "self_attn.v_proj": "encoder.layers.*.attention.v_proj", + "self_attn.q_proj": "encoder.layers.*.attention.q_proj", + "self_attn.out_proj": "encoder.layers.*.attention.out_proj", + "self_attn_layer_norm": "encoder.layers.*.layer_norm", + "fc1": "encoder.layers.*.feed_forward.intermediate_dense", + "fc2": "encoder.layers.*.feed_forward.output_dense", + "final_layer_norm": "encoder.layers.*.final_layer_norm", + "encoder.layer_norm": "encoder.layer_norm", + "w2v_model.layer_norm": "feature_projection.layer_norm", + "quantizer.weight_proj": "quantizer.weight_proj", + "quantizer.vars": "quantizer.codevectors", + "project_q": "project_q", + "final_proj": "project_hid", + "w2v_encoder.proj": "lm_head", + "mask_emb": "masked_spec_embed", +} +TOP_LEVEL_KEYS = [ + "lm_head", + "quantizer.weight_proj", + "quantizer.codevectors", + "project_q", + "project_hid", +] + + +def set_recursively(hf_pointer, key, value, full_name, weight_type): + for attribute in key.split("."): + hf_pointer = getattr(hf_pointer, attribute) + + if weight_type is not None: + hf_shape = getattr(hf_pointer, weight_type).shape + else: + hf_shape = hf_pointer.shape + + assert ( + hf_shape == value.shape + ), f"Shape of hf {key + '.' + weight_type if weight_type is not None else ''} is {hf_shape}, but should be {value.shape} for {full_name}" + + if weight_type == "weight": + hf_pointer.weight.data = value + elif weight_type == "weight_g": + hf_pointer.weight_g.data = value + elif weight_type == "weight_v": + hf_pointer.weight_v.data = value + elif weight_type == "bias": + hf_pointer.bias.data = value + else: + hf_pointer.data = value + + logger.info(f"{key + '.' + weight_type if weight_type is not None else ''} was initialized from {full_name}.") + + +def recursively_load_weights_wav2vec2(fairseq_model, hf_model): + unused_weights = [] + fairseq_dict = fairseq_model.state_dict() + + feature_extractor = hf_model.feature_extractor + adapter = hf_model.adapter + + for name, value in fairseq_dict.items(): + is_used = False + if "conv_layers" in name: + load_conv_layer( + name, + value, + feature_extractor, + unused_weights, + hf_model.config.feat_extract_norm == "group", + ) + is_used = True + elif any(x in name for x in ["adaptor", "w2v_encoder.proj.", "w2v_proj_ln."]): + load_adapter(name, value, adapter, unused_weights) + is_used = True + else: + for key, mapped_key in MAPPING.items(): + if key in name or key.split("w2v_model.")[-1] == name.split(".")[0]: + is_used = True + if "*" in mapped_key: + layer_index = name.split(key)[0].split(".")[-2] + mapped_key = mapped_key.replace("*", layer_index) + if "weight_g" in name: + weight_type = "weight_g" + elif "weight_v" in name: + weight_type = "weight_v" + elif "bias" in name: + weight_type = "bias" + elif "weight" in name: + weight_type = "weight" + else: + weight_type = None + set_recursively(hf_model, mapped_key, value, name, weight_type) + continue + if not is_used: + unused_weights.append(name) + + logger.warning(f"Unused weights: {unused_weights}") + + +def load_conv_layer(full_name, value, feature_extractor, unused_weights, use_group_norm): + name = full_name.split("conv_layers.")[-1] + items = name.split(".") + layer_id = int(items[0]) + type_id = int(items[1]) + + if type_id == 0: + if "bias" in name: + assert ( + value.shape == feature_extractor.conv_layers[layer_id].conv.bias.data.shape + ), f"{full_name} has size {value.shape}, but {feature_extractor.conv_layers[layer_id].conv.bias.data.shape} was found." + feature_extractor.conv_layers[layer_id].conv.bias.data = value + logger.info(f"Feat extract conv layer {layer_id} was initialized from {full_name}.") + elif "weight" in name: + assert ( + value.shape == feature_extractor.conv_layers[layer_id].conv.weight.data.shape + ), f"{full_name} has size {value.shape}, but {feature_extractor.conv_layers[layer_id].conv.weight.data.shape} was found." + feature_extractor.conv_layers[layer_id].conv.weight.data = value + logger.info(f"Feat extract conv layer {layer_id} was initialized from {full_name}.") + elif (type_id == 2 and not use_group_norm) or (type_id == 2 and layer_id == 0 and use_group_norm): + if "bias" in name: + assert ( + value.shape == feature_extractor.conv_layers[layer_id].layer_norm.bias.data.shape + ), f"{full_name} has size {value.shape}, but {feature_extractor[layer_id].layer_norm.bias.data.shape} was found." + feature_extractor.conv_layers[layer_id].layer_norm.bias.data = value + logger.info(f"Feat extract layer norm weight of layer {layer_id} was initialized from {full_name}.") + elif "weight" in name: + assert ( + value.shape == feature_extractor.conv_layers[layer_id].layer_norm.weight.data.shape + ), f"{full_name} has size {value.shape}, but {feature_extractor[layer_id].layer_norm.weight.data.shape} was found." + feature_extractor.conv_layers[layer_id].layer_norm.weight.data = value + logger.info(f"Feat extract layer norm weight of layer {layer_id} was initialized from {full_name}.") + else: + unused_weights.append(full_name) + + +def load_adapter(full_name, value, adapter, unused_weights): + name = full_name.split("adaptor.")[-1] + items = name.split(".") + + if items[1].isdigit(): + layer_id = int(items[1]) + else: + layer_id = None + + if "adaptor" not in full_name: + if "proj_ln" in full_name: + # has to be layer norm + if "bias" in name: + assert ( + value.shape == adapter.proj_layer_norm.bias.data.shape + ), f"{full_name} has size {value.shape}, but {adapter.proj_layer_norm.bias.data.shape} was found." + adapter.proj_layer_norm.bias.data = value + logger.info(f"Adapter proj layer norm bias was initialized from {full_name}.") + if "weight" in name: + assert ( + value.shape == adapter.proj_layer_norm.weight.data.shape + ), f"{full_name} has size {value.shape}, but {adapter.proj_layer_norm.weight.data.shape} was found." + adapter.proj_layer_norm.weight.data = value + else: + # has to be projection layer + if "bias" in name: + assert ( + value.shape == adapter.proj.bias.data.shape + ), f"{full_name} has size {value.shape}, but {adapter.proj.bias.data.shape} was found." + adapter.proj.bias.data = value + logger.info(f"Adapter proj layer bias was initialized from {full_name}.") + if "weight" in name: + assert ( + value.shape == adapter.proj.weight.data.shape + ), f"{full_name} has size {value.shape}, but {adapter.proj.weight.data.shape} was found." + adapter.proj.weight.data = value + logger.info(f"Adapter proj layer weight was initialized from {full_name}.") + elif isinstance(layer_id, int): + if "bias" in name: + assert ( + value.shape == adapter.layers[layer_id].conv.bias.data.shape + ), f"{full_name} has size {value.shape}, but {adapter.layers[layer_id].conv.bias.data.shape} was found." + adapter.layers[layer_id].conv.bias.data = value + logger.info(f"Adapter layer {layer_id} bias was initialized from {full_name}.") + elif "weight" in name: + assert ( + value.shape == adapter.layers[layer_id].conv.weight.data.shape + ), f"{full_name} has size {value.shape}, but {adapter.layers[layer_id].conv.weight.data.shape} was found." + adapter.layers[layer_id].conv.weight.data = value + logger.info(f"Adapter layer {layer_id} bias was initialized from {full_name}.") + else: + unused_weights.append(full_name) + + +def make_linear_from_emb(emb): + vocab_size, emb_size = emb.weight.shape + lin_layer = nn.Linear(vocab_size, emb_size, bias=False) + lin_layer.weight.data = emb.weight.data + return lin_layer + + +@torch.no_grad() +def convert_wav2vec2_checkpoint( + checkpoint_path, + pytorch_dump_folder_path, + dict_path, + config_yaml_path, + encoder_config_path, + decoder_config_path, + add_adapter, + adapter_kernel_size, + adapter_stride, + decoder_start_token_id, + encoder_output_dim, +): + """ + Copy/paste/tweak model's weights to transformers design. + """ + # load configs + encoder_config = Wav2Vec2Config.from_pretrained( + encoder_config_path, + add_adapter=True, + adapter_stride=adapter_stride, + adapter_kernel_size=adapter_kernel_size, + use_auth_token=True, + output_hidden_size=encoder_output_dim, + ) + decoder_config = MBartConfig.from_pretrained(decoder_config_path) + + # load model + model, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task( + [checkpoint_path], + arg_overrides={ + "config_yaml": config_yaml_path, + "data": "/".join(dict_path.split("/")[:-1]), + "w2v_path": checkpoint_path, + "load_pretrained_decoder_from": None, + }, + ) + model = model[0].eval() + + # load feature extractor + feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(encoder_config_path, use_auth_token=True) + + # set weights for wav2vec2 encoder + hf_encoder = Wav2Vec2Model(encoder_config) + + recursively_load_weights_wav2vec2(model.encoder, hf_encoder) + + # load decoder weights + hf_decoder = MBartForCausalLM(decoder_config) + missing_keys, unexpected_keys = hf_decoder.model.decoder.load_state_dict(model.decoder.state_dict(), strict=False) + logger.warning(f"The following keys are missing when loading the decoder weights: {missing_keys}") + logger.warning(f"The following keys are unexpected when loading the decoder weights: {unexpected_keys}") + + hf_wav2vec = SpeechEncoderDecoderModel(encoder=hf_encoder, decoder=hf_decoder) + hf_wav2vec.config.tie_word_embeddings = False + + tokenizer = MBart50Tokenizer(dict_path) + tokenizer.save_pretrained(pytorch_dump_folder_path) + + config = hf_wav2vec.config.to_dict() + config["pad_token_id"] = tokenizer.pad_token_id + config["bos_token_id"] = tokenizer.bos_token_id + config["eos_token_id"] = tokenizer.eos_token_id + config["tokenizer_class"] = "mbart50" + config["feature_extractor_type"] = "wav2vec2" + + config["decoder_start_token_id"] = tokenizer.eos_token_id + config["forced_bos_token_id"] = 250004 + config["forced_eos_token_id"] = tokenizer.eos_token_id + + hf_wav2vec.config = SpeechEncoderDecoderConfig.from_dict(config) + + hf_wav2vec.save_pretrained(pytorch_dump_folder_path) + feature_extractor.save_pretrained(pytorch_dump_folder_path) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.") + parser.add_argument("--checkpoint_path", default=None, type=str, help="Path to fairseq checkpoint") + parser.add_argument("--dict_path", default=None, type=str, help="Path to dict of fine-tuned model") + parser.add_argument("--config_yaml_path", default=None, type=str, help="Path to yaml file of fine-tuned model") + parser.add_argument( + "--encoder_config_path", + default="facebook/wav2vec2-xls-r-1b", + type=str, + help="Path to hf encoder wav2vec2 checkpoint config", + ) + parser.add_argument( + "--decoder_config_path", + default="facebook/mbart-large-50-one-to-many-mmt", + type=str, + help="Path to hf decoder checkpoint config", + ) + parser.add_argument("--add_adapter", default=True, type=bool, help="whethere to add model adapter layers") + parser.add_argument("--adapter_stride", default=2, type=int, help="stride of adapter layers") + parser.add_argument("--adapter_kernel_size", default=3, type=int, help="kernel size of adapter layers") + parser.add_argument("--encoder_output_dim", default=1024, type=int, help="encoder output dim") + parser.add_argument("--start_token_id", default=250004, type=int, help="`decoder_start_token_id` of model config") + + args = parser.parse_args() + convert_wav2vec2_checkpoint( + args.checkpoint_path, + args.pytorch_dump_folder_path, + args.dict_path, + args.config_yaml_path, + encoder_config_path=args.encoder_config_path, + decoder_config_path=args.decoder_config_path, + add_adapter=args.add_adapter, + adapter_kernel_size=args.adapter_kernel_size, + adapter_stride=args.adapter_stride, + decoder_start_token_id=args.start_token_id, + encoder_output_dim=args.encoder_output_dim, + ) diff --git a/src/transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py b/src/transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py index ec0b459c34..98a41b3c55 100644 --- a/src/transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py +++ b/src/transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py @@ -223,7 +223,9 @@ class SpeechEncoderDecoderModel(PreTrainedModel): self.encoder.config = self.config.encoder self.decoder.config = self.config.decoder - if self.encoder.config.hidden_size != self.decoder.config.hidden_size: + # get encoder output hidden size + self.encoder_output_dim = getattr(config.encoder, "output_hidden_size", config.encoder.hidden_size) + if self.encoder_output_dim != self.decoder.config.hidden_size: # encoder outputs might need to be projected to different dimension for decoder self.enc_to_dec_proj = nn.Linear(self.encoder.config.hidden_size, self.decoder.config.hidden_size) @@ -471,7 +473,7 @@ class SpeechEncoderDecoderModel(PreTrainedModel): encoder_hidden_states = encoder_outputs[0] # project encoder_hidden_states - if self.encoder.config.hidden_size != self.decoder.config.hidden_size: + if self.encoder_output_dim != self.decoder.config.hidden_size: encoder_hidden_states = self.enc_to_dec_proj(encoder_hidden_states) # compute correct encoder attention mask diff --git a/src/transformers/models/unispeech/modeling_unispeech.py b/src/transformers/models/unispeech/modeling_unispeech.py index e1a7bc4e42..3f700ee153 100755 --- a/src/transformers/models/unispeech/modeling_unispeech.py +++ b/src/transformers/models/unispeech/modeling_unispeech.py @@ -892,7 +892,6 @@ class UniSpeechGumbelVectorQuantizer(nn.Module): return codevectors, perplexity -# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2PreTrainedModel with Wav2Vec2->UniSpeech, wav2vec2->unispeech class UniSpeechPreTrainedModel(PreTrainedModel): """ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained @@ -1032,7 +1031,6 @@ UNISPEECH_INPUTS_DOCSTRING = r""" "The bare UniSpeech Model transformer outputting raw hidden-states without any specific head on top.", UNISPEECH_START_DOCSTRING, ) -# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Model with Wav2Vec2->UniSpeech, wav2vec2->unispeech, WAV_2_VEC_2->UNISPEECH class UniSpeechModel(UniSpeechPreTrainedModel): def __init__(self, config: UniSpeechConfig): super().__init__(config) @@ -1049,6 +1047,7 @@ class UniSpeechModel(UniSpeechPreTrainedModel): self.init_weights() + # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Model._mask_hidden_states def _mask_hidden_states( self, hidden_states: torch.FloatTensor, diff --git a/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py b/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py index 3c6c77e14b..ae28492064 100755 --- a/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py +++ b/src/transformers/models/unispeech_sat/modeling_unispeech_sat.py @@ -893,7 +893,6 @@ class UniSpeechSatGumbelVectorQuantizer(nn.Module): return codevectors, perplexity -# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2PreTrainedModel with Wav2Vec2->UniSpeechSat, wav2vec2->unispeech_sat class UniSpeechSatPreTrainedModel(PreTrainedModel): """ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained @@ -1033,7 +1032,6 @@ UNISPEECH_SAT_INPUTS_DOCSTRING = r""" "The bare UniSpeechSat Model transformer outputting raw hidden-states without any specific head on top.", UNISPEECH_SAT_START_DOCSTRING, ) -# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Model with Wav2Vec2->UniSpeechSat, wav2vec2->unispeech_sat, WAV_2_VEC_2->UNISPEECH_SAT class UniSpeechSatModel(UniSpeechSatPreTrainedModel): def __init__(self, config: UniSpeechSatConfig): super().__init__(config) @@ -1050,6 +1048,7 @@ class UniSpeechSatModel(UniSpeechSatPreTrainedModel): self.init_weights() + # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Model._mask_hidden_states def _mask_hidden_states( self, hidden_states: torch.FloatTensor, diff --git a/src/transformers/models/wav2vec2/configuration_wav2vec2.py b/src/transformers/models/wav2vec2/configuration_wav2vec2.py index 44e156f2b2..69e2106c11 100644 --- a/src/transformers/models/wav2vec2/configuration_wav2vec2.py +++ b/src/transformers/models/wav2vec2/configuration_wav2vec2.py @@ -140,6 +140,19 @@ class Wav2Vec2Config(PretrainedConfig): instance of :class:`~transformers.Wav2Vec2ForSequenceClassification`. classifier_proj_size (:obj:`int`, `optional`, defaults to 256): Dimensionality of the projection before token mean-pooling for classification. + add_adapter (:obj:`bool`, `optional`, defaults to :obj:`False`): + Whether a convolutional network should be stacked on top of the Wav2Vec2 Encoder. Can be very useful for + warm-starting Wav2Vec2 for SpeechEncoderDecoder models. + adapter_kernel_size (:obj:`int`, `optional`, defaults to 3): + Kernel size of the convolutional layers in the adapter network. Only relevant if ``add_adapter is True``. + adapter_stride (:obj:`int`, `optional`, defaults to 2): + Stride of the convolutional layers in the adapter network. Only relevant if ``add_adapter is True``. + num_adapter_layers (:obj:`int`, `optional`, defaults to 3): + Number of convolutional layers that should be used in the adapter network. Only relevant if ``add_adapter + is True``. + output_hidden_size (:obj:`int`, `optional`): + Dimensionality of the encoder output layer. If not defined, this defaults to `hidden-size`. Only relevant + if ``add_adapter is True``. Example:: @@ -201,6 +214,11 @@ class Wav2Vec2Config(PretrainedConfig): pad_token_id=0, bos_token_id=1, eos_token_id=2, + add_adapter=False, + adapter_kernel_size=3, + adapter_stride=2, + num_adapter_layers=3, + output_hidden_size=None, **kwargs ): super().__init__(**kwargs, pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id) @@ -263,3 +281,10 @@ class Wav2Vec2Config(PretrainedConfig): # ctc loss self.ctc_loss_reduction = ctc_loss_reduction self.ctc_zero_infinity = ctc_zero_infinity + + # adapter + self.add_adapter = add_adapter + self.adapter_kernel_size = adapter_kernel_size + self.adapter_stride = adapter_stride + self.num_adapter_layers = num_adapter_layers + self.output_hidden_size = output_hidden_size or hidden_size diff --git a/src/transformers/models/wav2vec2/modeling_wav2vec2.py b/src/transformers/models/wav2vec2/modeling_wav2vec2.py index bbe363861e..0bb456620b 100755 --- a/src/transformers/models/wav2vec2/modeling_wav2vec2.py +++ b/src/transformers/models/wav2vec2/modeling_wav2vec2.py @@ -935,6 +935,55 @@ class Wav2Vec2GumbelVectorQuantizer(nn.Module): return codevectors, perplexity +class Wav2Vec2Adapter(nn.Module): + def __init__(self, config): + super().__init__() + + # feature dim might need to be down-projected + if config.output_hidden_size != config.hidden_size: + self.proj = nn.Linear(config.hidden_size, config.output_hidden_size) + self.proj_layer_norm = nn.LayerNorm(config.output_hidden_size) + else: + self.proj = self.proj_layer_norm = None + + self.layers = nn.ModuleList(Wav2Vec2AdapterLayer(config) for _ in range(config.num_adapter_layers)) + self.layerdrop = config.layerdrop + + def forward(self, hidden_states): + # down project hidden_states if necessary + if self.proj is not None and self.proj_layer_norm is not None: + hidden_states = self.proj(hidden_states) + hidden_states = self.proj_layer_norm(hidden_states) + + hidden_states = hidden_states.transpose(1, 2) + + for layer in self.layers: + layerdrop_prob = np.random.random() + if not self.training or (layerdrop_prob > self.layerdrop): + hidden_states = layer(hidden_states) + + hidden_states = hidden_states.transpose(1, 2) + return hidden_states + + +class Wav2Vec2AdapterLayer(nn.Module): + def __init__(self, config): + super().__init__() + self.conv = nn.Conv1d( + config.output_hidden_size, + 2 * config.output_hidden_size, + config.adapter_kernel_size, + stride=config.adapter_stride, + padding=1, + ) + + def forward(self, hidden_states): + hidden_states = self.conv(hidden_states) + hidden_states = nn.functional.glu(hidden_states, dim=1) + + return hidden_states + + class Wav2Vec2PreTrainedModel(PreTrainedModel): """ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained @@ -979,11 +1028,15 @@ class Wav2Vec2PreTrainedModel(PreTrainedModel): k = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0])) nn.init.uniform_(module.bias, a=-k, b=k) - def _get_feat_extract_output_lengths(self, input_lengths: Union[torch.LongTensor, int]): + def _get_feat_extract_output_lengths( + self, input_lengths: Union[torch.LongTensor, int], add_adapter: Optional[bool] = None + ): """ Computes the output length of the convolutional layers """ + add_adapter = self.config.add_adapter if add_adapter is None else add_adapter + def _conv_out_length(input_length, kernel_size, stride): # 1D convolutional layer output length formula taken # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html @@ -992,13 +1045,22 @@ class Wav2Vec2PreTrainedModel(PreTrainedModel): for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride): input_lengths = _conv_out_length(input_lengths, kernel_size, stride) + if add_adapter: + for _ in range(self.config.num_adapter_layers): + input_lengths = _conv_out_length(input_lengths, 1, self.config.adapter_stride) + return input_lengths - def _get_feature_vector_attention_mask(self, feature_vector_length: int, attention_mask: torch.LongTensor): + def _get_feature_vector_attention_mask( + self, feature_vector_length: int, attention_mask: torch.LongTensor, add_adapter=None + ): # Effectively attention_mask.sum(-1), but not inplace to be able to run # on inference mode. non_padded_lengths = attention_mask.cumsum(dim=-1)[:, -1] - output_lengths = self._get_feat_extract_output_lengths(non_padded_lengths).to(torch.long) + + output_lengths = self._get_feat_extract_output_lengths(non_padded_lengths, add_adapter=add_adapter) + output_lengths = output_lengths.to(torch.long) + batch_size = attention_mask.shape[0] attention_mask = torch.zeros( @@ -1088,6 +1150,8 @@ class Wav2Vec2Model(Wav2Vec2PreTrainedModel): else: self.encoder = Wav2Vec2Encoder(config) + self.adapter = Wav2Vec2Adapter(config) if config.add_adapter else None + self.init_weights() def _mask_hidden_states( @@ -1163,7 +1227,9 @@ class Wav2Vec2Model(Wav2Vec2PreTrainedModel): if attention_mask is not None: # compute reduced attention_mask corresponding to feature vectors - attention_mask = self._get_feature_vector_attention_mask(extract_features.shape[1], attention_mask) + attention_mask = self._get_feature_vector_attention_mask( + extract_features.shape[1], attention_mask, add_adapter=False + ) hidden_states, extract_features = self.feature_projection(extract_features) hidden_states = self._mask_hidden_states( @@ -1180,6 +1246,9 @@ class Wav2Vec2Model(Wav2Vec2PreTrainedModel): hidden_states = encoder_outputs[0] + if self.adapter is not None: + hidden_states = self.adapter(hidden_states) + if not return_dict: return (hidden_states, extract_features) + encoder_outputs[1:] @@ -1328,7 +1397,9 @@ class Wav2Vec2ForPreTraining(Wav2Vec2PreTrainedModel): if attention_mask is not None: # compute reduced attention_mask correponding to feature vectors - attention_mask = self._get_feature_vector_attention_mask(extract_features.shape[1], attention_mask) + attention_mask = self._get_feature_vector_attention_mask( + extract_features.shape[1], attention_mask, add_adapter=False + ) quantized_features, codevector_perplexity = self.quantizer( extract_features, mask_time_indices=mask_time_indices diff --git a/tests/test_modeling_wav2vec2.py b/tests/test_modeling_wav2vec2.py index 5ac40fe98e..c689b05f25 100644 --- a/tests/test_modeling_wav2vec2.py +++ b/tests/test_modeling_wav2vec2.py @@ -82,6 +82,8 @@ class Wav2Vec2ModelTester: mask_time_length=2, vocab_size=32, do_stable_layer_norm=False, + num_adapter_layers=1, + adapter_stride=2, scope=None, ): self.parent = parent @@ -107,6 +109,8 @@ class Wav2Vec2ModelTester: self.initializer_range = initializer_range self.vocab_size = vocab_size self.do_stable_layer_norm = do_stable_layer_norm + self.num_adapter_layers = num_adapter_layers + self.adapter_stride = adapter_stride self.mask_time_prob = mask_time_prob self.mask_time_length = mask_time_length self.scope = scope @@ -117,6 +121,8 @@ class Wav2Vec2ModelTester: self.output_seq_length = int(math.ceil(output_seq_length)) self.encoder_seq_length = self.output_seq_length + self.adapter_output_seq_length = (self.output_seq_length - 1) // adapter_stride + 1 + def prepare_config_and_inputs(self): input_values = floats_tensor([self.batch_size, self.seq_length], self.vocab_size) attention_mask = random_attention_mask([self.batch_size, self.seq_length]) @@ -148,6 +154,8 @@ class Wav2Vec2ModelTester: hidden_act=self.hidden_act, initializer_range=self.initializer_range, vocab_size=self.vocab_size, + num_adapter_layers=self.num_adapter_layers, + adapter_stride=self.adapter_stride, ) def create_and_check_model(self, config, input_values, attention_mask): @@ -159,6 +167,28 @@ class Wav2Vec2ModelTester: result.last_hidden_state.shape, (self.batch_size, self.output_seq_length, self.hidden_size) ) + def create_and_check_model_with_adapter(self, config, input_values, attention_mask): + config.add_adapter = True + model = Wav2Vec2Model(config=config) + model.to(torch_device) + model.eval() + result = model(input_values, attention_mask=attention_mask) + self.parent.assertEqual( + result.last_hidden_state.shape, (self.batch_size, self.adapter_output_seq_length, self.hidden_size) + ) + + def create_and_check_model_with_adapter_proj_dim(self, config, input_values, attention_mask): + config.add_adapter = True + config.output_hidden_size = 8 + model = Wav2Vec2Model(config=config) + model.to(torch_device) + model.eval() + result = model(input_values, attention_mask=attention_mask) + self.parent.assertEqual( + result.last_hidden_state.shape, + (self.batch_size, self.adapter_output_seq_length, config.output_hidden_size), + ) + def create_and_check_batch_inference(self, config, input_values, *args): # test does not pass for models making use of `group_norm` # check: https://github.com/pytorch/fairseq/issues/3227 @@ -332,6 +362,14 @@ class Wav2Vec2ModelTest(ModelTesterMixin, unittest.TestCase): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_model(*config_and_inputs) + def test_model_with_adapter(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_model_with_adapter(*config_and_inputs) + + def test_model_with_adapter_proj_dim(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_model_with_adapter_proj_dim(*config_and_inputs) + def test_ctc_loss_inference(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.check_ctc_loss(*config_and_inputs) @@ -544,6 +582,14 @@ class Wav2Vec2RobustModelTest(ModelTesterMixin, unittest.TestCase): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_model(*config_and_inputs) + def test_model_with_adapter(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_model_with_adapter(*config_and_inputs) + + def test_model_with_adapter_proj_dim(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_model_with_adapter_proj_dim(*config_and_inputs) + def test_batched_inference(self): config_and_inputs = self.model_tester.prepare_config_and_inputs() self.model_tester.create_and_check_batch_inference(*config_and_inputs) diff --git a/tests/test_pipelines_automatic_speech_recognition.py b/tests/test_pipelines_automatic_speech_recognition.py index a50512887d..ecdf447752 100644 --- a/tests/test_pipelines_automatic_speech_recognition.py +++ b/tests/test_pipelines_automatic_speech_recognition.py @@ -203,3 +203,41 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel data = f.read() output = asr(data) self.assertEqual(output, {"text": "Un uomo disse all'universo: \"Signore, io esisto."}) + + @slow + @require_torch + @require_torchaudio + @require_datasets + def test_xls_r_to_en(self): + speech_recognizer = pipeline( + task="automatic-speech-recognition", + model="facebook/wav2vec2-xls-r-1b-21-to-en", + feature_extractor="facebook/wav2vec2-xls-r-1b-21-to-en", + framework="pt", + ) + + from datasets import load_dataset + + ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation").sort("id") + filename = ds[40]["file"] + output = speech_recognizer(filename) + self.assertEqual(output, {"text": "A man said to the universe: “Sir, I exist."}) + + @slow + @require_torch + @require_torchaudio + @require_datasets + def test_xls_r_from_en(self): + speech_recognizer = pipeline( + task="automatic-speech-recognition", + model="facebook/wav2vec2-xls-r-1b-en-to-15", + feature_extractor="facebook/wav2vec2-xls-r-1b-en-to-15", + framework="pt", + ) + + from datasets import load_dataset + + ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation").sort("id") + filename = ds[40]["file"] + output = speech_recognizer(filename) + self.assertEqual(output, {"text": "Ein Mann sagte zu dem Universum, Sir, ich bin da."})