[Wav2Vec2] Improve Tokenizer & Model for batched inference (#10117)
* save intermediate * finish batch the same as fairseq * add normalization * fix batched input * add better comment * Update src/transformers/models/wav2vec2/modeling_wav2vec2.py * add nice docstring * add tokenizer tests * make all slow tests pass * finish PR * correct import
This commit is contained in:
committed by
GitHub
parent
2f3b5f4dcc
commit
495c157d6f
@@ -36,7 +36,10 @@ logger = logging.get_logger(__name__)
|
||||
_CONFIG_FOR_DOC = "Wav2Vec2Config"
|
||||
|
||||
WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
||||
"facebook/wav2vec2-base-960h"
|
||||
"facebook/wav2vec2-base-960h",
|
||||
"facebook/wav2vec2-large-960h",
|
||||
"facebook/wav2vec2-large-960h-lv60",
|
||||
"facebook/wav2vec2-large-960h-lv60-self",
|
||||
# See all Wav2Vec2 models at https://huggingface.co/models?filter=wav2vec2
|
||||
]
|
||||
|
||||
@@ -191,7 +194,6 @@ class Wav2Vec2FeatureProjection(nn.Module):
|
||||
self.dropout = nn.Dropout(config.feat_extract_dropout)
|
||||
|
||||
def forward(self, hidden_states):
|
||||
hidden_states = hidden_states.transpose(1, 2)
|
||||
hidden_states = self.layer_norm(hidden_states)
|
||||
hidden_states = self.projection(hidden_states)
|
||||
hidden_states = self.dropout(hidden_states)
|
||||
@@ -387,9 +389,11 @@ class Wav2Vec2EncoderLayer(nn.Module):
|
||||
self.feed_forward = Wav2Vec2FeedForward(config)
|
||||
self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||
|
||||
def forward(self, hidden_states, output_attentions=False):
|
||||
def forward(self, hidden_states, attention_mask=None, output_attentions=False):
|
||||
attn_residual = hidden_states
|
||||
hidden_states, attn_weights, _ = self.attention(hidden_states, output_attentions=output_attentions)
|
||||
hidden_states, attn_weights, _ = self.attention(
|
||||
hidden_states, attention_mask=attention_mask, output_attentions=output_attentions
|
||||
)
|
||||
hidden_states = self.dropout(hidden_states)
|
||||
hidden_states = attn_residual + hidden_states
|
||||
|
||||
@@ -414,10 +418,12 @@ class Wav2Vec2EncoderLayerStableLayerNorm(nn.Module):
|
||||
self.feed_forward = Wav2Vec2FeedForward(config)
|
||||
self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||
|
||||
def forward(self, hidden_states, output_attentions=False):
|
||||
def forward(self, hidden_states, attention_mask=None, output_attentions=False):
|
||||
attn_residual = hidden_states
|
||||
hidden_states = self.layer_norm(hidden_states)
|
||||
hidden_states, attn_weights, _ = self.attention(hidden_states, output_attentions=output_attentions)
|
||||
hidden_states, attn_weights, _ = self.attention(
|
||||
hidden_states, attention_mask=attention_mask, output_attentions=output_attentions
|
||||
)
|
||||
hidden_states = self.dropout(hidden_states)
|
||||
hidden_states = attn_residual + hidden_states
|
||||
hidden_states = hidden_states + self.feed_forward(self.final_layer_norm(hidden_states))
|
||||
@@ -438,6 +444,7 @@ class Wav2Vec2Encoder(nn.Module):
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
attention_mask=None,
|
||||
output_attentions=False,
|
||||
output_hidden_states=False,
|
||||
return_dict=True,
|
||||
@@ -445,6 +452,16 @@ class Wav2Vec2Encoder(nn.Module):
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_self_attentions = () if output_attentions else None
|
||||
|
||||
if attention_mask is not None:
|
||||
# make sure padded tokens output 0
|
||||
hidden_states[~attention_mask] = 0.0
|
||||
|
||||
# extend attention_mask
|
||||
attention_mask = (1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)) * -10000.0
|
||||
attention_mask = attention_mask.expand(
|
||||
attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1]
|
||||
)
|
||||
|
||||
position_embeddings = self.pos_conv_embed(hidden_states)
|
||||
hidden_states = hidden_states + position_embeddings
|
||||
hidden_states = self.layer_norm(hidden_states)
|
||||
@@ -454,7 +471,9 @@ class Wav2Vec2Encoder(nn.Module):
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
hidden_states, attn_weights = layer(hidden_states, output_attentions=output_attentions)
|
||||
hidden_states, attn_weights = layer(
|
||||
hidden_states, attention_mask=attention_mask, output_attentions=output_attentions
|
||||
)
|
||||
|
||||
if output_attentions:
|
||||
all_self_attentions = all_self_attentions + (attn_weights,)
|
||||
@@ -486,6 +505,7 @@ class Wav2Vec2EncoderStableLayerNorm(nn.Module):
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
attention_mask=None,
|
||||
output_attentions=False,
|
||||
output_hidden_states=False,
|
||||
return_dict=True,
|
||||
@@ -493,6 +513,16 @@ class Wav2Vec2EncoderStableLayerNorm(nn.Module):
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_self_attentions = () if output_attentions else None
|
||||
|
||||
if attention_mask is not None:
|
||||
# make sure padded tokens are not attended to
|
||||
hidden_states[~attention_mask] = 0
|
||||
|
||||
# extend attention_mask
|
||||
attention_mask = (1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)) * -10000.0
|
||||
attention_mask = attention_mask.expand(
|
||||
attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1]
|
||||
)
|
||||
|
||||
position_embeddings = self.pos_conv_embed(hidden_states)
|
||||
hidden_states = hidden_states + position_embeddings
|
||||
hidden_states = self.dropout(hidden_states)
|
||||
@@ -501,7 +531,9 @@ class Wav2Vec2EncoderStableLayerNorm(nn.Module):
|
||||
if output_hidden_states:
|
||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||
|
||||
hidden_states, attn_weights = layer(hidden_states, output_attentions=output_attentions)
|
||||
hidden_states, attn_weights = layer(
|
||||
hidden_states, attention_mask=attention_mask, output_attentions=output_attentions
|
||||
)
|
||||
|
||||
if output_attentions:
|
||||
all_self_attentions = all_self_attentions + (attn_weights,)
|
||||
@@ -544,6 +576,21 @@ class Wav2Vec2PreTrainedModel(PreTrainedModel):
|
||||
if isinstance(module, (nn.Linear, nn.Conv1d)) and module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
|
||||
def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor):
|
||||
"""
|
||||
Computes the output length of the convolutional layers
|
||||
"""
|
||||
|
||||
def _conv_out_length(input_length, kernel_size, stride):
|
||||
# 1D convolutional layer output length formula taken
|
||||
# from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
|
||||
return torch.floor((input_length - kernel_size) / stride + 1)
|
||||
|
||||
for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
|
||||
input_lengths = _conv_out_length(input_lengths, kernel_size, stride)
|
||||
|
||||
return input_lengths.to(torch.long)
|
||||
|
||||
|
||||
WAV_2_VEC_2_START_DOCSTRING = r"""
|
||||
Wav2Vec2 was proposed in `wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations
|
||||
@@ -572,6 +619,24 @@ WAV_2_VEC_2_INPUTS_DOCSTRING = r"""
|
||||
soundfile`). To prepare the array into `input_values`, the :class:`~transformers.Wav2Vec2Tokenizer` should
|
||||
be used for padding and conversion into a tensor of type `torch.FloatTensor`. See
|
||||
:meth:`transformers.Wav2Vec2Tokenizer.__call__` for details.
|
||||
attention_mask (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
||||
Mask to avoid performing convolution and attention on padding token indices. Mask values selected in ``[0,
|
||||
1]``:
|
||||
|
||||
- 1 for tokens that are **not masked**,
|
||||
- 0 for tokens that are **masked**.
|
||||
|
||||
`What are attention masks? <../glossary.html#attention-mask>`__
|
||||
|
||||
.. warning::
|
||||
:obj:`attention_mask` should only be passed if the corresponding tokenizer has
|
||||
``config.return_attention_mask == True``. For all models whose tokenizer has
|
||||
``config.return_attention_mask == False``, such as `wav2vec2-base
|
||||
<https://huggingface.co/facebook/wav2vec2-base-960h>`__, :obj:`attention_mask` should **not** be passed
|
||||
to avoid degraded performance when doing batched inference. For such models :obj:`input_values` should
|
||||
simply be padded with 0 and passed without :obj:`attention_mask`. Be aware that these models also yield
|
||||
slightly different results depending on whether :obj:`input_values` is padded or not.
|
||||
|
||||
output_attentions (:obj:`bool`, `optional`):
|
||||
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned
|
||||
tensors for more detail.
|
||||
@@ -606,6 +671,7 @@ class Wav2Vec2Model(Wav2Vec2PreTrainedModel):
|
||||
def forward(
|
||||
self,
|
||||
input_values,
|
||||
attention_mask=None,
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
return_dict=None,
|
||||
@@ -641,14 +707,33 @@ class Wav2Vec2Model(Wav2Vec2PreTrainedModel):
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
hidden_states = self.feature_extractor(input_values)
|
||||
hidden_states = hidden_states.transpose(1, 2)
|
||||
|
||||
if attention_mask is not None:
|
||||
# compute real output lengths according to convolution formula
|
||||
output_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1))
|
||||
|
||||
attention_mask = torch.zeros(
|
||||
hidden_states.shape[:2], dtype=hidden_states.dtype, device=hidden_states.device
|
||||
)
|
||||
|
||||
# these two operations makes sure that all values
|
||||
# before the output lengths indices are attended to
|
||||
attention_mask[
|
||||
(torch.arange(attention_mask.shape[0], device=hidden_states.device), output_lengths - 1)
|
||||
] = 1
|
||||
attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool()
|
||||
|
||||
hidden_states = self.feature_projection(hidden_states)
|
||||
|
||||
encoder_outputs = self.encoder(
|
||||
hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
|
||||
hidden_states = encoder_outputs[0]
|
||||
|
||||
if not return_dict:
|
||||
@@ -681,6 +766,7 @@ class Wav2Vec2ForMaskedLM(Wav2Vec2PreTrainedModel):
|
||||
def forward(
|
||||
self,
|
||||
input_values,
|
||||
attention_mask=None,
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
return_dict=None,
|
||||
@@ -755,6 +841,7 @@ class Wav2Vec2ForCTC(Wav2Vec2PreTrainedModel):
|
||||
def forward(
|
||||
self,
|
||||
input_values,
|
||||
attention_mask=None,
|
||||
output_attentions=None,
|
||||
output_hidden_states=None,
|
||||
return_dict=None,
|
||||
@@ -795,6 +882,7 @@ class Wav2Vec2ForCTC(Wav2Vec2PreTrainedModel):
|
||||
|
||||
outputs = self.wav2vec2(
|
||||
input_values,
|
||||
attention_mask=attention_mask,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
@@ -802,6 +890,7 @@ class Wav2Vec2ForCTC(Wav2Vec2PreTrainedModel):
|
||||
|
||||
hidden_states = outputs[0]
|
||||
hidden_states = self.dropout(hidden_states)
|
||||
|
||||
logits = self.lm_head(hidden_states)
|
||||
|
||||
if not return_dict:
|
||||
|
||||
@@ -87,6 +87,26 @@ class Wav2Vec2Tokenizer(PreTrainedTokenizer):
|
||||
The token used for padding, for example when batching sequences of different lengths.
|
||||
word_delimiter_token (:obj:`str`, `optional`, defaults to :obj:`"|"`):
|
||||
The token used for defining the end of a word.
|
||||
do_lower_case (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
Whether or not to lowercase the output when decoding.
|
||||
do_normalize (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
Whether or not to zero-mean unit-variance normalize the input. Normalizing can help to significantly
|
||||
improve the performance for some models, *e.g.*, `wav2vec2-lv60
|
||||
<https://huggingface.co/models?search=lv60>`__.
|
||||
return_attention_mask (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||
Whether or not :meth:`~transformers.Wav2Vec2Tokenizer.__call__` should return :obj:`attention_mask`.
|
||||
|
||||
.. note::
|
||||
|
||||
Wav2Vec2 models that have set ``config.feat_extract_norm == "group"``, such as `wav2vec2-base
|
||||
<https://huggingface.co/facebook/wav2vec2-base-960h>`__, have **not** been trained using
|
||||
:obj:`attention_mask`. For such models, :obj:`input_values` should simply be padded with 0 and no
|
||||
:obj:`attention_mask` should be passed.
|
||||
|
||||
For Wav2Vec2 models that have set ``config.feat_extract_norm == "layer"``, such as `wav2vec2-lv60
|
||||
<https://huggingface.co/facebook/wav2vec2-large-960h-lv60-self>`__, :obj:`attention_mask` should be
|
||||
passed for batched inference.
|
||||
|
||||
**kwargs
|
||||
Additional keyword arguments passed along to :class:`~transformers.PreTrainedTokenizer`
|
||||
"""
|
||||
@@ -100,7 +120,7 @@ class Wav2Vec2Tokenizer(PreTrainedTokenizer):
|
||||
"facebook/wav2vec2-base-960h": "https://huggingface.co/facebook/wav2vec2-base-960h/resolve/main/tokenizer.json",
|
||||
},
|
||||
}
|
||||
model_input_names = ["input_values"]
|
||||
model_input_names = ["input_values", "attention_mask"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -111,6 +131,8 @@ class Wav2Vec2Tokenizer(PreTrainedTokenizer):
|
||||
pad_token="<pad>",
|
||||
word_delimiter_token="|",
|
||||
do_lower_case=False,
|
||||
do_normalize=False,
|
||||
return_attention_mask=False,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__(
|
||||
@@ -119,11 +141,16 @@ class Wav2Vec2Tokenizer(PreTrainedTokenizer):
|
||||
eos_token=eos_token,
|
||||
pad_token=pad_token,
|
||||
do_lower_case=do_lower_case,
|
||||
do_normalize=do_normalize,
|
||||
return_attention_mask=return_attention_mask,
|
||||
word_delimiter_token=word_delimiter_token,
|
||||
**kwargs,
|
||||
)
|
||||
self._word_delimiter_token = word_delimiter_token
|
||||
|
||||
self.do_lower_case = do_lower_case
|
||||
self.return_attention_mask = return_attention_mask
|
||||
self.do_normalize = do_normalize
|
||||
|
||||
with open(vocab_file, encoding="utf-8") as vocab_handle:
|
||||
self.encoder = json.load(vocab_handle)
|
||||
@@ -193,6 +220,10 @@ class Wav2Vec2Tokenizer(PreTrainedTokenizer):
|
||||
if not is_batched:
|
||||
raw_speech = [raw_speech]
|
||||
|
||||
# zero-mean and unit-variance normalization
|
||||
if self.do_normalize:
|
||||
raw_speech = [(x - np.mean(x)) / np.sqrt(np.var(x) + 1e-5) for x in raw_speech]
|
||||
|
||||
# convert into correct format for padding
|
||||
encoded_inputs = BatchEncoding({"input_values": raw_speech})
|
||||
|
||||
@@ -201,7 +232,7 @@ class Wav2Vec2Tokenizer(PreTrainedTokenizer):
|
||||
padding=padding,
|
||||
max_length=max_length,
|
||||
pad_to_multiple_of=pad_to_multiple_of,
|
||||
return_attention_mask=False,
|
||||
return_attention_mask=self.return_attention_mask,
|
||||
return_tensors=return_tensors,
|
||||
verbose=verbose,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user