[Wav2Vec2] Add New Wav2Vec2 Translation (#14392)
* add new wav2vec2 translation * correct * up * add tests * correct end copy * correct more * up * correct unispeech sat * finish * finalize * finish * up
This commit is contained in:
committed by
GitHub
parent
b567510cff
commit
700a748fe6
@@ -0,0 +1,353 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2021 The HuggingFace Inc. team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Convert Wav2Vec2 checkpoint."""
|
||||
|
||||
|
||||
import argparse
|
||||
|
||||
import fairseq
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from transformers import (
|
||||
MBart50Tokenizer,
|
||||
MBartConfig,
|
||||
MBartForCausalLM,
|
||||
SpeechEncoderDecoderConfig,
|
||||
SpeechEncoderDecoderModel,
|
||||
Wav2Vec2Config,
|
||||
Wav2Vec2FeatureExtractor,
|
||||
Wav2Vec2Model,
|
||||
logging,
|
||||
)
|
||||
|
||||
|
||||
logging.set_verbosity_info()
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
MAPPING = {
|
||||
"post_extract_proj": "feature_projection.projection",
|
||||
"encoder.pos_conv.0": "encoder.pos_conv_embed.conv",
|
||||
"self_attn.k_proj": "encoder.layers.*.attention.k_proj",
|
||||
"self_attn.v_proj": "encoder.layers.*.attention.v_proj",
|
||||
"self_attn.q_proj": "encoder.layers.*.attention.q_proj",
|
||||
"self_attn.out_proj": "encoder.layers.*.attention.out_proj",
|
||||
"self_attn_layer_norm": "encoder.layers.*.layer_norm",
|
||||
"fc1": "encoder.layers.*.feed_forward.intermediate_dense",
|
||||
"fc2": "encoder.layers.*.feed_forward.output_dense",
|
||||
"final_layer_norm": "encoder.layers.*.final_layer_norm",
|
||||
"encoder.layer_norm": "encoder.layer_norm",
|
||||
"w2v_model.layer_norm": "feature_projection.layer_norm",
|
||||
"quantizer.weight_proj": "quantizer.weight_proj",
|
||||
"quantizer.vars": "quantizer.codevectors",
|
||||
"project_q": "project_q",
|
||||
"final_proj": "project_hid",
|
||||
"w2v_encoder.proj": "lm_head",
|
||||
"mask_emb": "masked_spec_embed",
|
||||
}
|
||||
TOP_LEVEL_KEYS = [
|
||||
"lm_head",
|
||||
"quantizer.weight_proj",
|
||||
"quantizer.codevectors",
|
||||
"project_q",
|
||||
"project_hid",
|
||||
]
|
||||
|
||||
|
||||
def set_recursively(hf_pointer, key, value, full_name, weight_type):
|
||||
for attribute in key.split("."):
|
||||
hf_pointer = getattr(hf_pointer, attribute)
|
||||
|
||||
if weight_type is not None:
|
||||
hf_shape = getattr(hf_pointer, weight_type).shape
|
||||
else:
|
||||
hf_shape = hf_pointer.shape
|
||||
|
||||
assert (
|
||||
hf_shape == value.shape
|
||||
), f"Shape of hf {key + '.' + weight_type if weight_type is not None else ''} is {hf_shape}, but should be {value.shape} for {full_name}"
|
||||
|
||||
if weight_type == "weight":
|
||||
hf_pointer.weight.data = value
|
||||
elif weight_type == "weight_g":
|
||||
hf_pointer.weight_g.data = value
|
||||
elif weight_type == "weight_v":
|
||||
hf_pointer.weight_v.data = value
|
||||
elif weight_type == "bias":
|
||||
hf_pointer.bias.data = value
|
||||
else:
|
||||
hf_pointer.data = value
|
||||
|
||||
logger.info(f"{key + '.' + weight_type if weight_type is not None else ''} was initialized from {full_name}.")
|
||||
|
||||
|
||||
def recursively_load_weights_wav2vec2(fairseq_model, hf_model):
|
||||
unused_weights = []
|
||||
fairseq_dict = fairseq_model.state_dict()
|
||||
|
||||
feature_extractor = hf_model.feature_extractor
|
||||
adapter = hf_model.adapter
|
||||
|
||||
for name, value in fairseq_dict.items():
|
||||
is_used = False
|
||||
if "conv_layers" in name:
|
||||
load_conv_layer(
|
||||
name,
|
||||
value,
|
||||
feature_extractor,
|
||||
unused_weights,
|
||||
hf_model.config.feat_extract_norm == "group",
|
||||
)
|
||||
is_used = True
|
||||
elif any(x in name for x in ["adaptor", "w2v_encoder.proj.", "w2v_proj_ln."]):
|
||||
load_adapter(name, value, adapter, unused_weights)
|
||||
is_used = True
|
||||
else:
|
||||
for key, mapped_key in MAPPING.items():
|
||||
if key in name or key.split("w2v_model.")[-1] == name.split(".")[0]:
|
||||
is_used = True
|
||||
if "*" in mapped_key:
|
||||
layer_index = name.split(key)[0].split(".")[-2]
|
||||
mapped_key = mapped_key.replace("*", layer_index)
|
||||
if "weight_g" in name:
|
||||
weight_type = "weight_g"
|
||||
elif "weight_v" in name:
|
||||
weight_type = "weight_v"
|
||||
elif "bias" in name:
|
||||
weight_type = "bias"
|
||||
elif "weight" in name:
|
||||
weight_type = "weight"
|
||||
else:
|
||||
weight_type = None
|
||||
set_recursively(hf_model, mapped_key, value, name, weight_type)
|
||||
continue
|
||||
if not is_used:
|
||||
unused_weights.append(name)
|
||||
|
||||
logger.warning(f"Unused weights: {unused_weights}")
|
||||
|
||||
|
||||
def load_conv_layer(full_name, value, feature_extractor, unused_weights, use_group_norm):
|
||||
name = full_name.split("conv_layers.")[-1]
|
||||
items = name.split(".")
|
||||
layer_id = int(items[0])
|
||||
type_id = int(items[1])
|
||||
|
||||
if type_id == 0:
|
||||
if "bias" in name:
|
||||
assert (
|
||||
value.shape == feature_extractor.conv_layers[layer_id].conv.bias.data.shape
|
||||
), f"{full_name} has size {value.shape}, but {feature_extractor.conv_layers[layer_id].conv.bias.data.shape} was found."
|
||||
feature_extractor.conv_layers[layer_id].conv.bias.data = value
|
||||
logger.info(f"Feat extract conv layer {layer_id} was initialized from {full_name}.")
|
||||
elif "weight" in name:
|
||||
assert (
|
||||
value.shape == feature_extractor.conv_layers[layer_id].conv.weight.data.shape
|
||||
), f"{full_name} has size {value.shape}, but {feature_extractor.conv_layers[layer_id].conv.weight.data.shape} was found."
|
||||
feature_extractor.conv_layers[layer_id].conv.weight.data = value
|
||||
logger.info(f"Feat extract conv layer {layer_id} was initialized from {full_name}.")
|
||||
elif (type_id == 2 and not use_group_norm) or (type_id == 2 and layer_id == 0 and use_group_norm):
|
||||
if "bias" in name:
|
||||
assert (
|
||||
value.shape == feature_extractor.conv_layers[layer_id].layer_norm.bias.data.shape
|
||||
), f"{full_name} has size {value.shape}, but {feature_extractor[layer_id].layer_norm.bias.data.shape} was found."
|
||||
feature_extractor.conv_layers[layer_id].layer_norm.bias.data = value
|
||||
logger.info(f"Feat extract layer norm weight of layer {layer_id} was initialized from {full_name}.")
|
||||
elif "weight" in name:
|
||||
assert (
|
||||
value.shape == feature_extractor.conv_layers[layer_id].layer_norm.weight.data.shape
|
||||
), f"{full_name} has size {value.shape}, but {feature_extractor[layer_id].layer_norm.weight.data.shape} was found."
|
||||
feature_extractor.conv_layers[layer_id].layer_norm.weight.data = value
|
||||
logger.info(f"Feat extract layer norm weight of layer {layer_id} was initialized from {full_name}.")
|
||||
else:
|
||||
unused_weights.append(full_name)
|
||||
|
||||
|
||||
def load_adapter(full_name, value, adapter, unused_weights):
|
||||
name = full_name.split("adaptor.")[-1]
|
||||
items = name.split(".")
|
||||
|
||||
if items[1].isdigit():
|
||||
layer_id = int(items[1])
|
||||
else:
|
||||
layer_id = None
|
||||
|
||||
if "adaptor" not in full_name:
|
||||
if "proj_ln" in full_name:
|
||||
# has to be layer norm
|
||||
if "bias" in name:
|
||||
assert (
|
||||
value.shape == adapter.proj_layer_norm.bias.data.shape
|
||||
), f"{full_name} has size {value.shape}, but {adapter.proj_layer_norm.bias.data.shape} was found."
|
||||
adapter.proj_layer_norm.bias.data = value
|
||||
logger.info(f"Adapter proj layer norm bias was initialized from {full_name}.")
|
||||
if "weight" in name:
|
||||
assert (
|
||||
value.shape == adapter.proj_layer_norm.weight.data.shape
|
||||
), f"{full_name} has size {value.shape}, but {adapter.proj_layer_norm.weight.data.shape} was found."
|
||||
adapter.proj_layer_norm.weight.data = value
|
||||
else:
|
||||
# has to be projection layer
|
||||
if "bias" in name:
|
||||
assert (
|
||||
value.shape == adapter.proj.bias.data.shape
|
||||
), f"{full_name} has size {value.shape}, but {adapter.proj.bias.data.shape} was found."
|
||||
adapter.proj.bias.data = value
|
||||
logger.info(f"Adapter proj layer bias was initialized from {full_name}.")
|
||||
if "weight" in name:
|
||||
assert (
|
||||
value.shape == adapter.proj.weight.data.shape
|
||||
), f"{full_name} has size {value.shape}, but {adapter.proj.weight.data.shape} was found."
|
||||
adapter.proj.weight.data = value
|
||||
logger.info(f"Adapter proj layer weight was initialized from {full_name}.")
|
||||
elif isinstance(layer_id, int):
|
||||
if "bias" in name:
|
||||
assert (
|
||||
value.shape == adapter.layers[layer_id].conv.bias.data.shape
|
||||
), f"{full_name} has size {value.shape}, but {adapter.layers[layer_id].conv.bias.data.shape} was found."
|
||||
adapter.layers[layer_id].conv.bias.data = value
|
||||
logger.info(f"Adapter layer {layer_id} bias was initialized from {full_name}.")
|
||||
elif "weight" in name:
|
||||
assert (
|
||||
value.shape == adapter.layers[layer_id].conv.weight.data.shape
|
||||
), f"{full_name} has size {value.shape}, but {adapter.layers[layer_id].conv.weight.data.shape} was found."
|
||||
adapter.layers[layer_id].conv.weight.data = value
|
||||
logger.info(f"Adapter layer {layer_id} bias was initialized from {full_name}.")
|
||||
else:
|
||||
unused_weights.append(full_name)
|
||||
|
||||
|
||||
def make_linear_from_emb(emb):
|
||||
vocab_size, emb_size = emb.weight.shape
|
||||
lin_layer = nn.Linear(vocab_size, emb_size, bias=False)
|
||||
lin_layer.weight.data = emb.weight.data
|
||||
return lin_layer
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def convert_wav2vec2_checkpoint(
|
||||
checkpoint_path,
|
||||
pytorch_dump_folder_path,
|
||||
dict_path,
|
||||
config_yaml_path,
|
||||
encoder_config_path,
|
||||
decoder_config_path,
|
||||
add_adapter,
|
||||
adapter_kernel_size,
|
||||
adapter_stride,
|
||||
decoder_start_token_id,
|
||||
encoder_output_dim,
|
||||
):
|
||||
"""
|
||||
Copy/paste/tweak model's weights to transformers design.
|
||||
"""
|
||||
# load configs
|
||||
encoder_config = Wav2Vec2Config.from_pretrained(
|
||||
encoder_config_path,
|
||||
add_adapter=True,
|
||||
adapter_stride=adapter_stride,
|
||||
adapter_kernel_size=adapter_kernel_size,
|
||||
use_auth_token=True,
|
||||
output_hidden_size=encoder_output_dim,
|
||||
)
|
||||
decoder_config = MBartConfig.from_pretrained(decoder_config_path)
|
||||
|
||||
# load model
|
||||
model, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task(
|
||||
[checkpoint_path],
|
||||
arg_overrides={
|
||||
"config_yaml": config_yaml_path,
|
||||
"data": "/".join(dict_path.split("/")[:-1]),
|
||||
"w2v_path": checkpoint_path,
|
||||
"load_pretrained_decoder_from": None,
|
||||
},
|
||||
)
|
||||
model = model[0].eval()
|
||||
|
||||
# load feature extractor
|
||||
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(encoder_config_path, use_auth_token=True)
|
||||
|
||||
# set weights for wav2vec2 encoder
|
||||
hf_encoder = Wav2Vec2Model(encoder_config)
|
||||
|
||||
recursively_load_weights_wav2vec2(model.encoder, hf_encoder)
|
||||
|
||||
# load decoder weights
|
||||
hf_decoder = MBartForCausalLM(decoder_config)
|
||||
missing_keys, unexpected_keys = hf_decoder.model.decoder.load_state_dict(model.decoder.state_dict(), strict=False)
|
||||
logger.warning(f"The following keys are missing when loading the decoder weights: {missing_keys}")
|
||||
logger.warning(f"The following keys are unexpected when loading the decoder weights: {unexpected_keys}")
|
||||
|
||||
hf_wav2vec = SpeechEncoderDecoderModel(encoder=hf_encoder, decoder=hf_decoder)
|
||||
hf_wav2vec.config.tie_word_embeddings = False
|
||||
|
||||
tokenizer = MBart50Tokenizer(dict_path)
|
||||
tokenizer.save_pretrained(pytorch_dump_folder_path)
|
||||
|
||||
config = hf_wav2vec.config.to_dict()
|
||||
config["pad_token_id"] = tokenizer.pad_token_id
|
||||
config["bos_token_id"] = tokenizer.bos_token_id
|
||||
config["eos_token_id"] = tokenizer.eos_token_id
|
||||
config["tokenizer_class"] = "mbart50"
|
||||
config["feature_extractor_type"] = "wav2vec2"
|
||||
|
||||
config["decoder_start_token_id"] = tokenizer.eos_token_id
|
||||
config["forced_bos_token_id"] = 250004
|
||||
config["forced_eos_token_id"] = tokenizer.eos_token_id
|
||||
|
||||
hf_wav2vec.config = SpeechEncoderDecoderConfig.from_dict(config)
|
||||
|
||||
hf_wav2vec.save_pretrained(pytorch_dump_folder_path)
|
||||
feature_extractor.save_pretrained(pytorch_dump_folder_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--pytorch_dump_folder_path", default=None, type=str, help="Path to the output PyTorch model.")
|
||||
parser.add_argument("--checkpoint_path", default=None, type=str, help="Path to fairseq checkpoint")
|
||||
parser.add_argument("--dict_path", default=None, type=str, help="Path to dict of fine-tuned model")
|
||||
parser.add_argument("--config_yaml_path", default=None, type=str, help="Path to yaml file of fine-tuned model")
|
||||
parser.add_argument(
|
||||
"--encoder_config_path",
|
||||
default="facebook/wav2vec2-xls-r-1b",
|
||||
type=str,
|
||||
help="Path to hf encoder wav2vec2 checkpoint config",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--decoder_config_path",
|
||||
default="facebook/mbart-large-50-one-to-many-mmt",
|
||||
type=str,
|
||||
help="Path to hf decoder checkpoint config",
|
||||
)
|
||||
parser.add_argument("--add_adapter", default=True, type=bool, help="whethere to add model adapter layers")
|
||||
parser.add_argument("--adapter_stride", default=2, type=int, help="stride of adapter layers")
|
||||
parser.add_argument("--adapter_kernel_size", default=3, type=int, help="kernel size of adapter layers")
|
||||
parser.add_argument("--encoder_output_dim", default=1024, type=int, help="encoder output dim")
|
||||
parser.add_argument("--start_token_id", default=250004, type=int, help="`decoder_start_token_id` of model config")
|
||||
|
||||
args = parser.parse_args()
|
||||
convert_wav2vec2_checkpoint(
|
||||
args.checkpoint_path,
|
||||
args.pytorch_dump_folder_path,
|
||||
args.dict_path,
|
||||
args.config_yaml_path,
|
||||
encoder_config_path=args.encoder_config_path,
|
||||
decoder_config_path=args.decoder_config_path,
|
||||
add_adapter=args.add_adapter,
|
||||
adapter_kernel_size=args.adapter_kernel_size,
|
||||
adapter_stride=args.adapter_stride,
|
||||
decoder_start_token_id=args.start_token_id,
|
||||
encoder_output_dim=args.encoder_output_dim,
|
||||
)
|
||||
@@ -223,7 +223,9 @@ class SpeechEncoderDecoderModel(PreTrainedModel):
|
||||
self.encoder.config = self.config.encoder
|
||||
self.decoder.config = self.config.decoder
|
||||
|
||||
if self.encoder.config.hidden_size != self.decoder.config.hidden_size:
|
||||
# get encoder output hidden size
|
||||
self.encoder_output_dim = getattr(config.encoder, "output_hidden_size", config.encoder.hidden_size)
|
||||
if self.encoder_output_dim != self.decoder.config.hidden_size:
|
||||
# encoder outputs might need to be projected to different dimension for decoder
|
||||
self.enc_to_dec_proj = nn.Linear(self.encoder.config.hidden_size, self.decoder.config.hidden_size)
|
||||
|
||||
@@ -471,7 +473,7 @@ class SpeechEncoderDecoderModel(PreTrainedModel):
|
||||
encoder_hidden_states = encoder_outputs[0]
|
||||
|
||||
# project encoder_hidden_states
|
||||
if self.encoder.config.hidden_size != self.decoder.config.hidden_size:
|
||||
if self.encoder_output_dim != self.decoder.config.hidden_size:
|
||||
encoder_hidden_states = self.enc_to_dec_proj(encoder_hidden_states)
|
||||
|
||||
# compute correct encoder attention mask
|
||||
|
||||
@@ -892,7 +892,6 @@ class UniSpeechGumbelVectorQuantizer(nn.Module):
|
||||
return codevectors, perplexity
|
||||
|
||||
|
||||
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2PreTrainedModel with Wav2Vec2->UniSpeech, wav2vec2->unispeech
|
||||
class UniSpeechPreTrainedModel(PreTrainedModel):
|
||||
"""
|
||||
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
||||
@@ -1032,7 +1031,6 @@ UNISPEECH_INPUTS_DOCSTRING = r"""
|
||||
"The bare UniSpeech Model transformer outputting raw hidden-states without any specific head on top.",
|
||||
UNISPEECH_START_DOCSTRING,
|
||||
)
|
||||
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Model with Wav2Vec2->UniSpeech, wav2vec2->unispeech, WAV_2_VEC_2->UNISPEECH
|
||||
class UniSpeechModel(UniSpeechPreTrainedModel):
|
||||
def __init__(self, config: UniSpeechConfig):
|
||||
super().__init__(config)
|
||||
@@ -1049,6 +1047,7 @@ class UniSpeechModel(UniSpeechPreTrainedModel):
|
||||
|
||||
self.init_weights()
|
||||
|
||||
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Model._mask_hidden_states
|
||||
def _mask_hidden_states(
|
||||
self,
|
||||
hidden_states: torch.FloatTensor,
|
||||
|
||||
@@ -893,7 +893,6 @@ class UniSpeechSatGumbelVectorQuantizer(nn.Module):
|
||||
return codevectors, perplexity
|
||||
|
||||
|
||||
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2PreTrainedModel with Wav2Vec2->UniSpeechSat, wav2vec2->unispeech_sat
|
||||
class UniSpeechSatPreTrainedModel(PreTrainedModel):
|
||||
"""
|
||||
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
||||
@@ -1033,7 +1032,6 @@ UNISPEECH_SAT_INPUTS_DOCSTRING = r"""
|
||||
"The bare UniSpeechSat Model transformer outputting raw hidden-states without any specific head on top.",
|
||||
UNISPEECH_SAT_START_DOCSTRING,
|
||||
)
|
||||
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Model with Wav2Vec2->UniSpeechSat, wav2vec2->unispeech_sat, WAV_2_VEC_2->UNISPEECH_SAT
|
||||
class UniSpeechSatModel(UniSpeechSatPreTrainedModel):
|
||||
def __init__(self, config: UniSpeechSatConfig):
|
||||
super().__init__(config)
|
||||
@@ -1050,6 +1048,7 @@ class UniSpeechSatModel(UniSpeechSatPreTrainedModel):
|
||||
|
||||
self.init_weights()
|
||||
|
||||
# Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2Model._mask_hidden_states
|
||||
def _mask_hidden_states(
|
||||
self,
|
||||
hidden_states: torch.FloatTensor,
|
||||
|
||||
@@ -140,6 +140,19 @@ class Wav2Vec2Config(PretrainedConfig):
|
||||
instance of :class:`~transformers.Wav2Vec2ForSequenceClassification`.
|
||||
classifier_proj_size (:obj:`int`, `optional`, defaults to 256):
|
||||
Dimensionality of the projection before token mean-pooling for classification.
|
||||
add_adapter (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
Whether a convolutional network should be stacked on top of the Wav2Vec2 Encoder. Can be very useful for
|
||||
warm-starting Wav2Vec2 for SpeechEncoderDecoder models.
|
||||
adapter_kernel_size (:obj:`int`, `optional`, defaults to 3):
|
||||
Kernel size of the convolutional layers in the adapter network. Only relevant if ``add_adapter is True``.
|
||||
adapter_stride (:obj:`int`, `optional`, defaults to 2):
|
||||
Stride of the convolutional layers in the adapter network. Only relevant if ``add_adapter is True``.
|
||||
num_adapter_layers (:obj:`int`, `optional`, defaults to 3):
|
||||
Number of convolutional layers that should be used in the adapter network. Only relevant if ``add_adapter
|
||||
is True``.
|
||||
output_hidden_size (:obj:`int`, `optional`):
|
||||
Dimensionality of the encoder output layer. If not defined, this defaults to `hidden-size`. Only relevant
|
||||
if ``add_adapter is True``.
|
||||
|
||||
Example::
|
||||
|
||||
@@ -201,6 +214,11 @@ class Wav2Vec2Config(PretrainedConfig):
|
||||
pad_token_id=0,
|
||||
bos_token_id=1,
|
||||
eos_token_id=2,
|
||||
add_adapter=False,
|
||||
adapter_kernel_size=3,
|
||||
adapter_stride=2,
|
||||
num_adapter_layers=3,
|
||||
output_hidden_size=None,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__(**kwargs, pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id)
|
||||
@@ -263,3 +281,10 @@ class Wav2Vec2Config(PretrainedConfig):
|
||||
# ctc loss
|
||||
self.ctc_loss_reduction = ctc_loss_reduction
|
||||
self.ctc_zero_infinity = ctc_zero_infinity
|
||||
|
||||
# adapter
|
||||
self.add_adapter = add_adapter
|
||||
self.adapter_kernel_size = adapter_kernel_size
|
||||
self.adapter_stride = adapter_stride
|
||||
self.num_adapter_layers = num_adapter_layers
|
||||
self.output_hidden_size = output_hidden_size or hidden_size
|
||||
|
||||
@@ -935,6 +935,55 @@ class Wav2Vec2GumbelVectorQuantizer(nn.Module):
|
||||
return codevectors, perplexity
|
||||
|
||||
|
||||
class Wav2Vec2Adapter(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
|
||||
# feature dim might need to be down-projected
|
||||
if config.output_hidden_size != config.hidden_size:
|
||||
self.proj = nn.Linear(config.hidden_size, config.output_hidden_size)
|
||||
self.proj_layer_norm = nn.LayerNorm(config.output_hidden_size)
|
||||
else:
|
||||
self.proj = self.proj_layer_norm = None
|
||||
|
||||
self.layers = nn.ModuleList(Wav2Vec2AdapterLayer(config) for _ in range(config.num_adapter_layers))
|
||||
self.layerdrop = config.layerdrop
|
||||
|
||||
def forward(self, hidden_states):
|
||||
# down project hidden_states if necessary
|
||||
if self.proj is not None and self.proj_layer_norm is not None:
|
||||
hidden_states = self.proj(hidden_states)
|
||||
hidden_states = self.proj_layer_norm(hidden_states)
|
||||
|
||||
hidden_states = hidden_states.transpose(1, 2)
|
||||
|
||||
for layer in self.layers:
|
||||
layerdrop_prob = np.random.random()
|
||||
if not self.training or (layerdrop_prob > self.layerdrop):
|
||||
hidden_states = layer(hidden_states)
|
||||
|
||||
hidden_states = hidden_states.transpose(1, 2)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class Wav2Vec2AdapterLayer(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.conv = nn.Conv1d(
|
||||
config.output_hidden_size,
|
||||
2 * config.output_hidden_size,
|
||||
config.adapter_kernel_size,
|
||||
stride=config.adapter_stride,
|
||||
padding=1,
|
||||
)
|
||||
|
||||
def forward(self, hidden_states):
|
||||
hidden_states = self.conv(hidden_states)
|
||||
hidden_states = nn.functional.glu(hidden_states, dim=1)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class Wav2Vec2PreTrainedModel(PreTrainedModel):
|
||||
"""
|
||||
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
||||
@@ -979,11 +1028,15 @@ class Wav2Vec2PreTrainedModel(PreTrainedModel):
|
||||
k = math.sqrt(module.groups / (module.in_channels * module.kernel_size[0]))
|
||||
nn.init.uniform_(module.bias, a=-k, b=k)
|
||||
|
||||
def _get_feat_extract_output_lengths(self, input_lengths: Union[torch.LongTensor, int]):
|
||||
def _get_feat_extract_output_lengths(
|
||||
self, input_lengths: Union[torch.LongTensor, int], add_adapter: Optional[bool] = None
|
||||
):
|
||||
"""
|
||||
Computes the output length of the convolutional layers
|
||||
"""
|
||||
|
||||
add_adapter = self.config.add_adapter if add_adapter is None else add_adapter
|
||||
|
||||
def _conv_out_length(input_length, kernel_size, stride):
|
||||
# 1D convolutional layer output length formula taken
|
||||
# from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
|
||||
@@ -992,13 +1045,22 @@ class Wav2Vec2PreTrainedModel(PreTrainedModel):
|
||||
for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
|
||||
input_lengths = _conv_out_length(input_lengths, kernel_size, stride)
|
||||
|
||||
if add_adapter:
|
||||
for _ in range(self.config.num_adapter_layers):
|
||||
input_lengths = _conv_out_length(input_lengths, 1, self.config.adapter_stride)
|
||||
|
||||
return input_lengths
|
||||
|
||||
def _get_feature_vector_attention_mask(self, feature_vector_length: int, attention_mask: torch.LongTensor):
|
||||
def _get_feature_vector_attention_mask(
|
||||
self, feature_vector_length: int, attention_mask: torch.LongTensor, add_adapter=None
|
||||
):
|
||||
# Effectively attention_mask.sum(-1), but not inplace to be able to run
|
||||
# on inference mode.
|
||||
non_padded_lengths = attention_mask.cumsum(dim=-1)[:, -1]
|
||||
output_lengths = self._get_feat_extract_output_lengths(non_padded_lengths).to(torch.long)
|
||||
|
||||
output_lengths = self._get_feat_extract_output_lengths(non_padded_lengths, add_adapter=add_adapter)
|
||||
output_lengths = output_lengths.to(torch.long)
|
||||
|
||||
batch_size = attention_mask.shape[0]
|
||||
|
||||
attention_mask = torch.zeros(
|
||||
@@ -1088,6 +1150,8 @@ class Wav2Vec2Model(Wav2Vec2PreTrainedModel):
|
||||
else:
|
||||
self.encoder = Wav2Vec2Encoder(config)
|
||||
|
||||
self.adapter = Wav2Vec2Adapter(config) if config.add_adapter else None
|
||||
|
||||
self.init_weights()
|
||||
|
||||
def _mask_hidden_states(
|
||||
@@ -1163,7 +1227,9 @@ class Wav2Vec2Model(Wav2Vec2PreTrainedModel):
|
||||
|
||||
if attention_mask is not None:
|
||||
# compute reduced attention_mask corresponding to feature vectors
|
||||
attention_mask = self._get_feature_vector_attention_mask(extract_features.shape[1], attention_mask)
|
||||
attention_mask = self._get_feature_vector_attention_mask(
|
||||
extract_features.shape[1], attention_mask, add_adapter=False
|
||||
)
|
||||
|
||||
hidden_states, extract_features = self.feature_projection(extract_features)
|
||||
hidden_states = self._mask_hidden_states(
|
||||
@@ -1180,6 +1246,9 @@ class Wav2Vec2Model(Wav2Vec2PreTrainedModel):
|
||||
|
||||
hidden_states = encoder_outputs[0]
|
||||
|
||||
if self.adapter is not None:
|
||||
hidden_states = self.adapter(hidden_states)
|
||||
|
||||
if not return_dict:
|
||||
return (hidden_states, extract_features) + encoder_outputs[1:]
|
||||
|
||||
@@ -1328,7 +1397,9 @@ class Wav2Vec2ForPreTraining(Wav2Vec2PreTrainedModel):
|
||||
|
||||
if attention_mask is not None:
|
||||
# compute reduced attention_mask correponding to feature vectors
|
||||
attention_mask = self._get_feature_vector_attention_mask(extract_features.shape[1], attention_mask)
|
||||
attention_mask = self._get_feature_vector_attention_mask(
|
||||
extract_features.shape[1], attention_mask, add_adapter=False
|
||||
)
|
||||
|
||||
quantized_features, codevector_perplexity = self.quantizer(
|
||||
extract_features, mask_time_indices=mask_time_indices
|
||||
|
||||
@@ -82,6 +82,8 @@ class Wav2Vec2ModelTester:
|
||||
mask_time_length=2,
|
||||
vocab_size=32,
|
||||
do_stable_layer_norm=False,
|
||||
num_adapter_layers=1,
|
||||
adapter_stride=2,
|
||||
scope=None,
|
||||
):
|
||||
self.parent = parent
|
||||
@@ -107,6 +109,8 @@ class Wav2Vec2ModelTester:
|
||||
self.initializer_range = initializer_range
|
||||
self.vocab_size = vocab_size
|
||||
self.do_stable_layer_norm = do_stable_layer_norm
|
||||
self.num_adapter_layers = num_adapter_layers
|
||||
self.adapter_stride = adapter_stride
|
||||
self.mask_time_prob = mask_time_prob
|
||||
self.mask_time_length = mask_time_length
|
||||
self.scope = scope
|
||||
@@ -117,6 +121,8 @@ class Wav2Vec2ModelTester:
|
||||
self.output_seq_length = int(math.ceil(output_seq_length))
|
||||
self.encoder_seq_length = self.output_seq_length
|
||||
|
||||
self.adapter_output_seq_length = (self.output_seq_length - 1) // adapter_stride + 1
|
||||
|
||||
def prepare_config_and_inputs(self):
|
||||
input_values = floats_tensor([self.batch_size, self.seq_length], self.vocab_size)
|
||||
attention_mask = random_attention_mask([self.batch_size, self.seq_length])
|
||||
@@ -148,6 +154,8 @@ class Wav2Vec2ModelTester:
|
||||
hidden_act=self.hidden_act,
|
||||
initializer_range=self.initializer_range,
|
||||
vocab_size=self.vocab_size,
|
||||
num_adapter_layers=self.num_adapter_layers,
|
||||
adapter_stride=self.adapter_stride,
|
||||
)
|
||||
|
||||
def create_and_check_model(self, config, input_values, attention_mask):
|
||||
@@ -159,6 +167,28 @@ class Wav2Vec2ModelTester:
|
||||
result.last_hidden_state.shape, (self.batch_size, self.output_seq_length, self.hidden_size)
|
||||
)
|
||||
|
||||
def create_and_check_model_with_adapter(self, config, input_values, attention_mask):
|
||||
config.add_adapter = True
|
||||
model = Wav2Vec2Model(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
result = model(input_values, attention_mask=attention_mask)
|
||||
self.parent.assertEqual(
|
||||
result.last_hidden_state.shape, (self.batch_size, self.adapter_output_seq_length, self.hidden_size)
|
||||
)
|
||||
|
||||
def create_and_check_model_with_adapter_proj_dim(self, config, input_values, attention_mask):
|
||||
config.add_adapter = True
|
||||
config.output_hidden_size = 8
|
||||
model = Wav2Vec2Model(config=config)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
result = model(input_values, attention_mask=attention_mask)
|
||||
self.parent.assertEqual(
|
||||
result.last_hidden_state.shape,
|
||||
(self.batch_size, self.adapter_output_seq_length, config.output_hidden_size),
|
||||
)
|
||||
|
||||
def create_and_check_batch_inference(self, config, input_values, *args):
|
||||
# test does not pass for models making use of `group_norm`
|
||||
# check: https://github.com/pytorch/fairseq/issues/3227
|
||||
@@ -332,6 +362,14 @@ class Wav2Vec2ModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_model(*config_and_inputs)
|
||||
|
||||
def test_model_with_adapter(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_model_with_adapter(*config_and_inputs)
|
||||
|
||||
def test_model_with_adapter_proj_dim(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_model_with_adapter_proj_dim(*config_and_inputs)
|
||||
|
||||
def test_ctc_loss_inference(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.check_ctc_loss(*config_and_inputs)
|
||||
@@ -544,6 +582,14 @@ class Wav2Vec2RobustModelTest(ModelTesterMixin, unittest.TestCase):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_model(*config_and_inputs)
|
||||
|
||||
def test_model_with_adapter(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_model_with_adapter(*config_and_inputs)
|
||||
|
||||
def test_model_with_adapter_proj_dim(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_model_with_adapter_proj_dim(*config_and_inputs)
|
||||
|
||||
def test_batched_inference(self):
|
||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||
self.model_tester.create_and_check_batch_inference(*config_and_inputs)
|
||||
|
||||
@@ -203,3 +203,41 @@ class AutomaticSpeechRecognitionPipelineTests(unittest.TestCase, metaclass=Pipel
|
||||
data = f.read()
|
||||
output = asr(data)
|
||||
self.assertEqual(output, {"text": "Un uomo disse all'universo: \"Signore, io esisto."})
|
||||
|
||||
@slow
|
||||
@require_torch
|
||||
@require_torchaudio
|
||||
@require_datasets
|
||||
def test_xls_r_to_en(self):
|
||||
speech_recognizer = pipeline(
|
||||
task="automatic-speech-recognition",
|
||||
model="facebook/wav2vec2-xls-r-1b-21-to-en",
|
||||
feature_extractor="facebook/wav2vec2-xls-r-1b-21-to-en",
|
||||
framework="pt",
|
||||
)
|
||||
|
||||
from datasets import load_dataset
|
||||
|
||||
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation").sort("id")
|
||||
filename = ds[40]["file"]
|
||||
output = speech_recognizer(filename)
|
||||
self.assertEqual(output, {"text": "A man said to the universe: “Sir, I exist."})
|
||||
|
||||
@slow
|
||||
@require_torch
|
||||
@require_torchaudio
|
||||
@require_datasets
|
||||
def test_xls_r_from_en(self):
|
||||
speech_recognizer = pipeline(
|
||||
task="automatic-speech-recognition",
|
||||
model="facebook/wav2vec2-xls-r-1b-en-to-15",
|
||||
feature_extractor="facebook/wav2vec2-xls-r-1b-en-to-15",
|
||||
framework="pt",
|
||||
)
|
||||
|
||||
from datasets import load_dataset
|
||||
|
||||
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation").sort("id")
|
||||
filename = ds[40]["file"]
|
||||
output = speech_recognizer(filename)
|
||||
self.assertEqual(output, {"text": "Ein Mann sagte zu dem Universum, Sir, ich bin da."})
|
||||
|
||||
Reference in New Issue
Block a user