[Wav2Vec2] Improve Tokenizer & Model for batched inference (#10117)
* save intermediate * finish batch the same as fairseq * add normalization * fix batched input * add better comment * Update src/transformers/models/wav2vec2/modeling_wav2vec2.py * add nice docstring * add tokenizer tests * make all slow tests pass * finish PR * correct import
This commit is contained in:
committed by
GitHub
parent
2f3b5f4dcc
commit
495c157d6f
@@ -36,7 +36,10 @@ logger = logging.get_logger(__name__)
|
|||||||
_CONFIG_FOR_DOC = "Wav2Vec2Config"
|
_CONFIG_FOR_DOC = "Wav2Vec2Config"
|
||||||
|
|
||||||
WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
||||||
"facebook/wav2vec2-base-960h"
|
"facebook/wav2vec2-base-960h",
|
||||||
|
"facebook/wav2vec2-large-960h",
|
||||||
|
"facebook/wav2vec2-large-960h-lv60",
|
||||||
|
"facebook/wav2vec2-large-960h-lv60-self",
|
||||||
# See all Wav2Vec2 models at https://huggingface.co/models?filter=wav2vec2
|
# See all Wav2Vec2 models at https://huggingface.co/models?filter=wav2vec2
|
||||||
]
|
]
|
||||||
|
|
||||||
@@ -191,7 +194,6 @@ class Wav2Vec2FeatureProjection(nn.Module):
|
|||||||
self.dropout = nn.Dropout(config.feat_extract_dropout)
|
self.dropout = nn.Dropout(config.feat_extract_dropout)
|
||||||
|
|
||||||
def forward(self, hidden_states):
|
def forward(self, hidden_states):
|
||||||
hidden_states = hidden_states.transpose(1, 2)
|
|
||||||
hidden_states = self.layer_norm(hidden_states)
|
hidden_states = self.layer_norm(hidden_states)
|
||||||
hidden_states = self.projection(hidden_states)
|
hidden_states = self.projection(hidden_states)
|
||||||
hidden_states = self.dropout(hidden_states)
|
hidden_states = self.dropout(hidden_states)
|
||||||
@@ -387,9 +389,11 @@ class Wav2Vec2EncoderLayer(nn.Module):
|
|||||||
self.feed_forward = Wav2Vec2FeedForward(config)
|
self.feed_forward = Wav2Vec2FeedForward(config)
|
||||||
self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||||
|
|
||||||
def forward(self, hidden_states, output_attentions=False):
|
def forward(self, hidden_states, attention_mask=None, output_attentions=False):
|
||||||
attn_residual = hidden_states
|
attn_residual = hidden_states
|
||||||
hidden_states, attn_weights, _ = self.attention(hidden_states, output_attentions=output_attentions)
|
hidden_states, attn_weights, _ = self.attention(
|
||||||
|
hidden_states, attention_mask=attention_mask, output_attentions=output_attentions
|
||||||
|
)
|
||||||
hidden_states = self.dropout(hidden_states)
|
hidden_states = self.dropout(hidden_states)
|
||||||
hidden_states = attn_residual + hidden_states
|
hidden_states = attn_residual + hidden_states
|
||||||
|
|
||||||
@@ -414,10 +418,12 @@ class Wav2Vec2EncoderLayerStableLayerNorm(nn.Module):
|
|||||||
self.feed_forward = Wav2Vec2FeedForward(config)
|
self.feed_forward = Wav2Vec2FeedForward(config)
|
||||||
self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
|
||||||
|
|
||||||
def forward(self, hidden_states, output_attentions=False):
|
def forward(self, hidden_states, attention_mask=None, output_attentions=False):
|
||||||
attn_residual = hidden_states
|
attn_residual = hidden_states
|
||||||
hidden_states = self.layer_norm(hidden_states)
|
hidden_states = self.layer_norm(hidden_states)
|
||||||
hidden_states, attn_weights, _ = self.attention(hidden_states, output_attentions=output_attentions)
|
hidden_states, attn_weights, _ = self.attention(
|
||||||
|
hidden_states, attention_mask=attention_mask, output_attentions=output_attentions
|
||||||
|
)
|
||||||
hidden_states = self.dropout(hidden_states)
|
hidden_states = self.dropout(hidden_states)
|
||||||
hidden_states = attn_residual + hidden_states
|
hidden_states = attn_residual + hidden_states
|
||||||
hidden_states = hidden_states + self.feed_forward(self.final_layer_norm(hidden_states))
|
hidden_states = hidden_states + self.feed_forward(self.final_layer_norm(hidden_states))
|
||||||
@@ -438,6 +444,7 @@ class Wav2Vec2Encoder(nn.Module):
|
|||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
hidden_states,
|
hidden_states,
|
||||||
|
attention_mask=None,
|
||||||
output_attentions=False,
|
output_attentions=False,
|
||||||
output_hidden_states=False,
|
output_hidden_states=False,
|
||||||
return_dict=True,
|
return_dict=True,
|
||||||
@@ -445,6 +452,16 @@ class Wav2Vec2Encoder(nn.Module):
|
|||||||
all_hidden_states = () if output_hidden_states else None
|
all_hidden_states = () if output_hidden_states else None
|
||||||
all_self_attentions = () if output_attentions else None
|
all_self_attentions = () if output_attentions else None
|
||||||
|
|
||||||
|
if attention_mask is not None:
|
||||||
|
# make sure padded tokens output 0
|
||||||
|
hidden_states[~attention_mask] = 0.0
|
||||||
|
|
||||||
|
# extend attention_mask
|
||||||
|
attention_mask = (1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)) * -10000.0
|
||||||
|
attention_mask = attention_mask.expand(
|
||||||
|
attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1]
|
||||||
|
)
|
||||||
|
|
||||||
position_embeddings = self.pos_conv_embed(hidden_states)
|
position_embeddings = self.pos_conv_embed(hidden_states)
|
||||||
hidden_states = hidden_states + position_embeddings
|
hidden_states = hidden_states + position_embeddings
|
||||||
hidden_states = self.layer_norm(hidden_states)
|
hidden_states = self.layer_norm(hidden_states)
|
||||||
@@ -454,7 +471,9 @@ class Wav2Vec2Encoder(nn.Module):
|
|||||||
if output_hidden_states:
|
if output_hidden_states:
|
||||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||||
|
|
||||||
hidden_states, attn_weights = layer(hidden_states, output_attentions=output_attentions)
|
hidden_states, attn_weights = layer(
|
||||||
|
hidden_states, attention_mask=attention_mask, output_attentions=output_attentions
|
||||||
|
)
|
||||||
|
|
||||||
if output_attentions:
|
if output_attentions:
|
||||||
all_self_attentions = all_self_attentions + (attn_weights,)
|
all_self_attentions = all_self_attentions + (attn_weights,)
|
||||||
@@ -486,6 +505,7 @@ class Wav2Vec2EncoderStableLayerNorm(nn.Module):
|
|||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
hidden_states,
|
hidden_states,
|
||||||
|
attention_mask=None,
|
||||||
output_attentions=False,
|
output_attentions=False,
|
||||||
output_hidden_states=False,
|
output_hidden_states=False,
|
||||||
return_dict=True,
|
return_dict=True,
|
||||||
@@ -493,6 +513,16 @@ class Wav2Vec2EncoderStableLayerNorm(nn.Module):
|
|||||||
all_hidden_states = () if output_hidden_states else None
|
all_hidden_states = () if output_hidden_states else None
|
||||||
all_self_attentions = () if output_attentions else None
|
all_self_attentions = () if output_attentions else None
|
||||||
|
|
||||||
|
if attention_mask is not None:
|
||||||
|
# make sure padded tokens output 0
|
||||||
|
hidden_states[~attention_mask] = 0
|
||||||
|
|
||||||
|
# extend attention_mask
|
||||||
|
attention_mask = (1.0 - attention_mask[:, None, None, :].to(dtype=hidden_states.dtype)) * -10000.0
|
||||||
|
attention_mask = attention_mask.expand(
|
||||||
|
attention_mask.shape[0], 1, attention_mask.shape[-1], attention_mask.shape[-1]
|
||||||
|
)
|
||||||
|
|
||||||
position_embeddings = self.pos_conv_embed(hidden_states)
|
position_embeddings = self.pos_conv_embed(hidden_states)
|
||||||
hidden_states = hidden_states + position_embeddings
|
hidden_states = hidden_states + position_embeddings
|
||||||
hidden_states = self.dropout(hidden_states)
|
hidden_states = self.dropout(hidden_states)
|
||||||
@@ -501,7 +531,9 @@ class Wav2Vec2EncoderStableLayerNorm(nn.Module):
|
|||||||
if output_hidden_states:
|
if output_hidden_states:
|
||||||
all_hidden_states = all_hidden_states + (hidden_states,)
|
all_hidden_states = all_hidden_states + (hidden_states,)
|
||||||
|
|
||||||
hidden_states, attn_weights = layer(hidden_states, output_attentions=output_attentions)
|
hidden_states, attn_weights = layer(
|
||||||
|
hidden_states, attention_mask=attention_mask, output_attentions=output_attentions
|
||||||
|
)
|
||||||
|
|
||||||
if output_attentions:
|
if output_attentions:
|
||||||
all_self_attentions = all_self_attentions + (attn_weights,)
|
all_self_attentions = all_self_attentions + (attn_weights,)
|
||||||
@@ -544,6 +576,21 @@ class Wav2Vec2PreTrainedModel(PreTrainedModel):
|
|||||||
if isinstance(module, (nn.Linear, nn.Conv1d)) and module.bias is not None:
|
if isinstance(module, (nn.Linear, nn.Conv1d)) and module.bias is not None:
|
||||||
module.bias.data.zero_()
|
module.bias.data.zero_()
|
||||||
|
|
||||||
|
def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor):
    """
    Computes the output length of the convolutional layers

    Args:
        input_lengths (:obj:`torch.LongTensor`):
            Length of each input sequence before the convolutional layers are applied.

    Returns:
        :obj:`torch.LongTensor`: The length of each sequence after applying, in order, every
        ``(kernel_size, stride)`` pair found in ``self.config.conv_kernel`` and ``self.config.conv_stride``.
    """

    def _conv_out_length(input_length, kernel_size, stride):
        # 1D convolutional layer output length formula taken
        # from https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html
        #
        # Integer floor division is used on purpose: true division of an integer tensor is
        # carried out in the default float dtype (float32), which cannot represent every
        # integer above 2**24 and therefore yields wrong lengths for very long inputs.
        return torch.div(input_length - kernel_size, stride, rounding_mode="floor") + 1

    for kernel_size, stride in zip(self.config.conv_kernel, self.config.conv_stride):
        input_lengths = _conv_out_length(input_lengths, kernel_size, stride)

    return input_lengths.to(torch.long)
|
||||||
|
|
||||||
|
|
||||||
WAV_2_VEC_2_START_DOCSTRING = r"""
|
WAV_2_VEC_2_START_DOCSTRING = r"""
|
||||||
Wav2Vec2 was proposed in `wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations
|
Wav2Vec2 was proposed in `wav2vec 2.0: A Framework for Self-Supervised Learning of Speech Representations
|
||||||
@@ -572,6 +619,24 @@ WAV_2_VEC_2_INPUTS_DOCSTRING = r"""
|
|||||||
soundfile`). To prepare the array into `input_values`, the :class:`~transformers.Wav2Vec2Tokenizer` should
|
soundfile`). To prepare the array into `input_values`, the :class:`~transformers.Wav2Vec2Tokenizer` should
|
||||||
be used for padding and conversion into a tensor of type `torch.FloatTensor`. See
|
be used for padding and conversion into a tensor of type `torch.FloatTensor`. See
|
||||||
:meth:`transformers.Wav2Vec2Tokenizer.__call__` for details.
|
:meth:`transformers.Wav2Vec2Tokenizer.__call__` for details.
|
||||||
|
attention_mask (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
||||||
|
Mask to avoid performing convolution and attention on padding token indices. Mask values selected in ``[0,
|
||||||
|
1]``:
|
||||||
|
|
||||||
|
- 1 for tokens that are **not masked**,
|
||||||
|
- 0 for tokens that are **masked**.
|
||||||
|
|
||||||
|
`What are attention masks? <../glossary.html#attention-mask>`__
|
||||||
|
|
||||||
|
.. warning::
|
||||||
|
:obj:`attention_mask` should only be passed if the corresponding tokenizer has
|
||||||
|
``config.return_attention_mask == True``. For all models whose tokenizer has
|
||||||
|
``config.return_attention_mask == False``, such as `wav2vec2-base
|
||||||
|
<https://huggingface.co/facebook/wav2vec2-base-960h>`__, :obj:`attention_mask` should **not** be passed
|
||||||
|
to avoid degraded performance when doing batched inference. For such models :obj:`input_values` should
|
||||||
|
simply be padded with 0 and passed without :obj:`attention_mask`. Be aware that these models also yield
|
||||||
|
slightly different results depending on whether :obj:`input_values` is padded or not.
|
||||||
|
|
||||||
output_attentions (:obj:`bool`, `optional`):
|
output_attentions (:obj:`bool`, `optional`):
|
||||||
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned
|
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned
|
||||||
tensors for more detail.
|
tensors for more detail.
|
||||||
@@ -606,6 +671,7 @@ class Wav2Vec2Model(Wav2Vec2PreTrainedModel):
|
|||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_values,
|
input_values,
|
||||||
|
attention_mask=None,
|
||||||
output_attentions=None,
|
output_attentions=None,
|
||||||
output_hidden_states=None,
|
output_hidden_states=None,
|
||||||
return_dict=None,
|
return_dict=None,
|
||||||
@@ -641,14 +707,33 @@ class Wav2Vec2Model(Wav2Vec2PreTrainedModel):
|
|||||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
|
||||||
hidden_states = self.feature_extractor(input_values)
|
hidden_states = self.feature_extractor(input_values)
|
||||||
|
hidden_states = hidden_states.transpose(1, 2)
|
||||||
|
|
||||||
|
if attention_mask is not None:
|
||||||
|
# compute real output lengths according to convolution formula
|
||||||
|
output_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1))
|
||||||
|
|
||||||
|
attention_mask = torch.zeros(
|
||||||
|
hidden_states.shape[:2], dtype=hidden_states.dtype, device=hidden_states.device
|
||||||
|
)
|
||||||
|
|
||||||
|
# these two operations make sure that all values
|
||||||
|
# before the output lengths indices are attended to
|
||||||
|
attention_mask[
|
||||||
|
(torch.arange(attention_mask.shape[0], device=hidden_states.device), output_lengths - 1)
|
||||||
|
] = 1
|
||||||
|
attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool()
|
||||||
|
|
||||||
hidden_states = self.feature_projection(hidden_states)
|
hidden_states = self.feature_projection(hidden_states)
|
||||||
|
|
||||||
encoder_outputs = self.encoder(
|
encoder_outputs = self.encoder(
|
||||||
hidden_states,
|
hidden_states,
|
||||||
|
attention_mask=attention_mask,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
return_dict=return_dict,
|
return_dict=return_dict,
|
||||||
)
|
)
|
||||||
|
|
||||||
hidden_states = encoder_outputs[0]
|
hidden_states = encoder_outputs[0]
|
||||||
|
|
||||||
if not return_dict:
|
if not return_dict:
|
||||||
@@ -681,6 +766,7 @@ class Wav2Vec2ForMaskedLM(Wav2Vec2PreTrainedModel):
|
|||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_values,
|
input_values,
|
||||||
|
attention_mask=None,
|
||||||
output_attentions=None,
|
output_attentions=None,
|
||||||
output_hidden_states=None,
|
output_hidden_states=None,
|
||||||
return_dict=None,
|
return_dict=None,
|
||||||
@@ -755,6 +841,7 @@ class Wav2Vec2ForCTC(Wav2Vec2PreTrainedModel):
|
|||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
input_values,
|
input_values,
|
||||||
|
attention_mask=None,
|
||||||
output_attentions=None,
|
output_attentions=None,
|
||||||
output_hidden_states=None,
|
output_hidden_states=None,
|
||||||
return_dict=None,
|
return_dict=None,
|
||||||
@@ -795,6 +882,7 @@ class Wav2Vec2ForCTC(Wav2Vec2PreTrainedModel):
|
|||||||
|
|
||||||
outputs = self.wav2vec2(
|
outputs = self.wav2vec2(
|
||||||
input_values,
|
input_values,
|
||||||
|
attention_mask=attention_mask,
|
||||||
output_attentions=output_attentions,
|
output_attentions=output_attentions,
|
||||||
output_hidden_states=output_hidden_states,
|
output_hidden_states=output_hidden_states,
|
||||||
return_dict=return_dict,
|
return_dict=return_dict,
|
||||||
@@ -802,6 +890,7 @@ class Wav2Vec2ForCTC(Wav2Vec2PreTrainedModel):
|
|||||||
|
|
||||||
hidden_states = outputs[0]
|
hidden_states = outputs[0]
|
||||||
hidden_states = self.dropout(hidden_states)
|
hidden_states = self.dropout(hidden_states)
|
||||||
|
|
||||||
logits = self.lm_head(hidden_states)
|
logits = self.lm_head(hidden_states)
|
||||||
|
|
||||||
if not return_dict:
|
if not return_dict:
|
||||||
|
|||||||
@@ -87,6 +87,26 @@ class Wav2Vec2Tokenizer(PreTrainedTokenizer):
|
|||||||
The token used for padding, for example when batching sequences of different lengths.
|
The token used for padding, for example when batching sequences of different lengths.
|
||||||
word_delimiter_token (:obj:`str`, `optional`, defaults to :obj:`"|"`):
|
word_delimiter_token (:obj:`str`, `optional`, defaults to :obj:`"|"`):
|
||||||
The token used for defining the end of a word.
|
The token used for defining the end of a word.
|
||||||
|
do_lower_case (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||||
|
Whether or not to lowercase the output when decoding.
|
||||||
|
do_normalize (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||||
|
Whether or not to zero-mean unit-variance normalize the input. Normalizing can help to significantly
|
||||||
|
improve the performance for some models, *e.g.*, `wav2vec2-lv60
|
||||||
|
<https://huggingface.co/models?search=lv60>`__.
|
||||||
|
return_attention_mask (:obj:`bool`, `optional`, defaults to :obj:`False`):
|
||||||
|
Whether or not :meth:`~transformers.Wav2Vec2Tokenizer.__call__` should return :obj:`attention_mask`.
|
||||||
|
|
||||||
|
.. note::
|
||||||
|
|
||||||
|
Wav2Vec2 models that have set ``config.feat_extract_norm == "group"``, such as `wav2vec2-base
|
||||||
|
<https://huggingface.co/facebook/wav2vec2-base-960h>`__, have **not** been trained using
|
||||||
|
:obj:`attention_mask`. For such models, :obj:`input_values` should simply be padded with 0 and no
|
||||||
|
:obj:`attention_mask` should be passed.
|
||||||
|
|
||||||
|
For Wav2Vec2 models that have set ``config.feat_extract_norm == "layer"``, such as `wav2vec2-lv60
|
||||||
|
<https://huggingface.co/facebook/wav2vec2-large-960h-lv60-self>`__, :obj:`attention_mask` should be
|
||||||
|
passed for batched inference.
|
||||||
|
|
||||||
**kwargs
|
**kwargs
|
||||||
Additional keyword arguments passed along to :class:`~transformers.PreTrainedTokenizer`
|
Additional keyword arguments passed along to :class:`~transformers.PreTrainedTokenizer`
|
||||||
"""
|
"""
|
||||||
@@ -100,7 +120,7 @@ class Wav2Vec2Tokenizer(PreTrainedTokenizer):
|
|||||||
"facebook/wav2vec2-base-960h": "https://huggingface.co/facebook/wav2vec2-base-960h/resolve/main/tokenizer.json",
|
"facebook/wav2vec2-base-960h": "https://huggingface.co/facebook/wav2vec2-base-960h/resolve/main/tokenizer.json",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
model_input_names = ["input_values"]
|
model_input_names = ["input_values", "attention_mask"]
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -111,6 +131,8 @@ class Wav2Vec2Tokenizer(PreTrainedTokenizer):
|
|||||||
pad_token="<pad>",
|
pad_token="<pad>",
|
||||||
word_delimiter_token="|",
|
word_delimiter_token="|",
|
||||||
do_lower_case=False,
|
do_lower_case=False,
|
||||||
|
do_normalize=False,
|
||||||
|
return_attention_mask=False,
|
||||||
**kwargs
|
**kwargs
|
||||||
):
|
):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
@@ -119,11 +141,16 @@ class Wav2Vec2Tokenizer(PreTrainedTokenizer):
|
|||||||
eos_token=eos_token,
|
eos_token=eos_token,
|
||||||
pad_token=pad_token,
|
pad_token=pad_token,
|
||||||
do_lower_case=do_lower_case,
|
do_lower_case=do_lower_case,
|
||||||
|
do_normalize=do_normalize,
|
||||||
|
return_attention_mask=return_attention_mask,
|
||||||
word_delimiter_token=word_delimiter_token,
|
word_delimiter_token=word_delimiter_token,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
self._word_delimiter_token = word_delimiter_token
|
self._word_delimiter_token = word_delimiter_token
|
||||||
|
|
||||||
self.do_lower_case = do_lower_case
|
self.do_lower_case = do_lower_case
|
||||||
|
self.return_attention_mask = return_attention_mask
|
||||||
|
self.do_normalize = do_normalize
|
||||||
|
|
||||||
with open(vocab_file, encoding="utf-8") as vocab_handle:
|
with open(vocab_file, encoding="utf-8") as vocab_handle:
|
||||||
self.encoder = json.load(vocab_handle)
|
self.encoder = json.load(vocab_handle)
|
||||||
@@ -193,6 +220,10 @@ class Wav2Vec2Tokenizer(PreTrainedTokenizer):
|
|||||||
if not is_batched:
|
if not is_batched:
|
||||||
raw_speech = [raw_speech]
|
raw_speech = [raw_speech]
|
||||||
|
|
||||||
|
# zero-mean and unit-variance normalization
|
||||||
|
if self.do_normalize:
|
||||||
|
raw_speech = [(x - np.mean(x)) / np.sqrt(np.var(x) + 1e-5) for x in raw_speech]
|
||||||
|
|
||||||
# convert into correct format for padding
|
# convert into correct format for padding
|
||||||
encoded_inputs = BatchEncoding({"input_values": raw_speech})
|
encoded_inputs = BatchEncoding({"input_values": raw_speech})
|
||||||
|
|
||||||
@@ -201,7 +232,7 @@ class Wav2Vec2Tokenizer(PreTrainedTokenizer):
|
|||||||
padding=padding,
|
padding=padding,
|
||||||
max_length=max_length,
|
max_length=max_length,
|
||||||
pad_to_multiple_of=pad_to_multiple_of,
|
pad_to_multiple_of=pad_to_multiple_of,
|
||||||
return_attention_mask=False,
|
return_attention_mask=self.return_attention_mask,
|
||||||
return_tensors=return_tensors,
|
return_tensors=return_tensors,
|
||||||
verbose=verbose,
|
verbose=verbose,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -18,7 +18,7 @@
|
|||||||
import math
|
import math
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from tests.test_modeling_common import floats_tensor
|
from tests.test_modeling_common import floats_tensor, random_attention_mask
|
||||||
from transformers import is_torch_available
|
from transformers import is_torch_available
|
||||||
from transformers.testing_utils import require_datasets, require_soundfile, require_torch, slow, torch_device
|
from transformers.testing_utils import require_datasets, require_soundfile, require_torch, slow, torch_device
|
||||||
|
|
||||||
@@ -93,6 +93,7 @@ class Wav2Vec2ModelTester:
|
|||||||
|
|
||||||
def prepare_config_and_inputs(self):
|
def prepare_config_and_inputs(self):
|
||||||
input_values = floats_tensor([self.batch_size, self.seq_length], self.vocab_size)
|
input_values = floats_tensor([self.batch_size, self.seq_length], self.vocab_size)
|
||||||
|
attention_mask = random_attention_mask([self.batch_size, self.seq_length])
|
||||||
|
|
||||||
config = Wav2Vec2Config(
|
config = Wav2Vec2Config(
|
||||||
hidden_size=self.hidden_size,
|
hidden_size=self.hidden_size,
|
||||||
@@ -115,20 +116,48 @@ class Wav2Vec2ModelTester:
|
|||||||
vocab_size=self.vocab_size,
|
vocab_size=self.vocab_size,
|
||||||
)
|
)
|
||||||
|
|
||||||
return config, input_values
|
return config, input_values, attention_mask
|
||||||
|
|
||||||
def create_and_check_model(self, config, input_values):
|
def create_and_check_model(self, config, input_values, attention_mask):
|
||||||
model = Wav2Vec2Model(config=config)
|
model = Wav2Vec2Model(config=config)
|
||||||
model.to(torch_device)
|
model.to(torch_device)
|
||||||
model.eval()
|
model.eval()
|
||||||
result = model(input_values)
|
result = model(input_values, attention_mask=attention_mask)
|
||||||
self.parent.assertEqual(
|
self.parent.assertEqual(
|
||||||
result.last_hidden_state.shape, (self.batch_size, self.output_seq_length, self.hidden_size)
|
result.last_hidden_state.shape, (self.batch_size, self.output_seq_length, self.hidden_size)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def create_and_check_batch_inference(self, config, input_values, *args):
    """
    Compare the hidden states of a zero-padded batch (run with an attention mask) against the hidden states
    obtained by running every un-padded sample on its own.

    NOTE: Not sure how to make this test pass at the moment. Batched input yields the same results as the
    official fairseq implementation, but gives different results depending on whether batched input is used
    or not. Check: https://github.com/pytorch/fairseq/issues/3227
    """
    model = Wav2Vec2Model(config=config)
    model.to(torch_device)
    model.eval()

    # only the first three samples are used; they keep 1/4, 1/2 and all of their values
    input_values = input_values[:3]
    full_length = input_values.shape[-1]
    input_lengths = [full_length // divisor for divisor in [4, 2, 1]]

    attention_mask = torch.ones(input_values.shape, device=torch_device, dtype=torch.bool)

    # pad input: zero everything past each sample's length and mask it out
    for row, length in enumerate(input_lengths):
        input_values[row, length:] = 0.0
        attention_mask[row, length:] = False

    batch_outputs = model(input_values, attention_mask=attention_mask).last_hidden_state

    for row in range(input_values.shape[0]):
        length = input_lengths[row]
        single_input = input_values[row : row + 1, :length]
        single_output = model(single_input).last_hidden_state

        # compare only the frames that exist in the un-padded output
        matching_batch_output = batch_outputs[row : row + 1, : single_output.shape[1]]
        self.parent.assertTrue(torch.allclose(single_output, matching_batch_output, atol=1e-3))
|
||||||
|
|
||||||
def prepare_config_and_inputs_for_common(self):
|
def prepare_config_and_inputs_for_common(self):
|
||||||
config, input_values = self.prepare_config_and_inputs()
|
config, input_values, attention_mask = self.prepare_config_and_inputs()
|
||||||
inputs_dict = {"input_values": input_values}
|
inputs_dict = {"input_values": input_values, "attention_mask": attention_mask}
|
||||||
return config, inputs_dict
|
return config, inputs_dict
|
||||||
|
|
||||||
|
|
||||||
@@ -222,6 +251,10 @@ class Wav2Vec2RobustModelTest(ModelTesterMixin, unittest.TestCase):
|
|||||||
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
config_and_inputs = self.model_tester.prepare_config_and_inputs()
|
||||||
self.model_tester.create_and_check_model(*config_and_inputs)
|
self.model_tester.create_and_check_model(*config_and_inputs)
|
||||||
|
|
||||||
|
def test_batched_inference(self):
    """Run the model tester's batched-inference check on freshly prepared config and inputs."""
    tester = self.model_tester
    prepared = tester.prepare_config_and_inputs()
    tester.create_and_check_batch_inference(*prepared)
|
||||||
|
|
||||||
# Wav2Vec2 has no inputs_embeds
|
# Wav2Vec2 has no inputs_embeds
|
||||||
def test_inputs_embeds(self):
|
def test_inputs_embeds(self):
|
||||||
pass
|
pass
|
||||||
@@ -288,7 +321,7 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
|
|||||||
|
|
||||||
return ds["speech"][:num_samples]
|
return ds["speech"][:num_samples]
|
||||||
|
|
||||||
def test_inference_masked_lm_normal(self):
|
def test_inference_ctc_normal(self):
|
||||||
model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
|
model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
|
||||||
model.to(torch_device)
|
model.to(torch_device)
|
||||||
tokenizer = Wav2Vec2Tokenizer.from_pretrained("facebook/wav2vec2-base-960h", do_lower_case=True)
|
tokenizer = Wav2Vec2Tokenizer.from_pretrained("facebook/wav2vec2-base-960h", do_lower_case=True)
|
||||||
@@ -306,16 +339,16 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
|
|||||||
EXPECTED_TRANSCRIPTIONS = ["a man said to the universe sir i exist"]
|
EXPECTED_TRANSCRIPTIONS = ["a man said to the universe sir i exist"]
|
||||||
self.assertListEqual(predicted_trans, EXPECTED_TRANSCRIPTIONS)
|
self.assertListEqual(predicted_trans, EXPECTED_TRANSCRIPTIONS)
|
||||||
|
|
||||||
def test_inference_masked_lm_normal_batched(self):
|
def test_inference_ctc_normal_batched(self):
|
||||||
model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
|
model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h")
|
||||||
model.to(torch_device)
|
model.to(torch_device)
|
||||||
tokenizer = Wav2Vec2Tokenizer.from_pretrained("facebook/wav2vec2-base-960h", do_lower_case=True)
|
tokenizer = Wav2Vec2Tokenizer.from_pretrained("facebook/wav2vec2-base-960h", do_lower_case=True)
|
||||||
|
|
||||||
input_speech = self._load_datasamples(2)
|
input_speech = self._load_datasamples(2)
|
||||||
|
|
||||||
input_values = tokenizer(input_speech, return_tensors="pt", padding=True, truncation=True).input_values.to(
|
inputs = tokenizer(input_speech, return_tensors="pt", padding=True, truncation=True)
|
||||||
torch_device
|
|
||||||
)
|
input_values = inputs.input_values.to(torch_device)
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
logits = model(input_values).logits
|
logits = model(input_values).logits
|
||||||
@@ -329,18 +362,19 @@ class Wav2Vec2ModelIntegrationTest(unittest.TestCase):
|
|||||||
]
|
]
|
||||||
self.assertListEqual(predicted_trans, EXPECTED_TRANSCRIPTIONS)
|
self.assertListEqual(predicted_trans, EXPECTED_TRANSCRIPTIONS)
|
||||||
|
|
||||||
def test_inference_masked_lm_robust_batched(self):
|
def test_inference_ctc_robust_batched(self):
|
||||||
model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-large-960h-lv60-self").to(torch_device)
|
model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-large-960h-lv60-self").to(torch_device)
|
||||||
tokenizer = Wav2Vec2Tokenizer.from_pretrained("facebook/wav2vec2-large-960h-lv60-self", do_lower_case=True)
|
tokenizer = Wav2Vec2Tokenizer.from_pretrained("facebook/wav2vec2-large-960h-lv60-self", do_lower_case=True)
|
||||||
|
|
||||||
input_speech = self._load_datasamples(4)
|
input_speech = self._load_datasamples(4)
|
||||||
|
|
||||||
input_values = tokenizer(input_speech, return_tensors="pt", padding=True, truncation=True).input_values.to(
|
inputs = tokenizer(input_speech, return_tensors="pt", padding=True, truncation=True)
|
||||||
torch_device
|
|
||||||
)
|
input_values = inputs.input_values.to(torch_device)
|
||||||
|
attention_mask = inputs.attention_mask.to(torch_device)
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
logits = model(input_values).logits
|
logits = model(input_values, attention_mask=attention_mask).logits
|
||||||
|
|
||||||
predicted_ids = torch.argmax(logits, dim=-1)
|
predicted_ids = torch.argmax(logits, dim=-1)
|
||||||
predicted_trans = tokenizer.batch_decode(predicted_ids)
|
predicted_trans = tokenizer.batch_decode(predicted_ids)
|
||||||
|
|||||||
@@ -23,7 +23,10 @@ import unittest
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from transformers.models.wav2vec2.tokenization_wav2vec2 import VOCAB_FILES_NAMES, Wav2Vec2Tokenizer
|
from transformers import WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST
|
||||||
|
from transformers.models.wav2vec2 import Wav2Vec2Config, Wav2Vec2Tokenizer
|
||||||
|
from transformers.models.wav2vec2.tokenization_wav2vec2 import VOCAB_FILES_NAMES
|
||||||
|
from transformers.testing_utils import slow
|
||||||
|
|
||||||
|
|
||||||
global_rng = random.Random()
|
global_rng = random.Random()
|
||||||
@@ -299,3 +302,46 @@ class Wav2Vec2TokenizerTest(unittest.TestCase):
|
|||||||
for parameter_name, parameter in signature.parameters.items():
|
for parameter_name, parameter in signature.parameters.items():
|
||||||
if parameter.default != inspect.Parameter.empty:
|
if parameter.default != inspect.Parameter.empty:
|
||||||
self.assertIn(parameter_name, tokenizer.init_kwargs)
|
self.assertIn(parameter_name, tokenizer.init_kwargs)
|
||||||
|
|
||||||
|
def test_zero_mean_unit_variance_normalization(self):
    """Check that a tokenizer built with ``do_normalize=True`` returns zero-mean, unit-variance inputs."""
    tokenizer = self.get_tokenizer(do_normalize=True)

    # one input per size in 800, 1000, 1200 (presumably the number of samples - see `floats_list`)
    speech_inputs = [floats_list((1, size))[0] for size in range(800, 1400, 200)]
    input_values = tokenizer(speech_inputs, padding="longest").input_values

    def _check_zero_mean_unit_variance(input_vector):
        mean_is_zero = np.abs(np.mean(input_vector)) < 1e-3
        self.assertTrue(mean_is_zero)
        variance_is_one = np.abs(np.var(input_vector) - 1) < 1e-3
        self.assertTrue(variance_is_one)

    # the first two rows are checked only up to their original size, the last row in full
    _check_zero_mean_unit_variance(input_values[0, :800])
    _check_zero_mean_unit_variance(input_values[1, :1000])
    _check_zero_mean_unit_variance(input_values[2])
|
||||||
|
|
||||||
|
def test_return_attention_mask(self):
    """Check that ``attention_mask`` is returned only when the tokenizer is created with ``return_attention_mask=True``."""
    speech_inputs = [floats_list((1, size))[0] for size in range(800, 1400, 200)]

    # default case -> no attention_mask is returned
    default_tokenizer = self.get_tokenizer()
    default_output = default_tokenizer(speech_inputs)
    self.assertNotIn("attention_mask", default_output)

    # wav2vec2-lv60 -> return attention_mask
    masking_tokenizer = self.get_tokenizer(return_attention_mask=True)
    masked_output = masking_tokenizer(speech_inputs, padding="longest")
    self.assertIn("attention_mask", masked_output)

    # the mask has the same shape as the padded inputs ...
    mask_shape = list(masked_output.attention_mask.shape)
    values_shape = list(masked_output.input_values.shape)
    self.assertListEqual(mask_shape, values_shape)

    # ... and each row sums to the expected per-input total
    unmasked_counts = masked_output.attention_mask.sum(-1).tolist()
    self.assertListEqual(unmasked_counts, [800, 1000, 1200])
|
||||||
|
|
||||||
|
@slow
def test_pretrained_checkpoints_are_set_correctly(self):
    """
    Make sure that checkpoints using group norm don't have their tokenizer return the attention_mask:
    only the "layer" feature extraction norm should make use of attention_mask.
    """
    for model_id in WAV_2_VEC_2_PRETRAINED_MODEL_ARCHIVE_LIST:
        uses_layer_norm = Wav2Vec2Config.from_pretrained(model_id).feat_extract_norm == "layer"
        tokenizer = Wav2Vec2Tokenizer.from_pretrained(model_id)

        self.assertEqual(tokenizer.return_attention_mask, uses_layer_norm)
|
||||||
|
|||||||
Reference in New Issue
Block a user